# Simple masked language modeling code
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForMaskedLM
from datasets import load_dataset
import random
import numpy as np
from tqdm import tqdm
# Set random seed for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)
# Custom dataset class for masked language modeling
class MaskedLanguageDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=128, mlm_probability=0.15):
        self.tokenizer = tokenizer
        self.texts = texts
        self.max_length = max_length
        self.mlm_probability = mlm_probability

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize text
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()
        masked_input_ids = input_ids.clone()

        # Create masked input: labels keep the original ids at masked positions
        # and are set to -100 everywhere else so the loss ignores unmasked tokens
        labels = input_ids.clone()
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = self.tokenizer.get_special_tokens_mask(
            labels.tolist(), already_has_special_tokens=True)
        probability_matrix.masked_fill_(
            torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100

        # 80% of the time, replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        masked_input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, replace masked input tokens with a random word
        # (0.5 of the remaining 20%); the final 10% are left unchanged
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        masked_input_ids[indices_random] = random_words[indices_random]

        return {
            'input_ids': input_ids,
            'masked_input_ids': masked_input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
# Load AG News dataset
dataset = load_dataset("ag_news")
train_texts = dataset["train"]["text"]
# Initialize tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
# Create dataset and dataloader
train_dataset = MaskedLanguageDataset(train_texts, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
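# Optional sanity check (an illustrative sketch, not required for training):
# pull one item and confirm that labels are -100 everywhere except at the
# roughly 15% of positions chosen for masking.
sample = train_dataset[0]
num_masked = (sample['labels'] != -100).sum().item()
print(f"sequence length: {sample['input_ids'].numel()}, masked positions: {num_masked}")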
# Training settings
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
num_epochs = 3
# Training loop
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for batch in progress_bar:
        optimizer.zero_grad()
        masked_input_ids = batch['masked_input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(
            input_ids=masked_input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        # outputs.loss is cross-entropy over the masked positions only
        # (positions where labels == -100 are ignored), which matches the
        # MLM objective set up in the dataset class
        loss = outputs.loss
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix({'loss': loss.item()})
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')
# Save the model
model.save_pretrained('mlm_ag_news_model')
tokenizer.save_pretrained('mlm_ag_news_model')
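# Quick inference check, a minimal sketch using the model and tokenizer already
# in scope (the example sentence is made up for illustration): fill in a single
# [MASK] token with the fine-tuned model.
model.eval()
text = "The stock market [MASK] sharply after the announcement."
inputs = tokenizer(text, return_tensors='pt').to(device)
with torch.no_grad():
    logits = model(**inputs).logits
# Locate the [MASK] position and take the highest-scoring vocabulary id there
mask_pos = (inputs['input_ids'] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_id = logits[0, mask_pos].argmax(dim=-1)
print(tokenizer.decode(predicted_id))  # expect a plausible verb, e.g. 'fell' or 'rose'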