Train a semantic text compressor: a frozen Pythia-1.4B with LoRA adapters on its attention layers learns to compress a short token span into a single hidden state and reconstruct it from that state, which is potentially useful for very long-context language modeling. A usage sketch follows the script.
import random
from time import sleep
from functools import partial
from threading import Thread, Lock

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn
from torch.utils.data import DataLoader, IterableDataset

import wandb
from tqdm import tqdm
from datasets import load_dataset
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention
MODEL_NAME = "EleutherAI/pythia-1.4b-deduped-v0"

WANDB_STYLE = """
<style>
html, body {
    padding: 0;
    margin: 0;
    width: 100%;
    height: 100%;
}
p {
    font-family: 'Verdana', sans-serif;
}
table {
    border-collapse: collapse;
    width: 100%;
}
td, th {
    border: 1px solid #999999;
    text-align: left;
    padding: 8px;
}
tr:nth-child(even) {
    background-color: #eeeeee;
}
pre {
    white-space: pre-wrap;
}
</style>
"""
class LoRALinear(nn.Module):
    # Rank-r adapter: a small linear bottleneck whose second layer is zero-initialized,
    # so the adapter contributes nothing at the start of training.
    def __init__(self, finp, fout, r=4):
        super().__init__()
        self.finp = finp
        self.fout = fout
        self.r = r
        self.fc1 = nn.Linear(finp, r)
        self.fc2 = nn.Linear(r, fout)
        self.fc2.weight.data.zero_()
        self.fc2.bias.data.zero_()

    def forward(self, x):
        return self.fc2(self.fc1(x))


class LoRAWrapper(nn.Module):
    # Adds the trainable low-rank delta to the output of the frozen base projection.
    def __init__(self, main, lora):
        super().__init__()
        self.main = main
        self.lora = lora

    def forward(self, x):
        return self.main(x) + self.lora(x)
class Conceptor(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = GPTNeoXForCausalLM.from_pretrained(
            MODEL_NAME,
        )
        self.hh = hh = self.base_model.config.hidden_size
        # Two learned marker embeddings: index 0 marks the position to compress into
        # on the encoder side, index 1 separates the compressed vector from the
        # reconstruction on the decoder side.
        self.embeddings = nn.Embedding(2, hh)
        self.base_model.requires_grad_(False)
        self.loras = self.make_lora()

    def save(self, path):
        # Only the trainable pieces (LoRA adapters and marker embeddings) are saved.
        sd = dict(
            loras=self.loras.state_dict(),
            embeddings=self.embeddings.state_dict(),
        )
        torch.save(sd, path)

    def make_lora(self):
        # Attach LoRA adapters to the QKV and output projections of every attention block.
        layers = []
        for module in self.base_model.modules():
            if isinstance(module, GPTNeoXAttention):
                lora = LoRALinear(self.hh, 3 * self.hh)
                layers.append(lora)
                module.query_key_value = LoRAWrapper(module.query_key_value, lora)
                lora = LoRALinear(self.hh, self.hh)
                layers.append(lora)
                module.dense = LoRAWrapper(module.dense, lora)
        return nn.ModuleList(layers)
    def _encode(self, tokens, sizes):
        B, T = tokens.size()
        indices = torch.arange(B, device=tokens.device)
        embeddings = self.base_model.gpt_neox.embed_in(tokens)
        # replace the size'th token embedding with the "compress" marker, self.embeddings(0)
        embeddings[indices, sizes] = self.embeddings.weight[0]
        output = self.base_model.gpt_neox(inputs_embeds=embeddings)
        # take the size'th hidden state of each sequence as the compressed context
        context = output.last_hidden_state[indices, sizes]
        return context, embeddings

    def forward(self, tokens, sizes):
        B, T = tokens.size()
        context, embeddings = self._encode(tokens, sizes)
        # decoder input: compressed context, separator self.embeddings(1), then the
        # original token embeddings (teacher forcing)
        emb_token = self.embeddings.weight[1].unsqueeze(0).expand(B, -1)
        context = torch.cat(
            [context[:, None], emb_token[:, None], embeddings[:, :-2]], dim=1
        )
        # compute logits for reconstructing the original tokens
        logits = self.base_model(inputs_embeds=context).logits
        logits = logits[:, 1:].reshape(-1, logits.size(-1))
        # compute the reconstruction loss, averaged over each crop's real positions only
        targets = tokens[:, :-1].contiguous().view(-1)
        loss = F.cross_entropy(logits, targets, reduction="none")
        loss = loss.reshape(B, -1)
        loss_mask = torch.arange(T - 1, device=tokens.device) < sizes[:, None]
        loss = loss.masked_fill(~loss_mask, 0.0)
        loss = loss.sum(dim=1) / sizes.float()
        loss = loss.mean()
        return loss

    def encode(self, tokens, sizes):
        return self._encode(tokens, sizes)[0]
    @torch.no_grad()
    def sample(self, context, max_tokens=128, temperature=1.0):
        # Autoregressively decode a reconstruction from the compressed context vector;
        # the full prefix is re-run at every step (no KV cache).
        B = context.size(0)
        emb_token = self.embeddings.weight[1].unsqueeze(0).expand(B, -1)
        context = torch.cat([context[:, None], emb_token[:, None]], dim=1)
        output_tokens = []
        for _ in range(max_tokens):
            logits = self.base_model(inputs_embeds=context).logits
            logits = logits[:, -1, :] / temperature
            token = torch.multinomial(logits.softmax(dim=-1), num_samples=1)
            output_tokens.append(token)
            token_emb = self.base_model.gpt_neox.embed_in(token)
            context = torch.cat([context, token_emb], dim=1)
        return torch.cat(output_tokens, dim=-1)
class DatasetWrapper(IterableDataset):
    def __init__(self, min_tokens=1, max_tokens=32):
        self.tokenizer = GPTNeoXTokenizerFast.from_pretrained(MODEL_NAME)
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens
        self._buffer = []
        self._min_buffer = 10_000
        self._max_buffer = 20_000
        self._lock = None
        self._thread = None

    def _worker(self):
        # Background thread: stream The Pile, tokenize, and cut the token stream
        # into random-length crops that are pushed into a shared shuffle buffer.
        temp_buffer = []
        for sample in load_dataset(
            "EleutherAI/the_pile_deduplicated",
            split="train",
            streaming=True,
        ).shuffle(buffer_size=1000):
            text = sample["text"] + "<|endofdoc|>"
            tokens = self.tokenizer.encode(text)
            temp_buffer.extend(tokens)
            # crop into chunks
            while len(temp_buffer) >= self.max_tokens:
                size = random.randrange(self.min_tokens, self.max_tokens - 1)
                crop = temp_buffer[:size] + [self.tokenizer.eos_token_id] * 2
                with self._lock:
                    self._buffer.append(torch.tensor(crop))
                temp_buffer = temp_buffer[size:]
                sleep(0.001)
            # wait for buffer to drain
            while len(self._buffer) >= self._max_buffer:
                sleep(0.1)

    def __iter__(self):
        self._lock = Lock()
        self._thread = Thread(target=self._worker, daemon=True)
        self._thread.start()
        while True:
            # wait until the buffer is warm, then pop a random element
            while len(self._buffer) < self._min_buffer:
                sleep(0.1)
            with self._lock:
                idx = random.randrange(len(self._buffer))
                sample = self._buffer.pop(idx)
            yield sample
def dl_collate_fn(batch, pad_token_id):
    # Pad variable-length crops to a common length; the returned sizes point at the
    # last real token of each crop, which the encoder replaces with the compression marker.
    lengths = [t.size(0) for t in batch]
    tokens = rnn.pad_sequence(
        batch,
        batch_first=True,
        padding_value=pad_token_id,
    )
    return tokens, torch.tensor(lengths) - 1
class Trainer:
    def __init__(self):
        self.dataset = DatasetWrapper()
        self.loader = DataLoader(
            self.dataset,
            batch_size=32,
            num_workers=8,
            collate_fn=partial(
                dl_collate_fn,
                pad_token_id=self.dataset.tokenizer.eos_token_id,
            ),
        )
        self.model = model = Conceptor().cuda()
        print("Model parameters:", sum(p.numel() for p in model.parameters()))
        print(
            "Trainable parameters:",
            sum(p.numel() for p in model.parameters() if p.requires_grad),
        )
        self.opt = optim.Adam(
            params=model.parameters(),
            lr=6e-5,
            fused=True,
        )
        # self.model = torch.compile(model)

    def train_step(self, tokens, lengths):
        self.opt.zero_grad()
        loss = self.model(tokens.cuda(), lengths.cuda())
        loss.backward()
        self.opt.step()
        return loss

    def _detokenize(self, tokens):
        eos = self.dataset.tokenizer.eos_token_id
        output = []
        for tkn in tokens:
            # stop at the first EOS in the token list
            idx = tkn.index(eos) if eos in tkn else len(tkn)
            text = self.dataset.tokenizer.decode(tkn[:idx])
            output.append(text)
        return output

    def generate(self, tokens, sizes):
        # Compress a batch, reconstruct it, and render an original-vs-generated HTML table.
        self.model.eval()
        mem = self.model.encode(tokens, sizes)
        out = self.model.sample(mem, tokens.size(1), temperature=0.1).tolist()
        self.model.train()
        original = self._detokenize(tokens.tolist())
        generated = self._detokenize(out)
        table = "<table><tr><th>Original</th><th>Generated</th></tr>"
        for o, g in zip(original, generated):
            table += f"<tr><td><pre>{o}</pre></td><td><pre>{g}</pre></td></tr>"
        table += "</table>"
        return table

    def train(self):
        wandb.init(
            project="conceptor",
            entity="_",
        )
        # resume from a previously saved checkpoint
        sd = torch.load("model-v3.pt")
        self.model.load_state_dict(sd)
        del sd
        prog = tqdm(self.loader)
        for i, (tokens, lengths) in enumerate(prog):
            loss = self.train_step(tokens, lengths)
            prog.set_description(f"loss: {loss.item():.3f}")
            wandb.log(
                {
                    "loss": loss.item(),
                    "size": lengths.float().mean().item(),
                    "avgb": tokens.size(1),
                },
                step=i,
            )
            if i % 200 == 0:
                table = self.generate(tokens.cuda(), lengths.cuda())
                wandb.log(dict(diff=wandb.Html(WANDB_STYLE + table)), step=i)
                self.model.save("model.pt")


if __name__ == "__main__":
    trainer = Trainer()
    trainer.train()
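For reference, here is a minimal inference-time sketch, assuming the definitions above are in scope and a checkpoint was written by Conceptor.save (which stores only the LoRA adapters and the two marker embeddings); the checkpoint path, example text, and decoding length are arbitrary placeholders.

# Usage sketch (hypothetical; assumes a checkpoint produced by Conceptor.save above).
tokenizer = GPTNeoXTokenizerFast.from_pretrained(MODEL_NAME)
model = Conceptor().cuda().eval()

# Restore only the trainable parts that Conceptor.save writes out.
sd = torch.load("model.pt", map_location="cuda")
model.loras.load_state_dict(sd["loras"])
model.embeddings.load_state_dict(sd["embeddings"])

# Build a crop in the same layout as training: token ids followed by two EOS tokens,
# with `sizes` pointing at the final position (where the compression marker goes).
text = "The quick brown fox jumps over the lazy dog."
ids = tokenizer.encode(text) + [tokenizer.eos_token_id] * 2
tokens = torch.tensor([ids], device="cuda")
sizes = torch.tensor([len(ids) - 1], device="cuda")

with torch.no_grad():
    memory = model.encode(tokens, sizes)  # one hidden-size vector per sequence
    decoded = model.sample(memory, max_tokens=32, temperature=0.1)

print(tokenizer.decode(decoded[0].tolist()))

Because the training crops contain at most max_tokens=32 tokens, longer inputs would need to be split into spans of roughly that size and compressed one span at a time.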