import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
##########################################################
# Generating training and validation data
##########################################################
vocab_size = 47  # prime number p

ALL = []
for x in range(vocab_size):
    for y in range(vocab_size):
        ALL.append([x, y, (x - y + vocab_size) % vocab_size])  # x - y (mod p)

R = np.random.RandomState(32523414)
R.shuffle(ALL)

split_index = int(len(ALL) * 0.75)
TRAIN = ALL[:split_index]
TEST = ALL[split_index:]
TRAIN = torch.tensor(TRAIN, dtype=torch.long)
TEST = torch.tensor(TEST, dtype=torch.long)
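
# (Editorial addition, not part of the original gist.) A worked example of one
# generated row: with p = 47, the row for x = 3, y = 5 is
# [3, 5, (3 - 5 + 47) % 47] = [3, 5, 45], i.e. the label is x - y mod p.
assert [3, 5, 45] in ALL
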
##########################################################
# Simple model
##########################################################
hidden_size = 50

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.embeds = nn.Embedding(vocab_size, hidden_size)
        self.linear1 = nn.Linear(2 * hidden_size, 2 * hidden_size)
        self.relu = nn.LeakyReLU()
        self.linear2 = nn.Linear(2 * hidden_size, 2 * hidden_size)
        self.linear3 = nn.Linear(2 * hidden_size, vocab_size)

    def forward(self, inputs):
        # inputs: (batch, 2) LongTensor of (x, y) pairs.
        a = self.embeds(inputs)              # (batch, 2, hidden_size)
        a = a.reshape(-1, 2 * hidden_size)   # concatenate the two embeddings
        b = self.relu(self.linear1(a))
        c = self.relu(self.linear2(b))
        return F.log_softmax(self.linear3(c), dim=1)
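
# (Editorial addition, not part of the original gist.) A small shape check for
# the forward pass: a batch of B (x, y) pairs goes in as a LongTensor of shape
# (B, 2) and log-probabilities of shape (B, vocab_size) come out. Call it
# manually if desired; it is not invoked by the training run below.
def _check_model_shapes():
    m = Model()
    out = m(torch.tensor([[3, 5], [10, 2]], dtype=torch.long))
    assert out.shape == (2, vocab_size)
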
##########################################################
# A simple method of making small jumps in the tangent space.
##########################################################
def tangent_jump(v, alpha=5e-2, eps=1e-8):
    if isinstance(v, nn.Module):
        for p in v.parameters():
            if p.grad is None: continue
            tangent_jump(p, alpha)
    else:
        g = v.grad.data
        gl = torch.norm(g)
        gn = g / (gl + eps)
        while True:
            # Random vector.
            r = torch.randn_like(v.data)
            # Ensure we are jumping in the direction orthogonal to the gradient.
            r = r - gn * torch.sum(r * gn)
            rl = torch.norm(r)
            r = r / (rl + eps)
            if rl >= eps:
                break
        # Just a heuristic for scaling the step with the dimensionality.
        l = np.sqrt(len(v.data) - 1)
        v.data += r * l * alpha
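
# (Editorial addition, not part of the original gist.) A quick sanity check for
# tangent_jump on a single tensor: the applied step should be approximately
# orthogonal to the stored gradient. Call it manually if desired; it is not
# invoked by the training run below.
def _check_tangent_jump():
    w = torch.randn(10, requires_grad=True)
    loss = (w ** 2).sum()
    loss.backward()
    before = w.data.clone()
    tangent_jump(w, alpha=5e-2)
    step = w.data - before
    cos = torch.dot(step, w.grad.data) / (torch.norm(step) * torch.norm(w.grad.data) + 1e-8)
    print(f'cosine(step, grad) = {cos.item():.6f}')  # should be close to 0
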
##########################################################
# Optimization
##########################################################
model = Model()
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_function = nn.NLLLoss()

total_loss = 0
acc_train_acc = 0
acc_val_acc = 0
DATA = []
N_EPOCHS = 10000

# Turn this off if you want to see what happens without
# stochastic jumps in the tangent space.
USE_TANGENT_JUMPS = True
for epoch in range(N_EPOCHS + 1):
    # Note: If we were using small batches, there would be some
    #       gradient noise, which would result in small jumps
    #       in the tangent space. When not using batches, we
    #       can introduce those directly.

    ##########################################################
    # Jumping in the tangent space.
    ##########################################################
    if USE_TANGENT_JUMPS:
        model.zero_grad()
        log_probs = model(TRAIN[:, :2])
        loss = loss_function(log_probs, TRAIN[:, 2])
        loss.backward()
        alpha = 1e-1
        tangent_jump(model, alpha)

    ##########################################################
    # Regular (non-batched) optimization
    ##########################################################
    model.zero_grad()
    log_probs = model(TRAIN[:, :2])
    loss = loss_function(log_probs, TRAIN[:, 2])
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

    ##########################################################
    # Reporting and stop condition.
    ##########################################################
    train_acc = torch.sum(TRAIN[:, 2] == torch.argmax(model(TRAIN[:, :2]), dim=1)) / TRAIN.shape[0]
    val_acc = torch.sum(TEST[:, 2] == torch.argmax(model(TEST[:, :2]), dim=1)) / TEST.shape[0]
    DATA.append([epoch, train_acc.item(), val_acc.item()])
    if train_acc == 1.0 and val_acc == 1.0:
        break
    acc_train_acc += train_acc.item()
    acc_val_acc += val_acc.item()
    if epoch % 100 == 0:
        # Loss and accuracy accumulators cover (roughly) the previous 100
        # epochs; a sum of 100 per-epoch accuracy fractions reads as a percentage.
        print(f'{epoch}: L={total_loss / 100}, Train Acc={acc_train_acc:.3f}%, Val Acc={acc_val_acc:.3f}%')
        total_loss = 0
        acc_train_acc = 0
        acc_val_acc = 0
##########################################################
# Plot the accuracy.
##########################################################
DATA = np.asarray(DATA)
plt.figure(figsize=(16, 12), facecolor='white')
plt.plot(DATA[:, 0], DATA[:, 1], label='train accuracy')
plt.plot(DATA[:, 0], DATA[:, 2], label='validation accuracy')
plt.legend()
plt.show()