Created May 24, 2023 23:29
Fine-tune a Pythia model on the Multimodal C4 (MMC4) dataset
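The script below streams MMC4 jsonl shards, interleaves each matched image into the tokenized text as a run of placeholder tokens, packs documents into fixed 1024-token training sequences, encodes the images with a timm FocalNet backbone, projects them into Pythia's embedding space, and overwrites the placeholder positions with those image embeddings before computing a standard causal-LM loss (placeholder positions are excluded from the loss).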
# WIP: Fine-tune a causal LM on images & text interleaved from the MMC4 dataset
import os
import json
import random
from concurrent.futures import ThreadPoolExecutor

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cuda as cuda
from torch.utils.data import IterableDataset, DataLoader, get_worker_info

import timm
import timm.data
from transformers import AutoTokenizer, GPTNeoXForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention
def _attn_wrapper(self, query, key, value, attention_mask=None, head_mask=None):
    # Replace the eager GPT-NeoX attention with torch's fused SDPA
    # (memory-efficient backend), computed in fp16 and cast back to fp32.
    assert attention_mask is None and head_mask is None, "Not implemented"
    with cuda.sdp_kernel(enable_math=False, enable_flash=False):
        out = F.scaled_dot_product_attention(
            query.half(),
            key.half(),
            value.half(),
            is_causal=True,
        ).float()
    return out, None


# patch attention to save a lot of memory
GPTNeoXAttention._attn = _attn_wrapper
class MultiModalC4(IterableDataset):
    def __init__(
        self,
        dataset_path,
        tokenizer_name,
        image_model_name,
        image_tokens=49,
        max_seq_len=1024,
        image_token_id=1,
        cache_buffer_size=1000,
    ):
        self.path = dataset_path
        # jsonl files:
        shards = [
            os.path.join(self.path, f)
            for f in os.listdir(self.path)
            if f.endswith(".jsonl")
        ]
        self.shards = sorted(shards)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.image_tokens = image_tokens
        self.image_token_id = image_token_id
        self.max_seq_len = max_seq_len
        self.cache_buffer_size = cache_buffer_size
        data_config = timm.data.resolve_model_data_config(image_model_name)
        self.transforms = timm.data.create_transform(**data_config, is_training=True)
    def _load_image(self, shard_id, image_info):
        base_path = os.path.join(self.path, f"{shard_id}")
        img_path = os.path.join(base_path, image_info["image_name"])
        idx = image_info["matched_text_index"]
        try:
            img = Image.open(img_path)
        except Exception:
            # missing or corrupt images are simply skipped
            img = None
        return img, idx
    def _merge_images_texts(self, texts, img_map):
        txt_list = texts
        output = []
        for i, txt in enumerate(txt_list):
            if i in img_map:
                # always put image first
                output.append(img_map[i])
                output.append(txt)
            else:
                output.append(txt)
        # join consecutive text chunks so the result alternates image / text
        result = [output[0]]
        for x in output[1:]:
            if isinstance(x, str) and isinstance(result[-1], str):
                result[-1] += "\n" + x
            else:
                result.append(x)
        return result
    def _flatten_merged(self, merged):
        processed_images = []
        imgs_starts = []
        text_tokens = []
        for i, x in enumerate(merged):
            if isinstance(x, str):
                text_tokens += self.tokenizer.encode(x)
            else:
                # each image reserves `image_tokens` slots plus one boundary
                # token on either side; the inner slots are overwritten with
                # image embeddings inside the model
                img_token_placeholder = [self.image_token_id] * (self.image_tokens + 2)
                img = self.transforms(x.convert("RGB"))
                processed_images.append(img)
                imgs_starts.append(len(text_tokens) + 1)
                text_tokens += img_token_placeholder
        if processed_images:
            processed_images = torch.stack(processed_images)
            imgs_starts = torch.tensor(imgs_starts)
        else:
            processed_images = torch.zeros(0, 3, 224, 224)
            imgs_starts = torch.zeros(0, dtype=torch.long)
        return dict(
            text_tokens=text_tokens,
            processed_images=processed_images,
            imgs_starts=imgs_starts,
        )
    def _stream_shard(self, shard_id):
        with ThreadPoolExecutor() as executor:
            with open(self.shards[shard_id], "r") as f:
                for line in f:
                    obj = json.loads(line)
                    # load all matched images of this document in parallel
                    imgs = executor.map(
                        self._load_image,
                        [shard_id] * len(obj["image_info"]),
                        obj["image_info"],
                    )
                    imgs, idxs = zip(*imgs)
                    img_map = {idx: img for idx, img in zip(idxs, imgs) if img}
                    merged = self._merge_images_texts(obj["text_list"], img_map)
                    yield self._flatten_merged(merged)
    def _stream_all(self, rnd):
        shard_ids = list(range(len(self.shards)))
        rnd.shuffle(shard_ids)
        for shard_id in shard_ids:
            yield from self._stream_shard(shard_id)

    def _buffered_stream(self):
        # per-worker RNG so each DataLoader worker shuffles shards differently
        wi = get_worker_info()
        if wi is None:
            seed = None
        else:
            seed = wi.seed
        rnd = random.Random(seed)
        # approximate shuffling with a fixed-size shuffle buffer
        buffer = []
        for doc in self._stream_all(rnd):
            buffer.append(doc)
            if len(buffer) >= self.cache_buffer_size:
                idx = rnd.randint(0, len(buffer) - 1)
                yield buffer.pop(idx)
        yield from buffer
    def _find_trainable_seq_range(self, text_tokens):
        # slide forward until a max_seq_len window starts and ends on a
        # non-image token, so no image placeholder run is cut in half
        idx = 0
        buffer_found = False
        while idx + self.max_seq_len < len(text_tokens):
            if (
                text_tokens[idx] != self.image_token_id
                and text_tokens[idx + self.max_seq_len] != self.image_token_id
            ):
                buffer_found = True
                break
            idx += 1
        return buffer_found, idx

    def _find_images_on_crop(self, idx, imgs_starts):
        # indices of the images whose placeholder runs fall inside the crop
        in_range = (imgs_starts >= idx) * (imgs_starts <= idx + self.max_seq_len)
        indices = torch.nonzero(in_range).squeeze(1).tolist()
        if not indices:
            return 0, 0
        sid, eid = indices[0], indices[-1] + 1
        return sid, eid
    def _joined_docs(self):
        # pack consecutive documents (separated by EOS) into fixed-length
        # training sequences of max_seq_len + 1 tokens
        text_tokens = []
        imgs_starts = None
        processed_images = None
        for doc in self._buffered_stream():
            if processed_images is None:
                processed_images = doc["processed_images"]
                imgs_starts = doc["imgs_starts"]
            else:
                text_tokens += [self.tokenizer.eos_token_id]
                processed_images = torch.cat(
                    (
                        processed_images,
                        doc["processed_images"],
                    )
                )
                new_starts = doc["imgs_starts"] + len(text_tokens)
                imgs_starts = torch.cat((imgs_starts, new_starts))
            text_tokens += doc["text_tokens"]
            if len(text_tokens) < self.max_seq_len + 1:
                continue
            buffer_found, idx = self._find_trainable_seq_range(text_tokens)
            im_idx1, im_idx2 = self._find_images_on_crop(idx, imgs_starts)
            if buffer_found:
                # send cropped buffer
                yield dict(
                    text_tokens=text_tokens[idx : idx + self.max_seq_len + 1],
                    imgs_starts=imgs_starts[im_idx1:im_idx2] - idx,
                    processed_images=processed_images[im_idx1:im_idx2],
                )
            idx += self.max_seq_len + 1
            # destroy the buffer till idx
            text_tokens = text_tokens[idx:]
            imgs_starts = imgs_starts[im_idx2:] - idx
            processed_images = processed_images[im_idx2:]

    def __iter__(self):
        yield from self._joined_docs()
def mmc4_collate_fn(batch):
    # text_tokens:      [B, max_seq_len + 1] token ids
    # processed_images: [num_images_in_batch, 3, H, W]
    # imgs_starts:      per-image start position within its sample's crop
    # imgs_counts:      number of images contributed by each sample
    text_tokens = []
    processed_images = []
    imgs_starts = []
    imgs_counts = []
    for doc in batch:
        text_tokens.append(doc["text_tokens"])
        processed_images.append(doc["processed_images"])
        imgs_starts.append(doc["imgs_starts"])
        imgs_counts.append(len(doc["imgs_starts"]))
    text_tokens = torch.tensor(text_tokens)
    processed_images = torch.cat(processed_images)
    imgs_starts = torch.cat(imgs_starts)
    imgs_counts = torch.tensor(imgs_counts)
    return dict(
        text_tokens=text_tokens,
        processed_images=processed_images,
        imgs_starts=imgs_starts,
        imgs_counts=imgs_counts,
    )
class MultiModalPythia(nn.Module):
    def __init__(self, transformer_model, image_model, image_token_id=1):
        super().__init__()
        self.transformer = GPTNeoXForCausalLM.from_pretrained(transformer_model)
        self.vision = timm.create_model(
            image_model,
            pretrained=True,
            num_classes=0,
        )
        # project vision features into the LM's embedding space
        vis_emb = self.vision.embed_dim[-1]
        lm_emb = self.transformer.config.hidden_size
        self.proj = nn.Linear(vis_emb, lm_emb)
        self.image_token_id = image_token_id

    def forward(self, text_tokens, processed_images, imgs_starts, imgs_counts):
        inp_txt = text_tokens[:, :-1]
        out_txt = text_tokens[:, 1:].clone()
        # image placeholder positions do not contribute to the LM loss
        out_txt[out_txt == self.image_token_id] = -100
        txt_emb = self.transformer.gpt_neox.embed_in(inp_txt)
        if processed_images.size(0) > 0:
            # [num_images, C, H, W] -> [num_images, H*W, hidden]
            img_emb = self.vision.forward_features(processed_images)
            img_emb = img_emb.view(*img_emb.shape[:2], -1).permute(0, 2, 1)
            img_emb = self.proj(img_emb)
            N = img_emb.shape[1]
            # cumulative offsets split the flat image batch back per sample
            offsets = [0] + imgs_counts.cumsum(0).tolist()
            for b, (i, j) in enumerate(zip(offsets[:-1], offsets[1:])):
                imgs = img_emb[i:j]
                starts = imgs_starts[i:j]
                for s, img in zip(starts, imgs):
                    # overwrite the placeholder slots with image embeddings
                    txt_emb[b, s : s + N] = img
        logits = self.transformer(inputs_embeds=txt_emb).logits
        loss = F.cross_entropy(
            logits.view(-1, logits.shape[-1]),
            out_txt.reshape(-1),
            ignore_index=-100,
        )
        return loss, logits
if __name__ == "__main__":
    from tqdm import tqdm
    from torch.optim import Adam

    ds = MultiModalC4(
        "../mmc4-ff",
        "EleutherAI/pythia-1b-deduped",
        "focalnet_large_fl4.ms_in22k",
        image_tokens=49,
        max_seq_len=1024,
        image_token_id=1,
    )

    # sanity check that image placeholders line up with imgs_starts
    # max_images = 0
    # prog = tqdm(ds)
    # for i, x in enumerate(prog):
    #     t = x["text_tokens"]
    #     q = [-1] * len(t)
    #     for ims in x["imgs_starts"].long().tolist():
    #         q[ims - 1 : ims + 50] = [1] * 51
    #     assert torch.tensor(q == 1).sum() == torch.tensor(t == 1).sum()
    #     max_images = max(max_images, len(x["imgs_starts"]))
    #     prog.set_postfix(max_images=max_images)

    loader = DataLoader(
        dataset=ds,
        batch_size=4,
        num_workers=4,
        collate_fn=mmc4_collate_fn,
    )

    dev = "cuda"
    prog = tqdm(loader)
    model = MultiModalPythia(
        "EleutherAI/pythia-1b-deduped",
        "focalnet_large_fl4.ms_in22k",
        image_token_id=1,
    ).to(dev)
    opt = Adam(model.parameters(), lr=1e-5)

    for i, x in enumerate(prog):
        x = {k: v.to(dev) for k, v in x.items()}
        loss, logits = model(**x)
        opt.zero_grad()
        loss.backward()
        opt.step()
        prog.set_postfix(loss=loss.item())
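
    # A minimal greedy-decoding sketch, assuming training above has finished
    # and reusing `model`, `ds`, and `dev`; "some_image.jpg" and the prompt
    # are placeholders, the training-time transforms are reused for
    # preprocessing, and the backbone is assumed to emit ds.image_tokens
    # features per image (49 for focalnet_large_fl4 at 224px).
    @torch.no_grad()
    def greedy_caption(image_path, prompt, max_new_tokens=32):
        model.eval()
        img = ds.transforms(Image.open(image_path).convert("RGB"))[None].to(dev)
        img_emb = model.vision.forward_features(img)
        img_emb = model.proj(img_emb.view(*img_emb.shape[:2], -1).permute(0, 2, 1))
        # same layout as training: a run of image_tokens + 2 placeholder ids,
        # whose inner slots are overwritten by the projected image features
        ids = [ds.image_token_id] * (ds.image_tokens + 2) + ds.tokenizer.encode(prompt)
        ids = torch.tensor([ids], device=dev)
        emb = model.transformer.gpt_neox.embed_in(ids)
        emb[:, 1 : 1 + img_emb.shape[1]] = img_emb
        out_ids = []
        for _ in range(max_new_tokens):
            next_id = model.transformer(inputs_embeds=emb).logits[:, -1].argmax(-1)
            out_ids.append(next_id.item())
            if out_ids[-1] == ds.tokenizer.eos_token_id:
                break
            emb = torch.cat((emb, model.transformer.gpt_neox.embed_in(next_id)[:, None]), dim=1)
        return ds.tokenizer.decode(out_ids)

    # e.g. print(greedy_caption("some_image.jpg", "A photo of"))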