# mini_transformer.py
# A friendly, didactic, CPU-first, decoder-only Transformer (GPT-style).
# ---------------------------------------------------------------------
# Dependencies:
#   torch
#
# How to run:
#   python mini_transformer.py
#
# Things to tweak:
#   - TEXT (your training corpus)
#   - Config(...) at the bottom (n_layers, n_heads, d_model, block_size, etc.)
#   - sampling args in generate(...)
from __future__ import annotations

import math, os, time, warnings, random
from dataclasses import dataclass
from typing import Tuple, Optional

# --- Make sure we stay on CPU and silence scary CUDA banners for old GPUs ---
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
warnings.filterwarnings("ignore", message=".*cuda capability.*")
warnings.filterwarnings("ignore", message=".*not compatible with the current PyTorch installation.*")

import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------
# 0) Utilities / Config
# -----------------------
def set_seed(seed: int = 1337) -> None:
    random.seed(seed)
    torch.manual_seed(seed)


def pick_device() -> torch.device:
    # We intentionally use CPU for portability and predictable behavior.
    return torch.device("cpu")

@dataclass
class Config:
    vocab_size: int = 256   # byte-level tokenizer => 256 possible values
    d_model: int = 128
    n_heads: int = 4
    n_layers: int = 2
    block_size: int = 64    # context length (tokens)
    dropout: float = 0.1
    lr: float = 3e-3
    max_iters: int = 1000
    eval_every: int = 100
    batch_size: int = 16
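
# Rough scale check: with the defaults above, the model defined below (TinyGPT)
# comes out to roughly 0.44M parameters (the LM head is tied to the token
# embedding, so it adds nothing extra). You can verify the exact count with:
#   sum(p.numel() for p in TinyGPT(Config()).parameters())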

# -----------------------
# 1) Tokenizer: byte-level
# -----------------------
class ByteTokenizer:
    """UTF-8 bytes <-> ints in [0..255]. Simple & dependency-free."""

    def __init__(self) -> None:
        self.vocab_size = 256

    def encode(self, s: str) -> list[int]:
        return list(s.encode("utf-8"))

    def decode(self, ids: list[int]) -> str:
        return bytes(ids).decode("utf-8", errors="ignore")
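
# Quick round-trip example (the ids are just the UTF-8 bytes of the string):
#   tok = ByteTokenizer()
#   tok.encode("Hi!")          # -> [72, 105, 33]
#   tok.decode([72, 105, 33])  # -> "Hi!"
# Non-ASCII characters simply take more than one byte, e.g. "é" -> [195, 169].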

# ---------------------------------
# 2) Core building blocks (modules)
# ---------------------------------
class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention.

    Shapes:
        x: (B, T, C) where C = d_model
        returns: (B, T, C)
    """

    def __init__(self, d_model: int, n_heads: int, block_size: int, dropout: float) -> None:
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.block_size = block_size
        # One linear does Q, K, V together (faster & common)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)
        # Static lower-triangular mask to enforce "no peeking at the future"
        mask = torch.tril(torch.ones(block_size, block_size, dtype=torch.bool))
        # Register as a non-trainable buffer so it moves with the model device
        self.register_buffer("causal_mask", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        # 1) project to q, k, v then split heads: (B, T, 3C) -> 3 * (B, nH, T, head_dim)
        qkv = self.qkv(x)               # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, T, C)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)  # (B, nH, T, d)
        # 2) scaled dot-product attention
        #    scores: (B, nH, T, T)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # 3) causal mask: only allow attending to <= current position
        mask = self.causal_mask[:T, :T]  # (T, T)
        scores = scores.masked_fill(~mask, float("-inf"))
        # 4) attention weights over "keys/time"
        att = F.softmax(scores, dim=-1)  # (B, nH, T, T)
        att = self.attn_drop(att)
        # 5) weight the values and merge heads back
        y = att @ v                                       # (B, nH, T, d)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, C)
        # 6) final projection
        y = self.out_proj(y)
        y = self.resid_drop(y)
        return y
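
# A minimal sketch (helper not used by the model) of what the causal mask does:
# with uniform all-zero scores, row t of the attention matrix spreads its weight
# evenly over positions 0..t and assigns exactly zero to every future position.
def _demo_causal_attention(T: int = 4) -> torch.Tensor:
    scores = torch.zeros(T, T)                             # pretend q·k / sqrt(d) == 0 everywhere
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))  # same construction as above
    att = F.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
    return att  # row 0 -> [1, 0, 0, 0]; row 3 -> [0.25, 0.25, 0.25, 0.25]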

class FeedForward(nn.Module):
    """Position-wise MLP (applied to each token independently)."""

    def __init__(self, d_model: int, mult: int = 4, dropout: float = 0.0) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, mult * d_model),
            nn.GELU(),
            nn.Linear(mult * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class TransformerBlock(nn.Module):
    """Pre-LN block = LN -> MHA -> residual; LN -> MLP -> residual."""

    def __init__(self, d_model: int, n_heads: int, block_size: int, dropout: float) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, n_heads, block_size, dropout)
        self.mlp = FeedForward(d_model, mult=4, dropout=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
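
# Why "pre-LN"? The original Transformer used post-LN (normalize after the
# residual add: x = LN(x + sublayer(x))). Pre-LN instead normalizes the input to
# each sublayer (x = x + sublayer(LN(x))), which tends to train more stably
# without a learning-rate warmup, so it is a good default for a tiny demo.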

class TinyGPT(nn.Module):
    """A tiny decoder-only Transformer language model."""

    def __init__(self, cfg: Config) -> None:
        super().__init__()
        self.cfg = cfg
        # (1) token embeddings: map token id -> d_model vector
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        # (2) positional embeddings: add order info to each position [0..block_size-1]
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg.d_model, cfg.n_heads, cfg.block_size, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])
        self.ln_f = nn.LayerNorm(cfg.d_model)
        # LM head: project back to vocab. Weight tying reduces params & often helps.
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight  # weight tying
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        idx: (B, T) int64 token ids
        targets: (B, T) int64 next-token labels (optional)
        returns:
            logits: (B, T, vocab)
            loss: scalar (if targets provided)
        """
        B, T = idx.shape
        assert T <= self.cfg.block_size, "Sequence length exceeds block_size"
        tok = self.tok_emb(idx)                                  # (B, T, C)
        pos = self.pos_emb(torch.arange(T, device=idx.device))   # (T, C)
        x = self.drop(tok + pos)                                 # (B, T, C)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)                                    # (B, T, vocab)
        loss = None
        if targets is not None:
            # Flatten time & batch for cross-entropy
            loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
        """Autoregressive sampling with temperature and top-k filtering."""
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.cfg.block_size:]  # crop to context window
            logits, _ = self(idx_cond)                # (B, T, vocab)
            logits = logits[:, -1, :]                 # last step: (B, vocab)
            # temperature: soften / sharpen distribution
            if temperature != 1.0:
                logits = logits / temperature
            # top-k: keep only the k most likely tokens
            if top_k is not None:
                k = min(top_k, logits.size(-1))       # guard against k > vocab size
                v, _ = torch.topk(logits, k)
                thresh = v[:, [-1]]                   # (B, 1)
                logits = torch.where(logits < thresh, torch.full_like(logits, float("-inf")), logits)
            probs = F.softmax(logits, dim=-1)                  # (B, vocab)
            next_id = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat([idx, next_id], dim=1)             # append
        return idx
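
# A minimal sketch (helper not used by the model) of the top-k filter above, on
# a concrete 5-token vocabulary: every logit below the k-th largest is pushed to
# -inf, so softmax assigns it exactly zero probability and multinomial can never
# pick it.
def _demo_top_k(k: int = 2) -> torch.Tensor:
    logits = torch.tensor([[2.0, 0.5, 1.0, -1.0, 0.0]])  # (B=1, vocab=5)
    v, _ = torch.topk(logits, k)                          # k largest values per row
    thresh = v[:, [-1]]                                   # smallest kept logit
    filtered = torch.where(logits < thresh, torch.full_like(logits, float("-inf")), logits)
    return F.softmax(filtered, dim=-1)  # only indices 0 and 2 keep nonzero probability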

# -----------------------
# 3) Tiny training loop
# -----------------------
def train_demo(text: str, cfg: Config, device: torch.device) -> TinyGPT:
    tok = ByteTokenizer()
    # Encode entire corpus to a single 1-D tensor of token ids
    ids = torch.tensor(tok.encode(text), dtype=torch.long)
    n = int(0.9 * len(ids))
    train_data, val_data = ids[:n], ids[n:]

    # Safety: if a split is too short for the chosen block_size, repeat it until
    # it holds at least one full (block_size + 1)-token window
    def ensure_block(src: torch.Tensor) -> torch.Tensor:
        if len(src) < cfg.block_size + 1:
            reps = (cfg.block_size + 1 + len(src) - 1) // len(src)  # ceil division
            return src.repeat(reps)
        return src

    train_data = ensure_block(train_data)
    val_data = ensure_block(val_data)

    def get_batch(split: str) -> Tuple[torch.Tensor, torch.Tensor]:
        src = train_data if split == "train" else val_data
        hi = len(src) - cfg.block_size  # valid start indices are 0 .. hi-1 (randint's high is exclusive)
        ix = torch.randint(0, hi, (cfg.batch_size,))
        x = torch.stack([src[i:i+cfg.block_size] for i in ix])
        y = torch.stack([src[i+1:i+cfg.block_size+1] for i in ix])
        return x.to(device), y.to(device)

    model = TinyGPT(cfg).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    @torch.no_grad()
    def estimate_loss() -> dict[str, float]:
        model.eval()
        out = {}
        for split in ("train", "val"):
            losses = []
            for _ in range(20):
                xb, yb = get_batch(split)
                _, loss = model(xb, yb)
                losses.append(loss.item())
            out[split] = sum(losses) / len(losses)
        model.train()
        return out

    t0 = time.time()
    for it in range(1, cfg.max_iters + 1):
        if it % cfg.eval_every == 0 or it == 1:
            losses = estimate_loss()
            def ppl(x): return math.exp(x)  # perplexity = exp(cross-entropy)
            print(f"iter {it:4d} | train {losses['train']:.3f} (ppl {ppl(losses['train']):.2f})"
                  f" | val {losses['val']:.3f} (ppl {ppl(losses['val']):.2f})")
        xb, yb = get_batch("train")
        _, loss = model(xb, yb)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # tiny safety
        opt.step()
    dt = time.time() - t0
    print(f"Finished training in {dt:.1f}s")
    model.eval()  # turn dropout off before handing the model back for sampling
    return model
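
# What get_batch produces, on a concrete example: with block_size = 4 and a
# source of ids [10, 11, 12, 13, 14, ...], the window starting at i = 0 is
#   x = [10, 11, 12, 13]
#   y = [11, 12, 13, 14]
# i.e. y is x shifted one position to the left, so every position in x is
# trained to predict the token that follows it.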

@torch.no_grad()
def ask(model, tok, device, question, temperature=0.8, top_k=50, max_new_tokens=120):
    # A few-shot prefix helps the model follow the "Q: ... A:" pattern from context at inference time
    examples = [
        ("Why does the narrator go to sea?",
         "To drive off the spleen and regulate the circulation."),
        ("What city are the Manhattoes?",
         "Manhattan (New York)."),
    ]
    fewshot = ""
    for q, a in examples:
        fewshot += f"Q: {q}\nA: {a}\n\n"
    prompt = fewshot + f"Q: {question}\nA:"
    ctx_ids = torch.tensor([tok.encode(prompt)], dtype=torch.long, device=device)
    out = model.generate(ctx_ids, max_new_tokens=max_new_tokens,
                         temperature=temperature, top_k=top_k)[0].tolist()
    # Return only the generated answer (strip the prompt)
    gen = tok.decode(out[len(ctx_ids[0]):])
    # Stop at the next Q: if it appears, or at a double newline
    stop_idx = min([i for i in [
        gen.find("\nQ:"), gen.find("\n\n")
    ] if i != -1] or [len(gen)])
    return gen[:stop_idx].strip()
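
# Example usage after training (illustrative; ask() is not called in the demo below):
#   answer = ask(model, ByteTokenizer(), device, "Why does he go to sea?")
#   print(answer)
# Note that generate() crops the context to the last block_size tokens, so with
# block_size = 64 only the tail of the few-shot prompt is actually visible to
# the model; expect very rough answers from a corpus this small.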

# -----------------------
# 4) Demo entry point
# -----------------------
if __name__ == "__main__":
    set_seed(42)
    device = pick_device()
    print("Device:", device)

    # A small public-domain excerpt (byte-level is okay here).
    TEXT = (
        "Call me Ishmael. Some years ago—never mind how long precisely—having "
        "little or no money in my purse, and nothing particular to interest me "
        "on shore, I thought I would sail about a little and see the watery part "
        "of the world. It is a way I have of driving off the spleen and "
        "regulating the circulation. "
    )

    cfg = Config(
        vocab_size=256,   # byte-level
        d_model=128,
        n_heads=4,
        n_layers=2,
        block_size=64,
        dropout=0.1,
        lr=3e-3,
        max_iters=800,    # keep short for CPU demo
        eval_every=100,
        batch_size=16,
    )

    model = train_demo(TEXT, cfg, device)

    # Sample a short continuation
    # start = torch.zeros((1, 1), dtype=torch.long, device=device)  # start token 0
    tok = ByteTokenizer()
    prompt = "Q: Why does he go to sea?\nA:"
    start = torch.tensor([tok.encode(prompt)], dtype=torch.long, device=device)
    ids = model.generate(start, max_new_tokens=300, temperature=0.8, top_k=50)[0].tolist()

    print("\n=== SAMPLE ===")
    print(tok.decode(ids))