Skip to content

Instantly share code, notes, and snippets.

@bashtage
Created August 18, 2025 12:22
Show Gist options
  • Save bashtage/39c50b223580b9bc3f763a3b1be3e478 to your computer and use it in GitHub Desktop.
Save bashtage/39c50b223580b9bc3f763a3b1be3e478 to your computer and use it in GitHub Desktop.
from numba import njit, prange
import numpy as np
# Small problem used to cross-check the implementations against each other.
N = 1000
K = 10
X = np.random.standard_normal(size=(N, K))
w = np.arange(1, N + 1, dtype=float)
# np.diag(w) is a dense N x N matrix, so this form is only built for small N.
if N <= 1000:
    ref = (X.T @ np.diag(w)) @ X
def base_fn(X, w):
    """Return ``X' diag(w) X`` without materialising ``diag(w)``.

    ``w`` is indexed as ``w[:, None]``, so each row ``X[i]`` is scaled by
    ``w[i]`` through broadcasting; the scaled matrix is then transposed and
    multiplied by ``X``.
    """
    scaled_rows = X * w[:, np.newaxis]
    return scaled_rows.T @ X
# Vectorised result; the baseline the other implementations are compared with.
base = base_fn(X,w)
def test(X, w):
    """Accumulate the weighted cross-product ``X' diag(w) X`` row by row.

    For every row ``i`` the ``k x k`` product ``X[i]' X[i]`` is scaled by
    ``w[i]`` and added to a running total.

    Parameters
    ----------
    X : 2-D array with ``n`` rows and ``k`` columns.
    w : indexable holding one scalar weight per row of ``X``.

    Returns
    -------
    ``(k, k)`` float array; all zeros when ``X`` has no rows.
    """
    n, k = X.shape[:2]
    out = np.zeros((k, k))
    for i in range(n):
        # Basic slicing yields a (1, k) view, avoiding the two per-iteration
        # copies that fancy indexing ``X[[i]]`` would allocate.
        row = X[i:i + 1]
        out += w[i] * np.dot(row.T, row)
    return out


# The name matches pytest's default ``test*`` discovery pattern; opt out of
# collection so importing this helper into a test module does not produce a
# "fixture 'X' not found" error.
test.__test__ = False
# Row-by-row NumPy result; compared with `base` in the assertions further down.
alt = test(X, w)
@njit(parallel=False)
def nb_test(X, w):
    """Numba-compiled (``parallel=False``) sum of ``w[i] * X[i]' X[i]`` over rows.

    Returns a ``(k, k)`` array where ``k = X.shape[1]``.
    """
    n, k = X.shape
    total = np.zeros((k, k))
    for i in range(n):
        # (1, k) slice holding row i.
        row = X[i:i + 1]
        total += (w[i] * row.T) @ row
    return total
@njit(parallel=True)
def prange_test(X, w):
    """Numba-compiled (``parallel=True``) sum of ``w[i] * X[i]' X[i]`` over rows.

    The row loop runs over ``prange``. Returns a ``(k, k)`` array where
    ``k = X.shape[1]``.
    """
    n, k = X.shape
    acc = np.zeros((k, k))
    for i in prange(n):
        # (1, k) slice holding row i.
        row = X[i:i + 1]
        # NOTE(review): the in-place ``+=`` on ``acc`` is presumably what lets
        # Numba treat this as a prange reduction -- keep this form; confirm
        # against the Numba parallel documentation.
        acc += (w[i] * row.T) @ row
    return acc
no_parallel_alt = nb_test(X, w)
parallel_alt = prange_test(X, w)
np.testing.assert_allclose(base, alt)
np.testing.assert_allclose(base, no_parallel_alt)
np.testing.assert_allclose(base, parallel_alt)
parallel_alt = prange_test(X, w)
N=1000000
K=10
X = np.random.standard_normal((N,K))
w = np.arange(1, N+1).astype(float)
%timeit base_fn(X, w)
%timeit nb_test(X, w)
%timeit prange_test(X, w)
# Results, in the order of the %timeit calls above (base_fn, nb_test, prange_test):
# 125 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 386 ms ± 1.83 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 46.8 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment