Simpler Flux demos
""" | |
Learn a linear model for a linear relationship of three variables | |
""" | |
module MultipleLinearRegressionDemo | |
using Flux: gradient, params | |
using Statistics: mean | |
function linreg(X, Y) | |
# Add the intercept column to X | |
X = hcat(ones(size(X)[1]), X) | |
# The weight for each independent variable and the intercept | |
W = randn(size(X)[2]) | |
# Prediction function (a linear model) | |
predict(X) = X * W | |
# Ordinary Least Squares loss function: mean squared error | |
# Use mean rather than sum so that learning_rate is independent of input size. | |
loss() = mean((Y - predict(X)).^2) | |
# Learn | |
# Compute the current loss, and differentiate with respect to W | |
# Update W in the opposite direction of the gradient (i.e. descend) | |
learning_rate = 0.01 | |
for i = 1:3000 | |
gs = gradient(loss, params(W)) | |
W .-= learning_rate .* gs[W] | |
end | |
return W | |
end | |
# Training data | |
X = randn(10, 2) | |
Y = X[:, 1] .* 3 .+ X[:, 2] .* .5 .+ 1 | |
# We're expecting: 1, 3, .5 | |
@show linreg(X, Y) | |
end |
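
As a quick sanity check (not part of the gist), the same ordinary least squares fit can also be computed in closed form with Julia's backslash operator; a minimal sketch, assuming the same noiseless training data as above:

# Closed-form OLS solution for comparison with the gradient-descent estimate.
X = randn(10, 2)
Y = X[:, 1] .* 3 .+ X[:, 2] .* .5 .+ 1
Xb = hcat(ones(size(X)[1]), X)   # same intercept column as in linreg
W_exact = Xb \ Y                 # solves the least-squares problem directly
# W_exact ≈ [1.0, 3.0, 0.5] because the training data is noiseless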
""" | |
Learn a linear model for data that really does have a linear relationship | |
""" | |
module SimpleLinearRegressionDemo | |
using Flux: gradient, params | |
using Statistics: mean | |
function linreg(xs, ys) | |
# Parameters have to wrapped so that they can be updated in-place. | |
c = [randn()] | |
m = [randn()] | |
# Prediction function (a linear model y = mx + c) | |
predict(x) = m[1] * x + c[1] | |
# Loss function: squared error | |
loss(x, y) = (predict(x) - y)^2 | |
# Learn | |
# Compute the current loss, and differentiate with respect to m and c. | |
# Update c and m in the opposite direction of the gradient (i.e. descend) | |
learning_rate = 0.01 | |
for i = 1:3000 | |
gs = gradient(() -> mean(loss.(xs, ys)), params(c, m)) | |
c .-= learning_rate .* gs[c] | |
m .-= learning_rate .* gs[m] | |
end | |
return vcat(c, m) | |
end | |
# Training data | |
xs = 1:10 | |
ys = xs .* 3 .+ 2; | |
# I expect c should approach 2 and m should approach 3. | |
@show linreg(xs, ys) | |
end |
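
For comparison, here is a hedged sketch (not from the gist) of the same fit written with Flux's built-in training loop, assuming the Flux.train!(loss, params, data, opt) signature and the Descent optimiser that Flux provided at the time:

using Flux
using Flux: params
using Statistics: mean

c = [randn()]
m = [randn()]
predict(x) = m[1] * x + c[1]
loss(x, y) = mean((predict.(x) .- y).^2)

xs = 1:10
ys = xs .* 3 .+ 2

# Present the whole data set 3000 times, mirroring the manual loop above.
data = Iterators.repeated((xs, ys), 3000)
Flux.train!(loss, params(c, m), data, Descent(0.01))

@show vcat(c, m)   # should approach [2.0, 3.0]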
Another oddity to examine before proposing in a PR is why the simple linear regression requires me to calculate the mean of the loss rather than the sum.
If I try the sum, the loss rapidly approaches infinity and I end up with [NaN, NaN].
Yup, it's just that the learning rate is too high for sum. If it's reduced to 0.002, it performs well.
Using a dynamic learning rate would probably fix it.
Well... you do want the mean squared error, because then your learning rate doesn't depend on batch size :). It is also the BLUE (best linear unbiased estimator) for least squares, so it has a definitional purpose.
That is a very good point ;)
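
To illustrate that point (a hypothetical snippet, not part of the discussion above): the gradient of a summed loss scales with the number of samples, so with sum the effective step size grows with the data set and 0.01 overshoots, while mean keeps the learning rate independent of the number of rows.

using Flux: gradient, params
using Statistics: mean

W = randn(3)
X = hcat(ones(10), randn(10, 2))
Y = X * [1.0, 3.0, 0.5]

gs_sum = gradient(() -> sum((Y - X * W).^2), params(W))
gs_mean = gradient(() -> mean((Y - X * W).^2), params(W))

# The summed-loss gradient is exactly 10x the mean-loss gradient here,
# because there are 10 rows of training data.
gs_sum[W] ≈ 10 .* gs_mean[W]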
I wonder if that's too magical and if I should do this instead?