Last active January 18, 2020 19:52
cmcaine/426db5d7f9af402c7bd76eeb12bffd7f
Simpler Flux demos
""" | |
Learn a linear model for a linear relationship of three variables | |
""" | |
module MultipleLinearRegressionDemo | |
using Flux: gradient, params | |
using Statistics: mean | |
function linreg(X, Y) | |
# Add the intercept column to X | |
X = hcat(ones(size(X)[1]), X) | |
# The weight for each independent variable and the intercept | |
W = randn(size(X)[2]) | |
# Prediction function (a linear model) | |
predict(X) = X * W | |
# Ordinary Least Squares loss function: mean squared error | |
# Use mean rather than sum so that learning_rate is independent of input size. | |
loss() = mean((Y - predict(X)).^2) | |
# Learn | |
# Compute the current loss, and differentiate with respect to W | |
# Update W in the opposite direction of the gradient (i.e. descend) | |
learning_rate = 0.01 | |
for i = 1:3000 | |
gs = gradient(loss, params(W)) | |
W .-= learning_rate .* gs[W] | |
end | |
return W | |
end | |
# Training data | |
X = randn(10, 2) | |
Y = X[:, 1] .* 3 .+ X[:, 2] .* .5 .+ 1 | |
# We're expecting: 1, 3, .5 | |
@show linreg(X, Y) | |
end |
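The loop above relies on Flux to differentiate the mean-squared-error loss. For a linear model that gradient has a closed form, -(2/n) * X' * (Y - X*W) with n the number of rows, so the same descent can be sketched in plain Julia without Flux and checked against the exact least-squares solution `X \ Y`. This is a minimal sketch under stated assumptions: the hand-derived gradient replaces automatic differentiation, the names `mse_gradient` and `linreg_manual` are hypothetical, and the seeds are arbitrary choices made only for reproducibility.

```julia
using LinearAlgebra
using Random: MersenneTwister

# Hand-derived gradient of mean((Y - X*W).^2) with respect to W:
# -(2/n) * X' * (Y - X*W), where n is the number of rows (samples) in X.
mse_gradient(X, Y, W) = (-2 / size(X, 1)) .* (X' * (Y .- X * W))

# Hypothetical helper: plain gradient descent on the MSE loss, no Flux needed.
# learning_rate and iterations default to the values used in the demo.
function linreg_manual(X, Y; learning_rate = 0.01, iterations = 3000,
                       rng = MersenneTwister(0))
    Xi = hcat(ones(size(X, 1)), X)   # prepend the intercept column
    W = randn(rng, size(Xi, 2))      # random initial weights
    for _ in 1:iterations
        W .-= learning_rate .* mse_gradient(Xi, Y, W)
    end
    return W
end

# Same kind of training data as the demo, with a fixed seed.
rng = MersenneTwister(1)
X = randn(rng, 10, 2)
Y = X[:, 1] .* 3 .+ X[:, 2] .* 0.5 .+ 1

W_gd = linreg_manual(X, Y)
# Exact least-squares solution: backslash on a tall matrix solves it via QR.
W_exact = hcat(ones(size(X, 1)), X) \ Y
@show W_gd W_exact
```

Because the data is noiseless, `W_exact` recovers [1, 3, 0.5] up to rounding, while `W_gd` only approaches it; how close it gets after 3000 steps depends on how well conditioned the random `X` is.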
""" | |
Learn a linear model for data that really does have a linear relationship | |
""" | |
module SimpleLinearRegressionDemo | |
using Flux: gradient, params | |
using Statistics: mean | |
function linreg(xs, ys) | |
# Parameters have to wrapped so that they can be updated in-place. | |
c = [randn()] | |
m = [randn()] | |
# Prediction function (a linear model y = mx + c) | |
predict(x) = m[1] * x + c[1] | |
# Loss function: squared error | |
loss(x, y) = (predict(x) - y)^2 | |
# Learn | |
# Compute the current loss, and differentiate with respect to m and c. | |
# Update c and m in the opposite direction of the gradient (i.e. descend) | |
learning_rate = 0.01 | |
for i = 1:3000 | |
gs = gradient(() -> mean(loss.(xs, ys)), params(c, m)) | |
c .-= learning_rate .* gs[c] | |
m .-= learning_rate .* gs[m] | |
end | |
return vcat(c, m) | |
end | |
# Training data | |
xs = 1:10 | |
ys = xs .* 3 .+ 2; | |
# I expect c should approach 2 and m should approach 3. | |
@show linreg(xs, ys) | |
end |
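For a single input variable, the least-squares line also has a closed form: slope m = cov(x, y) / var(x) and intercept c = mean(y) - m * mean(x). The sketch below computes it for the demo's training data, giving the exact values the gradient-descent loop should approach. It is an illustration only: it swaps gradient descent for the textbook closed-form solution, uses just the `Statistics` standard library, and the function name `linreg_closed_form` is hypothetical.

```julia
using Statistics: cov, mean, var

# Hypothetical helper: exact least-squares fit of y = m*x + c.
function linreg_closed_form(xs, ys)
    m = cov(xs, ys) / var(xs)     # slope (both use the n-1 normalization, which cancels)
    c = mean(ys) - m * mean(xs)   # intercept: the fitted line passes through (mean(x), mean(y))
    return [c, m]                 # same order as the demo's vcat(c, m)
end

xs = 1:10
ys = xs .* 3 .+ 2
@show linreg_closed_form(xs, ys)
```

Up to floating-point rounding this returns [2.0, 3.0], which is a convenient reference when checking whether 3000 descent steps at a learning rate of 0.01 have converged.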