Skip to content

Instantly share code, notes, and snippets.

@smorad
Created September 7, 2024 00:22
Show Gist options
Save smorad/a882289e18b119ec36ee40d90ee5e06b to your computer and use it in GitHub Desktop.
import jax.numpy as jnp
import numpy as np
import jax
jax.config.update("jax_enable_x64", True)
import equinox as eqx
import equinox.nn as nn
import optax
import tqdm
from tensorflow_probability.substrates import jax as tfp
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
import torch
import jax_dataloader as jdl
import matplotlib.pyplot as plt
# Length of the latent vectors. The encoder's last Linear layer emits DIM + 1
# values: the first DIM are normalised (via `project`) into a unit mean
# direction, and the last one becomes the concentration. The 3D scatter plot
# at the end of the script reads latent coordinates 0, 1 and 2, so it
# assumes DIM >= 3.
DIM = 3
def project(x, key=None):
    """Scale ``x`` to unit L2 norm along its last axis.

    ``key`` is accepted and ignored, so this has the same ``(x, key=None)``
    signature as the other layer-style callables in this file. A small
    epsilon (1e-7) is added to the norm so an all-zero input maps to zeros
    rather than producing a division by zero.
    """
    del key  # unused; present only for call-signature compatibility
    magnitude = jnp.linalg.norm(x, axis=-1, keepdims=True)
    return x / (magnitude + 1e-7)
def leaky_relu(x, key=None):
    """Apply ``jax.nn.leaky_relu`` with its default negative slope.

    ``key`` is taken and ignored so this function shares the ``(x, key)``
    signature of the other callables placed inside ``nn.Sequential`` in
    this file.
    """
    del key  # accepted for Sequential compatibility only
    activation = jax.nn.leaky_relu
    return activation(x)
class AE(eqx.Module):
    """Autoencoder whose latent code is a point on the unit sphere in R^DIM.

    The encoder reshapes an input to 784 values and maps it to a unit-norm
    mean direction ``mu`` (length ``DIM``) plus a scalar ``concentration``
    (``1 + softplus(.)``, so >= 1). Together these parameterise a TFP
    ``PowerSpherical`` distribution. The decoder maps a length-``DIM``
    latent vector back to a 28x28 output.
    """

    encoder: eqx.Module
    decoder: eqx.Module

    def __init__(self, key=None):
        """Build the encoder and decoder MLPs.

        Args:
            key: optional PRNG key for weight initialisation. Defaults to
                ``jax.random.PRNGKey(0)``, i.e. the previously hard-coded seed,
                so ``AE()`` produces the same weights as before.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        keys = jax.random.split(key, 8)
        # The last encoder Linear emits DIM + 1 values: the first DIM are
        # projected onto the unit sphere (mean direction), and the last one is
        # mapped through 1 + softplus to give the concentration.
        self.encoder = nn.Sequential([
            lambda x, key=None: x.reshape(784),
            nn.Linear(784, 256, key=keys[0]),
            nn.LayerNorm(256, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(256, 128, key=keys[6]),
            nn.LayerNorm(128, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(128, 64, key=keys[1]),
            nn.LayerNorm(64, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(64, DIM + 1, key=keys[2]),
            # `key=None` for consistency with the other lambdas in this file.
            lambda x, key=None: (project(x[:-1]), 1 + jax.nn.softplus(x[-1])),
        ])
        self.decoder = nn.Sequential([
            nn.Linear(DIM, 64, key=keys[3]),
            nn.LayerNorm(64, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(64, 128, key=keys[4]),
            nn.LayerNorm(128, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(128, 256, key=keys[7]),
            nn.LayerNorm(256, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(256, 784, key=keys[5]),
            lambda x, key=None: x.reshape(28, 28),
        ])

    def __call__(self, x, key=None):
        """Encode ``x`` and decode a latent point.

        Args:
            x: a single input with 784 elements (the encoder reshapes it).
            key: PRNG key. If ``None``, the decoder is fed the mean direction
                ``mu`` (deterministic). Otherwise a latent is sampled from
                ``PowerSpherical(mu, concentration)`` using this key.

        Returns:
            ``((mu, concentration), reconstruction)``, where the
            reconstruction has shape (28, 28).
        """
        mu, concentration = self.encoder(x)
        if key is None:
            z_sample = mu
        else:
            # Build the distribution only on the path that samples from it.
            dist = tfp.distributions.PowerSpherical(
                mean_direction=mu, concentration=concentration
            )
            z_sample = dist.sample(seed=key).reshape(-1)
        return (mu, concentration), self.decoder(z_sample)
def update(model, x, opt, opt_state, key):
    """Run one optimisation step of the spherical autoencoder on a batch.

    Args:
        model: the ``AE`` being trained.
        x: batch of inputs; the leading axis is the batch axis.
        opt: optax gradient transformation.
        opt_state: current optimiser state for ``opt``.
        key: PRNG key, split into one sampling key per batch element.

    Returns:
        ``(model, opt_state, loss, grad, conc)``: the updated model, the
        updated optimiser state, the scalar loss, the gradients, and the mean
        predicted concentration over the batch. The returned ``opt_state``
        must be passed to the next call. Otherwise the optimiser's internal
        state (e.g. AdamW step count and moment estimates) never advances.
    """

    def loss_fn(model, x, key):
        # One independent sampling key per example in the batch.
        keys = jax.random.split(key, x.shape[0])
        (mu, concentration), x_hat = eqx.filter_vmap(model)(x, keys)
        uniform = tfp.distributions.SphericalUniform(mu.shape[-1])
        spherical = tfp.distributions.PowerSpherical(
            mean_direction=mu, concentration=concentration
        )
        # Reconstruction MSE plus KL to the uniform distribution on the sphere.
        kl_term = spherical.kl_divergence(uniform)
        return jnp.mean((x - x_hat) ** 2) + kl_term.mean(), concentration.mean()

    (loss, conc), grad = eqx.filter_value_and_grad(loss_fn, has_aux=True)(
        model, x, key
    )
    updates, opt_state = opt.update(
        grad, opt_state, params=eqx.filter(model, eqx.is_inexact_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss, grad, conc


def _to_unit_float(img):
    """Convert one dataset image to a float64 numpy array divided by 255."""
    return np.array(img, dtype=float) / 255.0


# --- Data -------------------------------------------------------------------
key = jax.random.PRNGKey(0)
pt_ds = MNIST('/tmp/mnist/', download=True, transform=_to_unit_float, train=True)
pt_ds_test = MNIST('/tmp/mnist/', download=True, transform=_to_unit_float, train=False)
dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=32, shuffle=True)
test_dataloader = jdl.DataLoader(pt_ds_test, 'pytorch', batch_size=16, shuffle=False)
# A fixed batch of test images, used for the per-epoch test metric and plot.
test_x, _ = next(iter(test_dataloader))

# --- Model and optimiser ----------------------------------------------------
model = AE()
lr_schedule = optax.constant_schedule(0.001)
opt = optax.adamw(lr_schedule)
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))
train_key = jax.random.PRNGKey(1)
# Wrap once, outside the loop, rather than re-wrapping on every step.
train_step = eqx.filter_jit(update)

# --- Training ---------------------------------------------------------------
for epoch in range(4):
    losses = []
    for x, _ in tqdm.tqdm(dataloader):
        # `train_key` is only ever split, never consumed, so no key is reused.
        train_key, step_key = jax.random.split(train_key)
        model, opt_state, loss, grad, conc = train_step(
            model, x, opt, opt_state, step_key
        )
        losses.append(loss)
    # Evaluation reuses the same `key` every epoch, so the sampling noise is
    # identical across epochs and the test numbers are comparable.
    z, x_hat = eqx.filter_jit(eqx.filter_vmap(model))(
        test_x, jax.random.split(key, test_x.shape[0])
    )
    print(f"Epoch {epoch} train: {np.mean(losses):.4f} test: {np.mean((test_x - x_hat) ** 2):.4f}, conc {conc:.3f}")

# --- Reconstructions (top row) vs. originals (bottom row) --------------------
# NOTE(review): indentation was lost in the source; this plot is assumed to run
# once after training (using the last epoch's x_hat). Confirm it was not meant
# to be drawn per epoch.
top = torch.tensor(np.asarray(x_hat)).unsqueeze(1)
bottom = torch.tensor(np.asarray(test_x)).unsqueeze(1)
grid = make_grid(torch.cat([top, bottom], dim=0), nrow=16)
plt.figure()
plt.imshow(grid.permute(1, 2, 0))

# --- Decode random points on the sphere --------------------------------------
# Normalised isotropic Gaussian samples are uniformly distributed on the sphere.
r = project(jax.random.normal(jax.random.PRNGKey(0), (32, DIM)))
random_images = eqx.filter_vmap(model.decoder)(r)
grid = make_grid(torch.from_numpy(np.array(random_images)).unsqueeze(1))
plt.figure()
plt.imshow(grid.permute(1, 2, 0))
plt.title("Spherical AE Random")

# --- Plot test-set mean directions on the sphere ----------------------------
# Uses latent coordinates 0-2, so this assumes DIM >= 3.
ds = pt_ds_test.data.numpy()[:512] / 255.0
(mu, conc), _ = eqx.filter_vmap(model)(ds)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
mu = np.array(mu)
ax.scatter(mu[:, 0], mu[:, 1], mu[:, 2])
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_zlim(-1, 1)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment