# Created April 23, 2021 07:33
# Gist: jjerphan/026f5dde3b2314e0a44f0d2e89413c80
# Exploration for scikit-learn#19952
# Note: GitHub warned that this file may contain hidden or bidirectional
# Unicode text that may be interpreted or compiled differently than it
# appears; review it in an editor that reveals hidden Unicode characters.
import numpy as np
from scipy.linalg import cho_solve, cholesky, solve_triangular
from sklearn.datasets import make_regression
from sklearn.gaussian_process.kernels import RBF
from sklearn.model_selection import train_test_split
def posterior_via_cho_solve(K, K_trans, K_test, y_train, lower=True):
    """Compute the GP posterior with one ``cho_solve`` per linear system ("Current").

    ``cho_solve((c, lower), b)`` solves ``K x = b`` given a Cholesky factor
    ``c`` of ``K``; it performs *both* triangular solves internally.

    Parameters
    ----------
    K : ndarray of shape (n_train, n_train)
        Training Gram matrix, with any noise term already added to its diagonal.
    K_trans : ndarray of shape (n_test, n_train)
        Cross-kernel between test and training points.
    K_test : ndarray of shape (n_test, n_test)
        Gram matrix of the test points.
    y_train : ndarray of shape (n_train,) or (n_train, n_targets)
        Training targets.
    lower : bool, default=True
        Use the lower (``L L^T``) or upper (``U^T U``) Cholesky factor;
        the results do not depend on this choice.

    Returns
    -------
    alpha : ndarray
        Dual coefficients ``K^{-1} y_train``.
    y_mean : ndarray
        Posterior mean ``K_trans @ alpha``.
    y_cov : ndarray
        Posterior covariance ``K_test - K_trans K^{-1} K_trans^T``.
    """
    factor = cholesky(K, lower=lower)
    alpha = cho_solve((factor, lower), y_train)
    y_mean = K_trans @ alpha
    v = cho_solve((factor, lower), K_trans.T)  # v = K^{-1} K_trans^T
    y_cov = K_test - K_trans @ v
    return alpha, y_mean, y_cov


def posterior_via_triangular_solves(K, K_trans, K_test, y_train):
    """Compute the GP posterior with explicit triangular solves ("Reference").

    Textbook formulation with ``K = L L^T`` (L lower triangular)::

        alpha = solve(L^T, solve(L, y))
        v     = solve(L, K_trans^T)
        y_cov = K_test - v^T v

    The solves here must be *triangular* solves. Chaining two ``cho_solve``
    calls instead solves against ``K`` twice, which yields ``K^{-2} y`` for
    ``alpha`` and ``K_trans K^{-2} K_trans^T`` for the subtracted covariance
    term. That mismatch is why the two approaches previously disagreed.

    Parameters and returns are as for ``posterior_via_cho_solve``.
    """
    L = cholesky(K, lower=True)
    z = solve_triangular(L, y_train, lower=True)  # z = L^{-1} y
    alpha = solve_triangular(L.T, z, lower=False)  # alpha = L^{-T} z = K^{-1} y
    y_mean = K_trans @ alpha
    v = solve_triangular(L, K_trans.T, lower=True)  # v = L^{-1} K_trans^T
    y_cov = K_test - v.T @ v  # v^T v = K_trans K^{-1} K_trans^T
    return alpha, y_mean, y_cov


if __name__ == "__main__":
    kernel = RBF(length_scale=1.0)
    # Use few features so the RBF cross-kernel is not numerically zero.
    # make_regression's default of 100 standard-normal features gives squared
    # distances of ~200. K_trans would then be ~exp(-100), and any covariance
    # formula would trivially "agree".
    X, y = make_regression(n_features=5, random_state=0)
    X_train, X_test, y_train, _ = train_test_split(X, y, random_state=0)

    noise = 5  # added to the diagonal of the training Gram matrix
    K = kernel(X_train)
    K[np.diag_indices_from(K)] += noise
    K_trans = kernel(X_test, X_train)
    K_test = kernel(X_test)

    current = posterior_via_cho_solve(K, K_trans, K_test, y_train)
    reference = posterior_via_triangular_solves(K, K_trans, K_test, y_train)

    # With correct triangular solves, all three quantities agree.
    for name, cur, ref in zip(("alpha", "y_mean", "y_cov"), current, reference):
        np.testing.assert_allclose(cur, ref, atol=1e-10, err_msg=name)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.