@buttercutter
Last active February 21, 2025 14:20
A simple implementation of [Protein Discovery with Discrete Walk-Jump Sampling](http://arxiv.org/abs/2306.12360)
# Credit : gpt-4o , Claude-3.5-Sonnet-200k , Gemini-Pro-1.5
# Reference :
# [Protein Discovery with Discrete Walk-Jump Sampling](http://arxiv.org/abs/2306.12360)
# [Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion](http://arxiv.org/abs/2407.01392)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy
import os
import random
import string
from collections import Counter
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForMaskedLM, AutoModelForSeq2SeqLM, T5Tokenizer, T5ForConditionalGeneration
from transformers.optimization import Adafactor, AdafactorSchedule
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.parametrizations import weight_norm
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from adam_mini import Adam_mini
# set to 1 when the models have already been trained / finetuned and only the inference code should run
INFERENCE_ONLY = 0
# Just for code development / debugging purpose
TEST_OVERFIT = 0
# for the denoiser module, choose only ONE of the following options :
USE_PRETRAINED_BERT = 0
USE_PRETRAINED_BERT_MLM = 0
USE_PRETRAINED_T5 = 0
USE_CUSTOM_TRANSFORMER_ENCODER = 0 # the most memory-efficient option
USE_CUSTOM_TRANSFORMER_ENCODER_DECODER = 1
# Early-stopping for the models training
USE_EARLY_STOP = 0
EARLY_STOP_THRESHOLD = 2.175 #1.91
# for sentence completion downstream task
ENABLE_MASK_LEARNING = 1
# google colab T4 GPU does not have a lot of RAM for computation
# custom transformer module can now handle multiple masked tokens
if torch.cuda.is_available(): #or USE_CUSTOM_TRANSFORMER_ENCODER_DECODER or USE_CUSTOM_TRANSFORMER_ENCODER:
MASK_RATIO = 0.15 # use 0.15 for 15% masking probability, use the value of -1 to indicate only a single masked token
else:
MASK_RATIO = 0.15 # use 0.15 for 15% masking probability, use the value of -1 to indicate only a single masked token
# allows the denoiser model to train on [batch_size, sequence_length, vocab_size]
USE_LOGITS_FOR_THE_ENTIRE_SENTENCE = 1
USE_LOGITS_FOR_THE_ENTIRE_SENTENCE = USE_LOGITS_FOR_THE_ENTIRE_SENTENCE or (MASK_RATIO != -1) # if masking more than 1 token, then it makes sense to train on [batch_size, sequence_length, vocab_size]
# custom transformer module can now handle multiple masked tokens
#USE_LOGITS_FOR_THE_ENTIRE_SENTENCE = USE_LOGITS_FOR_THE_ENTIRE_SENTENCE and not (USE_CUSTOM_TRANSFORMER_ENCODER_DECODER or USE_CUSTOM_TRANSFORMER_ENCODER)
# Analyze walk-jump's output samples for debugging purpose
ENABLE_SAMPLE_ANALYSIS = 0 # turns off for reducing memory consumption
if torch.cuda.is_available():
    device_str = "cuda"
elif torch.backends.mps.is_available():
    device_str = "mps"
else:
    device_str = "cpu"
device = torch.device(device_str)
# Automatic Mixed Precision for training
from torch import autocast
if torch.cuda.is_available():
from torch.amp import GradScaler
USE_MIXED_PRECISION_TRAINING = 0 # optional, turns off for this code since it hurts model performance
else:
USE_MIXED_PRECISION_TRAINING = 0 # not implemented
# for reducing memory usage during training : https://github.com/zyushun/Adam-mini
USE_ADAM_MINI = 0
# 0: Sinusoidal Positional Embedding , 1: Rotary Positional Embedding
USE_ROPE = 0
# Just for code development / debugging purpose
USE_DUMMY_TRAINING_DATA = 0
# for adjusting the generation process due to fixed output length
GENERATES_OUTPUT_OF_VARYING_LENGTH = 0
# for more difficult denoising task
ADD_EXTRA_GAUSSIAN_NOISE = 0 # turns off for now
# Select between diffusion forcing and walk-jump
# if the following two variables are turned off, it would be walk-jump (single constant noise level)
USE_DIFFUSION_FORCING = 1 & ADD_EXTRA_GAUSSIAN_NOISE
USE_PRECOMPUTE_NOISE_SCHEDULE = 0 # testing only; not recommended due to its storage cost
# Regarding two different approaches for Langevin MCMC sampling
USE_MCMC = 1
USE_ALGORITHM_1_OR_4 = 0 # value of 1 means Algorithm 1, value of 0 means Algorithm 4, see walk-jump paper
USE_OBABO = 0 # Using KIPLMC2 is slow because of the need to compute gradients of U with respect to both theta and X
# sequential monte-carlo (SMC)
USE_SMC = 0 # if use SMC, then ignore USE_ALGORITHM_1_OR_4 which is related to Langevin MCMC
# Markov-approximate fractional Brownian motion (MA-fBM)
USE_MAFBM = 0 # if use MAFBM, then ignore USE_ALGORITHM_1_OR_4 which is related to Langevin MCMC
# Once turned on, it will be different from the walk-jump denoise update equation
USE_LOGITS_FOR_DENOISING = 0 # consumes much more memory
USE_LOGITS_FOR_DENOISING = USE_LOGITS_FOR_DENOISING and (USE_SMC or USE_MAFBM or USE_MCMC)
# kl_div method (requires extra run of denoiser model) to improve sampling based on prior distribution
# Only turn on USE_GRAD_KL if USE_PRETRAINED_T5 is disabled, because USE_GRAD_KL uses
# "tokenizer.vocab_size"-rounds of denoiser module execution, hence extremely long execution time.
# Using large pretrained T5 model as denoiser module will only worsen the runtime issue.
USE_GRAD_KL = 0
# Choose only one of the following training recipes for walk-jump sampling
USE_dWJS_ENERGY = 1
USE_dWJS_SCORE = 1 - USE_dWJS_ENERGY # note: bitwise ~ on 0/1 would give -1/-2 (both truthy), so use arithmetic negation instead
# Define parameters
input_dim = 128
model_dim = input_dim
model_dim_ebm = model_dim >> 2 # specific only to EBM model
hidden_dim = 256
num_layers = 4
num_layers_ebm = num_layers >> 1 # specific only to EBM model
num_heads = 8
num_heads_ebm = num_heads >> 2 # specific only to EBM model
num_smc_steps = 5 # sequential monte-carlo (SMC)
N_particles = 10 # sequential monte-carlo (SMC)
hurst = 0.7 # Markov-approximate fractional Brownian motion (MA-fBM)
T_fbm = 1.0 # Markov-approximate fractional Brownian motion (MA-fBM)
n_steps = 1000 # Markov-approximate fractional Brownian motion (MA-fBM)
K_fbm = 3 # Markov-approximate fractional Brownian motion (MA-fBM)
num_walk_steps = 5 # for langevin dynamics MCMC sampling process
num_jump_steps = 20 #num_walk_steps
walk_step_size = 0.6 # for langevin dynamics MCMC sampling process
sigma_max = 1.1
sigma_min = 0.1
num_epochs = 500
batch_size = 512
if USE_PRETRAINED_BERT or USE_PRETRAINED_BERT_MLM:
# BERT model is larger than TransformerDenoiser() module
batch_size = batch_size >> 6
elif USE_PRETRAINED_T5:
# T5 models are way larger than both BERT model and TransformerDenoiser() module
batch_size = 1
elif USE_CUSTOM_TRANSFORMER_ENCODER_DECODER:
# we have extra decoder layers inside the TransformerDenoiser() module
batch_size = batch_size >> 4
else: # USE_CUSTOM_TRANSFORMER_ENCODER
# we do not have extra decoder layers inside the TransformerDenoiser() module
batch_size = batch_size >> 3
#if torch.cuda.is_available(): # so far colab run session has some extra unused GPU RAM on T4 GPU
# batch_size = batch_size << 2 # increasing batch_size worsens the validation loss convergence rate
# Monitors the quality of the generated samples throughout the training and validation
# processes to assess the model's performance and identify potential issues
def analyze_samples(generated_samples, tokenizer, skip_special_tokens=False, num_samples=1):
decoded_samples = []
if num_samples != 1:
num_samples = generated_samples.size(0)
for i in range(num_samples):
sample = generated_samples[i]
sample = sample.long() # Convert the sample to integer tensor
decoded_sample = tokenizer.decode(sample, skip_special_tokens=skip_special_tokens)
print(f"Sample {i+1}: {decoded_sample}")
decoded_samples.append(decoded_sample)
return decoded_samples
def assert_sample_range_compliance(sample, tokenizer):
# Assert that all token IDs are within the valid range
assert sample.min() >= 0, f"Token ID is less than 0! sample = {sample}, sample.min() = {sample.min()}"
assert sample.max() < tokenizer.vocab_size, f"Token ID exceeds valid range! Max ID: {sample.max()}, Vocab Size: {tokenizer.vocab_size}"
# Assert that the tokens input to the model are not all zeros
assert not torch.all(sample == 0), "Error: sample contains all zeros!"
return True
def check_for_vanishing_gradients(model):
for name, param in model.named_parameters():
if param.grad is not None:
grad_norm = param.grad.data.norm(2)
if grad_norm < 1e-5: # Threshold for detecting vanishing gradients
print(f"Warning: Vanishing gradient detected in {name} with norm {grad_norm.item():.6f}")
if USE_PRETRAINED_T5: #or USE_CUSTOM_TRANSFORMER_ENCODER_DECODER or USE_CUSTOM_TRANSFORMER_ENCODER:
tokenizer = AutoTokenizer.from_pretrained("pnawrot/nanoT5-base")
#tokenizer = T5Tokenizer.from_pretrained('google/t5-efficient-tiny')
else:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
def tokenizer_function(raw_sequence_input, tokenizer, max_length=input_dim):
tokenized_sequence = tokenizer(
raw_sequence_input,
padding='max_length',
truncation=True,
max_length=max_length,
return_tensors="pt"
)
return tokenized_sequence.to(device)
#print(f"tokenizer.pad_token_id = {tokenizer.pad_token_id}")
# for initializing target_label for denoiser module
CONSTANTS_VALUE_IGNORE = tokenizer.pad_token_id # -100
# for creating data loader for span-masking task
class DataCollatorForSpanCorruption:
def __init__(self, tokenizer, mlm_probability=0.15, mean_noise_span_length=3, input_length=input_dim):
self.tokenizer = tokenizer
self.mlm_probability = mlm_probability
self.mean_noise_span_length = mean_noise_span_length
self.input_length = input_length
def __call__(self, examples):
# If examples are tensors, convert them to lists
if isinstance(examples[0], torch.Tensor):
            input_ids = list(examples)  # keep the tensors as-is; _collate_batch() clones and pads them
attention_mask = None # No attention mask for tensor inputs
else:
# Assuming examples are dicts with 'input_ids' keys
input_ids = [example['input_ids'] for example in examples]
attention_mask = [example['attention_mask'] for example in examples] if 'attention_mask' in examples[0] else None
batch = self._collate_batch(input_ids)
# Add attention mask if it exists
if attention_mask is not None:
batch['attention_mask'] = pad_sequence(
[mask.clone().detach() for mask in attention_mask],
batch_first=True,
padding_value=0
)
return batch
def _collate_batch(self, input_ids_list):
# Pad input_ids to the same length
batch_input_ids = pad_sequence(
[ids.clone().detach() for ids in input_ids_list],
batch_first=True,
padding_value=self.tokenizer.pad_token_id
)
# Create masked inputs and labels
if USE_PRETRAINED_T5:
masked_input_ids, labels, mlm_mask = self._mask_tokens_span(batch_input_ids)
return {'input_ids': masked_input_ids, 'labels': labels, 'mask_indices': mlm_mask}
elif USE_PRETRAINED_BERT or USE_PRETRAINED_BERT_MLM:
#labels, mlm_mask = self._mask_tokens_span(batch_input_ids)
labels, mlm_mask = self._mask_tokens_standard(batch_input_ids)
return {'input_ids': batch_input_ids, 'labels': labels, 'mask_indices': mlm_mask}
else: # USE_CUSTOM_TRANSFORMER_ENCODER or USE_CUSTOM_TRANSFORMER_ENCODER_DECODER
#labels, mlm_mask = self._mask_tokens_span(batch_input_ids)
labels, mlm_mask = self._mask_tokens_standard(batch_input_ids)
return {'input_ids': batch_input_ids, 'labels': labels, 'mask_indices': mlm_mask}
# span-masking strategy
def _mask_tokens_span(self, inputs):
"""
Prepare masked tokens inputs/labels for masked span language modeling according to T5's objective.
"""
inputs = inputs.clone()
labels = torch.full(inputs.shape, self.tokenizer.pad_token_id)
special_tokens = {self.tokenizer.pad_token_id}
batch_size, seq_len = inputs.shape
mask_indices = []
# Track masking locations
mask_indices_tensor = torch.zeros_like(inputs, dtype=torch.bool)
for i in range(batch_size):
input_ids = inputs[i].tolist()
num_to_mask = max(1, int(round(seq_len * self.mlm_probability)))
# Get candidate indices to mask
candidate_indices = [
idx for idx in range(len(input_ids)) if input_ids[idx] not in special_tokens
]
# Shuffle candidate indices
random.shuffle(candidate_indices)
masked_indices = set()
current_idx = 0
spans = []
while len(masked_indices) < num_to_mask and current_idx < len(candidate_indices):
span_length = max(1, int(numpy.random.poisson(lam=self.mean_noise_span_length)))
start = candidate_indices[current_idx]
end = min(start + span_length, seq_len)
span_indices = list(range(start, end))
# Avoid overlapping spans
if any(idx in masked_indices for idx in span_indices):
current_idx += 1
continue
masked_indices.update(span_indices)
spans.append((start, end))
current_idx += 1
# Sort spans in reverse order to avoid index shifting issues
spans = sorted(spans, key=lambda x: x[0], reverse=True)
target_tokens = []
prev_end = seq_len
for idx, (start, end) in enumerate(spans):
# Replace span with sentinel token in inputs
if USE_PRETRAINED_T5:
sentinel_token_id = self.tokenizer.convert_tokens_to_ids(f'<extra_id_{idx}>')
else:
sentinel_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
inputs[i, start:end] = sentinel_token_id
# Build labels
target_tokens = [sentinel_token_id] + input_ids[start:end] + target_tokens
# Record the masked positions
for start, end in spans:
mask_indices_tensor[i, start:end] = True
# Handle unmasked positions in labels
#labels[~mask_indices_tensor] = CONSTANTS_VALUE_IGNORE
#labels[i, :len(target_tokens)] = torch.tensor(target_tokens, dtype=torch.long)
# debug prints
if len(spans) > 0:
total_masked = sum(end - start for start, end in spans)
#print(f"Sequence {i}: Created {len(spans)} spans, masking {total_masked} tokens")
#print(f"Spans: {spans}")
if USE_PRETRAINED_T5:
return inputs, labels, mask_indices_tensor # T5 masking tokens are not unique, so need to return masked "inputs"
else:
return labels, mask_indices_tensor # Return the mask information
# standard BERT masking strategy without any span-masking
def _mask_tokens_standard(self, inputs):
"""
Prepare masked tokens inputs/labels for standard masked language modeling (e.g., BERT).
"""
labels = inputs.clone()
# Create a mask for tokens to mask
probability_matrix = torch.full(labels.shape, self.mlm_probability)
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
# Set labels for masked tokens, set CONSTANTS_VALUE_IGNORE for others
#labels[~masked_indices] = CONSTANTS_VALUE_IGNORE # We only compute loss on masked tokens
# Replace masked input tokens according to BERT's strategy
# 80% of the time, replace with [MASK]
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
# 10% of the time, replace with random token
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), inputs.shape, device=device, dtype=torch.long)
indices_random = indices_random.to(device)
inputs[indices_random] = random_words[indices_random]
        # For the remaining 10% of cases, keep the original token (do nothing)
return labels, masked_indices
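# A minimal usage sketch of the collator above (illustrative only, never called
# in this script). The two example sentences below are hypothetical; the real
# pipeline plugs the collator into a DataLoader via collate_fn instead.
def _demo_span_collator():
    collator = DataCollatorForSpanCorruption(tokenizer, mlm_probability=0.15)
    sentences = ["the quick brown fox", "jumps over the lazy dog"]
    examples = [{k: v.squeeze(0) for k, v in tokenizer_function(s, tokenizer).items()}
                for s in sentences]
    return collator(examples)  # dict with 'input_ids', 'labels', 'mask_indices'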
sigma = 0.5 # single noise level
mask_token_penalty_weight = 1.0 # Increase this value to penalize more heavily
sep_token_penalty_weight = 1.0 # Increase this value to penalize more heavily
unused_token_penalty_weight = 0.005 # Increase this value to penalize more heavily
ebm_energy_regularization_scale = 16 # for L2 regularization on EBM loss
if USE_CUSTOM_TRANSFORMER_ENCODER_DECODER:
ebm_energy_regularization_scale = ebm_energy_regularization_scale << 1 # for L2 regularization on EBM loss
'''
log(Σ exp(x_i)) = log(Σ exp(x_i - C + C))
= log(Σ exp(x_i - C) * exp(C))
= log(exp(C) * Σ exp(x_i - C))
= log(exp(C)) + log(Σ exp(x_i - C))
= C + log(Σ exp(x_i - C))
where C is any constant.
The log_sum_exp() implementation chooses C to be max_val (the maximum value among the x_i values). Here's why this is brilliant:
1. Shifting by max_val: By subtracting max_val from each x_i before exponentiating, we ensure that:
- The largest value among x_i - max_val will be 0 (because max_val - max_val = 0).
- All other values of x_i - max_val will be negative or 0.
2. Avoiding Overflow: Since exp(0) = 1, and exp(x) for negative x is always between 0 and 1, we avoid computing exp() of large positive numbers, thus preventing overflow.
3. Reducing Underflow: While underflow might still occur for extremely small values of exp(x_i - max_val), it's less severe because we are summing these values. The sum is less likely to underflow to zero compared to individual terms.
4. Adding Back max_val: Finally, we add max_val back to the result to compensate for the subtraction we did earlier. This ensures that we get the correct value of log_sum_exp(x_i).
'''
def log_sum_exp(x):
max_val = x.max()
return max_val + torch.log(torch.sum(torch.exp(x - max_val)))
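# A small self-contained check of the shift-by-max trick described above
# (illustrative only; not called anywhere in this script): the naive form
# overflows to inf in float32 for large inputs, while log_sum_exp() stays finite.
def _demo_log_sum_exp():
    x = torch.tensor([1000.0, 1001.0, 1002.0])
    naive = torch.log(torch.sum(torch.exp(x)))  # inf, since exp(1000) overflows
    stable = log_sum_exp(x)                     # ~= 1002.41
    return naive, stable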
# USE_ROPE = 0
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
#print(f"pe.shape = {self.pe.shape}")
return x + self.pe[:x.size(0), :]
# USE_ROPE = 1
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000):
super().__init__()
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
self.max_position_embeddings = max_position_embeddings
self.dim = dim
def forward(self, seq_len):
positions = torch.arange(seq_len, device=self.inv_freq.device)
sinusoid = torch.einsum('i,j->ij', positions, self.inv_freq)
sin = sinusoid.sin()
cos = sinusoid.cos()
return cos, sin
class RoPEMultiheadAttention(nn.Module):
def __init__(self, d_model, nhead, dropout=0.1, is_causal=False, batch_first=False):
super().__init__()
assert d_model % nhead == 0
self.head_dim = d_model // nhead
self.nhead = nhead
self.d_model = d_model
self.is_causal = is_causal
self.batch_first = batch_first
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.out_proj = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
self.rope = RotaryEmbedding(self.head_dim)
def apply_rotary_emb(self, x, cos, sin):
"""
Apply rotary embeddings to the input tensor using the provided cosine and sine values.
Args:
x (torch.Tensor): Input tensor.
cos (torch.Tensor): Precomputed cosine values.
sin (torch.Tensor): Precomputed sine values.
Returns:
torch.Tensor: Tensor with rotary embeddings applied.
"""
assert x.ndim == 4 # Ensure input is for multi-head attention
#print(f"x.ndim = {x.ndim}")
d = x.shape[3] // 2
x1 = x[..., :d]
x2 = x[..., d:]
y1 = x1 * cos - x2 * sin
y2 = x1 * sin + x2 * cos
return torch.cat([y1, y2], 3).type_as(x)
def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
if self.batch_first:
batch_size, tgt_len, embed_dim = query.shape
else:
tgt_len, batch_size, embed_dim = query.shape
src_len = key.shape[1]
scaling = float(self.head_dim) ** -0.5
q = self.q_proj(query).view(batch_size, tgt_len, self.nhead, self.head_dim).transpose(1, 2)
k = self.k_proj(key).view(batch_size, src_len, self.nhead, self.head_dim).transpose(1, 2)
v = self.v_proj(value).view(batch_size, src_len, self.nhead, self.head_dim).transpose(1, 2)
# Apply RoPE to Q and K
cos, sin = self.rope(max(src_len, tgt_len))
q = self.apply_rotary_emb(q, cos, sin)
k = self.apply_rotary_emb(k, cos, sin)
# Attention weights
attn = torch.matmul(q, k.transpose(-2, -1)) * scaling
# Apply causal mask for decoder self-attention
if self.is_causal:
causal_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=q.device), diagonal=1)
attn = attn.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
if attn_mask is not None:
attn += attn_mask
if key_padding_mask is not None:
attn = attn.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)
attn = F.softmax(attn, dim=-1)
attn = self.dropout(attn)
# Attention output
output = torch.matmul(attn, v)
#print(f"After attention, output.shape = {output.shape}, tgt_len = {tgt_len}")
# This is for USE_CUSTOM_TRANSFORMER_ENCODER_DECODER or USE_CUSTOM_TRANSFORMER_DECODER
# Before the final reshape, handle the case where tgt_len < embed_dim
if tgt_len < embed_dim:
# Method 1: Pad with zeros to reach embed_dim
padding = torch.zeros(batch_size, self.nhead, embed_dim - tgt_len, self.head_dim, device=output.device)
#print(f"padding.shape = {padding.shape}")
output = torch.cat([output, padding], dim=2)
tgt_len = embed_dim
# OR Method 2: Repeat the output to reach embed_dim
# output = output.repeat_interleave(math.ceil(embed_dim / tgt_len), dim=1)[:, :embed_dim, :]
output = output.transpose(1, 2).contiguous().view(batch_size, tgt_len, embed_dim)
#print(f"After reshape view, output.shape = {output.shape}")
output = self.out_proj(output)
return output
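# A minimal smoke test for the RoPE attention module above (illustrative only;
# not called anywhere in this script). The sequence length is chosen equal to
# model_dim so that the zero-padding branch in forward() is not triggered.
def _demo_rope_attention():
    attn = RoPEMultiheadAttention(d_model=model_dim, nhead=num_heads, batch_first=True)
    x = torch.randn(2, model_dim, model_dim)  # [batch_size, seq_len, d_model]
    return attn(x, x, x).shape                # torch.Size([2, model_dim, model_dim])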
# Encoder Layer: Uses single self-attention (bidirectional)
class RoPETransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, batch_first=False):
super().__init__()
# Single self-attention layer (non-causal/bidirectional)
self.self_attn = RoPEMultiheadAttention(d_model, nhead, dropout=dropout, is_causal=False, batch_first=batch_first)
# One set of normalization and feedforward
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
# Two layer norms (pre-attention and pre-FFN)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Two dropouts
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
x = src
# Single self-attention block
attn_output = self.self_attn(
self.norm1(x), self.norm1(x), self.norm1(x),
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask
)
x = x + self.dropout1(attn_output)
# Single feedforward block
ff_output = self.linear2(self.dropout(F.relu(self.linear1(self.norm2(x)))))
x = x + self.dropout2(ff_output)
return x
# Decoder Layer: Uses both self-attention (causal) and cross-attention
class RoPETransformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, batch_first=False):
super().__init__()
# Causal self-attention for decoder
self.self_attn = RoPEMultiheadAttention(d_model, nhead, dropout=dropout, is_causal=True, batch_first=batch_first)
# Cross-attention to connect with encoder outputs
self.multihead_attn = RoPEMultiheadAttention(d_model, nhead, dropout=dropout, is_causal=False, batch_first=batch_first)
# Same feedforward as encoder
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
# Three layer norms (pre-self-attn, pre-cross-attn, pre-FFN)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
# Three dropouts
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
tgt_key_padding_mask=None, memory_key_padding_mask=None,
memory_is_causal=True, tgt_is_causal=True):
x = tgt
# Self-attention block (causal)
attn_output = self.self_attn(
self.norm1(x), self.norm1(x), self.norm1(x),
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask
)
#print(f"In RoPETransformerDecoderLayer(), x.shape = {x.shape}, attn_output = {attn_output.shape}")
if USE_CUSTOM_TRANSFORMER_ENCODER_DECODER:
x = x.mean(dim=1).unsqueeze(1)
x = x + self.dropout1(attn_output)
# Cross-attention block
cross_attn_output = self.multihead_attn(
self.norm2(x), self.norm2(memory), self.norm2(memory),
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask
)
x = x + self.dropout2(cross_attn_output)
# Feedforward block
ff_output = self.linear2(self.dropout(F.relu(self.linear1(self.norm3(x)))))
x = x + self.dropout3(ff_output)
return x
# USE_PRETRAINED_BERT = 1
class BertDenoiser(nn.Module):
def __init__(self, model_dim, use_bert_mlm=USE_PRETRAINED_BERT_MLM): # model_dim == sequence_length
super(BertDenoiser, self).__init__()
self.use_bert_mlm = use_bert_mlm
self.final_layer = nn.Linear(tokenizer.vocab_size, 1)
# SiLU layer
self.SiLU = nn.SiLU()
# ReLU layer
self.ReLU = nn.ReLU()
if self.use_bert_mlm:
self.model = AutoModelForMaskedLM.from_pretrained("prajjwal1/bert-tiny").to(device)
#self.model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased").to(device)
self.config = AutoConfig.from_pretrained("prajjwal1/bert-tiny")
#self.config = AutoConfig.from_pretrained("google-bert/bert-base-cased")
else:
self.model = AutoModel.from_pretrained("prajjwal1/bert-tiny").to(device)
#self.model = AutoModel.from_pretrained("bert-base-uncased").to(device)
self.middle_layer = nn.Linear(self.model.config.hidden_size, tokenizer.vocab_size)
self.dropout = nn.Dropout(0.2) # Add dropout
# Apply Xavier/Glorot or He initialization
#self._initialize_weights()
# Initialize the final layer
#nn.init.xavier_uniform_(self.final_layer_A.weight)
#nn.init.xavier_uniform_(self.final_layer_B.weight)
#nn.init.zeros_(self.final_layer_A.bias)
#nn.init.zeros_(self.final_layer_B.bias)
def initialize_weights(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_normal_(m.weight)
if m.bias is not None:
m.bias.data.fill_(0.01)
elif isinstance(m, nn.Conv1d):
torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
if m.bias is not None:
m.bias.data.fill_(0.01)
def _initialize_weights(self):
for name, param in self.named_parameters():
if 'weight' in name:
if isinstance(param, torch.nn.Parameter):
if param.dim() > 1: # Only apply to matrices, not biases
if 'self_attn' in name or 'multihead_attn' in name:
torch.nn.init.xavier_uniform_(param) # Xavier for attention layers
else:
torch.nn.init.kaiming_uniform_(param, nonlinearity='relu') # He for ReLU-based layers
elif 'bias' in name:
torch.nn.init.zeros_(param) # Biases are usually initialized to zero
def forward(self, inputs, mlm_mask=None):
if isinstance(inputs, dict):
# Convert input_ids to long tensor
input_ids = inputs['input_ids'].long()
labels = inputs['labels'].to(device)
attention_mask = inputs['attention_mask']
else:
input_ids = inputs.long()
labels = input_ids.clone().detach()
attention_mask = (input_ids != tokenizer.pad_token_id).long()
#print(f"input_ids.shape = {input_ids.shape}")
if self.use_bert_mlm:
# Process text
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
).logits
else:
# Process text
outputs = self.model(
input_ids,
attention_mask=attention_mask,
labels=labels
).last_hidden_state # [batch_size, sequence_length, hidden_size=128]
outputs = self.middle_layer(outputs)
# Shape: [batch_size, seq_len, vocab_size]
#print(f"outputs.shape = {outputs.shape}")
outputs = self.dropout(outputs) # Apply dropout
denoised_sentence = self.final_layer(outputs).squeeze(-1) # shape : [batch_size, sequence_length]
# Apply activation function
denoised_sentence = self.SiLU(denoised_sentence)
if mlm_mask is not None:
masked_positions = mlm_mask.bool()
denoised_masked_token_logits = outputs[masked_positions] # shape : [batch_size, vocab_size] if MASK_RATIO = -1
denoised_token_logits = outputs
#print(f"denoised_sentence.shape = {denoised_sentence.shape}, denoised_masked_token_logits.shape = {denoised_masked_token_logits.shape}, denoised_token_logits.shape = {denoised_token_logits.shape}")
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
return denoised_sentence, denoised_masked_token_logits, denoised_token_logits
else:
return denoised_sentence, denoised_masked_token_logits
else:
return denoised_sentence
'''
Original Sentence: "The quick brown fox jumps over the lazy dog."
Masked Encoder Input (input_ids): ['The', 'quick', '<extra_id_0>', 'jumps', 'over', 'the', '<extra_id_1>', 'dog', '.']
Decoder's Target Output (labels): ['<extra_id_0>', 'brown', 'fox', '<extra_id_1>', 'lazy']
Decoder's Generation Process:
-----------------------------------------------------------------------------------
Time Step | Decoder Input Token | Target Label Token | Prediction Objective
-----------------------------------------------------------------------------------
t=0       | <pad>               | <extra_id_0>       | Predict <extra_id_0>
t=1       | <extra_id_0>        | brown              | Predict brown
t=2       | brown               | fox                | Predict fox
t=3       | fox                 | <extra_id_1>       | Predict <extra_id_1>
t=4       | <extra_id_1>        | lazy               | Predict lazy
-----------------------------------------------------------------------------------
Note: There are no timesteps corresponding to 'jumps', 'over', 'the' in the decoder's output because these tokens are unmasked and present in the encoder input.
'''
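# A minimal sketch of the masking scheme illustrated above (illustrative only;
# assumes USE_PRETRAINED_T5 so that `tokenizer` provides the <extra_id_k>
# sentinel tokens; with the BERT tokenizer they simply map to [UNK]).
def _demo_t5_span_example():
    src = "The quick <extra_id_0> jumps over the <extra_id_1> dog."
    tgt = "<extra_id_0> brown fox <extra_id_1> lazy"
    return tokenizer(src).input_ids, tokenizer(tgt).input_ids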
# USE_PRETRAINED_T5 = 1
class T5Denoiser(nn.Module):
def __init__(self, model_dim):
super(T5Denoiser, self).__init__()
#self.model = T5ForConditionalGeneration.from_pretrained('google/t5-efficient-tiny')
self.model = AutoModelForSeq2SeqLM.from_pretrained("pnawrot/nanoT5-base")
        # Projection layer to map logits space back to sequence_length (which is the same as model_dim)
self.projection_A = nn.Sequential(
nn.Linear(self.model.config.vocab_size, model_dim),
#nn.ReLU() # no need of activation function before being fed into cross-entropy loss function
)
# Projection layer to map logits space back to a single token embedding
self.projection_B = nn.Sequential(
nn.Linear(self.model.config.vocab_size, 1),
#nn.ReLU() # no need of activation function before being fed into cross-entropy loss function
)
def forward(self, input_ids, target_label=None, decoder_input_ids=None, mlm_mask=None):
if decoder_input_ids is not None:
# Shift tgt to the right to create decoder input ids
decoder_input_ids = self.model._shift_right(decoder_input_ids)
else:
batch_size = input_ids.size(0)
# Use the decoder start token and expand it to match the batch size
# If tgt is not provided, use the BOS token as the initial input for decoder
decoder_start_token = torch.tensor([[self.model.config.decoder_start_token_id]], device=device)
decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=device)
decoder_input_ids = torch.cat((decoder_input_ids, decoder_start_token.expand(batch_size, -1)), dim=1)
#print(f"decoder_input_ids.shape = {decoder_input_ids.shape}")
# Generate output logits
        # We do not need to manually feed in decoder_input_ids; the model handles them internally during training
output = self.model(input_ids=input_ids.long(), labels=target_label)
#output = self.model(input_ids=input_ids.long(), decoder_input_ids=decoder_input_ids.long())
output = output.logits # shape : [batch_size, tgt_sequence_length, vocab_size]
#print(f"output.shape = {output.shape}")
if ENABLE_MASK_LEARNING: # there is a new token concatenated to tgt tensor
# Get the most recent timestep prediction
# We want to update denoised_sentence based on the prediction for the last token in the sequence
denoised_sentence = output[:, -1, :] # Select the last timestep
else:
# Remove unnecessary dimension
denoised_sentence = output.squeeze(1)
#print(f"denoised_sentence.shape = {denoised_sentence.shape}")
# denoised_sentence has a shape of [batch_size, vocab_size]
        # projection_A layer uses almost the same amount of memory as projection_B layer (which relies on a broadcast operation)
        # We should not use torch.max() because it introduces non-differentiable points, hindering gradient-based optimization.
# Besides, only the maximum value receives a gradient; all other inputs get zero gradients, which is inefficient for learning.
if mlm_mask is not None:
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
denoised_token_logits = output
denoised_masked_token_logits = denoised_sentence
denoised_sentence = self.projection_A(denoised_sentence) # shape : [batch_size, sequence_length]
#denoised_sentence = self.projection_B(denoised_sentence) # shape : [batch_size, 1]
#denoised_sentence, _ = torch.max(denoised_sentence, dim=-1, keepdim=True) # shape : [batch_size, 1]
return denoised_sentence, denoised_masked_token_logits, denoised_token_logits
else:
denoised_masked_token_logits = denoised_sentence
denoised_sentence = self.projection_A(denoised_sentence) # shape : [batch_size, sequence_length]
#denoised_sentence = self.projection_B(denoised_sentence) # shape : [batch_size, 1]
#denoised_sentence, _ = torch.max(denoised_sentence, dim=-1, keepdim=True) # shape : [batch_size, 1]
return denoised_sentence, denoised_masked_token_logits
else:
denoised_sentence = self.projection_A(denoised_sentence) # shape : [batch_size, sequence_length]
#denoised_sentence = self.projection_B(denoised_sentence) # shape : [batch_size, 1]
#denoised_sentence, _ = torch.max(denoised_sentence, dim=-1, keepdim=True) # shape : [batch_size, 1]
return denoised_sentence
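# A minimal call sketch for the T5 denoiser above (illustrative only;
# `t5_denoiser` and `batch` are hypothetical arguments: an instantiated
# T5Denoiser and a dict produced by DataCollatorForSpanCorruption under
# USE_PRETRAINED_T5).
def _demo_t5_denoiser_call(t5_denoiser, batch):
    return t5_denoiser(input_ids=batch['input_ids'],
                       target_label=batch['labels'],
                       mlm_mask=batch['mask_indices'])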
# USE_CUSTOM_TRANSFORMER_ENCODER_DECODER or USE_CUSTOM_TRANSFORMER_ENCODER
class TransformerDenoiser(nn.Module):
def __init__(self, input_dim, model_dim, num_layers, num_heads, max_noise_level):
super(TransformerDenoiser, self).__init__()
self.embedding = nn.Embedding(tokenizer.vocab_size, model_dim)
#self.noise_level_embeddings = nn.Embedding(max_noise_level, model_dim)
if not USE_ROPE:
self.pos_encoder = PositionalEncoding(model_dim)
self.pos_decoder = PositionalEncoding(model_dim)
if USE_ROPE:
# Use RoPE Transformer Encoder layers
encoder_layers = RoPETransformerEncoderLayer(
model_dim,
num_heads,
model_dim,
batch_first=True
)
# Use RoPE Transformer Decoder layers
decoder_layers = RoPETransformerDecoderLayer(
model_dim,
num_heads,
model_dim,
batch_first=True
)
else:
# Use Transformer Encoder Layers from Pytorch library
encoder_layers = nn.TransformerEncoderLayer(model_dim, num_heads, model_dim, batch_first=True)
# Use Transformer Decoder Layers from Pytorch library
decoder_layers = nn.TransformerDecoderLayer(model_dim, num_heads, model_dim, batch_first=True)
self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_layers)
# Layer Normalization to prevent vanishing gradients
self.norm = nn.LayerNorm(model_dim)
# SiLU layer
self.SiLU = nn.SiLU()
#Sigmoid layer
#self.Sigmoid = nn.Sigmoid()
"""
# Projection layer to map single token embedding back to logits space
self.projection = nn.Sequential(
nn.Linear(1, tokenizer.vocab_size),
#nn.ReLU() # no need of activation function before being fed into cross-entropy loss function
)
"""
# Projection layer (tie weights with embedding)
self.projection = nn.Linear(model_dim, tokenizer.vocab_size)
self.projection.weight = self.embedding.weight # Weight tying
"""
self.denoise_head = nn.Sequential(
nn.Linear(model_dim, model_dim),
nn.ReLU()
)
"""
        # The convolutional denoise head does not depend on input_dim or on the
        # input sequence length. This is helpful in the NLP domain, because the
        # model will see varying input sequence lengths.
# Weight normalization is one technique to address vanishing gradients
self.denoise_head = nn.Sequential(
weight_norm(nn.Conv1d(in_channels=model_dim, out_channels=model_dim, kernel_size=3, padding=1)),
#nn.SiLU(),
#weight_norm(nn.Conv1d(in_channels=model_dim, out_channels=model_dim, kernel_size=3, padding=1)),
#nn.SiLU(),
#weight_norm(nn.Conv1d(in_channels=model_dim, out_channels=model_dim, kernel_size=3, padding=1)),
#nn.ReLU()
)
# Apply Xavier/Glorot or He initialization
self._initialize_weights()
def initialize_weights(m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_normal_(m.weight)
if m.bias is not None:
m.bias.data.fill_(0.01)
elif isinstance(m, nn.Conv1d):
torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
if m.bias is not None:
m.bias.data.fill_(0.01)
def _initialize_weights(self):
for name, param in self.named_parameters():
if 'weight' in name:
if isinstance(param, torch.nn.Parameter):
if param.dim() > 1: # Only apply to matrices, not biases
if 'self_attn' in name or 'multihead_attn' in name:
torch.nn.init.xavier_uniform_(param) # Xavier for attention layers
else:
torch.nn.init.kaiming_uniform_(param, nonlinearity='relu') # He for ReLU-based layers
elif 'bias' in name:
torch.nn.init.zeros_(param) # Biases are usually initialized to zero
#nn.init.xavier_uniform_(self.projection[0].weight)
#nn.init.zeros_(self.projection[0].bias)
# for decoder only
def _shift_right(self, input_ids, start_token_id):
"""
Shift input_ids to the right by one position and prepend the start_token_id.
"""
shifted_input_ids = input_ids.new_zeros(input_ids.size())
shifted_input_ids[:, 0] = start_token_id
shifted_input_ids[:, 1:] = input_ids[:, :-1]
return shifted_input_ids
def forward(self, inputs, tgt=None, input_pad_mask=None, mlm_mask=None):
if isinstance(inputs, dict):
src = inputs['input_ids']
else:
src = inputs
# src: [batch_size, sequence_length]
# Embed input tokens
src = self.embedding(src.long()) # [batch_size, sequence_length, model_dim]
#print(f"After nn.embedding(), src.shape = {src.shape}")
# Saves memory
del inputs
# Add sequence length dimension
#src = src.unsqueeze(1)
if tgt is not None:
#tgt = tgt.unsqueeze(2)
# Determine the start token ID based on tokenizer and model
if USE_PRETRAINED_T5:
start_token_id = tokenizer.pad_token_id # T5 uses pad_token_id as start token
else:
start_token_id = tokenizer.cls_token_id # BERT uses cls_token_id as start token
# Shift tgt to the right
tgt = self._shift_right(tgt, start_token_id)
tgt = self.embedding(tgt.long()) # [batch_size, sequence_length, model_dim]
#print(f"After nn.embedding(), src.shape = {src.shape}")
# Apply input masking for padding if provided
if input_pad_mask is not None:
src = src.masked_fill(input_pad_mask.unsqueeze(1), 0.0)
#print(f"After masked_fill, src = {src}")
if not USE_ROPE:
# Add positional encodings
src = self.pos_encoder(src)
#print(f"src.shape = {src.shape}")
if tgt is not None:
#print(f"Before pos_decoder(), tgt.shape = {tgt.shape}")
tgt = self.pos_decoder(tgt)
# Pre-Norm: Apply LayerNorm before the encoder layer
src = self.norm(src)
# Pass through the transformer encoder
memory = self.transformer_encoder(src)#, src_key_padding_mask=input_pad_mask)
#print(f"memory.shape = {memory.shape}")
# Add residual connection
memory = memory + src
# Pre-Norm: Apply LayerNorm before the decoder layer
memory = self.norm(memory)
if tgt is not None:
tgt = self.norm(tgt)
#print(f"tgt.shape = {tgt.shape}")
# Decoder
if tgt is not None:
#print(f"Before transformer_decoder(), tgt.shape = {tgt.shape} , memory.shape = {memory.shape}")
output = self.transformer_decoder(tgt, memory)#, tgt_key_padding_mask=input_pad_mask, memory_key_padding_mask=input_pad_mask)
#print(f"After transformer_decoder(), output.shape = {output.shape}")
else:
output = memory # bypass the decoder for EBM module under dWJS_SCORE mode
if tgt is not None:
#print(f"In TransformerDenoiser(), output.shape = {output.shape}, tgt = {tgt.shape}")
# Add residual connection
if USE_ROPE:
output = output + tgt.mean(dim=1).unsqueeze(1) # Residual connection
else:
output = output + tgt # Residual connection
output = output.mean(dim=1).unsqueeze(1)
#print(f"In TransformerDenoiser(), output.shape = {output.shape}, src = {src.shape}")
# Add residual connection
output = output + src # Residual connection
# Apply normalization
output = self.norm(output)
# Transpose for Conv1d: (batch_size, model_dim, 1)
output = output.transpose(1, 2)
# Pass through the denoising head
#print(f"Before denoise_head(), output.shape = {output.shape}, output = {output}")
output = self.denoise_head(output)
#print(f"After denoise_head(), output.shape = {output.shape}, output = {output}")
# Transpose back to original shape: (batch_size, 1, model_dim)
output = output.transpose(1, 2)
# Add residual connection
output = output + src
# Add residual connection
#output = output + memory
# Add residual connection
if tgt is not None:
output = output + tgt.mean(dim=1).unsqueeze(1) # Residual connection
# Apply normalization
output = self.norm(output)
# Apply activation function
output = self.SiLU(output)
#output = self.Sigmoid(output)
#print(f"output.shape = {output.shape}") # shape: [batch_size, tgt_sequence_length, src_sequence_length]
if ENABLE_MASK_LEARNING: # there is a new token concatenated to tgt tensor
# Get the most recent timestep prediction
# We want to update denoised_sentence based on the prediction for the last token in the sequence
#denoised_sentence = output[:, -1, :] # Select the last timestep
denoised_sentence = output.mean(dim=1) # shape: [batch_size, model_dim]
else:
# Remove unnecessary dimension
denoised_sentence = output.squeeze(1) # shape: [batch_size, model_dim]
#print(f"denoised_sentence.shape = {denoised_sentence.shape}")
# denoised_sentence has a shape of [batch_size, sequence_length]
        # Project back to logits over the vocabulary (the reverse direction of the embedding lookup, since the weights are tied)
if mlm_mask is not None:
#print(f"mlm_mask.shape = {mlm_mask.shape}")
denoised_masked_token_logits = self.projection(denoised_sentence) # shape: [batch_size, vocab_size]
#denoised_masked_token_logits = self.projection(denoised_sentence[mlm_mask.bool()].unsqueeze(-1))
#print(f"denoised_masked_token_logits.shape = {denoised_masked_token_logits.shape}")
#print(f"denoised_sentence.shape = {denoised_sentence.shape}, denoised_masked_token_logits.shape = {denoised_masked_token_logits.shape}, denoised_token_logits.shape = {denoised_token_logits.shape}")
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
denoised_token_logits = self.projection(output) # shape: [batch_size, sequence_length, vocab_size]
#print(f"denoised_token_logits.shape = {denoised_token_logits.shape}")
return denoised_sentence, denoised_masked_token_logits, denoised_token_logits
else:
return denoised_sentence, denoised_masked_token_logits
else:
return denoised_sentence
# will switch to a transformer model later, due to varying input sequence lengths as well as
# the higher-order gradient issue described in http://arxiv.org/abs/1907.05600
class EnergyBasedModel(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(EnergyBasedModel, self).__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, x):
return self.net(x)
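# A minimal sketch (not part of the training loop) of how the energy model above
# yields a score estimate: the score is the negative energy gradient,
# score(x) = -grad_x E(x), obtained via autograd. `ebm` and `x` are hypothetical
# arguments (an EnergyBasedModel instance and a float tensor of shape [batch_size, input_dim]).
def _demo_ebm_score(ebm, x):
    x = x.clone().detach().requires_grad_(True)
    energy = ebm(x).sum()                     # scalar total energy of the batch
    (grad,) = torch.autograd.grad(energy, x)
    return -grad                              # score estimate, same shape as x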
def denoiser_model(noisy_y, mlm_mask=None, target_label=None, tgt=None, input_pad_mask=None):
denoised_sentence = None
denoised_masked_token_logits = None
denoised_token_logits = None
if USE_PRETRAINED_BERT or USE_PRETRAINED_BERT_MLM:
if mlm_mask is not None:
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
denoised_sentence, denoised_masked_token_logits, denoised_token_logits = denoiser(inputs=noisy_y, mlm_mask=mlm_mask)
else:
denoised_sentence, denoised_masked_token_logits = denoiser(inputs=noisy_y, mlm_mask=mlm_mask)
else:
denoised_sentence = denoiser(inputs=noisy_y, mlm_mask=mlm_mask)
elif USE_PRETRAINED_T5:
if mlm_mask is not None:
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
denoised_sentence, denoised_masked_token_logits, denoised_token_logits = denoiser(input_ids=noisy_y, target_label=target_label, mlm_mask=mlm_mask)
else:
denoised_sentence, denoised_masked_token_logits = denoiser(input_ids=noisy_y, target_label=target_label, mlm_mask=mlm_mask)
else:
denoised_sentence = denoiser(input_ids=noisy_y, target_label=target_label, mlm_mask=mlm_mask)
else: # USE_CUSTOM_TRANSFORMER
# Use denoiser with the current noisy sequence (src) and current target sequence (tgt)
if mlm_mask is not None:
if USE_CUSTOM_TRANSFORMER_ENCODER_DECODER:
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
denoised_sentence, denoised_masked_token_logits, denoised_token_logits = denoiser(noisy_y, tgt, input_pad_mask, mlm_mask)
else:
denoised_sentence, denoised_masked_token_logits = denoiser(noisy_y, tgt, input_pad_mask, mlm_mask)
else: # USE_CUSTOM_TRANSFORMER_ENCODER
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
denoised_sentence, denoised_masked_token_logits, denoised_token_logits = denoiser(noisy_y, None, None, mlm_mask) # for isolating decoder related code
else:
denoised_sentence, denoised_masked_token_logits = denoiser(noisy_y, None, None, mlm_mask) # for isolating decoder related code
else:
if USE_CUSTOM_TRANSFORMER_ENCODER_DECODER:
denoised_sentence = denoiser(noisy_y, tgt, input_pad_mask, mlm_mask)
else: # USE_CUSTOM_TRANSFORMER_ENCODER
denoised_sentence = denoiser(noisy_y, None, None, mlm_mask) # for isolating decoder related code
return denoised_sentence, denoised_masked_token_logits, denoised_token_logits
# sequential monte-carlo (SMC)
def proposal(particle):
# Clone the particle to avoid in-place modifications
new_particle = particle.clone()
batch_size, seq_length = new_particle.size()
# Number of modifications per sample
num_modifications = max(1, int(0.05 * seq_length)) # Modify 5% of tokens
# Generate random indices and tokens for all samples
indices = torch.randint(0, seq_length, (batch_size, num_modifications), device=particle.device)
random_tokens = torch.randint(0, tokenizer.vocab_size, (batch_size, num_modifications), device=particle.device)
# Create batch indices for advanced indexing
batch_indices = torch.arange(batch_size, device=particle.device).unsqueeze(1).expand(-1, num_modifications)
# Modify the new_particle tensor
new_particle[batch_indices, indices] = random_tokens.float().to(device)
return new_particle
# sequential monte-carlo (SMC)
def compute_weights(particles, ebm):
# particles: list of length N_particles, each tensor of shape [batch_size, seq_length]
#print(f"particles[0].shape = {particles[0].shape}")
batch_size = particles[0].size(0)
N_particles = len(particles)
with torch.no_grad():
# Compute scalar energies for all particles
energies = torch.stack([ebm(particle).sum(dim=1) for particle in particles], dim=1) # Shape: [batch_size, N_particles]
#print(f"energies.shape = {energies.shape}")
# Convert energies to weights
weights = torch.exp(-energies) # Lower energy = higher probability
weights = weights / weights.sum(dim=1, keepdim=True) # Normalize over particles
# Verify the shape of weights
#print(f"weights shape after computation: {weights.shape}") # Should be [batch_size, N_particles]
return weights # Shape: [batch_size, N_particles]
# sequential monte-carlo (SMC)
def resample(particles, weights):
# particles: list of length N_particles, each tensor of shape [batch_size, seq_length]
# weights: tensor of shape [batch_size, N_particles]
batch_size = particles[0].size(0)
N_particles = len(particles)
seq_length = particles[0].size(1)
# Stack particles to create a tensor of shape [N_particles, batch_size, seq_length]
particles_tensor = torch.stack(particles, dim=0) # Shape: [N_particles, batch_size, seq_length]
# Transpose to shape [batch_size, N_particles, seq_length]
particles_tensor = particles_tensor.permute(1, 0, 2) # Shape: [batch_size, N_particles, seq_length]
# Perform batch-wise multinomial sampling
# particle_indices: tensor of shape [batch_size, N_particles]
particle_indices = torch.multinomial(weights, N_particles, replacement=True)
# Create batch indices
batch_indices = torch.arange(batch_size).unsqueeze(1).expand(-1, N_particles) # Shape: [batch_size, N_particles]
# Gather the resampled particles
resampled_particles_tensor = particles_tensor[batch_indices, particle_indices] # Shape: [batch_size, N_particles, seq_length]
# Transpose back to [N_particles, batch_size, seq_length]
resampled_particles_tensor = resampled_particles_tensor.permute(1, 0, 2) # Shape: [N_particles, batch_size, seq_length]
# Split into a list of particles
resampled_particles = [resampled_particles_tensor[i] for i in range(N_particles)]
return resampled_particles
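# A minimal sketch of one sequential Monte Carlo iteration composed from the
# three helpers above (illustrative only; the actual sampling loop lives
# elsewhere in this script). `particles` is a hypothetical list of N_particles
# tensors of shape [batch_size, seq_length] and `ebm` an EnergyBasedModel.
def _demo_smc_step(particles, ebm):
    proposed = [proposal(p) for p in particles]  # locally perturb each particle
    weights = compute_weights(proposed, ebm)     # [batch_size, N_particles]
    return resample(proposed, weights)           # multinomial resampling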
# USE_MAFBM
class MA_fBM:
def __init__(self, hurst, T, n_steps, K):
"""
Initialize MA-fBM with optimal coefficients from the GFDM paper.
Args:
hurst (float): Hurst parameter H in (0,1)
T (float): Terminal time
n_steps (int): Number of time steps
K (int): Number of Ornstein-Uhlenbeck processes
device (str): 'cuda' for GPU, 'cpu' for CPU
"""
self.H = hurst
self.T = T
self.n_steps = n_steps
self.K = K
self.dt = T / n_steps
# Calculate gamma grid according to paper
n = (K + 1) / 2
r = 1.5 # r > 1 , Geometric spacing parameter from paper
self.gammas = torch.tensor([r**(k-n) for k in range(1, K+1)],
device=device, dtype=torch.float32)
# Compute optimal coefficients
self.weights = self._compute_optimal_coefficients()
def _torch_gamma_inc(self, a, x):
"""
PyTorch implementation of regularized lower incomplete gamma function P(a,x)
P(a,x) = 1/Γ(a) ∫₀ˣ t^(a-1) e^(-t) dt
Args:
a: Shape parameter (Hurst index + 0.5)
x: Upper limit of integration (gamma_k * T)
Uses series expansion approximation:
e^(-t) = ∑_{n=0}^∞ (-t)^n/n!
γ(a,x) = ∫₀ˣ t^(a-1) e^(-t) dt
= ∫₀ˣ t^(a-1) ∑_{n=0}^∞ (-t)^n/n! dt
= ∑_{n=0}^∞ (-1)^n/n! ∫₀ˣ t^(a-1+n) dt
= ∑_{n=0}^∞ (-1)^n/n! [t^(a+n)/(a+n)]₀ˣ
= ∑_{n=0}^∞ (-1)^n x^(a+n)/(n!(a+n))
P(a,x) = γ(a,x)/Γ(a)
= (1/Γ(a)) ∑_{n=0}^∞ (-1)^n x^(a+n)/(n!(a+n))
"""
eps = 1e-8 # Convergence threshold
iterations = 100 # Maximum iterations
# Convert inputs to tensors
a = torch.tensor(a, device=device, dtype=torch.float32)
x = torch.as_tensor(x, device=device, dtype=torch.float32)
# Initialize sum
result = torch.zeros_like(x, device=device)
        temp = x**a  # initialize the running series term with x^a
factorial = 1
# Compute series expansion
for n in range(iterations):
# Update term: term[n] = term[n-1] * (-1)^n * x / (n! * (a+n))
# Use -x instead of x to get (-1)^n term
factorial = factorial * (n + 1) if n > 0 else 1
temp = temp * (-x) / (factorial * (a + n)) # self-multiply gives the x^n term
result += temp
# Check convergence when terms become very small (< eps)
# Adding more terms won't significantly change the result
if torch.all(torch.abs(temp) < eps):
break
# P(a,x) = result / Γ(a)
return result / torch.exp(torch.lgamma(a))
def _compute_optimal_coefficients(self):
"""
Compute optimal approximation coefficients following the GFDM paper's equation (9) : Aω = b
"""
# Create matrix A and vector b
A = torch.zeros((self.K, self.K), device=device)
b = torch.zeros(self.K, device=device)
# Compute matrix A
for i in range(self.K):
for j in range(self.K):
gamma_i = self.gammas[i]
gamma_j = self.gammas[j]
A[i,j] = (2*self.T + (torch.exp(-(gamma_i + gamma_j)*self.T) - 1) /
(gamma_i + gamma_j)) / (gamma_i + gamma_j)
# Compute vector b
z = self.H + 0.5
for k in range(self.K):
gamma_k = self.gammas[k]
x = gamma_k * self.T
# Compute regularized incomplete gamma functions
P_1 = self._torch_gamma_inc(z, x) # P(H+1/2, γₖT)
P_2 = self._torch_gamma_inc(z + 1, x) # P(H+3/2, γₖT)
# Implementation of b formula,
# see Appendix D.2 (Type II) of [Variational Inference for SDEs Driven by Fractional Noise](http://arxiv.org/abs/2310.12975)
b[k] = (self.T * P_1 / (gamma_k**z) -
(self.H + 0.5) * P_2 / (gamma_k**(z+1)))
# Solve linear system Aω = b for optimal weights
if device_str == 'mps':
weights = self.conjugate_gradient_solver(A, b)
else:
weights = torch.linalg.solve(A, b) # linalg.solve has no MPS backend support yet
return weights
def conjugate_gradient_solver(self, A, b, tol=1e-6, max_iter=1000):
"""
Solve Ax = b using the Conjugate Gradient method.
Args:
A (torch.Tensor): Symmetric positive definite matrix of shape [N, N].
b (torch.Tensor): Right-hand side vector of shape [N].
tol (float): Tolerance for convergence.
max_iter (int): Maximum number of iterations.
Returns:
x (torch.Tensor): Solution vector of shape [N].
"""
x = torch.zeros_like(b) # Initial guess
r = b - torch.matmul(A, x)
p = r.clone()
rs_old = torch.dot(r, r)
for i in range(max_iter):
Ap = torch.matmul(A, p)
alpha = rs_old / torch.dot(p, Ap)
x = x + alpha * p
r = r - alpha * Ap
rs_new = torch.dot(r, r)
if torch.sqrt(rs_new) < tol:
break
p = r + (rs_new / rs_old) * p
rs_old = rs_new
return x
@torch.no_grad()
def simulate(self, return_processes=False):
"""
Generate a sample path of MA-fBM.
Args:
return_processes (bool): If True, also return individual OU processes
Returns:
tuple: Time points, MA-fBM path, and optionally OU processes
"""
times = torch.linspace(0, self.T, self.n_steps, device=device)
ou_processes = torch.zeros((self.K, self.n_steps),
device=device)
# Generate Brownian increments
dW = torch.randn(self.K, self.n_steps-1,
device=device) * torch.sqrt(torch.tensor(self.dt, device=device))
# Simulate OU processes
for i in range(1, self.n_steps):
ou_processes[:, i] = (ou_processes[:, i-1] * torch.exp(-self.gammas * self.dt) +
dW[:, i-1])
# Combine OU processes to get MA-fBM
mafbm_path = torch.sum(self.weights.reshape(-1,1) * ou_processes, dim=0)
if return_processes:
return times, mafbm_path, ou_processes
return times, mafbm_path
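# A minimal usage sketch of the MA-fBM class above (illustrative only; reuses
# the hurst / T_fbm / n_steps / K_fbm hyperparameters defined near the top).
def _demo_mafbm_path():
    mafbm = MA_fBM(hurst=hurst, T=T_fbm, n_steps=n_steps, K=K_fbm)
    times, path = mafbm.simulate()  # both tensors of shape [n_steps]
    return times, path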
# Define the Langevin MCMC step function based on Algorithm 4 of the walk-jump paper
def langevin_mcmc_step_advanced(y, v, mlm_mask, input_pad_mask, ebm, denoiser, step_size, u=1.0, gamma=0.1, K=1):
# See Algorithm 1 of [Langevin Dynamics with Variable Coefficients and Nonconservative Forces: From Stationary States to Numerical Methods](https://www.mdpi.com/1099-4300/19/12/647)
step_size = torch.tensor(step_size, device=device) # Convert step_size to tensor
# Gather the correct tokens from the train data based on the mask positions
#train_data_correct = y[mask.bool()].long()
def energy_func(input_tensor):
return ebm(input_tensor).sum() # This is where the EBM is used
"""
Formula Breakdown:
- (w_t - w_t_minus_1)**2 measures the squared distance between the two states
    - exp(-((w_t - w_t_minus_1)**2) / (2 * sigma**2)) gives the unnormalized probability
    - The denominator (sigma * math.sqrt(2 * math.pi)) normalizes the probability
Walk-Jump Perspective:
- In the "walk" phase, we're exploring the smoothed data space using this forward process
- This Gaussian model allows for continuous transitions between states, which is key to the walk-jump approach for discrete data
- The "jump" phase then projects these smoothed states back to the discrete space
"""
def forward_process(w_t, w_t_minus_1):
# Normalize inputs such that `diff` does not result in very large values
w_t_norm = (w_t - w_t.mean()) / w_t.std()
w_t_minus_1_norm = (w_t_minus_1 - w_t_minus_1.mean()) / w_t_minus_1.std()
# Apply forward process on normalized inputs
diff = w_t_norm - w_t_minus_1_norm
return torch.exp(-(diff**2) / (2 * sigma**2)) / (sigma * math.sqrt(2 * math.pi))
# p(ŵ|w_t, w_{t-1})
def auxiliary_prob(w_hat, w_t, w_t_minus_1, denoiser):
# Construct the noisy input
noisy_input = w_t # Shape: [batch_size, seq_len]
# Provide w_t_minus_1 as context
context = w_t_minus_1 # Shape: [batch_size, seq_len]
#print(f"noisy_input.shape = {noisy_input.shape}, context.shape = {context.shape}")
# Apply the denoiser to predict possible clean versions of w_t, given w_t and w_{t-1}
        denoised_sentence, denoised_masked_token_logits = denoiser(inputs=noisy_input, tgt=context, input_pad_mask=input_pad_mask, mlm_mask=mlm_mask) # Shape: [batch_size, seq_len] , [batch_size, vocab_size]
# Get probabilities
probs = F.softmax(denoised_masked_token_logits.squeeze(1), dim=-1) # Shape: [batch_size, vocab_size]
# The probability the denoiser assigns to ŵ can be interpreted as how likely ŵ is to be the "clean" version of w_t, given w_t and w_{t-1}
# Return the probability of w_hat
return probs[0, w_hat]
def prior():
# Could be uniform or based on token frequencies in the dataset
return 1.0 / tokenizer.vocab_size
# to compute p(w_{t-1}|w_t, ŵ) which is the probability of the previous state given the current state and the auxiliary variable
def compute_transition_probability(w_t, w_hat, w_t_minus_1, denoiser, position=None, prev_word=None):
print(f"w_t.shape = {w_t.shape}, w_hat = {w_hat}, w_t_minus_1.shape = {w_t_minus_1.shape}")
# p(w_t|w_{t-1})
p_w_t_given_w_t_minus_1 = forward_process(w_t, w_t_minus_1)
print(f"forward_process() gives {p_w_t_given_w_t_minus_1} which has a shape of {p_w_t_given_w_t_minus_1.shape}")
# p(ŵ|w_t, w_{t-1})
p_w_hat_given_w_t_w_t_minus_1 = auxiliary_prob(w_hat, w_t, w_t_minus_1, denoiser=denoiser)
print(f"auxiliary_prob() gives {p_w_hat_given_w_t_w_t_minus_1} which has a shape of {p_w_hat_given_w_t_w_t_minus_1.shape}")
# p(w_{t-1})
p_w_t_minus_1 = prior() # prior(w_t_minus_1, position, prev_word)
# Compute p(w_{t-1}, w_t, ŵ)
p_w_t_minus_1_w_t_w_hat = p_w_t_given_w_t_minus_1 * p_w_hat_given_w_t_w_t_minus_1 * p_w_t_minus_1
print(f"p_w_t_minus_1_w_t_w_hat = {p_w_t_minus_1_w_t_w_hat}")
# Approximate p(w_t, ŵ) using the current w_{t-1}
# Avoids explicit marginalization over all possible w_{t-1} using `vocab_size` loop iterations
p_w_t_w_hat = p_w_t_given_w_t_minus_1 * p_w_hat_given_w_t_w_t_minus_1
print(f"p_w_t_w_hat = {p_w_t_w_hat}")
print(f"p_w_t_minus_1_w_t_w_hat.shape = {p_w_t_minus_1_w_t_w_hat.shape}, p_w_t_w_hat.shape = {p_w_t_w_hat.shape}")
# Compute p(w_{t-1}|w_t, ŵ) using Bayes' rule
p_w_t_minus_1_given_w_t_w_hat = p_w_t_minus_1_w_t_w_hat / p_w_t_w_hat
# transition probability
return p_w_t_minus_1_given_w_t_w_hat
# KL divergence function using one-hot encoded target
# See Algorithm 2 of [Protein Design with Guided Discrete Diffusion](http://arxiv.org/abs/2305.20009)
def kl_div_func(model_output):
#target_distribution = F.one_hot(train_data_correct, num_classes=tokenizer.vocab_size).float()
#target_distribution = torch.rand_like(model_output) / model_output.size(-1)
p_h = torch.zeros_like(y)
with torch.no_grad(): # for saving RAM memory consumption
for w in range(tokenizer.vocab_size):
# here, model is the denoiser
p_transition = compute_transition_probability(y, w, model_output, denoiser)
print(f"p_transition = {p_transition}")
# Calculate p_h
p_h += p_transition * model_output
# Add a small epsilon to avoid log(0) in KL divergence calculation
epsilon = 1e-8
p_h = p_h + epsilon
# Normalize to ensure sum equals 1 such that p_h is a valid probability distribution
p_h = p_h / p_h.sum(dim=-1, keepdim=True)
# Ensure model_output is also a valid probability distribution
model_output = F.softmax(model_output, dim=-1)
#print(f"model_output.shape = {model_output.shape}, p_h.shape = {p_h.shape}")
# Re-enable gradients for the KL divergence computation
with torch.enable_grad():
# the following is semantically similar to kl_div(current_prob.log(), previous_prob)
return F.kl_div(model_output.log(), p_h, reduction='batchmean')
# gradients for a Gaussian likelihood with prior N(0, I)
def grad_U_theta(ebm, theta, X, mlm_mask=None):
"""
Gradient of U with respect to theta.
Args:
ebm: BERT-based EBM (computes -log p(y|x)).
theta: Tokenized sequences (Tensor, shape [batch_size, seq_length]).
X: Latent particles (Tensor, shape [N_particles, batch_size, seq_length]).
mlm_mask: Optional mask for MLM tasks (Tensor, shape [batch_size, seq_length]).
Returns:
Gradient w.r.t theta (Tensor, shape [batch_size, seq_length]).
"""
# Ensure theta requires gradients
theta = theta.requires_grad_(True)
# Initialize gradient accumulator
grad_sum = 0
# Iterate over particles
for x in X:
# Compute energy for this particle
#energy = ebm(theta, mlm_mask=mlm_mask).sum()
# Compute gradient w.r.t theta for this particle
grad = torch.autograd.functional.jacobian(energy_func, theta)
'''
grad = torch.autograd.grad(
outputs=energy,
inputs=theta,
create_graph=True
)[0]
'''
# Accumulate gradient
grad_sum += grad
# Average the gradients over all particles
grad_theta = grad_sum / X.size(0)
return grad_theta
def prior_grad(X):
"""
Gradient of the log prior w.r.t X.
Args:
X: Latent particles (Tensor, shape [N_particles, batch_size, seq_length]).
Returns:
Gradient w.r.t X (Tensor, same shape as X).
"""
return -X
def grad_U_X(ebm, X, mlm_mask=None, prior_grad=None):
"""
Gradient of U with respect to X.
Args:
ebm: BERT-based EBM (computes -log p(y|x)).
X: Latent particles (Tensor, shape [N_particles, batch_size, seq_length]).
mlm_mask: Optional mask for MLM tasks (Tensor, shape [batch_size, seq_length]).
prior_grad: Optional function to compute gradient of log prior w.r.t X.
Returns:
Gradient w.r.t X (Tensor, same shape as X).
"""
# Ensure X requires gradients
X = X.requires_grad_(True)
# Initialize gradient accumulator
grad_sum = 0
# Iterate over particles
for x in X:
# Compute energy for this particle
#energy = ebm(x, mlm_mask=mlm_mask).sum()
#print(f"in grad_U_X, x has a shape of {x.shape}")
# Compute gradient w.r.t X for this particle
grad = torch.autograd.functional.jacobian(energy_func, x)
'''
grad = torch.autograd.grad(
outputs=energy,
inputs=x,
create_graph=True
)[0]
'''
# Accumulate gradient
grad_sum += grad
# Average the gradients over all particles
grad_X = grad_sum / X.size(0)
# Add prior gradient if specified,
# In the NLP domain, it is hard to define a prior gradient when the data is not Gaussian-distributed.
if prior_grad is not None:
grad_X += prior_grad(X)
return grad_X
y.requires_grad_(True)
#print(f"y shape: {y.shape}, y grad: {y.requires_grad}")
batch_size, particle_dim = y.shape # reusing the shape dimension of tensor y for particles X
# Initialize variables for OBABO
#X = torch.stack([y.clone() for _ in range(N_particles)], dim=0) # this is variable X in KIPLMC2
theta = torch.rand_like(y, device=device, requires_grad=True) * (tokenizer.vocab_size-1) # θ (parameter)
v_theta = torch.randn_like(theta, device=device) # V^θ_0 (velocity for θ)
X = torch.rand(N_particles, batch_size, particle_dim, device=device, requires_grad=True) * (tokenizer.vocab_size-1) # X^i (particles)
v_X = torch.randn_like(X, device=device) # V^{X,i}_0 (velocities for X)
for k in range(K):
# Compute gradients
if USE_OBABO:
grad_theta = grad_U_theta(ebm=ebm, theta=y, X=X) # Gradient w.r.t theta
grad_X = grad_U_X(ebm=ebm, X=X) # Gradient w.r.t X
else:
if USE_dWJS_ENERGY:
# Compute gradients
#print(f"torch.autograd.grad(energy, y, allow_unused=True) = {torch.autograd.grad(energy, y, allow_unused=True)}")
#grad_energy = torch.autograd.grad(energy, y)[0]
#grad_energy = y.grad
grad_energy = torch.autograd.functional.jacobian(energy_func, y)
total_grad = grad_energy
else: # USE_dWJS_SCORE
with torch.no_grad(): # for saving RAM memory consumption
# energy and score are related by a derivative
# score = (denoised - inputs) / (sigma ** 2) = ∇log p(y) = -∇f(y) = -∇energy = -∇ebm
x_hat, _, _ = denoiser_model(y)
score = (x_hat - y) / (sigma**2) # This is ∇log p(y)
total_grad = -1 * score
if USE_GRAD_KL:
with torch.no_grad(): # for saving RAM memory consumption
model_output, _, _ = denoiser_model(y)
grad_kl = torch.autograd.functional.jacobian(kl_div_func, model_output)
# Combine gradients, we subtract grad_kl because we want to always minimize KL divergence
#print(f"grad_energy.shape = {grad_energy.shape}, grad_kl.shape = {grad_kl.shape}")
total_grad = total_grad - grad_kl
grad_theta = total_grad
grad_X = total_grad
# See equation 6 of [Rational Construction of Stochastic Numerical Methods for Molecular Sampling](https://arxiv.org/abs/1203.5428)
# "To slightly simplify the presentation that follows, we make the change of variables q -> M^(−1/2) * q, p -> M^(+1/2) * p,
# with a corresponding adjustment of the potential; this is equivalent to assuming M = I"
# q is equivalent to y, p is equivalent to v, M is equivalent to u
# Besides, the following velocity update is similar in form to the discrete Langevin dynamics recursion: x[t+1] = x[t] + τ*sθ(x[t]) + sqrt(2τ)*z,
# where τ is the step_size and z is eps (Gaussian white noise).
# The exp() terms appear in the velocity update because it is obtained by rearranging the recursion as a
# continuous-time differential equation and solving it (see https://en.wikipedia.org/wiki/Magnus_expansion).
eps = torch.randn_like(y)
# Sample alpha from [0, 1]
alpha = torch.rand(1).to(device) # This corresponds to the random midpoint step
# KIPLMC2 OBABO step (O+B): First velocity update with noise and friction
v_theta = torch.exp(-gamma * step_size) * v_theta - u * step_size * torch.exp(-2 * gamma * (step_size - alpha * step_size)) * grad_theta + torch.sqrt(u * (1 - torch.exp(-2 * gamma * step_size))) * eps
v_X = torch.exp(-gamma * step_size) * v_X - u * step_size * torch.exp(-2 * gamma * (step_size - alpha * step_size)) * grad_X + torch.sqrt(u * (1 - torch.exp(-2 * gamma * step_size))) * eps
# equation (4) of [The Randomized Midpoint Method for Log-Concave Sampling](https://arxiv.org/abs/1909.05503) uses x
# while KIPLMC2 from [Kinetic Interacting Particle Langevin Monte Carlo](http://arxiv.org/abs/2407.05790) uses θ, but here we use symbol y instead
# we are also not using the midpoint n+1/2 method due to extra compute logic for gradients
# KIPLMC2 OBABO step (A): Update positions θ and X^i
y = y + (step_size / 2) * v_theta
X = X + (step_size / 2) * v_X
#print(f"y.max() = {y.max()}, y.min() = {y.min()}")
#print(f"u = {u} , step_size = {step_size}")
#v = v + u * (step_size / 2) * g # only needed in walk-jump paper Algorithm 4, but not in KIPLMC2 OBABO
# We are using underdamped langevin dynamics, see equations (4) and (5) of [The Randomized Midpoint Method for Log-Concave Sampling](https://arxiv.org/abs/1909.05503)
# If we remove the first term below, then it will become overdamped langevin dynamics.
# This is also similar to the KIPLMC1 (which uses Exponential Integrators) approach in [Kinetic Interacting Particle Langevin Monte Carlo](http://arxiv.org/abs/2407.05790)
# KIPLMC2 OBABO step (B+O): Second velocity update with noise and gradients
v_theta = torch.exp(-gamma * step_size) * v_theta - u * step_size * torch.exp(-2 * gamma * (step_size - alpha * step_size)) * grad_theta + torch.sqrt(u * (1 - torch.exp(-2 * gamma * step_size))) * eps
v_X = torch.exp(-gamma * step_size) * v_X - u * step_size * torch.exp(-2 * gamma * (step_size - alpha * step_size)) * grad_X + torch.sqrt(u * (1 - torch.exp(-2 * gamma * step_size))) * eps
# Detaching during the walk phase ensures that the NN model is used as a fixed, pretrained guide
# for sample generation, preventing any unintended parameter updates.
return y.detach()
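# Small numeric sketch of the Gaussian forward process p(w_t | w_{t-1}) described in the
# "Formula Breakdown" comment inside langevin_mcmc_step_advanced() above (illustrative only).
# It reuses the same global smoothing level `sigma`; the scalar inputs are made up.
def _demo_forward_process_density(w_t=0.0, w_t_minus_1=0.0):
    diff = torch.tensor(float(w_t) - float(w_t_minus_1))
    density = torch.exp(-(diff ** 2) / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))
    # identical states give the peak value 1 / (sigma * sqrt(2*pi)); the density decays as |diff| grows
    return density.item()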
def langevin_mcmc_step(y, model, step_size):
y.requires_grad_(True)
energy = model(y).sum()
energy.backward(retain_graph=True)
#print(f"torch.autograd.grad(energy, y, allow_unused=True) = {torch.autograd.grad(energy, y, allow_unused=True)}")
grad = torch.autograd.grad(energy, y)[0]
#grad = y.grad
y_next = y - step_size * grad + torch.sqrt(torch.tensor(2 * step_size, device=y.device)) * torch.randn_like(y)
return y_next.detach()
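# Minimal sanity-check sketch for langevin_mcmc_step() above (illustrative only, never called during training).
# With the toy energy E(y) = 0.5 * y**2 the update y <- y - step_size * dE/dy + sqrt(2 * step_size) * z
# has a stationary distribution close to N(0, 1), so after enough iterations the samples should have
# mean ~0 and standard deviation ~1.
class _QuadraticEnergy(nn.Module):
    def forward(self, y):
        return 0.5 * (y ** 2)

def _demo_langevin_step(n_iters=500, step_size=1e-2):
    toy_energy = _QuadraticEnergy().to(device)
    y = torch.zeros(1024, 1, device=device)
    for _ in range(n_iters):
        y = langevin_mcmc_step(y, toy_energy, step_size)
    return y.mean().item(), y.std().item()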
"""
This implementation below follows the stabilization mechanism described in diffusion forcing paper:
1. The input "sequence" are tokens fully diffused to the maximum noise level (sigma_max).
2. It then denoises tokens one by one, starting from the first token.
3. For each token, it gradually reduces the noise level from sigma_max to sigma_min.
4. When moving to the next token, it treats the previously denoised tokens as slightly noisy ground truth by adding back a small amount of noise (sigma_min + (sigma_max - sigma_min) / M, matching the code below).
5. For subsequent tokens, it ensures that the noise level is at least as high as the noise level of the previously denoised tokens.
This approach should help prevent the accumulation of single-step errors in autoregressive sampling by treating predicted tokens as noisy ground truth, rather than perfect observations.
"""
def apply_noise(sequence, M, sigma_min, sigma_max):
# M is the number of denoising (or jump) steps
seq_len = len(sequence)
noisy_seq = sequence.clone()
# Gradually denoise tokens while implementing the stabilization mechanism
for t in range(seq_len):
#print(f"t = {t}")
for m in range(M-1, -1, -1): # Start from highest noise and decrease
#print(f"m = {m}")
# Use a slightly higher noise level for previously denoised tokens
# Decreases noise level gradually as denoising takes place
noise_level = sigma_max - (sigma_max - sigma_min) * (m / (M-1)) * ((seq_len - t) / seq_len)
#print(f"noise_level = {noise_level}")
# Apply denoising step
noisy_seq[t] = sequence[t] + torch.randn_like(sequence[t]) * noise_level
# After fully denoising a token, slightly increase its noise level for stability
if t < seq_len - 1 and m == 0:
noisy_seq[t] += torch.randn_like(sequence[t]) * (sigma_min + (sigma_max - sigma_min) / M)
return noisy_seq
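# Illustrative sketch of the per-token, per-step noise levels produced by the schedule inside apply_noise()
# above (not called by the pipeline). It simply re-evaluates the same noise_level formula on a toy
# configuration; the seq_len/M/sigma values here are made-up assumptions for inspection purposes.
def _demo_diffusion_forcing_schedule(seq_len=4, M=3, sig_min=0.1, sig_max=1.0):
    levels = []
    for t in range(seq_len):
        row = []
        for m in range(M - 1, -1, -1):  # same iteration order as apply_noise()
            noise_level = sig_max - (sig_max - sig_min) * (m / (M - 1)) * ((seq_len - t) / seq_len)
            row.append(round(noise_level, 3))
        levels.append(row)
    return levels  # a seq_len x M grid of noise levels, one row per token position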
def walk_jump_sampling(init_y, mlm_mask, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, input_pad_mask, target_label=None):
if isinstance(init_y, dict):
y = init_y['input_ids']
else:
y = init_y
batch_size, seq_length = y.size()
"""
1. Initialization of tgt:
tgt starts as a sequence of [MASK] tokens, which will be updated as the model generates new tokens.
2. Using the Transformer Decoder:
In each step of walk_jump_sampling, the denoiser now takes both noisy_y (as src) and tgt. The decoder will use the context from noisy_y to generate predictions for tgt.
3. Updating tgt:
After each step, a new token is generated (using denoised_y), and this token is added to tgt. This way, tgt grows in length during each iteration, effectively generating the sequence step by step.
4. Sampling Process:
The loop continues, gradually building the output sequence one token at a time. The use of tgt allows the model to generate sequences dynamically, making decisions based on the tokens generated so far.
"""
if ENABLE_MASK_LEARNING:
# Initialize tgt with just the CLS (to indicate beginning of sequence) token for custom transformer using BERT tokenizer
# CLS [sequence 1] SEP [sequence 2] SEP
#tgt = torch.full((batch_size, 1), tokenizer.cls_token_id, dtype=torch.long, device=device) # for inference, but not for training
# Determine the start token ID based on tokenizer and model
if USE_PRETRAINED_T5:
start_token_id = tokenizer.pad_token_id # T5 uses pad_token_id as start token
else:
start_token_id = tokenizer.cls_token_id # BERT uses cls_token_id as start token
if USE_PRETRAINED_BERT or USE_PRETRAINED_BERT_MLM: # BERT models do not have decoders, so no need 'tgt'
tgt = None
else:
# Shift tgt to the right
tgt = denoiser._shift_right(y, start_token_id)
else:
tgt = None
# Define the step size schedule function
def get_step_size(t, initial_step_size=1e-3, gamma=0.55):
return initial_step_size / (t + 1) ** gamma
# Preparation steps for walk stage:
# NLP token sampling search space is huge with tokenizer.vocab_size = 30522
# Protein token sampling search space is only 20
# So, we might have to temporarily disable the following rand_like() operation for now
#y = torch.rand_like(y) * tokenizer.vocab_size + torch.randn_like(y) * sigma
if mlm_mask is not None and (USE_SMC or USE_MAFBM or USE_MCMC): # non-masked language modeling task
# Add more noise to masked positions
#y[mlm_mask] = torch.rand_like(y[mlm_mask]) * tokenizer.vocab_size + torch.randn_like(y[mlm_mask]) * sigma
# Add less noise to unmasked positions
#y[~mlm_mask] = y[~mlm_mask] + torch.randn_like(y[~mlm_mask]) * (sigma * 0.1)
#y = y + torch.randn_like(y) * sigma
pass
v = torch.randn_like(y) # Initialize velocity from a standard normal distribution
#v = torch.zeros_like(y) # Initialize velocity
# walk then jump, so num_walk_steps == num_jump_steps
for t in range(num_walk_steps):
# walk stage (sampling process is guided by using EBM)
if USE_SMC: # see [Probabilistic Inference in Language Models via Twisted Sequential Monte Carlo](https://arxiv.org/abs/2404.17546)
# Initialize particles
particles = [y for _ in range(N_particles)]
for smc_step in range(num_smc_steps): # use a separate index so the outer walk-step counter `t` is not shadowed
# Prediction Step: Propose new particles
particles = [proposal(particle) for particle in particles]
# Weighting Step: Compute weights based on EBM energy
weights = compute_weights(particles, ebm)
#print(f"weights shape: {weights.shape}") # Shape: [batch_size, N_particles]
# Resampling Step: Resample particles based on weights
particles = resample(particles, weights)
# After SMC steps, select the particle with the highest weight for each batch item
# Vectorized implementation
batch_size = particles[0].size(0)
# Stack particles into a tensor: [N_particles, batch_size, seq_length]
particles_tensor = torch.stack(particles, dim=0) # [N_particles, batch_size, seq_length]
# Permute to [batch_size, N_particles, seq_length]
particles_tensor = particles_tensor.permute(1, 0, 2) # [batch_size, N_particles, seq_length]
# Compute the indices of the best particles for each batch item
best_particle_indices = torch.argmax(weights, dim=1) # [batch_size]
# Prepare batch indices for advanced indexing
batch_indices = torch.arange(batch_size, device=weights.device) # [batch_size]
# Select the final particles using advanced indexing
final_particles = particles_tensor[batch_indices, best_particle_indices, :] # [batch_size, seq_length]
# Update y with the selected final particles
y = final_particles # [batch_size, seq_length]
elif USE_MAFBM: # see [Generative Fractional Diffusion Models](http://arxiv.org/abs/2310.17638)
# Create and simulate MA-fBM
if t == 0:
ma_fbm = MA_fBM(hurst, T=T_fbm, n_steps=input_dim, K=K_fbm)
times, path, ou_processes = ma_fbm.simulate(return_processes=True)
# Print shapes to understand the dimensions
#print(f"Initial path shape: {path.shape}") # Should be [n_steps]
#print(f"Initial ou_processes shape: {ou_processes.shape}") # Should be [K, n_steps]
# Reshape path: [n_steps] -> [batch_size, n_steps]
path = path.unsqueeze(0).expand(batch_size, -1)
# Add MA-fBM contribution to y: [batch_size, seq_length]
y = y + walk_step_size * sigma * path
# saves memory
del times
del path
del ou_processes
elif USE_MCMC: # see [Provable Benefit of Annealed Langevin Monte Carlo for Non-log-concave Sampling](http://arxiv.org/abs/2407.16936)
# Annealing process such that step size decays across time, ensuring convergence
current_walk_step_size = get_step_size(t, initial_step_size=walk_step_size)
if USE_ALGORITHM_1_OR_4:
y = langevin_mcmc_step(y, ebm, current_walk_step_size) # Langevin MCMC sampling
else:
y = langevin_mcmc_step_advanced(y, v, mlm_mask, input_pad_mask, ebm, denoiser, current_walk_step_size) # Update using advanced Langevin dynamics
#print(f"USE_MCMC, y = {y}")
assert not y.isnan().any(), "mcmc is giving NaN output !!!"
else:
# The walk sampling stage (which serves the purpose of forward noising) is only needed for tasks such as:
# - Image denoising: removing noise from images corrupted by Gaussian noise.
# - Text denoising: correcting sentences with randomly inserted, deleted, or swapped words.
pass # no need of walk sampling stage for masked language/image model downstream task
# jump stage
if USE_PRECOMPUTE_NOISE_SCHEDULE:
# ONLY works for static pre-compute noise schedule with fixed input sequence length
sigma_t = noise_schedule[t]
else:
sigma_t = sigma # there is only a single denoising level for walk-jump equation
if ADD_EXTRA_GAUSSIAN_NOISE:
# Add noise for reverse denoising process (this step might be optional since we have a denoiser() NN module further down)
if USE_DIFFUSION_FORCING:
# Add noise level according to diffusion forcing scheme
noisy_y = apply_noise(y, num_jump_steps, sigma_min, sigma_max)
else:
# Add noise level according to full sequence diffusion scheme
noise = torch.randn_like(y) * sigma_t
noisy_y = y + noise
else:
noisy_y = y # for isolating the extra optional step just above
assert not noisy_y.isnan().any(), "noisy_y is giving NaN output !!!"
if (USE_SMC or USE_MAFBM or USE_MCMC): # non-masked language modeling task (walk sampling was applied above)
# Scales to range of [0, tokenizer.vocab_size-1]
noisy_y = noisy_y - noisy_y.min()
noisy_y = noisy_y / noisy_y.max()
noisy_y = noisy_y * (tokenizer.vocab_size - 1)
assert not noisy_y.isnan().any(), "noisy_y is giving NaN output !!!"
# checks for potential issues of NLP model's input range
#print(f"noisy_y.max() = {noisy_y.max()}, noisy_y.min() = {noisy_y.min()}")
assert_sample_range_compliance(noisy_y, tokenizer)
if isinstance(init_y, dict):
# put the noised 'y' back into the dictionary
init_y['input_ids'] = noisy_y
noisy_y = init_y
# for the purpose of learning to denoise masked tokens; this is often used in masked language models (MLM)
denoised_sentence, denoised_masked_token_logits, denoised_token_logits = denoiser_model(noisy_y, mlm_mask, target_label, tgt, input_pad_mask)
# denoised_sentence has a shape of [batch_size, src_sequence_length]
#print(f"denoised_sentence.shape = {denoised_sentence.shape}")
#print(f"denoised_masked_token_logits.shape = {denoised_masked_token_logits.shape}")
if USE_LOGITS_FOR_DENOISING:
denoised_y = denoised_token_logits
else:
denoised_y = denoised_sentence
#print(f"y.shape = {y.shape} , noisy_y.shape = {noisy_y.shape} , denoised_y.shape = {denoised_y.shape}")
if not(USE_SMC or USE_MAFBM or USE_MCMC): # masked language modeling task
y = denoised_y # DIRECT denoiser denoising, no need any specific denoising equation
else:
# See section 2.4 or Algorithm 1 of [Discrete Flow Matching](http://arxiv.org/abs/2407.15595)
if USE_LOGITS_FOR_DENOISING:
if isinstance(noisy_y, dict):
noisy_y = noisy_y['input_ids'] # when noisy_y is already a tensor, no unpacking is needed
# Get the shapes
batch_size, sequence_length, vocab_size = denoised_y.shape
# Create embedding projection that matches vocabulary size
embedding_proj = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=vocab_size,
padding_idx=tokenizer.pad_token_id,
device=device
)
# Both delta_X_t and u_t are of the shape of [batch_size, sequence_length, vocab_size]
delta_X_t = embedding_proj(noisy_y.long())
u_t = denoised_y
h = walk_step_size
# Update rule for equation (12): delta + h * velocity
prob_distribution = delta_X_t + h * u_t
prob_distribution = F.softmax(prob_distribution, dim=-1) # Normalize to a valid probability distribution
# Sampling the next state
X_t_h = torch.multinomial(prob_distribution.view(-1, vocab_size), 1).view(batch_size, sequence_length)
y = X_t_h
#print(y)
# Output: Tensor of shape [batch_size, sequence_length] representing sampled token indices
else:
y = y + sigma_t ** 2 * denoised_y # Update based on denoising equation in walk-jump
# Scales to range of [0, tokenizer.vocab_size-1]
y = y - y.min()
y = y / y.max()
y = y * (tokenizer.vocab_size - 1)
# checks for potential issues
#print(f"denoised_y.max() = {denoised_y.max()}, denoised_y.min() = {denoised_y.min()}")
assert_sample_range_compliance(y, tokenizer)
# we only run this during inference; during training we run "_shift_right()" instead
if mlm_mask is None and tgt is not None:
# Get the last token from the denoised sequence
new_token = denoised_y[:, -1].unsqueeze(-1) # shape [batch_size, 1]
#print(f"tgt.shape = {tgt.shape} , tgt = {tgt} , new_token.shape = {new_token.shape} , new_token = {new_token}")
# Step 1: Normalize the model output between (0, 1)
new_token_min = new_token.min() # Get the minimum value
new_token_max = new_token.max() # Get the maximum value
normalized_output = (new_token - new_token_min) / (new_token_max - new_token_min + 1e-8)
# Step 2: Rescale to the desired range (0, vocab_size)
rescaled_output = normalized_output * tokenizer.vocab_size
# Ensure values stay within bounds of (0, vocab_size-1)
new_token = rescaled_output.clamp(0, tokenizer.vocab_size-1)
# For each subsequent iteration, concatenate the new token to tgt
tgt = torch.cat((tgt, new_token), dim=1)
# Check for end token
if USE_PRETRAINED_T5: #or USE_CUSTOM_TRANSFORMER_ENCODER_DECODER or USE_CUSTOM_TRANSFORMER_ENCODER:
# T5 uses a different end-of-sequence token and does not use a separate SEP token
if (new_token == tokenizer.eos_token_id).all():
break
else:
# For other models, check for both EOS and SEP tokens
if hasattr(tokenizer, 'sep_token_id'):
if (new_token == tokenizer.sep_token_id).all():
break
elif hasattr(tokenizer, 'eos_token_id'):
if (new_token == tokenizer.eos_token_id).all():
break
else:
pass # nothing happens
denoised_sentence = y
if isinstance(noisy_y, dict):
noisy_ids = noisy_y['input_ids']
else:
noisy_ids = noisy_y
if mlm_mask is not None:
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
return noisy_ids, denoised_sentence, denoised_masked_token_logits, denoised_token_logits
else:
return noisy_ids, denoised_sentence, denoised_masked_token_logits
else:
return noisy_ids, denoised_sentence
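# Illustrative sketch of the "shift right" operation referenced in walk_jump_sampling() above.
# This is NOT the actual denoiser._shift_right() implementation (which lives in the model class);
# it only shows the standard decoder-input construction: prepend a start token and drop the last token.
def _demo_shift_right(ids, start_token_id):
    shifted = torch.full_like(ids, start_token_id)
    shifted[:, 1:] = ids[:, :-1]
    return shifted  # e.g. ids=[[5, 6, 7]], start_token_id=101 -> [[101, 5, 6]]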
class LabelSmoothingCrossEntropy(nn.Module):
def __init__(self, smoothing=0.1):
super().__init__()
self.smoothing = smoothing
def forward(self, x, target):
confidence = 1.0 - self.smoothing
logprobs = F.log_softmax(x, dim=-1)
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)
smooth_loss = -logprobs.mean(dim=-1)
loss = confidence * nll_loss + self.smoothing * smooth_loss
return loss.mean()
# Use this loss function instead of nn.CrossEntropyLoss
smooth_CE_loss_fn = LabelSmoothingCrossEntropy(smoothing=0.1)
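# Quick illustrative check of LabelSmoothingCrossEntropy on toy tensors (not part of training):
# with smoothing=0.0 it should match nn.CrossEntropyLoss exactly, while smoothing>0 yields a larger
# loss on confident correct predictions because some probability mass is spread over the other classes.
def _demo_label_smoothing():
    logits = torch.tensor([[4.0, 0.5, 0.1], [0.2, 3.0, 0.3]])
    target = torch.tensor([0, 1])
    plain_ce = nn.CrossEntropyLoss()(logits, target)
    no_smooth = LabelSmoothingCrossEntropy(smoothing=0.0)(logits, target)
    smoothed = LabelSmoothingCrossEntropy(smoothing=0.1)(logits, target)
    return plain_ce.item(), no_smooth.item(), smoothed.item()  # first two agree; the third is larger here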
def train_walk_jump(ebm, denoiser, train_loader, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, optimizer_ebm, optimizer_denoiser, scheduler_ebm, scheduler_denoiser, scaler):
ebm.train()
denoiser.train()
pad_token_id = tokenizer.pad_token_id
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id
for train_data in train_loader:
# Clear memory before processing each batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif hasattr(torch.mps, 'empty_cache'): # Check if MPS backend exists
torch.mps.empty_cache()
#if isinstance(train_data, dict):
if MASK_RATIO != -1: # will be using data_collator which returns dict format
input_ids = train_data['input_ids'].to(device)
target_label = train_data['labels'].to(device) # this is the target for the masked tokens in which the model should predict to unmask
else:
input_ids = train_data['input_ids'].to(device)
target_label = input_ids.clone()
if ENABLE_MASK_LEARNING:
if MASK_RATIO != -1:
# We are now using data collator class to deal with masking strategy
mask = train_data['mask_indices'].to(device) # Get mask from data collator
# Randomly mask some tokens in the clean train_data
#mask = torch.rand(train_data.shape).to(device) < MASK_RATIO # 15% masking probability
mask = mask * 1 # converts True/False into 1/0
else:
# Randomly mask only 1 single token in the clean train_data
batch_size, seq_len = train_data['attention_mask'].shape
mask = torch.zeros_like(train_data['attention_mask'], dtype=torch.bool).to(device)
# Create a boolean mask for non-pad and non-special tokens
# (True where tokens are real, and False for [PAD], [CLS], [SEP])
non_special_tokens = (train_data['input_ids'] != pad_token_id) & (train_data['input_ids'] != cls_token_id) & (train_data['input_ids'] != sep_token_id)
#print(f"non_special_tokens has a shape of {non_special_tokens.shape}")
# Sum along the sequence dimension to get the actual length of each sequence (excluding special tokens)
actual_seq_len = non_special_tokens.sum(dim=1)
#print(f"actual_seq_len has a shape of {actual_seq_len.shape}")
# For each sequence in the batch, randomly mask one token
# random indexing starts from 1 since we do not want to index the first [CLS] token in each training sequence
# Random index per sequence in the batch
random_indices = torch.stack([torch.randint(1, length.item(), (1,)) for length in actual_seq_len]).squeeze()
#print(f"random_indices = {random_indices}, random_indices.shape = {random_indices.shape}")
# Mask the selected tokens at the random indices
mask[torch.arange(batch_size), random_indices] = 1
mask = mask * 1 # converts True/False into 1/0
assert(mask.sum() == batch_size) # shape : [batch_size, seq_len] , so only 1 masked token for each sequence
#print(f"mask = {mask}")
#if not(MASK_RATIO != -1): # will not be using data_collator which returns dict format
# Set non-masked positions in target_label to -100 or tokenizer.pad_token_id
# See the use of ignore_index in https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
#target_label[~mask.bool()] = CONSTANTS_VALUE_IGNORE
input_ids = input_ids.float().to(device)
if ADD_EXTRA_GAUSSIAN_NOISE:
noisy_train_data = add_token_noise(input_ids, tokenizer, noise_level=sigma_min).to(device) # Or your chosen noising function
else:
noisy_train_data = input_ids # for isolating the extra noise added
if MASK_RATIO != -1:
if USE_PRETRAINED_T5:
# We should not use '<extra_id_0>' for all masked tokens; they should be numbered as '<extra_id_*>' to indicate their ordering
# We had used the data_collator to prepare train_loader, so no need to manually modify
masked_train_ids = noisy_train_data.clone() # we want to denoise noisy data to its clean version
else:
# we are now using data_collator for masking purpose, see _mask_tokens_standard()
# We had used the data_collator to prepare train_loader, so no need to manually modify
masked_train_ids = noisy_train_data.clone() # we want to denoise noisy data to its clean version
masked_train_ids[mask.bool()] = tokenizer.mask_token_id
else:
if USE_PRETRAINED_T5:
masked_train_ids = noisy_train_data.clone() # we want to denoise noisy data to its clean version
masked_train_ids[mask.bool()] = tokenizer.convert_tokens_to_ids('<extra_id_0>')
else:
masked_train_ids = noisy_train_data.clone() # we want to denoise noisy data to its clean version
masked_train_ids[mask.bool()] = tokenizer.mask_token_id
#print(f"Masked train ids = {masked_train_ids}")
#if MASK_RATIO == 0.00:
# for testing purpose only
#assert(torch.equal(train_data, masked_train_ids))
# ebm model and denoiser model are trained independently, no gradient connections between them
masked_train_data = {
'input_ids': masked_train_ids,
'labels': target_label,
'attention_mask': train_data['attention_mask'].clone().detach() if 'attention_mask' in train_data else None
}
# Train EBM
optimizer_ebm.zero_grad()
# Train denoiser
optimizer_denoiser.zero_grad()
# Add noise to input data
noisy_train_ids = train_data['input_ids'].float() + torch.randn_like(train_data['input_ids'].float()) * sigma
# Scales to range of [0, tokenizer.vocab_size-1]
noisy_train_ids = noisy_train_ids - noisy_train_ids.min()
noisy_train_ids = noisy_train_ids / noisy_train_ids.max()
noisy_train_ids = noisy_train_ids * (tokenizer.vocab_size - 1)
# checks for potential issues
assert_sample_range_compliance(noisy_train_ids, tokenizer)
if USE_MIXED_PRECISION_TRAINING:
with autocast(device_type=device_str, dtype=torch.float16):
# Get energy of noisy input
energy_real = ebm({
'input_ids': noisy_train_ids,
'labels': target_label,
'attention_mask': train_data['attention_mask']
})
else:
# Get energy of noisy input
energy_real = ebm({
'input_ids': noisy_train_ids,
'labels': target_label,
'attention_mask': train_data['attention_mask']
})
#print(f"energy_real has a shape of {energy_real.shape}")
assert not torch.all(energy_real == 0), "Error: energy_real contains all zeros!"
# for the purpose of more efficient run for the denoiser model
input_pad_mask = (input_ids == tokenizer.pad_token_id).to(device)
if ENABLE_MASK_LEARNING:
# Generate samples using walk-jump
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
noisy_ids, generated_samples, denoised_masked_token_logits, denoised_token_logits = walk_jump_sampling(masked_train_data, mask, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, input_pad_mask=input_pad_mask, target_label=target_label)
else:
noisy_ids, generated_samples, denoised_masked_token_logits = walk_jump_sampling(masked_train_data, mask, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, input_pad_mask=input_pad_mask, target_label=target_label)
else:
# Generate samples using walk-jump
noisy_ids, generated_samples = walk_jump_sampling(train_data, None, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, input_pad_mask=input_pad_mask, target_label=target_label)
# Analyzing the generated samples during training
if ENABLE_SAMPLE_ANALYSIS:
analyze_samples(generated_samples, tokenizer)
# checks for potential issues
assert_sample_range_compliance(noisy_ids, tokenizer)
# EBM module favours lower energy for the real data from the distribution, and higher energy for fake data from the sampling process
if USE_MIXED_PRECISION_TRAINING:
with autocast(device_type=device_str, dtype=torch.float16):
#energy_fake = ebm(generated_samples)
energy_fake = ebm({
'input_ids': noisy_ids, #train_data['input_ids'],
'labels': target_label,
'attention_mask': train_data['attention_mask']
})
else:
#energy_fake = ebm(generated_samples)
energy_fake = ebm({
'input_ids': noisy_ids, #train_data['input_ids'],
'labels': target_label,
'attention_mask': train_data['attention_mask']
})
#print(f"energy_fake has a shape of {energy_fake.shape}")
assert not torch.all(energy_fake == 0), "Error: energy_fake contains all zeros!"
'''
log(q(x_real) / q(x_fake)) = log( [exp(-E(x_real)) / Z] / [exp(-E(x_fake)) / Z] )
= log( exp(-E(x_real)) / exp(-E(x_fake)) ) (Z cancels out)
= -E(x_real) + E(x_fake)
The cancellation of Z is a simplification based on the assumption that Z is approximately the same when calculated for both real (ground truth) and fake (generated) data.
To maximize the log-likelihood ratio, we minimize its negative: -log(q(x_real) / q(x_fake)) = E(x_real) - E(x_fake) = energy_real.mean() - energy_fake.mean()
However, if we do not assume that Z cancels out due to imperfect MCMC sampling,
-log(q(x)) = -log(exp(-E(x)) / Z)
= -log(exp(-E(x))) + log(Z)
= E(x) + log(Z)
We've established that -log(q(x)) ≈ E(x) + log_sum_exp(-energy).
For real data, we can write: -log(q(x_real)) ≈ E(x_real) + log_sum_exp(-energy_real).
For fake data, we can write: -log(q(x_fake)) ≈ E(x_fake) + log_sum_exp(-energy_fake).
Constructing the Loss Function:
Goal: We want to minimize -log(q(x_real)) (make real data probable) and maximize -log(q(x_fake)) (make fake data improbable).
Using the Approximation: We can approximate this by minimizing E(x_real) + log_sum_exp(-energy_real) and maximizing E(x_fake) + log_sum_exp(-energy_fake).
Combining Terms: To achieve this with a single loss function, we can take the negative of the term we want to minimize and add it to the term we want to maximize:
Applying this to -log(q(x)) directly gives the following contrastive loss:
loss = [E(x_real) + log_sum_exp(-energy_real)] - [E(x_fake) + log_sum_exp(-energy_fake)]
Simplifying: We can rearrange this as:
loss = E(x_real) + log_sum_exp(-energy_real) - E(x_fake) - log_sum_exp(-energy_fake)
Averaging: In practice, we work with batches of data, so we take the mean over the data points in the batch:
loss_ebm = E(x_real).mean() + log_sum_exp(-energy_real).mean() - E(x_fake).mean() - log_sum_exp(-energy_fake).mean()
Theoretical E(x): In the theoretical derivations of EBMs and contrastive divergence, E(x) represents the energy function. Lower energy corresponds to higher probability.
Code Implementation - ebm() function: In the code, the ebm() function (or EnergyBasedModel) is implemented to compute the negative log-probability of a given input, up to a constant.
This is because we want to use gradient-based optimization to minimize this value for real data.
So, what the ebm() model outputs is not directly E(x), but rather something proportional to -log(q(x)), which in turn is approximately E(x) + log(Z) which we have derived in the text above.
Given that:
energy_real = ebm(noisy_train_ids)
energy_fake = ebm(noisy_ids)
so we have the following final loss_ebm expression:
loss_ebm = energy_real.mean() + log_sum_exp(-energy_real).mean() - energy_fake.mean() - log_sum_exp(-energy_fake).mean()
(a toy numeric sketch of this expression, _demo_contrastive_ebm_loss, is defined right after train_walk_jump() below)
'''
# Compute EBM loss with contrastive divergence
#loss_ebm = (energy_real.mean() - energy_fake.mean()) + ebm_energy_regularization_scale * (energy_real ** 2).mean() # Added L2 regularization
# Compute EBM loss with contrastive divergence, log-sum-exp trick and offset
energies = torch.cat([energy_real, energy_fake]) # Concatenate energy_real and energy_fake
mean_energy = torch.mean(energies)
#loss_ebm = log_sum_exp(-(energy_real-mean_energy)) - log_sum_exp(-(energy_fake-mean_energy)) + ebm_energy_regularization_scale * (energy_real ** 2).mean() # Added L2 regularization
loss_ebm = energy_real.mean() + log_sum_exp(-energy_real).mean() - energy_fake.mean() - log_sum_exp(-energy_fake).mean() + ebm_energy_regularization_scale * (energy_real ** 2).mean() # Added L2 regularization
if USE_MIXED_PRECISION_TRAINING and torch.cuda.is_available():
scaler.scale(loss_ebm).backward(retain_graph=True)
#nn.utils.clip_grad_norm_(ebm.parameters(), max_norm=10.0) # Gradient clipping
#check_for_vanishing_gradients(ebm) # Check for vanishing gradients
scaler.step(optimizer_ebm)
scaler.update()
else:
loss_ebm.backward(retain_graph=True)
#nn.utils.clip_grad_norm_(ebm.parameters(), max_norm=10.0) # Gradient clipping
#check_for_vanishing_gradients(ebm) # Check for vanishing gradients
optimizer_ebm.step()
if ENABLE_MASK_LEARNING:
#print(f"mask = {mask}")
unmask = 1 - mask # Inverse of mask
#print(f"unmask = {unmask}")
"""
Cross-Entropy Loss:
1. Why it’s preferred for masked language modeling:
Cross-entropy loss is widely used for classification tasks, including masked language modeling. MLM is essentially a classification problem where the model predicts a token from a discrete set of possibilities (the vocabulary) for each masked position. Cross-entropy loss measures how well the predicted probability distribution over the vocabulary matches the true distribution (typically a one-hot vector where the correct token has a probability of 1).
2. How it works:
Cross-entropy loss penalizes the model when the predicted probability of the correct token is low. It compares the predicted probability distribution with the true distribution and calculates the logarithmic loss, which is then averaged over all predictions. This loss function is effective for tasks where the outputs are discrete and categorical, such as predicting words or tokens in natural language processing tasks.
MSE Loss:
1. Why it’s less suitable:
MSE loss is more appropriate for regression tasks, where the goal is to predict continuous values. In the context of language modeling, using MSE would treat the token IDs as continuous values and penalize the squared differences between predicted and true token IDs. This approach doesn’t align well with the nature of language, where the relationship between token IDs is not linear or continuous.
2. Why it’s inappropriate for MLM:
Token IDs in a vocabulary do not have a meaningful numerical relationship to each other (e.g., the token ID for “apple” being 103 and “banana” being 104 does not mean they are numerically close in meaning). MSE loss would incorrectly interpret these IDs as continuous variables and could lead to suboptimal training because it doesn’t capture the categorical nature of the task.
"""
# Compute the loss for unmasked positions
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
# Compute CrossEntropy loss directly between denoised_token_logits and the correct tokens
#print(f"denoised_token_logits.shape = {denoised_token_logits.shape}, train_data.shape = {train_data.shape}")
target = train_data['input_ids'].long()
#print(f"target = {analyze_samples(target, tokenizer, skip_special_tokens=True)}")
loss_denoiser = nn.CrossEntropyLoss(ignore_index=-100)(denoised_token_logits.view(-1, denoised_token_logits.size(-1)), target.view(-1))
else:
# Compute the loss for masked positions
loss_masked = 0.00
if MASK_RATIO == -1: # loss computation in the case of a single masked token
# Assuming that 'mask' is a tensor of shape [batch_size, seq_len] where True indicates a masked position
train_data_correct = train_data['input_ids'][mask.bool()].long() # train_data is a dict batch, so index its 'input_ids' tensor
print(f"generated_samples[mask.bool()] = {analyze_samples(generated_samples[mask.bool()], tokenizer, skip_special_tokens=False)}, shape: {generated_samples[mask.bool()].shape}, dtype: {generated_samples[mask.bool()].dtype}")
#print(f"before tokenized, train_data_correct = {train_data_correct}, shape: {train_data_correct.shape}, dtype: {train_data_correct.dtype}")
#print(f"after tokenized, train_data_correct = {analyze_samples(train_data_correct, tokenizer, skip_special_tokens=True)}, shape: {train_data_correct.shape}, dtype: {train_data_correct.dtype}")
#print(f"for checking, train_data[:, 1] = {train_data[:, 1]}, shape: {train_data[:, 1].shape}, dtype: {train_data[:, 1].dtype}") # Ensure this is consistent
# Compute CrossEntropy loss directly between denoised_masked_token_logits and the correct tokens
loss_masked = nn.CrossEntropyLoss(ignore_index=-100)(denoised_masked_token_logits, train_data_correct)
#loss_masked = smooth_CE_loss_fn(denoised_masked_token_logits, train_data_correct)
# Both generated_samples and train_data are of tokenized embedding nature, hence use MSELoss() here for now
#loss_masked = nn.MSELoss()(generated_samples[mask.bool()], train_data[mask.bool()])
#loss_masked = nn.CrossEntropyLoss()(generated_samples[mask.bool()], train_data[mask.bool()])
else:
if USE_PRETRAINED_T5:
train_data_correct = train_data['input_ids'].long() # train_data is a dict batch, so index its 'input_ids' tensor
# we do not run CE loss for computing loss_masked because DataCollatorForSpanCorruption does not yet provide masked positions directly
else:
train_data_correct = train_data['input_ids'][mask.bool()].long() # train_data is a dict batch, so index its 'input_ids' tensor
# BERT_MLM model outputs a shape of [batch_size, sequence_length, vocab_size] which is feasible for computing CE loss
if USE_PRETRAINED_BERT or USE_PRETRAINED_BERT_MLM:
# Compute CrossEntropy loss directly between denoised_token_logits and the correct tokens
loss_masked = nn.CrossEntropyLoss(ignore_index=-100)(denoised_token_logits.view(-1, denoised_token_logits.size(-1)), train_data_correct)
if not USE_PRETRAINED_T5:
#print(f"generated_samples[mask.bool()] = {generated_samples[mask.bool()]}")
mask_token_penalty = (generated_samples[mask.bool()].int() == tokenizer.mask_token_id).sum().item()
sep_token_penalty = (generated_samples[mask.bool()].int() == tokenizer.sep_token_id).sum().item()
#print(f"mask_token_penalty = {mask_token_penalty}, sep_token_penalty = {sep_token_penalty}")
#loss_masked = loss_masked + mask_token_penalty * mask_token_penalty_weight + sep_token_penalty * sep_token_penalty_weight # Adjust the weight as needed
loss_unmasked = nn.MSELoss()(generated_samples[unmask.bool()], train_data['input_ids'][unmask.bool()].float())
if USE_CUSTOM_TRANSFORMER_ENCODER_DECODER or USE_CUSTOM_TRANSFORMER_ENCODER:
alpha = 1.0 # Focus on optimizing loss_masked
else:
alpha = 0.5 # Adjust alpha as needed
print(f"loss_masked = {loss_masked}")
print(f"loss_unmasked = {loss_unmasked}")
loss_denoiser = alpha * loss_masked + (1 - alpha) * loss_unmasked
#print(f"loss_denoiser = {loss_denoiser}")
# Identify the range of `[unused]` tokens
unused_token_min_id = 1 # [unused0]
unused_token_max_id = 107 # [unused102]
# Create a mask to check for `[unused]` tokens in the generated samples
unused_token_mask = (generated_samples >= unused_token_min_id) & (generated_samples <= unused_token_max_id)
unused_token_penalty = unused_token_mask.sum().item()
#print(f"unused_token_penalty = {unused_token_penalty}")
# Penalty for predicting [unused] tokens
loss_denoiser = loss_denoiser + unused_token_penalty * unused_token_penalty_weight
else:
loss_denoiser = nn.MSELoss()(generated_samples, train_data['input_ids'].float())
loss_denoiser.backward(retain_graph=True)
#nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=10.0) # Gradient clipping
#check_for_vanishing_gradients(denoiser) # Check for vanishing gradients
optimizer_denoiser.step()
# Update learning rate
scheduler_ebm.step()
scheduler_denoiser.step()
# Explicitly delete tensors to save memory across training epochs
del train_data
del noisy_train_data
del masked_train_data
del mask
del unmask
del target
del input_ids
del target_label
#del non_special_tokens
#del random_indices
del input_pad_mask
del energy_real
del energy_fake
del energies
del mean_energy
if 'unused_token_mask' in locals(): # only defined on the non-logits code path above
del unused_token_mask
if ENABLE_MASK_LEARNING:
# Generate samples using walk-jump
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
del generated_samples
del denoised_masked_token_logits
del denoised_token_logits
else:
del generated_samples
del denoised_masked_token_logits
else:
# Generate samples using walk-jump
del generated_samples
return loss_ebm.item(), loss_denoiser.item()
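# Toy numeric sketch of the contrastive EBM loss derived in the long comment inside train_walk_jump()
# above (illustrative only). torch.logsumexp stands in for the log_sum_exp() helper used in the real
# loss, assuming both compute log(sum(exp(.))) over the batch; the energy values below are made up.
def _demo_contrastive_ebm_loss():
    energy_real = torch.tensor([0.2, 0.3, 0.1])  # lower energy -> real data treated as more probable
    energy_fake = torch.tensor([2.0, 2.5, 1.8])  # higher energy -> generated samples treated as less probable
    # Simple contrastive divergence form (assuming Z cancels), negative when real energies are lower:
    cd_loss = energy_real.mean() - energy_fake.mean()
    # Form with the extra log-sum-exp terms, mirroring the loss_ebm expression in train_walk_jump():
    lse_loss = (energy_real.mean() + torch.logsumexp(-energy_real, dim=0)
                - energy_fake.mean() - torch.logsumexp(-energy_fake, dim=0))
    return cd_loss.item(), lse_loss.item()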
# Initialize models
#ebm = EnergyBasedModel(input_dim, hidden_dim).to(device)
if USE_PRETRAINED_BERT or USE_PRETRAINED_BERT_MLM:
ebm = BertDenoiser(model_dim).to(device)
elif USE_PRETRAINED_T5:
#ebm = T5Denoiser().to(device) # we do not use T5 due to RAM memory restriction
ebm = TransformerDenoiser(input_dim, model_dim_ebm, num_layers_ebm, num_heads_ebm, sigma_max).to(device)
else:
ebm = TransformerDenoiser(input_dim, model_dim_ebm, num_layers_ebm, num_heads_ebm, sigma_max).to(device)
if USE_PRETRAINED_BERT or USE_PRETRAINED_BERT_MLM:
denoiser = BertDenoiser(model_dim).to(device)
elif USE_PRETRAINED_T5:
denoiser = T5Denoiser(model_dim).to(device)
else:
denoiser = TransformerDenoiser(input_dim, model_dim, num_layers, num_heads, sigma_max).to(device)
def init_weights(m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Conv1d):
nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
# the following init_weights are not used because it hurts `loss_unmasked` quite a lot
#ebm.apply(init_weights)
#denoiser.apply(init_weights)
# Load the AG News dataset
dataset = load_dataset('ag_news')
# Load the IMDb dataset
#dataset = load_dataset('imdb')
# Select a small subset of the training set for testing
NUM_OF_SMALL_SUBSET_OF_TRAIN_SET = 6000
train_dataset = dataset['train'].select(range(NUM_OF_SMALL_SUBSET_OF_TRAIN_SET))
if TEST_OVERFIT and not INFERENCE_ONLY:
# In the following 'for' loop, the train_dataset variable seems to be immutable, hence no data overwriting actually happened
#for i in range(NUM_OF_SMALL_SUBSET_OF_TRAIN_SET):
# train_dataset[i]['text'] = train_dataset[0]['text']
# Set the text of all the entries in the training set to be the same as train_dataset[0]['text']
new_text = train_dataset[0]['text']
train_dataset = train_dataset.map(lambda x: {'text': new_text})
dataset['train'] = train_dataset
print(f"an example training dataset at index 0: {train_dataset[0]['text']}")
print(f"an example training dataset at index 300: {train_dataset[300]['text']}")
def add_token_noise(input_ids, tokenizer, noise_level=0.05, noise_fraction=0.1):
"""
Add noise to tokenized input_ids by randomly replacing tokens based on a noise level,
but without modifying special tokens and filtering them out during noise application.
Args:
input_ids (torch.Tensor): Batch of input token ids with shape [batch_size, seq_length].
tokenizer: Tokenizer with methods to convert ids to tokens and tokens to ids.
noise_level (float): The noise level that controls the magnitude of the noise.
noise_fraction (float): The fraction of tokens to apply noise to.
Returns:
torch.Tensor: The noisy token ids as a tensor with special tokens re-inserted.
"""
special_token_ids = set(tokenizer.all_special_ids) # Get the set of special token IDs
#print(f"All special token IDs: {tokenizer.all_special_ids}")
#print(f"Pad token ID: {tokenizer.pad_token_id}")
#print(f"special_token_ids: {special_token_ids}")
noisy_input_ids = []
for ids in input_ids:
# Ensure the output length matches the input length
original_length = len(ids)
ids = ids.tolist() # Convert tensor to a list of integers
#print(f"ids: {ids}")
#print(f"original_length: {original_length}")
# Store original positions of special tokens and filter them out
special_tokens_positions = {i: token_id for i, token_id in enumerate(ids) if token_id in special_token_ids}
filtered_ids = [token_id for token_id in ids if token_id not in special_token_ids]
#print(f"Input IDs before filtering: {ids}")
#print(f"Special tokens: {special_tokens_positions}")
#print(f"Filtered IDs: {filtered_ids}")
# Generate noise for the filtered non-special tokens
noise = torch.randn(len(filtered_ids)) # Generate noise for each non-special token
probs = torch.sigmoid(noise * noise_level)
# Create a counter of tokens in the tokenizer's vocabulary
vocab_size = len(tokenizer)
token_counter = Counter(range(vocab_size))
# Create a cumulative distribution for token sampling
cum_dist = []
total = 0
for token_id, count in token_counter.items():
total += count
cum_dist.append(total)
cum_dist = [x / total for x in cum_dist]
# Function to sample a token based on the cumulative distribution and a random value
def sample_token(rand_val):
for i, val in enumerate(cum_dist):
if rand_val <= val:
return list(token_counter.keys())[i]
return list(token_counter.keys())[-1]
# Apply noise to filtered non-special tokens
noisy_filtered_ids = [
sample_token(probs[i].item()) if probs[i].item() > random.random() and random.random() < noise_fraction
else token_id
for i, token_id in enumerate(filtered_ids)
]
# Re-insert special tokens into their original positions
noisy_ids = []
filtered_idx = 0
for i in range(len(ids)):
if i in special_tokens_positions:
noisy_ids.append(special_tokens_positions[i]) # Add special token
else:
noisy_ids.append(noisy_filtered_ids[filtered_idx]) # Add noisy/non-noisy token
filtered_idx += 1
# Ensure the length of noisy_ids matches the original length to prevent extra tokens
noisy_ids = noisy_ids[:original_length]
noisy_input_ids.append(noisy_ids)
# Convert noisy_input_ids (a list of lists) back into a tensor
return torch.tensor(noisy_input_ids)
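# Illustrative usage sketch for add_token_noise() above (not called by the pipeline). It assumes the
# global `tokenizer` loaded earlier in this file; the sentence and the max_length of 16 are made up.
# Special tokens (e.g. [CLS]/[SEP]/[PAD] for a BERT tokenizer) should come back unchanged, while a
# fraction of the remaining tokens may be replaced according to noise_level and noise_fraction.
def _demo_add_token_noise():
    toy_batch = tokenizer(["a quick brown fox jumps over the lazy dog"],
                          return_tensors="pt", padding="max_length", max_length=16, truncation=True)
    noisy_ids = add_token_noise(toy_batch['input_ids'], tokenizer, noise_level=0.05, noise_fraction=0.1)
    return tokenizer.batch_decode(noisy_ids, skip_special_tokens=True)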
def add_character_noise(text, noise_level):
# Sample noise from a normal distribution
noise = torch.randn(len(text))
# Convert noise to probabilities between 0 and 1
probs = torch.sigmoid(noise * noise_level)
# Create a counter of characters in the text
char_counter = Counter(string.ascii_letters + string.digits + string.punctuation)
# Create a cumulative distribution from the counter
cum_dist = []
total = 0
for char, count in char_counter.items():
total += count
cum_dist.append(total)
cum_dist = [x / total for x in cum_dist]
# Function to sample a character based on the cumulative distribution and a random value
def sample_char(rand_val):
for i, val in enumerate(cum_dist):
if rand_val <= val:
return list(char_counter.keys())[i]
return list(char_counter.keys())[-1] # Just in case of rounding errors
# Replace characters based on probability
noisy_text = "".join([
sample_char(probs[i].item()) if probs[i].item() > random.random() else char
for i, char in enumerate(text)
])
return noisy_text
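# Tiny illustrative call for add_character_noise() above (not used by the pipeline; the sentence is made up).
def _demo_add_character_noise():
    return add_character_noise("walk jump sampling", noise_level=0.5)  # some characters get randomly replaced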
def preprocess_function(examples):
#noisy_text = add_character_noise(examples, noise_level=sigma_max) # Or your chosen noising function
tokenized_inputs = tokenizer_function(examples, tokenizer)
#tokenized_inputs = tokenizer_function(noisy_text, tokenizer)
return tokenized_inputs
#dataset = dataset.map(preprocess_function, batched=True)
#dataset.set_format(type='torch', columns=['input_ids'])
encoded_inputs_file = 'encoded_inputs_walk_jump.pt'
if os.path.exists(encoded_inputs_file):
print("Loading pre-tokenized data...")
encoded_inputs = torch.load(encoded_inputs_file, weights_only=True)
else:
# Process data
print("Tokenizing data now ...")
processed_inputs = [preprocess_function(entry['text'])
for entry in dataset['train']]
# Concatenate tensors for each key
encoded_inputs = {
'input_ids': torch.cat([x['input_ids'] for x in processed_inputs], dim=0),
'attention_mask': torch.cat([x['attention_mask'] for x in processed_inputs], dim=0)
}
#encoded_inputs = torch.cat(encoded_inputs, dim=0)
torch.save(encoded_inputs, encoded_inputs_file)
print("Finished tokenizing data !!!")
class SomeDataset(Dataset):
def __init__(self, data):
#self.data = data
self.input_ids = data['input_ids'] # Shape: [total_size, sequence_length]
self.attention_mask = data['attention_mask'] # Shape: [total_size, sequence_length]
def __len__(self):
#return len(self.data)
return len(self.input_ids)
def __getitem__(self, idx):
#return self.data[idx]
return {
'input_ids': self.input_ids[idx], # Shape: [sequence_length]
'attention_mask': self.attention_mask[idx] # Shape: [sequence_length]
}
# Split the data into train and validation sets
total_size = len(encoded_inputs['input_ids'])
train_size = int(total_size * 0.8)
print(f"total_size = {total_size}")
# Split each tensor in the dictionary
train_data = {
'input_ids': encoded_inputs['input_ids'][:train_size],
'attention_mask': encoded_inputs['attention_mask'][:train_size]
}
val_data = {
'input_ids': encoded_inputs['input_ids'][train_size:],
'attention_mask': encoded_inputs['attention_mask'][train_size:]
}
#train_data = encoded_inputs[:train_size]
#val_data = encoded_inputs[train_size:]
train_dataset = SomeDataset(train_data)
val_dataset = SomeDataset(val_data)
# Use the data_collator to prepare inputs and labels for train_loader and val_loader
data_collator = DataCollatorForSpanCorruption(
tokenizer=tokenizer,
mlm_probability=MASK_RATIO,
mean_noise_span_length=3,
input_length=input_dim
)
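# Hedged sketch of what the data_collator produces on a toy batch (illustrative only). It assumes the
# custom DataCollatorForSpanCorruption follows the standard collate_fn interface implied by its use in
# the DataLoaders below: called with a list of dataset items, returning a dict of batched tensors
# (e.g. 'input_ids', 'labels', 'mask_indices' as consumed in train_walk_jump()).
def _demo_data_collator_batch():
    toy_features = [train_dataset[0], train_dataset[1]]
    batch = data_collator(toy_features)
    return {k: tuple(v.shape) for k, v in batch.items() if torch.is_tensor(v)}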
# Create a DataLoader for batch processing
# Now we can use data_loader in the training loop
if MASK_RATIO != -1:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)
else:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Dummy data loader
if USE_DUMMY_TRAINING_DATA:
train_loader = DataLoader(torch.randn(100, input_dim), batch_size=batch_size, shuffle=True)
# Define noise schedule
noise_schedule = torch.arange(sigma_min, sigma_max, 0.1).to(device)
print(f"noise_schedule = {noise_schedule}")
def validate_walk_jump(ebm, denoiser, val_loader, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule):
ebm.eval()
denoiser.eval()
val_ebm_losses = []
val_denoiser_losses = []
pad_token_id = tokenizer.pad_token_id
cls_token_id = tokenizer.cls_token_id
sep_token_id = tokenizer.sep_token_id
with torch.no_grad():
for val_data in val_loader:
# Clear memory before processing each batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif hasattr(torch.mps, 'empty_cache'): # Check if MPS backend exists
torch.mps.empty_cache()
#if isinstance(train_data, dict):
if MASK_RATIO != -1: # will be using data_collator which returns dict format
input_ids = val_data['input_ids'].to(device)
target_label = val_data['labels'].to(device) # this is the target for the masked tokens in which the model should predict to unmask
else:
input_ids = val_data['input_ids'].to(device)
target_label = input_ids.clone()
if ENABLE_MASK_LEARNING:
if MASK_RATIO != -1:
# We are now using data collator class to deal with masking strategy
mask = val_data['mask_indices'].to(device) # Get mask from data collator
# Randomly mask some tokens in the clean val_data
#mask = torch.rand(val_data.shape).to(device) < MASK_RATIO # 15% masking probability
mask = mask * 1 # converts True/False into 1/0
else:
# Randomly mask only 1 single token in the clean val_data
batch_size, seq_len = val_data['attention_mask'].shape
mask = torch.zeros_like(val_data['attention_mask'], dtype=torch.bool).to(device)
# Create a boolean mask for non-pad and non-special tokens
# (True where tokens are real, and False for [PAD], [CLS], [SEP])
non_special_tokens = (val_data['input_ids'] != pad_token_id) & (val_data['input_ids'] != cls_token_id) & (val_data['input_ids'] != sep_token_id)
#print(f"non_special_tokens has a shape of {non_special_tokens.shape}")
# Sum along the sequence dimension to get the actual length of each sequence (excluding special tokens)
actual_seq_len = non_special_tokens.sum(dim=1)
#print(f"actual_seq_len has a shape of {actual_seq_len.shape}")
# For each sequence in the batch, randomly mask one token
# random indexing starts from 1 since we do not want to index the first [CLS] token in each training sequence
# Random index per sequence in the batch
random_indices = torch.stack([torch.randint(1, length.item(), (1,)) for length in actual_seq_len]).squeeze()
#print(f"random_indices = {random_indices}, random_indices.shape = {random_indices.shape}")
# Mask the selected tokens at the random indices
mask[torch.arange(batch_size), random_indices] = 1
mask = mask * 1 # converts True/False into 1/0
assert(mask.sum() == batch_size) # shape : [batch_size, seq_len] , so only 1 masked token for each sequence
#if not(MASK_RATIO != -1): # will not be using data_collator which returns dict format
# Set non-masked positions in target_label to -100 or tokenizer.pad_token_id
# See the use of ignore_index in https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
#target_label[~mask.bool()] = CONSTANTS_VALUE_IGNORE
input_ids = input_ids.float().to(device)
if ADD_EXTRA_GAUSSIAN_NOISE:
noisy_val_data = add_token_noise(input_ids, tokenizer, noise_level=sigma_min).to(device) # Or your chosen noising function
else:
noisy_val_data = input_ids # for isolating the extra noise added
if MASK_RATIO != -1:
if USE_PRETRAINED_T5:
# We should not use '<extra_id_0>' for every masked token; the sentinel tokens should be numbered as '<extra_id_0>', '<extra_id_1>', ... to indicate ordering
# We had used the data_collator to prepare val_loader, so no need to manually modify
masked_val_ids = noisy_val_data.clone() # we want to denoise noisy data to its clean version
else:
# we are now using data_collator for masking purpose, see _mask_tokens_standard()
# We had used the data_collator to prepare val_loader, so no need to manually modify
masked_val_ids = noisy_val_data.clone() # we want to denoise noisy data to its clean version
masked_val_ids[mask.bool()] = tokenizer.mask_token_id
else:
if USE_PRETRAINED_T5:
masked_val_ids = noisy_val_data.clone() # we want to denoise noisy data to its clean version
masked_val_ids[mask.bool()] = tokenizer.convert_tokens_to_ids('<extra_id_0>')
else:
masked_val_ids = noisy_val_data.clone() # we want to denoise noisy data to its clean version
masked_val_ids[mask.bool()] = tokenizer.mask_token_id
#print(f"Masked validation ids = {masked_val_ids}")
# ebm model and denoiser model are trained independently, no gradient connections between them
masked_val_data = {
'input_ids': masked_val_ids,
'labels': target_label,
'attention_mask': val_data['attention_mask']
}
# Add noise to input data
noisy_val_ids = val_data['input_ids'].float() + torch.randn_like(val_data['input_ids'].float()) * sigma
# Scales to range of [0, tokenizer.vocab_size-1]
noisy_val_ids = noisy_val_ids - noisy_val_ids.min()
noisy_val_ids = noisy_val_ids / noisy_val_ids.max()
noisy_val_ids = noisy_val_ids * (tokenizer.vocab_size - 1)
# checks for potential issues
assert_sample_range_compliance(noisy_val_ids, tokenizer)
# Get energy of noisy input
energy_real = ebm({
'input_ids': noisy_val_ids,
'labels': target_label,
'attention_mask': val_data['attention_mask']
})
# for the purpose of more efficient run for the denoiser model
input_pad_mask = (val_data['input_ids'] == tokenizer.pad_token_id).to(device)
if ENABLE_MASK_LEARNING:
# Generate samples using walk-jump
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
noisy_ids, generated_samples, denoised_masked_token_logits, denoised_token_logits = walk_jump_sampling(masked_val_data, mask, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, input_pad_mask=input_pad_mask, target_label=target_label)
else:
noisy_ids, generated_samples, denoised_masked_token_logits = walk_jump_sampling(masked_val_data, mask, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, input_pad_mask=input_pad_mask, target_label=target_label)
else:
# Generate samples using walk-jump
noisy_ids, generated_samples = walk_jump_sampling(val_data, None, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, input_pad_mask=input_pad_mask, target_label=target_label)
# Analyzing the generated samples during validation
if ENABLE_SAMPLE_ANALYSIS:
analyze_samples(generated_samples, tokenizer)
#energy_fake = ebm(generated_samples)
energy_fake = ebm({
'input_ids': noisy_ids, #val_data['input_ids'],
'labels': target_label,
'attention_mask': val_data['attention_mask']
})
assert not torch.all(energy_fake == 0), "Error: energy_fake contains all zeros!"
# Compute EBM loss with contrastive divergence
#val_ebm_loss = (energy_real.mean() - energy_fake.mean()) + ebm_energy_regularization_scale * (energy_real ** 2).mean() # Added L2 regularization
# Compute EBM loss with contrastive divergence, log-sum-exp trick and offset
energies = torch.cat([energy_real, energy_fake])
mean_energy = torch.mean(energies)
#val_ebm_loss = log_sum_exp(-(energy_real-mean_energy)) - log_sum_exp(-(energy_fake-mean_energy)) + ebm_energy_regularization_scale * (energy_real ** 2).mean() # Added L2 regularization
val_ebm_loss = energy_real.mean() + log_sum_exp(-energy_real).mean() - energy_fake.mean() - log_sum_exp(-energy_fake).mean() + ebm_energy_regularization_scale * (energy_real ** 2).mean() # Added L2 regularization
if ENABLE_MASK_LEARNING:
#print(f"mask = {mask}")
unmask = 1 - mask # Inverse of mask
# Compute the loss for unmasked positions
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
# Compute CrossEntropy loss directly between denoised_token_logits and the correct tokens
#print(f"denoised_token_logits.shape = {denoised_token_logits.shape}, val_data.shape = {val_data.shape}")
target = val_data['input_ids'].long()
val_denoiser_loss = nn.CrossEntropyLoss(ignore_index=-100)(denoised_token_logits.view(-1, denoised_token_logits.size(-1)), target.view(-1))
else:
# Compute the loss for masked positions
loss_masked = 0.00
if MASK_RATIO == -1: # loss computation in the case of a single masked token
# Assuming that 'mask' is a tensor of shape [batch_size, seq_len] where True indicates a masked position
val_data_correct = val_data['input_ids'][mask.bool()].long() # val_data is a dict from the data loader, so index its 'input_ids' tensor
#print(f"denoised_masked_token_logits shape: {denoised_masked_token_logits.shape}, dtype: {denoised_masked_token_logits.dtype}")
#print(f"val_data_correct shape: {val_data_correct.shape}, dtype: {val_data_correct.dtype}")
# Compute CrossEntropy loss directly between denoised_masked_token_logits and the correct tokens
loss_masked = nn.CrossEntropyLoss(ignore_index=-100)(denoised_masked_token_logits, val_data_correct)
#loss_masked = smooth_CE_loss_fn(denoised_masked_token_logits, val_data_correct)
# Both generated_samples and val_data are of tokenized embedding nature, hence use MSELoss() here for now
#loss_masked = nn.MSELoss()(generated_samples[mask.bool()], val_data[mask.bool()])
#loss_masked = nn.CrossEntropyLoss()(generated_samples[mask.bool()], val_data[mask.bool()])
else:
if USE_PRETRAINED_T5:
val_data_correct = val_data['input_ids'].long()
# we do not run CE loss for computing loss_masked because DataCollatorForSpanCorruption does not yet provide masked positions directly
else:
val_data_correct = val_data['input_ids'][mask.bool()].long()
# BERT_MLM model outputs a shape of [batch_size, sequence_length, vocab_size] which is feasible for computing CE loss
if USE_PRETRAINED_BERT or USE_PRETRAINED_BERT_MLM:
# Compute CrossEntropy loss directly between denoised_token_logits and the correct tokens
loss_masked = nn.CrossEntropyLoss(ignore_index=-100)(denoised_token_logits.view(-1, denoised_token_logits.size(-1)), val_data_correct)
loss_unmasked = nn.MSELoss()(generated_samples[unmask.bool()], val_data['input_ids'].float()[unmask.bool()]) # val_data is a dict; compare against its clean input_ids
if USE_CUSTOM_TRANSFORMER_ENCODER_DECODER or USE_CUSTOM_TRANSFORMER_ENCODER:
alpha = 1.0 # Focus on optimizing loss_masked
else:
alpha = 0.5 # Adjust alpha as needed
print(f"loss_masked = {loss_masked}")
print(f"loss_unmasked = {loss_unmasked}")
val_denoiser_loss = alpha * loss_masked + (1 - alpha) * loss_unmasked
else:
val_denoiser_loss = nn.MSELoss()(generated_samples, val_data['input_ids'].float()) # val_data is a dict; compare against its clean input_ids
#print(f"val_denoiser_loss = {val_denoiser_loss}")
val_ebm_losses.append(val_ebm_loss.item())
val_denoiser_losses.append(val_denoiser_loss.item())
'''
# Explicitly delete tensors to save memory across training epochs
del val_data
del noisy_val_data
del masked_val_data
del mask
del unmask
del target
del input_ids
del target_label
#del non_special_tokens
#del random_indices
del input_pad_mask
del energy_real
del energy_fake
del energies
del mean_energy
#del unused_token_mask
if ENABLE_MASK_LEARNING:
# Generate samples using walk-jump
if USE_LOGITS_FOR_THE_ENTIRE_SENTENCE: # denoised_token_logits will have a shape of [batch_size, sequence_length, vocab_size]
del generated_samples
del denoised_masked_token_logits
del denoised_token_logits
else:
del generated_samples
del denoised_masked_token_logits
else:
# Generate samples using walk-jump
del generated_samples
'''
return numpy.mean(val_ebm_losses), numpy.mean(val_denoiser_losses)
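# A standalone sketch of the EBM objective evaluated inside validate_walk_jump above,
# pulled out only for readability (not a drop-in replacement; `reg_scale` corresponds to
# ebm_energy_regularization_scale, and `log_sum_exp` is the helper already defined in this
# script): the EBM is rewarded for assigning low energy to (noisy) real batches and high
# energy to the walk-jump samples, with an L2 penalty on the real energies for stability.
def ebm_contrastive_loss_sketch(energy_real, energy_fake, reg_scale):
    loss = (energy_real.mean() + log_sum_exp(-energy_real).mean()
            - energy_fake.mean() - log_sum_exp(-energy_fake).mean())
    return loss + reg_scale * (energy_real ** 2).mean()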
# Top-p sampling selects the smallest possible set of tokens whose cumulative probability
# exceeds a threshold p. This allows for more dynamic and contextual selection of tokens
# based on their probabilities.
def top_p_sampling(probabilities, p=0.9):
sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
threshold_index = (cumulative_probs > p).nonzero(as_tuple=True)[0][0]
top_p_probs = sorted_probs[:threshold_index+1]
top_p_indices = sorted_indices[:threshold_index+1]
return top_p_probs, top_p_indices
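# Minimal usage sketch for top_p_sampling (the toy distribution below is an illustrative
# assumption, not data produced by this script). The function expects a 1-D probability
# vector; with p=0.9 it keeps the smallest prefix of tokens, sorted by probability, whose
# cumulative mass exceeds 0.9, and a token can then be drawn from that truncated set.
def _top_p_sampling_demo():
    probs = torch.tensor([0.5, 0.3, 0.15, 0.05])  # already normalized
    top_p_probs, top_p_indices = top_p_sampling(probs, p=0.9)  # keeps indices 0, 1, 2 (cumulative 0.5 -> 0.8 -> 0.95)
    sampled_index = top_p_indices[torch.multinomial(top_p_probs, num_samples=1)]  # multinomial treats these as weights
    return sampled_index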
# Temperature-based sampling introduces a temperature parameter that controls the randomness
# of the sampling process. Higher temperatures result in more diverse and random samples,
# while lower temperatures produce more deterministic and conservative samples.
def temperature_sampling(probabilities, temperature=1.0):
tempered_probs = torch.pow(probabilities, 1.0 / temperature)
tempered_probs /= torch.sum(tempered_probs, dim=-1, keepdim=True)
# Flatten the tempered probabilities
tempered_probs = tempered_probs.view(-1)
sampled_token_index = torch.multinomial(tempered_probs, num_samples=1)
sampled_token_id = sampled_token_index.item()
return sampled_token_index, sampled_token_id
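# Minimal usage sketch for temperature_sampling (the toy distribution is an illustrative
# assumption). A temperature above 1 flattens the distribution and gives more diverse
# samples, while a temperature below 1 sharpens it towards the most likely token; for
# probs = [0.7, 0.2, 0.1], temperature=2.0 renormalizes to roughly [0.52, 0.28, 0.20].
def _temperature_sampling_demo():
    probs = torch.tensor([0.7, 0.2, 0.1])
    _, diverse_token_id = temperature_sampling(probs, temperature=2.0)  # flatter distribution
    _, greedy_token_id = temperature_sampling(probs, temperature=0.1)   # almost always index 0
    return diverse_token_id, greedy_token_id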
# Get all special token IDs
special_token_ids = tokenizer.all_special_ids
# Get all token IDs that are labeled as "unused" in the vocabulary
range_of_unused_token_ids = 1000
unused_token_ids = [tokenizer.convert_tokens_to_ids(f'[unused{i}]') for i in range(range_of_unused_token_ids)]
# Create a set of all token indices
all_token_indices = set(range(tokenizer.vocab_size))
# Subtract the special token indices and unused token indices from the full range
valid_token_indices = sorted(all_token_indices - set(special_token_ids) - set(unused_token_ids))
# Further filter out subword tokens (those that start with '##') due to WordPiece or Byte-Pair Encoding (BPE) scheme in the tokenizer
filtered_valid_token_indices = [idx for idx in valid_token_indices if not tokenizer.convert_ids_to_tokens(idx).startswith('##')]
# Function to generate a random non-special token ID
def generate_random_non_special_token():
return random.choice(filtered_valid_token_indices)
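# Optional sanity check for the filtering above (a sketch; assumes a BERT-style WordPiece
# tokenizer was loaded earlier in this script). Every id drawn by
# generate_random_non_special_token should decode to a plain word piece, never to a
# special token, an '[unused*]' placeholder, or a '##' subword continuation.
def _check_random_token_sampling(num_draws=5):
    for _ in range(num_draws):
        token = tokenizer.convert_ids_to_tokens(generate_random_non_special_token())
        assert token not in tokenizer.all_special_tokens
        assert not token.startswith(('##', '[unused'))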
def infer_walk_jump(input_text, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, max_length=20, top_k=5):
ebm.eval()
denoiser.eval()
# for sliding window operation
window_length = len(input_text.split())
print(f"window_length = {window_length}")
with torch.no_grad():
print(f"input_text = {input_text}")
tokenized_input = tokenizer_function(input_text, tokenizer)
input_ids = tokenized_input['input_ids'].float().to(device)
print(f"before add_token_noise, input_ids = {analyze_samples(input_ids, tokenizer, skip_special_tokens=True, num_samples=1)}")
if ADD_EXTRA_GAUSSIAN_NOISE:
noisy_input_ids = add_token_noise(input_ids, tokenizer, noise_level=sigma_min).to(device) # Or your chosen noising function
else:
noisy_input_ids = input_ids
print(f"after add_token_noise, noisy_input_ids = {analyze_samples(noisy_input_ids, tokenizer, skip_special_tokens=True, num_samples=1)}")
current_ids = noisy_input_ids.clone()
generated_texts = []
if not GENERATES_OUTPUT_OF_VARYING_LENGTH:
# we only run the following overlapping rolling diffusion loop ONCE, hence we directly get the output from the first iteration
max_length = 1
for current_index in range(max_length):
print(f"current_index = {current_index}")
print(f"current_ids has shape of {current_ids.shape}")
print(f"current_ids = {current_ids}")
if GENERATES_OUTPUT_OF_VARYING_LENGTH:
# sliding window diffusion as in [rolling diffusion models](http://arxiv.org/abs/2402.09470)
# The sliding-window concept is primarily designed for continuous domains such as images or
# videos, so it might not translate directly to NLP tasks. In the image or video domain,
# frames or pixels have strong local correlations, making interpolation and context windows
# effective. In NLP, however, each word or token is discrete and often depends on long-range
# dependencies that a simple sliding window cannot easily capture.
#generated_ids = current_ids[:, window_length-1]
#generated_ids = current_ids[:, -1]
#generated_ids = tokenizer.mask_token_id
#generated_ids = generate_random_non_special_token()
#generated_ids = torch.tensor(generated_ids, dtype=torch.float).unsqueeze(0)
#print(f"generated_ids = {generated_ids}")
#print(f"generated_ids is of type {type(generated_ids)}")
#generated_text = analyze_samples(generated_ids, tokenizer, skip_special_tokens=True, num_samples=1)
#print(f"generated_text = {generated_text}")
#print(f"generated_text is of type {type(generated_text)}")
if ENABLE_MASK_LEARNING:
# Create a copy of current_ids and replace the token at the desired position with the mask_token_id
masked_ids = current_ids.clone()
mask_position = len(input_text.split()) + current_index + 2 # plus 2 because of the B.O.S. token and the newly generated next token
sep_position = mask_position + 1
print(f"mask_position = {mask_position} , mask_token_id = {tokenizer.mask_token_id}")
print(f"sep_position = {sep_position} , sep_token_id = {tokenizer.sep_token_id}")
print(f"Before modification, masked_ids = {masked_ids}")
masked_ids[0, mask_position] = tokenizer.mask_token_id
# Print the specific index to see if it was updated
print(f"After modification, token at mask_position: {masked_ids[0, mask_position]}")
masked_ids[0, sep_position] = tokenizer.sep_token_id # indication of sentence phrase partial end-separation
print(f"After modification, masked_ids = {masked_ids}")
# for the purpose of more efficient run for the denoiser model
input_pad_mask = (masked_ids == tokenizer.pad_token_id).to(device)
# Pass the masked sequence to the walk-jump model for processing and denoising
noisy_ids, generated_samples = walk_jump_sampling(masked_ids, None, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, input_pad_mask=input_pad_mask)
else:
# for the purpose of more efficient run for the denoiser model
input_pad_mask = (current_ids == tokenizer.pad_token_id).to(device)
# Pass the sequence to the walk-jump model for processing and denoising
noisy_ids, generated_samples = walk_jump_sampling(current_ids, None, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, input_pad_mask=input_pad_mask)
print(f"generated_samples has shape of {generated_samples.shape}")
# Extract the next predicted token
generated_text = analyze_samples(generated_samples, tokenizer, skip_special_tokens=True, num_samples=1)
next_token = generated_text[0].split()[-1]
print(f"next_token = {next_token}")
# Append the predicted token to the generated_texts list
generated_texts.append(next_token)
# Update current_ids with the predicted token
current_ids = generated_samples
# Since current_ids carries noising and denoising artifacts, some tokens of the
# original input_text fed into the model may have been switched along the way
decoded_current_ids = analyze_samples(current_ids, tokenizer, skip_special_tokens=True, num_samples=1)
print(f"decoded_current_ids = {decoded_current_ids}")
print(f"decoded_current_ids is of type {type(decoded_current_ids)}")
decoded_current_ids = str(decoded_current_ids[0]).split()
decoded_current_ids[:len(input_text.split())] = input_text.split()
print(f"decoded_current_ids after 1st clean-up = {decoded_current_ids}")
# Remove the first word token
#decoded_current_ids = str(decoded_current_ids[0]).split()[1:]
# Sometimes the model will combine punctuation mark into part of words,
# hence the above operation will remove the first word token together with the punctuation mark
#if len(decoded_current_ids) < window_length-1:
# add back the punctuation mark, and it is usually period mark
#decoded_current_ids = list('.') + decoded_current_ids
#print(f"decoded_current_ids after removing first token = {decoded_current_ids}")
#print(f"decoded_current_ids after removing first token is of type {type(decoded_current_ids)}")
# appends the next predicted token
# Use the entire history: maintain the full sequence generated so far as the context for generating the next token.
# This avoids truncating important context, which is crucial in NLP.
if current_index > 0:
decoded_current_ids[len(input_text.split()):] = generated_texts
print(f"decoded_current_ids after 2nd clean-up = {decoded_current_ids}")
print(f"len(decoded_current_ids) = {len(decoded_current_ids)}")
new_input = decoded_current_ids
# converts to string
new_input = ' '.join(str(x) for x in new_input)
print(f"new_input = {new_input}")
# prepares the next input iteration for the walk-jump model
current_ids_retokenized = tokenizer_function(new_input, tokenizer)
current_ids = current_ids_retokenized['input_ids'].float().to(device)
# If EOS token is generated, stop generation
if next_token == tokenizer.eos_token:
break
else:
# for the purpose of more efficient run for the denoiser model
input_pad_mask = (current_ids == tokenizer.pad_token_id).to(device)
# Pass the sequence to the walk-jump model for processing and denoising
noisy_ids, generated_samples = walk_jump_sampling(current_ids, None, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, input_pad_mask=input_pad_mask)
print(f"generated_samples has shape of {generated_samples.shape}")
generated_texts = analyze_samples(generated_samples, tokenizer, skip_special_tokens=True, num_samples=1)
print(f"generated_texts has a length of {len(generated_texts)}")
return generated_texts
if not INFERENCE_ONLY:
# Train and validate models
if (USE_CUSTOM_TRANSFORMER_ENCODER_DECODER or USE_CUSTOM_TRANSFORMER_ENCODER):
lr = 1e-4
else:
lr = 3e-3
weight_decay = 1e-5
eps = 1e-8
betas = (0.9, 0.999) # only for Adam
if USE_PRETRAINED_T5:
# replace AdamW with Adafactor
# See https://github.com/PiotrNawrot/nanoT5/blob/1375b389d33ab4f34754a9fca62e4cfa1dd52379/README.md?plain=1#L36
optimizer_ebm = Adafactor(ebm.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
optimizer_denoiser = Adafactor(denoiser.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
else:
if USE_ADAM_MINI: # for saving RAM memory during training
if not (USE_CUSTOM_TRANSFORMER_ENCODER_DECODER or USE_CUSTOM_TRANSFORMER_ENCODER):
if USE_PRETRAINED_BERT_MLM:
model_dim = denoiser.config.hidden_size
num_heads = denoiser.config.num_attention_heads
else:
model_dim = denoiser.config.dim
num_heads = denoiser.config.num_attention_heads
optimizer_ebm = Adam_mini(
named_parameters = ebm.named_parameters(),
lr = lr,
betas = betas,
eps = eps,
weight_decay = weight_decay,
dim = model_dim_ebm,
n_heads = num_heads_ebm,
n_kv_heads = num_heads_ebm,
)
optimizer_denoiser = Adam_mini(
named_parameters = denoiser.named_parameters(),
lr = lr,
betas = betas,
eps = eps,
weight_decay = weight_decay,
dim = model_dim,
n_heads = num_heads,
n_kv_heads = num_heads,
)
# https://github.com/zyushun/Adam-mini/issues/30
optimizer_ebm.embd_names.add('embedding') # add the keyword of the embedding layer
optimizer_ebm.output_names.add('denoise_head') # output layer of EBM model is not using projection layer
optimizer_denoiser.embd_names.add('embedding') # add the keyword of the embedding layer
optimizer_denoiser.output_names.add('projection') # projection layer is using weight-tying with embedding layer
optimizer_ebm.mlp_names = {"self_attn"}
optimizer_denoiser.mlp_names = {"self_attn"}
optimizer_ebm.mlp_names.add("attn")
optimizer_ebm.mlp_names.add("linear")
optimizer_denoiser.mlp_names.add("attn")
optimizer_denoiser.mlp_names.add("linear")
optimizer_denoiser.wqk_names.add("self_attn") # For query, key, and value combined
optimizer_denoiser.wqk_names.add("multihead_attn")
else:
optimizer_ebm = optim.AdamW(ebm.parameters(), lr=lr, weight_decay=weight_decay) # Added weight decay
optimizer_denoiser = optim.AdamW(denoiser.parameters(), lr=lr, weight_decay=weight_decay) # Added weight decay
def warmup_schedule(current_step: int):
warmup_steps = 1000
step_size = 5 * len(train_loader) # 5 epochs
gamma = 0.5
if current_step < warmup_steps:
# Warmup phase
return float(current_step) / float(max(1, warmup_steps))
else:
# After warmup, apply step decay
num_steps_after_warmup = current_step - warmup_steps
num_step_decays = num_steps_after_warmup // step_size
return gamma ** num_step_decays
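# Worked example of warmup_schedule (illustrative numbers only, assuming warmup_steps = 1000
# and step_size = 5 * len(train_loader) = 5_000):
#   step 0      -> lr multiplier 0.0    (start of linear warmup)
#   step 500    -> lr multiplier 0.5    (halfway through warmup)
#   step 1_000  -> lr multiplier 1.0    (warmup finished)
#   step 6_000  -> lr multiplier 0.5    (one step decay of gamma = 0.5)
#   step 11_000 -> lr multiplier 0.25   (two step decays)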
def warmup_cosine_schedule(current_step: int):
warmup_steps = 500
total_steps = num_epochs * len(train_loader)
if current_step < warmup_steps:
# Warmup phase
return float(current_step) / float(max(1, warmup_steps))
else:
# After warmup, apply cosine decay
return 0.5 * (1 + math.cos(math.pi * (current_step - warmup_steps) / (total_steps - warmup_steps)))
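# Worked example of warmup_cosine_schedule (illustrative numbers only, assuming
# warmup_steps = 500 and total_steps = num_epochs * len(train_loader) = 10_000):
#   step 0      -> lr multiplier 0.0   (start of linear warmup)
#   step 250    -> lr multiplier 0.5   (halfway through warmup)
#   step 500    -> lr multiplier 1.0   (warmup finished, cosine decay begins)
#   step 5_250  -> lr multiplier 0.5   (midpoint of the cosine segment)
#   step 10_000 -> lr multiplier 0.0   (end of training)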
if USE_PRETRAINED_T5:
scheduler_ebm = AdafactorSchedule(optimizer_ebm)
scheduler_denoiser = AdafactorSchedule(optimizer_denoiser)
else:
scheduler_ebm = LambdaLR(optimizer_ebm, lr_lambda=warmup_cosine_schedule)
scheduler_denoiser = LambdaLR(optimizer_denoiser, lr_lambda=warmup_cosine_schedule)
#scheduler_ebm = optim.lr_scheduler.StepLR(optimizer_ebm, step_size=5, gamma=0.5)
#scheduler_denoiser = optim.lr_scheduler.StepLR(optimizer_denoiser, step_size=5, gamma=0.5)
if USE_MIXED_PRECISION_TRAINING and torch.cuda.is_available():
scaler = GradScaler() # MPS backend does not have this option yet
else:
scaler = None
best_val_ebm_loss = float('inf')
best_val_denoiser_loss = float('inf')
for epoch in range(num_epochs):
# Train models
train_ebm_loss, train_denoiser_loss = train_walk_jump(ebm, denoiser, train_loader, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, optimizer_ebm, optimizer_denoiser, scheduler_ebm, scheduler_denoiser, scaler)
# Validate models
val_ebm_loss, val_denoiser_loss = validate_walk_jump(ebm, denoiser, val_loader, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule)
print(f"Epoch {epoch + 1}/{num_epochs}, Train EBM Loss: {train_ebm_loss:.8f}, Train Denoiser Loss: {train_denoiser_loss:.8f}, Val EBM Loss: {val_ebm_loss:.8f}, Val Denoiser Loss: {val_denoiser_loss:.8f}")
# Save the trained models
if val_ebm_loss < best_val_ebm_loss:
best_val_ebm_loss = val_ebm_loss
torch.save(ebm.state_dict(), 'best_ebm.pth')
if val_denoiser_loss < best_val_denoiser_loss:
best_val_denoiser_loss = val_denoiser_loss
torch.save(denoiser.state_dict(), 'best_denoiser.pth')
# Early stopping check
if (val_denoiser_loss < EARLY_STOP_THRESHOLD) and USE_EARLY_STOP:
print(f"Early stopping triggered. Validation loss ({val_denoiser_loss:.4f}) is below the threshold ({EARLY_STOP_THRESHOLD}).")
break
print("Training complete.")
# Inference
ebm.load_state_dict(torch.load('best_ebm.pth', weights_only=True))
denoiser.load_state_dict(torch.load('best_denoiser.pth', weights_only=True))
if not GENERATES_OUTPUT_OF_VARYING_LENGTH:
#input_text = input_text + "an unfortunate shooting event in one of the Donald Trump's presidential election campaigns"
if TEST_OVERFIT:
input_text = dataset['train'][0]['text']
else:
input_text = dataset['train'][300]['text']
#input_text = "Pandemic is inevitable these days, we need to ensure that we follow lockdown policies restrictions"
num_of_words_fed_into_the_model = 6
num_of_words_to_be_generated = len(input_text.split()) - num_of_words_fed_into_the_model
input_text = " ".join(input_text.split()[0:num_of_words_fed_into_the_model-1])
print(f"input_text = {input_text}")
generated_texts = infer_walk_jump(input_text, ebm, denoiser, num_walk_steps, walk_step_size, num_jump_steps, noise_schedule, max_length=num_of_words_to_be_generated)
print("Generated text:", generated_texts)