An optimized implementation of `PyHessian`'s density function for computing the eigen-spectra of a Hessian in torch
# An optimized implementation of the Lanczos algorithm to compute an
# estimate of the eigen-spectra of the Hessian of a neural network.
#
# The original implementation by `pyhessian` is used as a benchmark.
# https://github.com/amirgholami/PyHessian/tree/master
#
# To see how to actually use and plot the eigen-spectra based on this function,
# please see the original repository.
#
# # Requires:
# - torch
# - torchvision
# - numpy
# - pyhessian (only needed for the comparison benchmark, i.e. `COMPARE = True`)
#
# # Results:
#
# ----- Running optimized ----------
# Worst time for optimized: 0.25586422 s
# Average time for optimized: 0.24603766893999995 s
# Best time for optimized: 0.243696563 s
# ----- Running pyhessian ----------
# Worst time for pyhessian: 0.796938346 s
# Average time for pyhessian: 0.7863519074800002 s
# Best time for pyhessian: 0.769821447 s
# ---------- Difference ----------
# >> Eigenvalues - sum(abs(difference))
# Total tensor(0.0169)
# Average tensor(0.0002)
# >> Weight - sum(abs(difference))
# Total tensor(0.0002)
# Average tensor(1.6736e-06)
#
# # Speedups:
#
# The primary speedups, at least for smaller networks, come from:
#
# * Reducing the computations done in the iteration to be on flat tensors
#   instead of the list of parameter tensors.
# * Using a projection Q consisting of orthonormal basis vectors to derive
#   the next orthonormal vector in the Lanczos iteration.
# * Better torch calls and re-use of pre-allocated buffers instead of re-allocations.
#
# These speedups are likely not as effective with larger networks,
# where the bottleneck is likely the computation of `torch.autograd.grad`
# for the Hessian-vector product.
#
# This code was optimized on a CPU and has not been tested on a GPU, although
# the speedups should likely be larger there due to the highly parallel nature
# of the GPU and the dense computations involved.
#
# # Profiling:
# Requires `py-spy` (`pip install py-spy`) and speedscope. The output file
# `eigen-unfolded.speedscope` can be viewed by opening it in the speedscope
# viewer; left-heavy mode is recommended for spotting bottlenecks.
#
# > py-spy record -f speedscope -n -o eigen-unfolded.speedscope -- python eigen_spectra.py
#
# The biggest bottlenecks are:
# * `torch.autograd.grad` for the Hessian-vector product
# * The Q.T @ (Q @ w) computation, which is a massive matrix multiplication
#
# Potential improvements:
# * At each iteration, the `w` encodes information about previous iterations,
#   i.e. the previous orthonormal vectors `v`, those from which the matrix `Q` is
#   built. There is likely a way to reduce the full use of `Q` if the information
#   encoded in `w` can be used to orthogonalize it w.r.t. Q.
# * The wikipedia page on the Lanczos algorithm states that we do not need
#   to re-orthogonalize `w` at each iteration unless it has collapsed to a `0`
#   vector. Skipping the re-orthogonalization would be a significant speedup, but
#   the method would then diverge from what `PyHessian` does.
# * The number of iterations can likely be early-stopped based on some criterion.
# * There are parallel versions of this algorithm, but on a single GPU it is likely
#   not worth it.
# * Computing the gradients with `torch.autograd.grad` can likely be improved in
#   some manner.
#
# # Caveat:
#
# * This algorithm only works w.r.t. a single batch; there may be other
#   optimizations possible in the case of using many batches or an entire
#   data-loader.
#
from collections.abc import Iterable

import torch
import torch.nn as nn
import time
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import os
from functools import partial
import random
import warnings

# This actually takes a surprising amount of time to run as it imports a deep nest of
# things. It can be removed for more _accurate_ profiling, but the import cost is
# significant for the user's first use of the function.
from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq  # noqa: F401

# Warnings about torch.load for the dataset
warnings.filterwarnings("ignore")

# Initialize models
INPUT_SIZE = 28 * 28
HIDDEN_SIZE = 128
OUTPUT_SIZE = 10

# ALGORITHM PARAMETERS
N_ALGO_ITER = 100

# Number of timing iterations to run (plus warmup)
WARMUP = 3
N_TIMING_SAMPLES = 50

# Whether to also run pyhessian and print a comparison of the resulting tensors
COMPARE = True

# Seeding
SEED = 1


def symmetric_tridiag(diag: torch.Tensor, offdiag: torch.Tensor) -> torch.Tensor:
    _arange = torch.arange(diag.shape[0]).unsqueeze(1)
    M = torch.diag(offdiag, -1)
    M.scatter_(1, _arange, diag.unsqueeze(1))
    M.scatter_(1, _arange[1:], offdiag.unsqueeze(1))
    return M

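
# A hedged, illustrative check of `symmetric_tridiag` (this helper is an addition
# for clarity and is not called anywhere below): with diag=[1, 2, 3] and
# offdiag=[4, 5] it builds the symmetric tridiagonal matrix
#     [[1., 4., 0.],
#      [4., 2., 5.],
#      [0., 5., 3.]]
def _example_symmetric_tridiag() -> torch.Tensor:
    return symmetric_tridiag(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0]))
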

class BaseNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BaseNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

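
# A hedged aside (this helper is illustrative only and unused below): with the
# default sizes above, the flattened parameter vector that the Lanczos iteration
# works over has N = 784*128 + 128 + 128*10 + 10 = 101,770 entries, which is also
# the row width of the pre-allocated basis matrix `_Q` in
# `eigen_spectra_via_lanczos_flat`.
def _example_flat_param_count() -> int:
    model = BaseNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)
    return sum(p.numel() for p in model.parameters())
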

def set_seed(seed_value):
    """Set seed for reproducibility."""
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed_value)


def eigen_spectra_via_lanczos_flat(
    model: nn.Module,
    criterion: nn.Module,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    alpha_list = []  # NOTE: Of size N_ALGO_ITER
    beta_list = []  # NOTE: Of size one less than alpha_list

    model.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    with torch.no_grad():
        # Use autograd.grad instead of backward to compute gradients
        gradsH = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
        params = [param for param in model.parameters() if param.requires_grad]
        gradsH = [
            g if g is not None else torch.zeros_like(p) for g, p in zip(gradsH, params)
        ]

        # Information about the shapes of the parameters
        _p_numels = [p.numel() for p in params]
        _p_shapes = [p.shape for p in params]
        _p_numels_cumsum = (
            torch.tensor(_p_numels, device=inputs.device).cumsum(0).tolist()
        )
        _offsets = zip([0] + _p_numels_cumsum[:-1], _p_numels_cumsum)
        N = sum(_p_numels)
        # We minimize re-allocations/duplications by pre-allocating the tensors
        # Basis space
        _Q = torch.empty((N_ALGO_ITER, N), device=inputs.device, dtype=dtype)

        # Upfront get the references to the basis vectors we'll need
        # The `__getitem__` call seems to be much slower when done in each loop
        QS = [_Q[:i] for i in range(1, N_ALGO_ITER)]
        qs = [_Q[i] for i in range(N_ALGO_ITER)]

        # Same underlying storage, multiple views
        # Holds the current vector we're working on
        w = torch.empty(N, device=inputs.device, dtype=dtype)
        w_viewed_as_grad_outputs = [
            w[start:end].view(shape) for (start, end), shape in zip(_offsets, _p_shapes)
        ]

        def _insert_hv_into_w(ts: Iterable[torch.Tensor]):
            for t, _view in zip(ts, w_viewed_as_grad_outputs):
                _view.copy_(t)

        _norm_buf = torch.empty(1, device=w.device, dtype=w.dtype)
        norm_of = partial(torch.linalg.vector_norm, ord=2, dtype=w.dtype, out=_norm_buf)
        # -------------- Algorithm --------------
        # Make first orthonormal vector
        w.copy_(torch.randint(0, 2, size=(N,), dtype=w.dtype, device=w.device))
        w[w == 0] = -1
        w.div_(norm_of(w))

        # Copy it as the first basis vector
        first_basis_vec = _Q[0].copy_(w)

        Hv = torch.autograd.grad(
            outputs=gradsH,
            inputs=params,
            grad_outputs=w_viewed_as_grad_outputs,
            only_inputs=True,
            retain_graph=True,
        )
        _insert_hv_into_w(Hv)

        alpha = torch.dot(w, first_basis_vec).item()
        alpha_list.append(alpha)
        w.sub_(first_basis_vec, alpha=alpha)
        # NOTE: Take beta as a Python float so it is a snapshot of the value rather
        # than a reference to the shared `_norm_buf` that `norm_of` writes into.
        beta = norm_of(w).item()
        beta_list.append(beta)

        for i, Q in enumerate(QS, start=1):
            # Reset our current vector if beta (norm) is zero, i.e. w is the zero vector
            if beta == 0:
                w.copy_(torch.randint(0, 2, size=(N,), dtype=w.dtype, device=w.device))
                w[w == 0] = -1
                w.div_(norm_of(w))

            # Make w orthogonal to our current basis space Q, and then normalize it
            # NOTE: We have row vectors, hence the reversal of where we place the T
            dots = Q @ w
            proj = Q.T @ dots
            # Gram-Schmidt - Much slower if used directly
            # proj = torch.zeros_like(w)
            # for j in range(i):
            #     proj += torch.dot(Q[j], w) * Q[j]
            w.sub_(proj).div_(norm_of(w))

            # Add it to our basis set, copying from the w buffer
            # `newest_basis_vector` is a reference to this newest basis vector
            newest_basis_vector = qs[i].copy_(w)
            # Calculate the next w
            Hv = torch.autograd.grad(
                gradsH,
                params,
                grad_outputs=w_viewed_as_grad_outputs,
                only_inputs=True,
                retain_graph=True,
            )
            _insert_hv_into_w(Hv)

            alpha = torch.dot(w, newest_basis_vector, out=_norm_buf).item()
            alpha_list.append(alpha)
            w.sub_(newest_basis_vector, alpha=alpha)
            w.sub_(qs[i - 1], alpha=beta)

            beta = norm_of(w).item()
            beta_list.append(beta)

        beta_list.pop()  # Remove the last beta, as we didn't need it
        T = symmetric_tridiag(torch.tensor(alpha_list), torch.tensor(beta_list))
        eigvals, eigvecs = torch.linalg.eigh(T)
        eig_real = eigvals.real
        weights = torch.pow(eigvecs[0, :], 2)
        return eig_real, weights

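
# A hedged usage sketch (this helper is an illustrative addition and is not called
# by the benchmark below): build a model, take one batch, and compute the spectral
# density estimate. The random batch shapes here are MNIST-like assumptions; any
# (inputs, targets) pair accepted by the model/criterion would do. Plotting the
# (eigenvalue, weight) pairs, e.g. by smoothing them with Gaussian kernels, is
# shown in the original PyHessian repository.
def _example_usage() -> tuple[torch.Tensor, torch.Tensor]:
    model = BaseNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)
    criterion = nn.CrossEntropyLoss()
    X = torch.randn(8, 1, 28, 28)
    y = torch.randint(0, OUTPUT_SIZE, (8,))
    eigvals, weights = eigen_spectra_via_lanczos_flat(model, criterion, X, y)
    return eigvals, weights
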

def main2(model, criterion, X, y, which):
    if which == "pyhessian":
        from pyhessian import hessian

        hessian_comp = hessian(model, criterion, data=(X, y), cuda=False)
        density_eigen, density_weight = hessian_comp.density(iter=N_ALGO_ITER)
        e = torch.tensor(density_eigen[0])
        w = torch.tensor(density_weight[0])
        return e, w
    elif which == "optimized":
        my_eigen, my_weight = eigen_spectra_via_lanczos_flat(model, criterion, X, y)
        return my_eigen, my_weight
    else:
        raise ValueError(f"Unknown which: {which}")


# Prevent accidental namespace leakage
if __name__ == "__main__":
    # Set seed for reproducibility
    set_seed(SEED)

    # Load MNIST dataset
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_dataset = datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transform,  # type: ignore
    )

_which = ["optimized", "pyhessian"] if COMPARE else ["optimized"] | |
for which in _which: | |
print(f"----- Running {which} ----------") | |
# Use a fixed seed for the DataLoader | |
set_seed(SEED) | |
generator = torch.Generator().manual_seed(SEED) | |
train_loader = DataLoader( | |
train_dataset, batch_size=8, shuffle=True, generator=generator | |
) | |
Xs_ys = [next(iter(train_loader)) for _ in range(N_TIMING_SAMPLES + WARMUP)] | |
times = [] | |
for X, y in Xs_ys: | |
model = BaseNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE) | |
criterion = nn.CrossEntropyLoss() | |
t0 = time.monotonic_ns() | |
e1, w1 = main2(model, criterion, X, y, which) | |
times.append(time.monotonic_ns() - t0) | |
times_secnds = np.array(times)[WARMUP:] / 1e9 | |
print(f"Worst time for {which}: {np.max(times_secnds)} s") | |
print(f"Average time for {which}: {np.mean(times_secnds)} s") | |
print(f"Best time for {which}: {np.min(times_secnds)} s") | |
    if COMPARE:
        # Orig
        set_seed(SEED)
        generator = torch.Generator().manual_seed(SEED)
        train_loader = DataLoader(
            train_dataset, batch_size=8, shuffle=True, generator=generator
        )
        X, y = next(iter(train_loader))
        model = BaseNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)
        criterion = nn.CrossEntropyLoss()
        e1, w1 = main2(model, criterion, X, y, "pyhessian")
        ix = torch.argsort(e1)
        e1 = e1[ix]
        w1 = w1[ix]

        # Mine
        set_seed(SEED)
        generator = torch.Generator().manual_seed(SEED)
        train_loader = DataLoader(
            train_dataset, batch_size=8, shuffle=True, generator=generator
        )
        X, y = next(iter(train_loader))
        model = BaseNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)
        criterion = nn.CrossEntropyLoss()
        e2, w2 = main2(model, criterion, X, y, "optimized")
        ix = torch.argsort(e2)
        e2 = e2[ix]
        w2 = w2[ix]

        print("---------- Difference ----------")
        print(">> Eigenvalues - sum(abs(difference))")
        absdiff = (e1 - e2).abs()
        print("Total", absdiff.sum())
        print("Average", absdiff.mean())
        print(">> Weight - sum(abs(difference))")
        absdiff = (w1 - w2).abs()
        print("Total", absdiff.sum())
        print("Average", absdiff.mean())