Created
September 7, 2024 00:22
-
-
Save smorad/a882289e18b119ec36ee40d90ee5e06b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax.numpy as jnp | |
import numpy as np | |
import jax | |
jax.config.update("jax_enable_x64", True) | |
import equinox as eqx | |
import equinox.nn as nn | |
import optax | |
import tqdm | |
from tensorflow_probability.substrates import jax as tfp | |
from torchvision.datasets import MNIST | |
from torchvision.utils import make_grid | |
import torch | |
import jax_dataloader as jdl | |
import matplotlib.pyplot as plt | |
# Size of the latent direction vector. The encoder's last Linear layer outputs
# DIM + 1 values (DIM for the direction, 1 for the concentration) and the
# decoder's first Linear layer takes DIM inputs. The scatter plot at the end of
# the script reads the first three coordinates of the encoded means.
DIM = 3
def project(x, key=None):
    """Rescale ``x`` to (approximately) unit Euclidean length on its last axis.

    Args:
        x: Array whose last axis holds the vector(s) to normalise.
        key: Unused; presumably accepted so the function matches the
            ``(x, key)`` call convention of the other layer callables in
            this file.

    Returns:
        ``x`` divided by its last-axis norm plus 1e-7. The small constant
        keeps the division finite, so an all-zero vector maps to zeros.
    """
    length = jnp.linalg.norm(x, axis=-1, keepdims=True)
    safe_length = length + 1e-7
    return x / safe_length
def leaky_relu(x, key=None):
    """Apply ``jax.nn.leaky_relu`` to ``x`` with JAX's default negative slope.

    Args:
        x: Input array.
        key: Unused; accepted so the function can sit in the ``nn.Sequential``
            layer lists next to callables that take a ``key`` argument.

    Returns:
        The activated array, same shape as ``x``.
    """
    activated = jax.nn.leaky_relu(x)
    return activated
class AE(eqx.Module):
    """Autoencoder whose latent code is a direction on the unit sphere.

    The encoder flattens its input to 784 values and produces a pair
    ``(mu, concentration)``: ``mu`` is a ``DIM``-vector normalised by
    ``project`` and ``concentration`` is ``1 + softplus(.)`` of the remaining
    output. These parameterise a ``PowerSpherical`` distribution. The decoder
    maps a ``DIM``-vector to a 28x28 array.
    """

    # MLP: 784 -> 256 -> 128 -> 64 -> DIM + 1, ending in the (mu, concentration) split.
    encoder: eqx.Module
    # MLP: DIM -> 64 -> 128 -> 256 -> 784, reshaped to 28x28.
    decoder: eqx.Module

    def __init__(self, key=None):
        """Build the encoder and decoder.

        Args:
            key: Optional PRNG key used to initialise the linear layers.
                Defaults to ``jax.random.PRNGKey(0)``, which reproduces the
                previously hard-coded initialisation.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        # One key per Linear layer; the index-to-layer assignment is kept as-is
        # so the default initialisation is unchanged.
        keys = jax.random.split(key, 8)

        # The final Linear emits DIM + 1 numbers: the first DIM are projected
        # onto the unit sphere (mean direction) and the last one becomes the
        # concentration, shifted by 1 after the softplus.
        self.encoder = nn.Sequential([
            lambda x, key=None: x.reshape(784),
            nn.Linear(784, 256, key=keys[0]),
            nn.LayerNorm(256, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(256, 128, key=keys[6]),
            nn.LayerNorm(128, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(128, 64, key=keys[1]),
            nn.LayerNorm(64, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(64, DIM + 1, key=keys[2]),
            lambda x, key=None: (project(x[:-1]), 1 + jax.nn.softplus(x[-1])),
        ])
        self.decoder = nn.Sequential([
            nn.Linear(DIM, 64, key=keys[3]),
            nn.LayerNorm(64, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(64, 128, key=keys[4]),
            nn.LayerNorm(128, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(128, 256, key=keys[7]),
            nn.LayerNorm(256, use_weight=False, use_bias=False),
            leaky_relu,
            nn.Linear(256, 784, key=keys[5]),
            lambda x, key=None: x.reshape(28, 28),
        ])

    def __call__(self, x, key=None):
        """Encode ``x``, pick a latent, and decode it.

        Args:
            x: A single input with 784 elements (it is reshaped internally).
            key: Optional PRNG key. When ``None`` the mean direction itself is
                decoded; otherwise a latent is sampled from the
                ``PowerSpherical`` distribution defined by the encoder output.

        Returns:
            ``((mu, concentration), reconstruction)`` where ``reconstruction``
            is the decoder output of shape ``(28, 28)``.
        """
        mu, concentration = self.encoder(x)
        if key is None:
            z_sample = mu
        else:
            dist = tfp.distributions.PowerSpherical(
                mean_direction=mu, concentration=concentration
            )
            # Flatten the sample so the decoder's first Linear gets a 1-D vector.
            z_sample = dist.sample(seed=key).reshape(-1)
        return (mu, concentration), self.decoder(z_sample)
def update(model, x, opt, opt_state, key):
    """Run one optimisation step on a batch and return the new training state.

    Args:
        model: The model being trained; it is vmapped over the batch.
        x: Batch of inputs; ``x.shape[0]`` is used as the batch size.
        opt: Optax gradient transformation.
        opt_state: Optimiser state for ``opt``.
        key: PRNG key, split into one key per batch element for sampling.

    Returns:
        ``(model, opt_state, loss, grad, conc)``: the updated model, the
        updated optimiser state, the scalar loss, the gradients, and the mean
        concentration over the batch. The optimiser state must be fed back
        into the next call, otherwise the optimiser restarts from its initial
        state on every step.
    """

    def loss_fn(model, x, key):
        sample_keys = jax.random.split(key, x.shape[0])
        (mu, concentration), x_hat = eqx.filter_vmap(model)(x, sample_keys)
        uniform = tfp.distributions.SphericalUniform(mu.shape[-1])
        spherical = tfp.distributions.PowerSpherical(
            mean_direction=mu, concentration=concentration
        )
        kl_term = spherical.kl_divergence(uniform)
        # Loss is mean squared reconstruction error plus the mean KL term;
        # the mean concentration is returned as an auxiliary value for logging.
        loss = jnp.mean((x - x_hat) ** 2) + kl_term.mean()
        return loss, concentration.mean()

    (loss, conc), grad = eqx.filter_value_and_grad(loss_fn, has_aux=True)(model, x, key)
    updates, opt_state = opt.update(
        grad, opt_state, params=eqx.filter(model, eqx.is_inexact_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss, grad, conc


def _to_unit_float(img):
    """Convert an image to a float array scaled by 1/255."""
    return np.array(img, dtype=float) / 255.0


key = jax.random.PRNGKey(0)
pt_ds = MNIST('/tmp/mnist/', download=True, transform=_to_unit_float, train=True)
pt_ds_test = MNIST('/tmp/mnist/', download=True, transform=_to_unit_float, train=False)
dataloader = jdl.DataLoader(pt_ds, 'pytorch', batch_size=32, shuffle=True)
test_dataloader = jdl.DataLoader(pt_ds_test, 'pytorch', batch_size=16, shuffle=False)
# A single fixed test batch is reused for the per-epoch evaluation and plots.
test_x, test_y = next(iter(test_dataloader))

model = AE()
lr_schedule = optax.constant_schedule(0.001)
opt = optax.adamw(lr_schedule)
opt_state = opt.init(eqx.filter(model, eqx.is_inexact_array))
train_key = jax.random.PRNGKey(1)
# Wrap the step once instead of re-wrapping it on every iteration.
train_step = eqx.filter_jit(update)

for epoch in range(4):
    losses = []
    for x, y in tqdm.tqdm(dataloader):
        # Split into a carried key and a key consumed by this step so that no
        # key is used twice.
        train_key, step_key = jax.random.split(train_key)
        # Thread opt_state through the loop so the optimiser accumulates its
        # statistics across steps.
        model, opt_state, loss, grad, conc = train_step(model, x, opt, opt_state, step_key)
        losses.append(loss)
    z, x_hat = eqx.filter_jit(eqx.filter_vmap(model))(test_x, jax.random.split(key, test_x.shape[0]))
    print(f"Epoch {epoch} train: {np.mean(losses):.4f} test: {np.mean((test_x - x_hat) ** 2):.4f}, conc {conc:.3f}")

# Reconstructions (first 16 grid cells) above the matching test inputs (last 16).
top = torch.tensor(np.asarray(x_hat)).unsqueeze(1)
bottom = torch.tensor(np.asarray(test_x)).unsqueeze(1)
grid = make_grid(torch.cat([top, bottom], dim=0), nrow=16)
plt.figure()
plt.imshow(grid.permute(1, 2, 0))

# Decode 32 random unit vectors (normalised Gaussian draws).
r = project(jax.random.normal(jax.random.PRNGKey(0), (32, DIM)))
random_images = eqx.filter_vmap(model.decoder)(r)
grid = make_grid(torch.from_numpy(np.array(random_images)).unsqueeze(1))
plt.figure()
plt.imshow(grid.permute(1, 2, 0))
plt.title("Spherical AE Random")

# Plot test set on sphere
ds = pt_ds_test.data.numpy()[:512] / 255.0
(mu, conc), _ = eqx.filter_vmap(model)(ds)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
mu = np.array(mu)
# Uses the first three latent coordinates, so this plot needs DIM >= 3.
ax.scatter(mu[:, 0], mu[:, 1], mu[:, 2])
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_zlim(-1, 1)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment