Created November 22, 2024 08:19
Simple masked language modeling code
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForMaskedLM
from datasets import load_dataset
import random
import numpy as np
from tqdm import tqdm
# Set random seed for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)
# Custom dataset class for masked language modeling
class MaskedLanguageDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=128, mlm_probability=0.15):
        self.tokenizer = tokenizer
        self.texts = texts
        self.max_length = max_length
        self.mlm_probability = mlm_probability

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize text
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()
        masked_input_ids = input_ids.clone().detach()

        # Create masked input: sample positions with mlm_probability,
        # never masking special tokens ([CLS], [SEP], [PAD])
        labels = input_ids.clone()
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = self.tokenizer.get_special_tokens_mask(
            labels.tolist(), already_has_special_tokens=True)
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool),
                                        value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # loss is only computed on masked positions

        # 80% of the time, replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        masked_input_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, replace masked input tokens with a random word
        # (the remaining 10% keep the original token)
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        masked_input_ids[indices_random] = random_words[indices_random]

        return {
            'input_ids': input_ids,
            'masked_input_ids': masked_input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
# Load AG News dataset
dataset = load_dataset("ag_news")
train_texts = dataset["train"]["text"]

# Initialize tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# Create dataset and dataloader
train_dataset = MaskedLanguageDataset(train_texts, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
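# (Optional sanity check, not in the original gist: pull one batch and confirm
# the masking statistics look plausible. The number of labeled positions should
# sit a bit under mlm_probability of the attended tokens, since special tokens
# are never masked.)
sample_batch = next(iter(train_loader))
num_labeled = (sample_batch['labels'] != -100).sum().item()
num_attended = sample_batch['attention_mask'].sum().item()
print(f'Masking sanity check: {num_labeled}/{num_attended} attended tokens labeled '
      f'({num_labeled / num_attended:.1%})')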
# Training settings
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
num_epochs = 3
# Training loop
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        masked_input_ids = batch['masked_input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(
            input_ids=masked_input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        # Cross-entropy against the original input_ids over every non-padding
        # position, not just the masked ones; the commented-out outputs.loss
        # below is the standard MLM loss restricted to the masked positions.
        loss = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)(
            outputs.logits.view(-1, outputs.logits.size(-1)), input_ids.view(-1))
        #loss = outputs.loss

        total_loss += loss.item()
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix({'loss': loss.item()})

    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')
# Save the model
model.save_pretrained('mlm_ag_news_model')
tokenizer.save_pretrained('mlm_ag_news_model')
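A minimal sketch (not part of the original gist) of how the saved checkpoint could be loaded back for fill-in-the-blank inference. The directory name 'mlm_ag_news_model' matches the save calls above; the example sentence and variable names are illustrative only.

from transformers import BertTokenizer, BertForMaskedLM
import torch

tokenizer = BertTokenizer.from_pretrained('mlm_ag_news_model')
model = BertForMaskedLM.from_pretrained('mlm_ag_news_model')
model.eval()

text = "The stock market [MASK] sharply after the announcement."
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
    logits = model(**inputs).logits

# Take the highest-scoring vocabulary entry at the [MASK] position
mask_pos = (inputs['input_ids'] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_id = logits[0, mask_pos].argmax(dim=-1)
print(tokenizer.decode(predicted_id))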