import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, dim_in, dim_out, context_length, dropout, num_heads):
        super().__init__()
        assert dim_out % num_heads == 0, "dim_out must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5  # for scaling the dot product

        # Q, K, V projections
        self.W_q = nn.Linear(dim_in, dim_out)
        self.W_k = nn.Linear(dim_in, dim_out)
        self.W_v = nn.Linear(dim_in, dim_out)

        # Final output projection
        self.out_proj = nn.Linear(dim_out, dim_out)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Causal mask (upper-triangular matrix)
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        # Register as a buffer so it moves with the module but is excluded from gradient updates
        self.register_buffer("mask", mask)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()  # (batch, sequence length, features)

        # Linear projections
        Q = self.W_q(x)  # (batch_size, seq_len, dim_out)
        K = self.W_k(x)
        V = self.W_v(x)

        # Split the features evenly across the heads
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (batch_size, num_heads, seq_len, head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Attention scores
        attention_scores = (Q @ K.transpose(-2, -1)) * self.scale  # (batch_size, num_heads, seq_len, seq_len)

        # Apply causal mask
        mask = self.mask[:seq_len, :seq_len] == 1
        attention_scores = attention_scores.masked_fill(mask, float('-inf'))

        # Softmax and dropout
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values
        context = attention_weights @ V  # (batch_size, num_heads, seq_len, head_dim)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)  # (batch_size, seq_len, dim_out)
        return self.out_proj(context)
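
# Shape walk-through for one attention forward pass (illustrative values,
# assuming batch_size=2, seq_len=8, dim_in=dim_out=768, num_heads=12, head_dim=64):
#   x:                 (2, 8, 768)
#   Q, K, V:           (2, 8, 768) -> reshaped/transposed to (2, 12, 8, 64)
#   attention_scores:  (2, 12, 8, 8)
#   context:           (2, 12, 8, 64) -> merged back to (2, 8, 768)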

class LayerNorm(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.eps = 1e-5  # epsilon: small constant to prevent division by zero
        self.scale = nn.Parameter(torch.ones(embed_dim))
        self.shift = nn.Parameter(torch.zeros(embed_dim))

    def forward(self, x):
        # Normalize over the last (feature) dimension, then apply the learnable scale and shift
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1 + torch.tanh(torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * torch.pow(x, 3))))

class FeedForward(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Expand to 4x the embedding dimension, apply GELU, then project back
        self.layers = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        return self.layers(x)

class TransformerLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.att = MultiHeadAttention(
            dim_in=config.embedding_dim,
            dim_out=config.embedding_dim,
            context_length=config.context_length,
            num_heads=config.attention_head,
            dropout=config.drop_rate,
        )
        self.ff = FeedForward(config.embedding_dim)
        self.norm1 = LayerNorm(config.embedding_dim)
        self.norm2 = LayerNorm(config.embedding_dim)
        self.drop_shortcut = nn.Dropout(config.drop_rate)

    def forward(self, x):
        # Attention sub-block with residual (shortcut) connection
        residual1 = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_shortcut(x)
        x = x + residual1

        # Feed-forward sub-block with residual (shortcut) connection
        residual2 = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + residual2
        return x
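
# Data flow through one block (pre-norm residual connections):
#   x -> norm1 -> attention    -> dropout -> add shortcut
#     -> norm2 -> feed-forward -> dropout -> add shortcut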

class GPTModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
        # Positional embeddings need one row per position, so size by context_length
        self.positional_embedding = nn.Embedding(config.context_length, config.embedding_dim)
        self.dropout = nn.Dropout(config.drop_rate)
        self.transformer_layers = nn.Sequential(
            *[TransformerLayer(config) for _ in range(config.transformer_layer)]
        )
        self.final_norm = nn.LayerNorm(config.embedding_dim)
        self.final_output = nn.Linear(config.embedding_dim, config.vocab_size)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        token_embedding = self.token_embedding(in_idx)  # (batch_size, seq_len, embedding_dim)
        positional_embedding = self.positional_embedding(torch.arange(seq_len, device=in_idx.device))
        x = token_embedding + positional_embedding
        x = self.dropout(x)
        x = self.transformer_layers(x)
        x = self.final_norm(x)
        logits = self.final_output(x)  # (batch_size, seq_len, vocab_size)
        return logits

class Config:
    def __init__(self):
        # GPT-2 small (124M) style hyperparameters
        self.vocab_size = 50257
        self.embedding_dim = 768
        self.drop_rate = 0.1
        self.transformer_layer = 12
        self.attention_head = 12
        self.context_length = 1024
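
# Illustrative usage sketch (batch_size=2 and seq_len=16 are arbitrary choices):
# build the default config, instantiate the model, and run a forward pass on
# random token ids to check the output shape.
if __name__ == "__main__":
    config = Config()
    model = GPTModel(config)
    dummy_tokens = torch.randint(0, config.vocab_size, (2, 16))  # (batch_size, seq_len)
    logits = model(dummy_tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 50257])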