# Gist by @mahadirz, created May 12, 2025 14:38
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, dim_in, dim_out, context_length, dropout, num_heads):
        super().__init__()
        assert dim_out % num_heads == 0, "dim_out must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5  # for scaling the dot product
        # Q, K, V projections
        self.W_q = nn.Linear(dim_in, dim_out)
        self.W_k = nn.Linear(dim_in, dim_out)
        self.W_v = nn.Linear(dim_in, dim_out)
        # Final output projection
        self.out_proj = nn.Linear(dim_out, dim_out)
        # Dropout
        self.dropout = nn.Dropout(dropout)
        # Causal mask (upper-triangular matrix above the diagonal)
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        # Register as a buffer so it moves with the module but is excluded from gradient updates
        self.register_buffer("mask", mask)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()  # (batch, sequence length, features)
        # Linear projections
        Q = self.W_q(x)  # (batch_size, seq_len, dim_out)
        K = self.W_k(x)
        V = self.W_v(x)
        # Split the features equally across the heads
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Attention scores
        attention_scores = (Q @ K.transpose(-2, -1)) * self.scale  # (batch_size, num_heads, seq_len, seq_len)
        # Apply the causal mask so each position attends only to itself and earlier positions
        mask = self.mask[:seq_len, :seq_len] == 1
        attention_scores = attention_scores.masked_fill(mask, float('-inf'))
        # Softmax and dropout
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Weighted sum of the values
        context = attention_weights @ V  # (batch_size, num_heads, seq_len, head_dim)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)  # (batch_size, seq_len, dim_out)
        return self.out_proj(context)
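

# A minimal shape check for MultiHeadAttention (added, not part of the original gist).
# The sizes below are illustrative assumptions: a 2-sample batch of 8 tokens with a
# 768-dimensional embedding split across 12 heads.
if __name__ == "__main__":
    mha = MultiHeadAttention(dim_in=768, dim_out=768, context_length=1024, dropout=0.1, num_heads=12)
    dummy = torch.randn(2, 8, 768)
    print(mha(dummy).shape)  # expected: torch.Size([2, 8, 768])
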
class LayerNorm(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.eps = 1e-5  # small epsilon to avoid division by zero
        self.scale = nn.Parameter(torch.ones(embed_dim))
        self.shift = nn.Parameter(torch.zeros(embed_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift
class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Tanh approximation of GELU, as used in GPT-2
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * torch.pow(x, 3))
        ))
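

# Sanity check (added, not part of the original gist): the tanh approximation above
# should closely match PyTorch's built-in GELU when approximate="tanh" is requested,
# assuming PyTorch >= 1.12 where F.gelu accepts that argument.
if __name__ == "__main__":
    x = torch.linspace(-3, 3, steps=7)
    print(torch.allclose(GELU()(x), F.gelu(x, approximate="tanh"), atol=1e-5))  # expected: True
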
class FeedForward(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        return self.layers(x)
class TransformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.att = MultiHeadAttention(
            dim_in=config.embedding_dim,
            dim_out=config.embedding_dim,
            context_length=config.context_length,
            num_heads=config.attention_head,
            dropout=config.drop_rate,
        )
        self.ff = FeedForward(config.embedding_dim)
        self.norm1 = LayerNorm(config.embedding_dim)
        self.norm2 = LayerNorm(config.embedding_dim)
        self.drop_shortcut = nn.Dropout(config.drop_rate)

    def forward(self, x):
        # Pre-norm attention block with residual (shortcut) connection
        residual1 = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_shortcut(x)
        x = x + residual1
        # Pre-norm feed-forward block with residual (shortcut) connection
        residual2 = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + residual2
        return x
class GPTModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
        # Positional embeddings need one entry per position, so they are sized by
        # context_length rather than vocab_size
        self.positional_embedding = nn.Embedding(config.context_length, config.embedding_dim)
        self.dropout = nn.Dropout(config.drop_rate)
        self.transformer_layers = nn.Sequential(
            *[TransformerLayer(config) for _ in range(config.transformer_layer)]
        )
        self.final_norm = LayerNorm(config.embedding_dim)
        self.final_output = nn.Linear(config.embedding_dim, config.vocab_size)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        token_embedding = self.token_embedding(in_idx)  # (batch_size, seq_len, embedding_dim)
        positional_embedding = self.positional_embedding(torch.arange(seq_len, device=in_idx.device))
        x = token_embedding + positional_embedding
        x = self.dropout(x)
        x = self.transformer_layers(x)
        x = self.final_norm(x)
        logits = self.final_output(x)  # (batch_size, seq_len, vocab_size)
        return logits
class Config:
    def __init__(self):
        self.vocab_size = 50257
        self.embedding_dim = 768
        self.drop_rate = 0.1
        self.transformer_layer = 12
        self.attention_head = 12
        self.context_length = 1024
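

# Minimal end-to-end sketch (added, not part of the original gist): build the
# GPT-2-small-sized model from Config and run a forward pass on a small batch of
# random token ids. The batch and sequence sizes below are arbitrary examples.
if __name__ == "__main__":
    config = Config()
    model = GPTModel(config)
    dummy_tokens = torch.randint(0, config.vocab_size, (2, 16))  # (batch=2, seq_len=16)
    logits = model(dummy_tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 50257])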