Last active
July 12, 2023 18:29
-
-
Save jiahao/837462f0d5f5a79152d60c58fd7ed268 to your computer and use it in GitHub Desktop.
A small example of double descent in Julia
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Plots | |
using StatsPlots | |
using LinearAlgebra | |
using ClassicalOrthogonalPolynomials | |
using ProgressMeter | |
using Statistics | |
k = 15  # number of training samples drawn per repetition
l = 15  # number of test samples drawn per repetition

# Ground-truth model the data are generated from.
# Alternatives kept for experimentation:
#   f(x) = 2*x+cos(25*x)  # model used in the MIT tutorial paper
#   1D Ackley function:
#   f(x) = -20*exp(-0.2*sqrt(0.5*x^2)) - exp(0.5*cos(2*pi*(x))) + exp(1) + 20
f(x) = 2x + x^2 - 0.1x^3 + 10x^14  # an arbitrary degree-14 polynomial

# Interval the x samples are drawn from.
# NOTE: this global shares its name with `Plots.xlims` (Plots is loaded above).
xlims = (-1, 1)
# Doing polynomial regression by hand | |
"""
    regress(y, x, p, λ=√eps(float(one(eltype(x)))))

Fit the samples `(x, y)` with a Legendre series of order `p` by least squares.
Return the `p + 1` coefficients `c`, where `c[j+1]` multiplies the degree-`j`
Legendre polynomial Pⱼ. `c[1]` is the constant term, because P₀ ≡ 1.

The design matrix `V[i, j+1] = Pⱼ(x[i])` is inverted with `pinv`:
- Singular values below `λ` are discarded. This is a hard cutoff, not ridge
  (Tikhonov) regularization.
- When `λ == 0`, `pinv` falls back to its default relative tolerance.
- When `p + 1 > length(x)`, the system is underdetermined and `pinv` yields the
  minimum-norm least-squares solution.
"""
function regress(y, x,
                 p,  # Order of polynomial
                 λ=√eps(float(one(eltype(x))))  # singular-value cutoff passed to pinv
                 )
    # Column 1 is P₀(x) ≡ 1, the intercept term. It must be ones, not zeros,
    # otherwise the constant coefficient is always forced to 0.
    V = ones(length(x), p + 1)
    # Inner loop runs over rows, matching Julia's column-major layout.
    for j in 1:p, (i, xi) in enumerate(x)
        V[i, j + 1] = legendrep(j, xi)  # Good old Legendre polynomials
    end
    return pinv(V; atol=λ) * y  # Least squares via truncated pseudoinverse
end
# Evaluate the function with coefficients c | |
"""
    y(x, c)

Evaluate the Legendre series with coefficients `c` at the point `x`:
`c[1] + c[2]*P₁(x) + … + c[end]*P_{length(c)-1}(x)`.

Terms are accumulated left to right, starting from `c[1]`. An empty `c`
throws a `BoundsError`.
"""
function y(x, c)
    higher_terms = 2:length(c)
    return foldl(higher_terms; init=c[1]) do acc, idx
        acc + c[idx] * legendrep(idx - 1, x)
    end
end
# Run the simulation.
# For each of N independent draws of training/test data, fit every polynomial
# order p = 0:maxp and record the mean squared error on both sets.
N = 100     # Samples (independent repetitions)
maxp = 100  # Largest order to sample
models = Vector{Float64}[]      # fitted coefficients, pushed in (draw, order) order
err_train = zeros(N, maxp + 1)  # err_train[i, p+1]: training MSE of the order-p fit on draw i
err_test = zeros(N, maxp + 1)   # err_test[i, p+1]: test MSE of the order-p fit on draw i
@showprogress for i in 1:N
    # Sample train data uniformly from [xlims[1], xlims[2])
    xk = (xlims[2] - xlims[1]) .* rand(k) .+ xlims[1]
    yk = f.(xk)
    # Sample test data from the same interval
    xl = (xlims[2] - xlims[1]) .* rand(l) .+ xlims[1]
    yl = f.(xl)
    for p in 0:maxp
        # λ = 0 means pinv uses its default relative tolerance
        c = regress(yk, xk, p, 0)
        push!(models, c)  # Save coefficients
        # Ref(c) keeps the coefficient vector from being broadcast over
        err_train[i, p + 1] = mean(abs2, yk .- y.(xk, Ref(c)))
        err_test[i, p + 1] = mean(abs2, yl .- y.(xl, Ref(c)))
    end
end
# Generate plot: log10 train/test MSE against log10 of the number of fitted
# coefficients. Column j of the error matrices holds order p = j - 1, so the
# x coordinate is log10(p + 1).
#
# errorline takes y as [x, repeat], so the (draw × order) matrices are
# transposed explicitly. This avoids relying on errorline to infer the
# orientation, which is impossible when N == maxp + 1.
errorline(log10.(1:(maxp + 1)), permutedims(log10.(err_train)), errorstyle=:ribbon,
          label="train", xlabel="log10(p+1)", ylabel="Log10 loss", ylim=(-7, 4))
errorline!(log10.(1:(maxp + 1)), permutedims(log10.(err_test)), errorstyle=:ribbon,
           label="test")
# Interpolation threshold: as many coefficients (p + 1) as training points (k).
vline!([log10(k)], label="p+1=k")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment