import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
##########################################################
# Generating training and validation data
##########################################################
vocab_size = 47 # prime number p
ALL = []
for x in range(vocab_size):
    for y in range(vocab_size):
        ALL.append([x, y, (x - y + vocab_size) % vocab_size])  # x - y (mod p)
R = np.random.RandomState(32523414)
R.shuffle(ALL)
split_index = int(len(ALL) * 0.75)
TRAIN = ALL[:split_index]
TEST = ALL[split_index:]
TRAIN = torch.tensor(TRAIN, dtype=torch.long)
TEST = torch.tensor(TEST, dtype=torch.long)
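
# (Added sketch, not part of the original gist.) Quick sanity check of the
# generated data: there are p * p examples in total, and every row stores
# x, y and (x - y) mod p.
assert TRAIN.shape[0] + TEST.shape[0] == vocab_size ** 2
assert torch.all((TRAIN[:, 0] - TRAIN[:, 1]) % vocab_size == TRAIN[:, 2])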
##########################################################
# Simple model
##########################################################
hidden_size = 50
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.embeds = nn.Embedding(vocab_size, hidden_size)
        self.linear1 = nn.Linear(2 * hidden_size, 2 * hidden_size)
        self.relu = nn.LeakyReLU()
        self.linear2 = nn.Linear(2 * hidden_size, 2 * hidden_size)
        self.linear3 = nn.Linear(2 * hidden_size, vocab_size)

    def forward(self, inputs):
        # inputs: (batch, 2) integer tokens -> (batch, 2, hidden_size) embeddings.
        a = self.embeds(inputs)
        # Concatenate the two token embeddings into one vector per example.
        a = a.reshape(-1, 2 * hidden_size)
        b = self.relu(self.linear1(a))
        c = self.relu(self.linear2(b))
        return F.log_softmax(self.linear3(c), dim=1)
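
# (Added sketch, not part of the original gist.) Shape check for the forward pass:
# a batch of B (x, y) pairs should map to B rows of log-probabilities over the
# p possible answers.
_probe = Model()(torch.tensor([[0, 1], [3, 5]], dtype=torch.long))
assert _probe.shape == (2, vocab_size)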
##########################################################
# A simple method of making small jumps in the tangent space.
##########################################################
def tangent_jump(v, alpha=5e-2, eps=1e-8):
    if isinstance(v, nn.Module):
        # Apply the jump to every parameter that has a gradient.
        for p in v.parameters():
            if p.grad is None: continue
            tangent_jump(p, alpha)
    else:
        g = v.grad.data
        gl = torch.norm(g)
        gn = g / (gl + eps)
        while True:
            # Random vector.
            r = torch.randn_like(v.data)
            # Project out the gradient component so the jump is orthogonal to the gradient.
            r = r - gn * torch.sum(r * gn)
            rl = torch.norm(r)
            r = r / (rl + eps)
            if rl >= eps:
                break
        # Just a heuristic for the jump length.
        l = np.sqrt(len(v.data) - 1)
        v.data += r * l * alpha
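
# (Added sketch, not part of the original gist.) Small check that the jump moves
# a tensor only within the hyperplane orthogonal to its gradient: the displacement
# should have (numerically) zero dot product with the gradient.
_w = torch.randn(64, requires_grad=True)
(_w * torch.randn(64)).sum().backward()  # populate _w.grad
_before = _w.data.clone()
tangent_jump(_w, alpha=1e-2)
assert torch.abs(torch.dot(_w.data - _before, _w.grad)).item() < 1e-4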
##########################################################
# Optimization
##########################################################
model = Model()
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_function = nn.NLLLoss()
total_loss = 0
acc_train_acc = 0
acc_val_acc = 0
DATA = []
N_EPOCHS = 10000
# Turn this off if you want to see what happens without
# stochastic jumps in the tangent space.
USE_TANGENT_JUMPS = True
for epoch in range(N_EPOCHS + 1):
    # Note: If we were using small batches, there would be some
    #       gradient noise, which would result in small jumps
    #       in the tangent space. When not using batches, we
    #       can introduce those jumps directly.

    ##########################################################
    # Jumping in the tangent space.
    ##########################################################
    if USE_TANGENT_JUMPS:
        model.zero_grad()
        log_probs = model(TRAIN[:, :2])
        loss = loss_function(log_probs, TRAIN[:, 2])
        loss.backward()
        alpha = 1e-1
        tangent_jump(model, alpha)

    ##########################################################
    # Regular (non-batched) optimization.
    ##########################################################
    model.zero_grad()
    log_probs = model(TRAIN[:, :2])
    loss = loss_function(log_probs, TRAIN[:, 2])
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

    ##########################################################
    # Reporting and stop condition.
    ##########################################################
    train_acc = torch.sum(TRAIN[:, 2] == torch.argmax(model(TRAIN[:, :2]), dim=1)) / TRAIN.shape[0]
    val_acc = torch.sum(TEST[:, 2] == torch.argmax(model(TEST[:, :2]), dim=1)) / TEST.shape[0]
    DATA.append([epoch, train_acc.item(), val_acc.item()])
    if train_acc == 1.0 and val_acc == 1.0:
        break
    acc_train_acc += train_acc.item()
    acc_val_acc += val_acc.item()
    if epoch % 100 == 0:
        # total_loss / 100 is the average loss over the last 100 epochs; the accuracy
        # accumulators sum 100 fractional accuracies, so they read as average percentages.
        print(f'{epoch}: L={total_loss / 100}, Train Acc={acc_train_acc:.3f}%, Val Acc={acc_val_acc:.3f}%')
        total_loss = 0
        acc_train_acc = 0
        acc_val_acc = 0
##########################################################
# Plot the accuracy.
##########################################################
DATA = np.asarray(DATA)
plt.figure(figsize=(16, 12), facecolor='white')
plt.plot(DATA[:, 0], DATA[:, 1], label='train accuracy')
plt.plot(DATA[:, 0], DATA[:, 2], label='validation accuracy')
plt.legend()
plt.show()
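
# (Added sketch, not part of the original gist.) One way to read off when the model
# first generalizes: the earliest recorded epoch at which validation accuracy
# reaches 1.0, if it ever does.
_perfect = np.nonzero(DATA[:, 2] >= 1.0)[0]
if len(_perfect) > 0:
    print(f'validation accuracy first reaches 1.0 at epoch {int(DATA[_perfect[0], 0])}')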