# An optimized implementation of the Lanczos algorithm to compute an
# estimate of the eigen-spectra of the Hessian of a neural network.
#
# The original implementation by `pyhessian` is used as a benchmark.
# https://github.com/amirgholami/PyHessian/tree/master
#
# To see how to actually use and plot the eigen-spectra based on this function,
# please see the original repository (a small plotting sketch is also included
# after the function below).
#
# # Requires:
# - torch
# - torchvision
# - numpy
#
# # Results:
#
# ----- Running optimized ----------
# Worst time for optimized: 0.25586422 s
# Average time for optimized: 0.24603766893999995 s
# Best time for optimized: 0.243696563 s
# ----- Running pyhessian ----------
# Worst time for pyhessian: 0.796938346 s
# Average time for pyhessian: 0.7863519074800002 s
# Best time for pyhessian: 0.769821447 s
# ---------- Difference ----------
# >> Eigenvalues - sum(abs(difference))
# Total tensor(0.0169)
# Average tensor(0.0002)
# >> Weight - sum(abs(difference))
# Total tensor(0.0002)
# Average tensor(1.6736e-06)
#
# # Speedups:
#
# The primary speedups, at least for smaller networks, come from:
#
# * Reducing the computations done in each iteration to operate on flat tensors
# instead of the list of per-parameter tensors (see the commented sketch below).
# * Using a projection Q consisting of orthonormal basis vectors to derive
# the next orthonormal vector in the Lanczos iteration.
# * Better torch calls and reuse of pre-allocated buffers (fewer re-allocations).
#
# These speedups are likely less effective with larger networks, where the
# bottleneck is the computation of `torch.autograd.grad`
# for the Hessian-vector product.
#
# This code was optimized on a CPU and has not been tested on a GPU, although
# the speedups should carry over (and possibly improve) given the highly
# parallel nature of the GPU and the dense tensor computations involved.
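#
# A minimal commented sketch of the flatten-and-view trick used below (the
# names here are illustrative; the real implementation lives in
# `eigen_spectra_via_lanczos_flat`):
#
#     params = [p for p in model.parameters() if p.requires_grad]
#     flat = torch.empty(sum(p.numel() for p in params))
#     views, start = [], 0
#     for p in params:
#         views.append(flat[start:start + p.numel()].view(p.shape))
#         start += p.numel()
#     # Writing through `views` also writes into `flat`: the per-parameter
#     # tensors and the flat vector share a single storage.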
#
# # Profiling:
# Profiling requires `py-spy` (`pip install py-spy`) and speedscope. The command
# below writes `eigen-unfolded.speedscope`, which you can open in the speedscope
# viewer; left-heavy mode is recommended for spotting bottlenecks.
#
# > py-spy record -f speedscope -n -o eigen-unfolded.speedscope -- python eigen_spectra.py
#
# The biggest bottlenecks are:
# * `torch.autograd.grad` for the Hessian-vector product (sketched below)
# * The Q.T @ (Q @ w) computation, which is a massive matrix multiplication
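#
# For reference, the Hessian-vector product is obtained by differentiating the
# gradient a second time (a standard double-backward pattern; `v_views` is an
# illustrative name for the candidate vector reshaped to the parameter shapes):
#
#     grads = torch.autograd.grad(loss, params, create_graph=True)
#     Hv = torch.autograd.grad(grads, params, grad_outputs=v_views,
#                              retain_graph=True)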
#
# Potential improvements:
# * At each iteration, `w` encodes information about previous iterations,
# i.e. the previous orthonormal vectors `v` from which the matrix `Q` is
# built. There is likely a way to avoid using the full `Q` if the information
# already encoded in `w` can be used to orthogonalize it w.r.t. `Q`.
# * The Wikipedia page on the Lanczos algorithm states that `w` does not need
# to be re-orthogonalized at each iteration unless it has collapsed to the
# zero vector. Skipping this would be a significant speedup, but it diverges
# from what `PyHessian` does.
# * The number of iterations can likely be early-stopped based on some criterion.
# * There are parallel versions of this algorithm, but on a single GPU they are
# likely not worth it.
# * Computing the gradients with `torch.autograd.grad` can likely be improved in
# some manner.
#
# # Caveat:
#
# * This implementation works w.r.t. a single batch only; other optimizations
# may be possible when using many batches or an entire data-loader, as in the
# commented sketch below.
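#
# A commented, illustrative sketch of one multi-batch approach (the names are
# hypothetical and it is not benchmarked here): accumulate the Hessian-vector
# product over batches, weighted by batch size, and normalize at the end.
#
#     Hv_total = [torch.zeros_like(p) for p in params]
#     n_total = 0
#     for inputs, targets in loader:
#         loss = criterion(model(inputs), targets)
#         grads = torch.autograd.grad(loss, params, create_graph=True)
#         Hv = torch.autograd.grad(grads, params, grad_outputs=v_views)
#         n = inputs.shape[0]
#         for acc, h in zip(Hv_total, Hv):
#             acc.add_(h, alpha=float(n))
#         n_total += n
#     Hv_avg = [h / n_total for h in Hv_total]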
#
from collections.abc import Iterable
import torch
import torch.nn as nn
import time
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import os
from functools import partial
import random
import warnings
# This actually takes a surprising amount of time to run, as it imports a deep
# nest of things. It can be removed for more _accurate_ profiling, but the
# import cost is significant for a user's first call of the function.
from torch.fx.experimental.symbolic_shapes import expect_true, sym_eq # noqa: F401
# Warnings about torch.load for the dataset
warnings.filterwarnings("ignore")
# Initialize models
INPUT_SIZE = 28 * 28
HIDDEN_SIZE = 128
OUTPUT_SIZE = 10
# ALGORITHM PARAMETERS
N_ALGO_ITER = 100
# Number of timing iterations to run
WARMUP = 3
N_TIMING_SAMPLES = 50
# Whether to print the resulting tensors
COMPARE = True
# Seeding
SEED = 1

def symmetric_tridiag(diag: torch.Tensor, offdiag: torch.Tensor) -> torch.Tensor:
    """Build a symmetric tridiagonal matrix from its diagonal and off-diagonal."""
    _arange = torch.arange(diag.shape[0]).unsqueeze(1)
    # Start from the sub-diagonal, then scatter in the diagonal and super-diagonal
    M = torch.diag(offdiag, -1)
    M.scatter_(1, _arange, diag.unsqueeze(1))
    M.scatter_(1, _arange[1:], offdiag.unsqueeze(1))
    return M
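
# For reference, a simpler (assumed-equivalent but unoptimized) construction of
# the same matrix via three `torch.diag` calls. This helper is illustrative only
# and is not used by the benchmark below.
def symmetric_tridiag_reference(
    diag: torch.Tensor, offdiag: torch.Tensor
) -> torch.Tensor:
    return torch.diag(diag) + torch.diag(offdiag, 1) + torch.diag(offdiag, -1)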

class BaseNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BaseNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def set_seed(seed_value):
    """Set seed for reproducibility."""
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed_value)

def eigen_spectra_via_lanczos_flat(
    model: nn.Module,
    criterion: nn.Module,
    inputs: torch.Tensor,
    targets: torch.Tensor,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    alpha_list = []  # NOTE: Of size N_ALGO_ITER
    beta_list = []  # NOTE: Of size one less than alpha_list
    model.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    with torch.no_grad():
        # Use autograd.grad instead of backward to compute gradients
        gradsH = torch.autograd.grad(loss, list(model.parameters()), create_graph=True)
        params = [param for param in model.parameters() if param.requires_grad]
        gradsH = [
            g if g is not None else torch.zeros_like(p) for g, p in zip(gradsH, params)
        ]

        # Information about the shapes of the parameters
        _p_numels = [p.numel() for p in params]
        _p_shapes = [p.shape for p in params]
        _p_numels_cumsum = (
            torch.tensor(_p_numels, device=inputs.device).cumsum(0).tolist()
        )
        _offsets = zip([0] + _p_numels_cumsum[:-1], _p_numels_cumsum)
        N = sum(_p_numels)

        # We minimize re-allocations/duplications by pre-allocating the tensors
        # Basis space
        _Q = torch.empty((N_ALGO_ITER, N), device=inputs.device, dtype=dtype)

        # Up front, get the references to the basis vectors we'll need.
        # The `__getitem__` call seems to be much slower when done in each loop.
        QS = [_Q[:i] for i in range(1, N_ALGO_ITER)]
        qs = [_Q[i] for i in range(N_ALGO_ITER)]
        # Same underlying storage, multiple views

        # Holds the current vector we're working on
        w = torch.empty(N, device=inputs.device, dtype=dtype)
        w_viewed_as_grad_outputs = [
            w[start:end].view(shape) for (start, end), shape in zip(_offsets, _p_shapes)
        ]

        def _insert_hv_into_w(ts: Iterable[torch.Tensor]):
            # Copy the per-parameter Hessian-vector product back into the flat `w`
            for t, _view in zip(ts, w_viewed_as_grad_outputs):
                _view.copy_(t)

        # 0-dim scratch buffer, reused for norms and dot products
        _norm_buf = torch.empty((), device=w.device, dtype=w.dtype)
        norm_of = partial(torch.linalg.vector_norm, ord=2, dtype=w.dtype, out=_norm_buf)

        # -------------- Algorithm --------------
        # Make first orthonormal vector
        w.copy_(torch.randint(0, 2, size=(N,), dtype=w.dtype, device=w.device))
        w[w == 0] = -1
        w.div_(norm_of(w))

        # Copy it as the first basis vector
        first_basis_vec = _Q[0].copy_(w)
        Hv = torch.autograd.grad(
            outputs=gradsH,
            inputs=params,
            grad_outputs=w_viewed_as_grad_outputs,
            only_inputs=True,
            retain_graph=True,
        )
        _insert_hv_into_w(Hv)
        alpha = torch.dot(w, first_basis_vec).item()
        alpha_list.append(alpha)
        w.sub_(first_basis_vec, alpha=alpha)
        # `.item()` so `beta` is a plain float and does not alias `_norm_buf`,
        # which is overwritten inside the loop before `beta` is used again.
        beta = norm_of(w).item()
        beta_list.append(beta)

        for i, Q in enumerate(QS, start=1):
            # Reset our current vector if beta (its norm) is zero, i.e. w is the zero vector
            if beta == 0:
                w.copy_(torch.randint(0, 2, size=(N,), dtype=w.dtype, device=w.device))
                w[w == 0] = -1
                w.div_(norm_of(w))

            # Make w orthogonal to our current basis space Q, and then normalize it
            # NOTE: We have row vectors, hence the reversal of where we place the T
            dots = Q @ w
            proj = Q.T @ dots
            # Gram-Schmidt - much slower if used directly
            # proj = torch.zeros_like(w)
            # for j in range(i):
            #     proj += torch.dot(Q[j], w) * Q[j]
            w.sub_(proj).div_(norm_of(w))

            # Add it to our basis set, copying from the w buffer.
            # `newest_basis_vector` is a reference to this newest row of `_Q`.
            newest_basis_vector = qs[i].copy_(w)

            # Calculate the next w
            Hv = torch.autograd.grad(
                gradsH,
                params,
                grad_outputs=w_viewed_as_grad_outputs,
                only_inputs=True,
                retain_graph=True,
            )
            _insert_hv_into_w(Hv)
            alpha = torch.dot(w, newest_basis_vector, out=_norm_buf).item()
            alpha_list.append(alpha)
            w.sub_(newest_basis_vector, alpha=alpha)
            w.sub_(qs[i - 1], alpha=beta)
            beta = norm_of(w).item()
            beta_list.append(beta)

        beta_list.pop()  # Remove the last beta, as we didn't need it

        T = symmetric_tridiag(torch.tensor(alpha_list), torch.tensor(beta_list))
        eigvals, eigvecs = torch.linalg.eigh(T)
        eig_real = eigvals.real
        weights = torch.pow(eigvecs[0, :], 2)
        return eig_real, weights
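
# A hedged sketch of how the returned (eigenvalues, weights) pair could be
# turned into a plot of the eigen-spectral density via Gaussian broadening,
# similar in spirit to what the PyHessian repository does. This helper is
# illustrative only (the benchmark never calls it); `sigma` and the grid
# resolution are assumptions, and it additionally requires `matplotlib`.
def plot_esd_sketch(
    eigenvalues: torch.Tensor, weights: torch.Tensor, sigma: float = 0.01
) -> None:
    import matplotlib.pyplot as plt  # local import: not a benchmark dependency

    lo = float(eigenvalues.min()) - 1.0
    hi = float(eigenvalues.max()) + 1.0
    grid = torch.linspace(lo, hi, 10_000)
    # Sum of Gaussians centred at each Ritz value, weighted by the squared first
    # components of the corresponding eigenvectors of the tridiagonal matrix T.
    density = (
        weights.unsqueeze(1)
        * torch.exp(-((grid.unsqueeze(0) - eigenvalues.unsqueeze(1)) ** 2) / (2 * sigma**2))
        / (sigma * (2 * torch.pi) ** 0.5)
    ).sum(dim=0)
    plt.semilogy(grid.numpy(), (density + 1e-7).numpy())
    plt.xlabel("Eigenvalue")
    plt.ylabel("Density (log scale)")
    plt.show()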

def main2(model, criterion, X, y, which):
    if which == "pyhessian":
        from pyhessian import hessian

        hessian_comp = hessian(model, criterion, data=(X, y), cuda=False)
        density_eigen, density_weight = hessian_comp.density(iter=N_ALGO_ITER)
        e = torch.tensor(density_eigen[0])
        w = torch.tensor(density_weight[0])
        return e, w
    elif which == "optimized":
        my_eigen, my_weight = eigen_spectra_via_lanczos_flat(model, criterion, X, y)
        return my_eigen, my_weight
    else:
        raise ValueError(f"Unknown which: {which}")

# Prevent accidental namespace leakage
if __name__ == "__main__":
    # Set seed for reproducibility
    set_seed(SEED)

    # Load MNIST dataset
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    train_dataset = datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transform,  # type: ignore
    )

    _which = ["optimized", "pyhessian"] if COMPARE else ["optimized"]
    for which in _which:
        print(f"----- Running {which} ----------")
        # Use a fixed seed for the DataLoader
        set_seed(SEED)
        generator = torch.Generator().manual_seed(SEED)
        train_loader = DataLoader(
            train_dataset, batch_size=8, shuffle=True, generator=generator
        )
        Xs_ys = [next(iter(train_loader)) for _ in range(N_TIMING_SAMPLES + WARMUP)]
        times = []
        for X, y in Xs_ys:
            model = BaseNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)
            criterion = nn.CrossEntropyLoss()
            t0 = time.monotonic_ns()
            e1, w1 = main2(model, criterion, X, y, which)
            times.append(time.monotonic_ns() - t0)
        times_seconds = np.array(times)[WARMUP:] / 1e9
        print(f"Worst time for {which}: {np.max(times_seconds)} s")
        print(f"Average time for {which}: {np.mean(times_seconds)} s")
        print(f"Best time for {which}: {np.min(times_seconds)} s")

    if COMPARE:
        # Original (pyhessian)
        set_seed(SEED)
        generator = torch.Generator().manual_seed(SEED)
        train_loader = DataLoader(
            train_dataset, batch_size=8, shuffle=True, generator=generator
        )
        X, y = next(iter(train_loader))
        model = BaseNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)
        criterion = nn.CrossEntropyLoss()
        e1, w1 = main2(model, criterion, X, y, "pyhessian")
        ix = torch.argsort(e1)
        e1 = e1[ix]
        w1 = w1[ix]

        # Optimized (this gist)
        set_seed(SEED)
        generator = torch.Generator().manual_seed(SEED)
        train_loader = DataLoader(
            train_dataset, batch_size=8, shuffle=True, generator=generator
        )
        X, y = next(iter(train_loader))
        model = BaseNN(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE)
        criterion = nn.CrossEntropyLoss()
        e2, w2 = main2(model, criterion, X, y, "optimized")
        ix = torch.argsort(e2)
        e2 = e2[ix]
        w2 = w2[ix]

        print("---------- Difference ----------")
        print(">> Eigenvalues - sum(abs(difference))")
        absdiff = (e1 - e2).abs()
        print("Total", absdiff.sum())
        print("Average", absdiff.mean())
        print(">> Weight - sum(abs(difference))")
        absdiff = (w1 - w2).abs()
        print("Total", absdiff.sum())
        print("Average", absdiff.mean())