import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable
from tqdm import tqdm

"""My attempt at Karpathy's char-rnn from "The Unreasonable Effectiveness of
Recurrent Neural Networks" post. Currently the loss goes down, but the eval
branch spits out gibberish; see the sampling sketch at the bottom of the file.
"""
class SimpleGRU(nn.Module):
    def __init__(self, vocab_size, emb_size, hid_size, batch_size, seq_len, n_layers=1):
        super(SimpleGRU, self).__init__()
        self.vocab_size = vocab_size
        self.emb_size = emb_size
        self.hid_size = hid_size
        self.n_layers = n_layers
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.emb = nn.Embedding(vocab_size, emb_size)
        # pass n_layers so the GRU depth matches the (n_layers, batch, hid) hidden state
        self.gru = nn.GRU(emb_size, hid_size, n_layers, batch_first=True)
        self.fc1 = nn.Linear(seq_len * hid_size, vocab_size)
        self.relu = nn.ReLU()
        self.selu = nn.SELU()
        self.logsoftmax = nn.LogSoftmax()
        # with batch_first embeddings, dim 1 is the time dimension, so this
        # normalizes over the seq_len positions
        self.batchnorm = nn.BatchNorm1d(seq_len)

    def forward(self, input, hidden):
        self.sizes = []
        self.sizes.append((input.size(), hidden.size()))
        x = self.emb(input)                      # (batch, seq_len, emb_size)
        x = self.batchnorm(x)
        self.sizes.append(x.size())
        x, hidden = self.gru(x, hidden)          # (batch, seq_len, hid_size)
        x = x.contiguous().view(self.batch_size, -1)
        x = self.selu(self.fc1(x))               # (batch, vocab_size)
        self.sizes.append((x.size(), hidden.size()))
        x = self.logsoftmax(x)
        self.sizes.append(x.size())
        return x, hidden
class CharDataset(data.Dataset):
    def __init__(self, data, seq_len):
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # input window of seq_len-1 characters; the target is the character
        # immediately following that window
        inp_seq = self.data[index:(index + self.seq_len - 1)]
        tgt_seq = torch.Tensor([self.data[index + self.seq_len - 1]]).type(self.data.type())
        return inp_seq, tgt_seq

    def __len__(self):
        return len(self.data) - self.seq_len
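
# A quick illustration (not in the original gist) of what CharDataset yields,
# assuming the indexing above; `toy` is a hypothetical example object:
#
#   toy = CharDataset(torch.arange(0, 10).long(), seq_len=5)
#   inp, tgt = toy[0]
#   # inp -> the 4-character window [0, 1, 2, 3]
#   # tgt -> [4], the character immediately after the window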
seq_length = 25
batch_size = 250
emb_size = 25
hid_size = 100
n_layers = 3

# build the character vocabulary and encode the corpus as integer indices
with open("/home/david/Programming/data/project_gutenberg/tiny-shakespeare.txt", "r") as f:
    text_raw = [c for l in f.readlines() for c in l]
charset = sorted(list(set(text_raw)))
c2i = {c: i for i, c in enumerate(charset)}
i2c = {i: c for c, i in c2i.items()}
text_idx = [c2i[c] for c in text_raw]
tqdm.write("{} {}".format(len(text_idx), len(text_raw)))
inputs = torch.Tensor(text_idx).long()
tqdm.write("{}".format(inputs.size()))

ds = CharDataset(inputs, seq_length)
dl = data.DataLoader(ds, batch_size=batch_size, drop_last=True)
tqdm.write("{}".format(len(dl)))

vocab_size = len(charset)
num_batches = len(dl)
epochs = 20
lr = 0.003
momo = 0.9

# the model sees windows of seq_length-1 characters and predicts the next one
model = SimpleGRU(vocab_size, emb_size, hid_size, batch_size, seq_length - 1, n_layers)
# resume from a previously saved checkpoint (expects the file to exist)
model.load_state_dict(torch.load("checkpoints/char-rnn/model_19.pt"))
criterion = nn.NLLLoss()
#optimizer = optim.Adam(model.parameters(), lr=lr)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momo)
tqdm.write("{}".format(model))

#epoch_bar = tqdm(range(epochs), desc="epochs", disable=True)
train = False
if train:
    for epoch in range(epochs):
        running_loss = 0
        batch_bar = tqdm(enumerate(dl), desc="batches")
        for i, (mb, tgts) in batch_bar:
            # fresh zero hidden state for every mini-batch
            h = Variable(torch.zeros(n_layers, batch_size, hid_size))
            tgts.squeeze_()
            model.train()
            model.zero_grad()
            mb, tgts = Variable(mb), Variable(tgts)
            out, h = model(mb, h)
            loss = criterion(out, tgts)
            loss.backward()
            optimizer.step()
            h.detach_()
            running_loss += loss.data[0]
            if (i % 25 == 0 and i > 0) or i == num_batches - 1:
                batch_bar.set_postfix(ave_loss=running_loss / (i + 1), last_loss=loss.data[0])
        #epoch_bar.set_postfix(prevloss=(running_loss / num_batches))
        torch.save(model.state_dict(), "model_{}.pt".format(epoch))
        tqdm.write("epoch {}".format(epoch + 1))
else:
    batch_bar = tqdm(enumerate(dl), desc="batches")
    model.eval()
    pred = []
    for i, (mb, tgts) in batch_bar:
        h = Variable(torch.zeros(n_layers, batch_size, hid_size))
        tgts.squeeze_()
        mb, tgts = Variable(mb), Variable(tgts)
        out, h = model(mb, h)
        pred.append(out.data.max(1)[1])   # greedy next-char prediction per window
        if i == 100:
            break
    pred = torch.cat(pred)
    pred_c = [i2c[int(i)] for i in pred]
    print("".join(pred_c[2000:2350]))