# mini_transformer.py
# A friendly, didactic, CPU-first, decoder-only Transformer (GPT-style).
# ---------------------------------------------------------------------
# Dependencies:
#   torch
#
# How to run:
# python mini_transformer.py
#
# Things to tweak:
# - TEXT (your training corpus)
# - Config(...) at the bottom (n_layers, n_heads, d_model, block_size, etc.)
# - sampling args in generate(...)
from __future__ import annotations
import math, os, time, warnings, random
from dataclasses import dataclass
from typing import Tuple, Optional
# --- Make sure we stay on CPU and silence scary CUDA banners for old GPUs ---
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
warnings.filterwarnings("ignore", message=".*cuda capability.*")
warnings.filterwarnings("ignore", message=".*not compatible with the current PyTorch installation.*")
import torch
import torch.nn as nn
import torch.nn.functional as F
# -----------------------
# 0) Utilities / Config
# -----------------------
def set_seed(seed: int = 1337) -> None:
    random.seed(seed)
    torch.manual_seed(seed)

def pick_device() -> torch.device:
    # We intentionally use CPU for portability and predictable behavior.
    return torch.device("cpu")
@dataclass
class Config:
    vocab_size: int = 256   # byte-level tokenizer => 256 possible values
    d_model: int = 128
    n_heads: int = 4
    n_layers: int = 2
    block_size: int = 64    # context length (tokens)
    dropout: float = 0.1
    lr: float = 3e-3
    max_iters: int = 1000
    eval_every: int = 100
    batch_size: int = 16
# -----------------------
# 1) Tokenizer: byte-level
# -----------------------
class ByteTokenizer:
"""UTF-8 bytes <-> ints in [0..255]. Simple & dependency-free."""
def __init__(self) -> None:
self.vocab_size = 256
def encode(self, s: str) -> list[int]:
return list(s.encode("utf-8"))
def decode(self, ids: list[int]) -> str:
return bytes(ids).decode("utf-8", errors="ignore")
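# A minimal usage sketch for the tokenizer above (illustrative only, not
# called by the demo): encode/decode is a lossless byte-level round trip
# for valid UTF-8 text.
def _demo_tokenizer_roundtrip() -> None:
    tok = ByteTokenizer()
    ids = tok.encode("hi")            # [104, 105]
    assert ids == [104, 105]
    assert tok.decode(ids) == "hi"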
# ---------------------------------
# 2) Core building blocks (modules)
# ---------------------------------
class MultiHeadSelfAttention(nn.Module):
"""Causal multi-head self-attention.
Shapes:
x: (B, T, C) where C = d_model
returns: (B, T, C)
"""
def __init__(self, d_model: int, n_heads: int, block_size: int, dropout: float) -> None:
super().__init__()
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.block_size = block_size
# One linear does Q, K, V together (faster & common)
self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.out_proj = nn.Linear(d_model, d_model, bias=False)
self.attn_drop = nn.Dropout(dropout)
self.resid_drop = nn.Dropout(dropout)
# Static lower-triangular mask to enforce "no peeking at the future"
mask = torch.tril(torch.ones(block_size, block_size, dtype=torch.bool))
# Register as a non-trainable buffer so it moves with the model device
self.register_buffer("causal_mask", mask)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, C = x.shape
# 1) project to q, k, v then split heads: (B, T, 3C) -> 3 * (B, nH, T, head_dim)
qkv = self.qkv(x) # (B, T, 3C)
q, k, v = qkv.chunk(3, dim=-1) # each (B, T, C)
def split_heads(t: torch.Tensor) -> torch.Tensor:
return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
q, k, v = split_heads(q), split_heads(k), split_heads(v) # (B, nH, T, d)
# 2) scaled dot-product attention
# scores: (B, nH, T, T)
scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
# 3) causal mask: only allow attending to <= current position
mask = self.causal_mask[:T, :T] # (T, T)
scores = scores.masked_fill(~mask, float("-inf"))
# 4) attention weights over "keys/time"
att = F.softmax(scores, dim=-1) # (B, nH, T, T)
att = self.attn_drop(att)
# 5) weight the values and merge heads back
y = att @ v # (B, nH, T, d)
y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)
# 6) final projection
y = self.out_proj(y)
y = self.resid_drop(y)
return y
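# A quick sanity check for the block above (illustrative only, not called by
# the demo): the output keeps the input's (B, T, C) shape, and the causal
# mask is lower-triangular, so position t can only attend to positions <= t.
def _demo_attention_shapes() -> None:
    attn = MultiHeadSelfAttention(d_model=32, n_heads=4, block_size=16, dropout=0.0)
    x = torch.randn(2, 8, 32)                     # (B=2, T=8, C=32), T <= block_size
    assert attn(x).shape == x.shape
    assert bool(attn.causal_mask[0, 1]) is False  # future position is masked
    assert bool(attn.causal_mask[1, 0]) is True   # past position is visible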
class FeedForward(nn.Module):
"""Position-wise MLP (applied to each token independently)."""
def __init__(self, d_model: int, mult: int = 4, dropout: float = 0.0) -> None:
super().__init__()
self.net = nn.Sequential(
nn.Linear(d_model, mult * d_model),
nn.GELU(),
nn.Linear(mult * d_model, d_model),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
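# A small check of "position-wise" (illustrative only, not called by the
# demo): the same MLP is applied to every token independently, so running it
# on the whole (B, T, C) tensor matches running it one position at a time.
def _demo_feedforward_positionwise() -> None:
    ff = FeedForward(d_model=8, dropout=0.0).eval()
    x = torch.randn(1, 4, 8)
    full = ff(x)
    per_token = torch.stack([ff(x[:, t, :]) for t in range(4)], dim=1)
    assert torch.allclose(full, per_token, atol=1e-6)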
class TransformerBlock(nn.Module):
"""Pre-LN block = LN -> MHA -> residual; LN -> MLP -> residual."""
def __init__(self, d_model: int, n_heads: int, block_size: int, dropout: float) -> None:
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.attn = MultiHeadSelfAttention(d_model, n_heads, block_size, dropout)
self.mlp = FeedForward(d_model, mult=4, dropout=dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
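# Because each block maps (B, T, C) -> (B, T, C), blocks can be stacked to
# any depth, and the residual additions keep a direct path from input to
# output. (Illustrative only, not called by the demo.)
def _demo_block_stacking() -> None:
    blocks = nn.Sequential(*[TransformerBlock(32, 4, 16, 0.0) for _ in range(3)])
    x = torch.randn(2, 8, 32)
    assert blocks(x).shape == x.shape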
class TinyGPT(nn.Module):
"""A tiny decoder-only Transformer language model."""
def __init__(self, cfg: Config) -> None:
super().__init__()
self.cfg = cfg
# (1) token embeddings: map token id -> d_model vector
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
# (2) positional embeddings: add order info to each position [0..block_size-1]
self.pos_emb = nn.Embedding(cfg.block_size, cfg.d_model)
self.drop = nn.Dropout(cfg.dropout)
self.blocks = nn.ModuleList([
TransformerBlock(cfg.d_model, cfg.n_heads, cfg.block_size, cfg.dropout)
for _ in range(cfg.n_layers)
])
self.ln_f = nn.LayerNorm(cfg.d_model)
# LM head: project back to vocab. Weight tying reduces params & often helps.
self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
self.head.weight = self.tok_emb.weight # weight tying
self.apply(self._init_weights)
@staticmethod
def _init_weights(m: nn.Module) -> None:
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0.0, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Embedding):
nn.init.normal_(m.weight, mean=0.0, std=0.02)
def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
idx: (B, T) int64 token ids
targets: (B, T) int64 next-token labels (optional)
returns:
logits: (B, T, vocab)
loss: scalar (if targets provided)
"""
B, T = idx.shape
assert T <= self.cfg.block_size, "Sequence length exceeds block_size"
tok = self.tok_emb(idx) # (B, T, C)
pos = self.pos_emb(torch.arange(T, device=idx.device)) # (T, C)
x = self.drop(tok + pos) # (B, T, C)
for blk in self.blocks:
x = blk(x)
x = self.ln_f(x)
logits = self.head(x) # (B, T, vocab)
loss = None
if targets is not None:
# Flatten time & batch for cross-entropy
loss = F.cross_entropy(logits.view(B * T, -1), targets.view(B * T))
return logits, loss
@torch.no_grad()
def generate(self, idx: torch.Tensor, max_new_tokens: int, temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
"""Autoregressive sampling with temperature and top-k filtering."""
for _ in range(max_new_tokens):
idx_cond = idx[:, -self.cfg.block_size:] # crop to context window
logits, _ = self(idx_cond) # (B, T, vocab)
logits = logits[:, -1, :] # last step: (B, vocab)
# temperature: soften / sharpen distribution
if temperature != 1.0:
logits = logits / temperature
# top-k: keep only the k most likely tokens
if top_k is not None:
v, _ = torch.topk(logits, top_k)
thresh = v[:, [-1]] # (B, 1)
logits = torch.where(logits < thresh, torch.full_like(logits, float("-inf")), logits)
probs = F.softmax(logits, dim=-1) # (B, vocab)
next_id = torch.multinomial(probs, num_samples=1) # (B, 1)
idx = torch.cat([idx, next_id], dim=1) # append
return idx
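# A shape / weight-tying sanity check for TinyGPT (illustrative only, not
# called by the demo; the tiny Config here is arbitrary and unrelated to the
# training run at the bottom of the file):
def _demo_tinygpt_forward() -> None:
    cfg = Config(d_model=32, n_heads=4, n_layers=1, block_size=16, dropout=0.0)
    model = TinyGPT(cfg)
    idx = torch.randint(0, cfg.vocab_size, (2, 8))    # (B=2, T=8) token ids
    logits, loss = model(idx, targets=idx)
    assert logits.shape == (2, 8, cfg.vocab_size)
    assert loss is not None and loss.dim() == 0       # scalar cross-entropy
    assert model.head.weight is model.tok_emb.weight  # tied: literally the same tensor
    n_params = sum(p.numel() for p in model.parameters())  # tied weight counted once
    assert n_params > 0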
# -----------------------
# 3) Tiny training loop
# -----------------------
def train_demo(text: str, cfg: Config, device: torch.device) -> TinyGPT:
    tok = ByteTokenizer()
    # Encode entire corpus to a single 1-D tensor of token ids
    ids = torch.tensor(tok.encode(text), dtype=torch.long)
    n = int(0.9 * len(ids))
    train_data, val_data = ids[:n], ids[n:]
    # Safety: if a split is too short for the chosen block_size, repeat it
    def ensure_block(src: torch.Tensor) -> torch.Tensor:
        if len(src) < cfg.block_size + 1:
            reps = (cfg.block_size + 1 + len(src) - 1) // len(src)
            return src.repeat(reps)
        return src

    train_data = ensure_block(train_data)
    val_data = ensure_block(val_data)

    def get_batch(split: str) -> Tuple[torch.Tensor, torch.Tensor]:
        src = train_data if split == "train" else val_data
        # any start i such that src[i : i + block_size + 1] fits is valid
        hi = len(src) - cfg.block_size
        ix = torch.randint(0, hi, (cfg.batch_size,))
        x = torch.stack([src[i:i + cfg.block_size] for i in ix])
        y = torch.stack([src[i + 1:i + cfg.block_size + 1] for i in ix])
        return x.to(device), y.to(device)
    model = TinyGPT(cfg).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    @torch.no_grad()
    def estimate_loss() -> dict[str, float]:
        model.eval()
        out = {}
        for split in ("train", "val"):
            losses = []
            for _ in range(20):
                xb, yb = get_batch(split)
                _, loss = model(xb, yb)
                losses.append(loss.item())
            out[split] = sum(losses) / len(losses)
        model.train()
        return out

    t0 = time.time()
    for it in range(1, cfg.max_iters + 1):
        if it % cfg.eval_every == 0 or it == 1:
            losses = estimate_loss()
            def ppl(x): return math.exp(x)
            print(f"iter {it:4d} | train {losses['train']:.3f} (ppl {ppl(losses['train']):.2f})"
                  f" | val {losses['val']:.3f} (ppl {ppl(losses['val']):.2f})")
        xb, yb = get_batch("train")
        _, loss = model(xb, yb)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # tiny safety
        opt.step()
    dt = time.time() - t0
    print(f"Finished training in {dt:.1f}s")
    return model
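# The training pairs are just shifted windows: each target y[t] is the token
# that immediately follows x[t], so the model is always predicting "the next
# byte". A toy version of the slicing used in get_batch above (illustrative
# only, not called by the demo):
def _demo_shifted_windows() -> None:
    src = torch.tensor([10, 11, 12, 13, 14])
    block = 3
    x = src[0:block]        # tensor([10, 11, 12])
    y = src[1:block + 1]    # tensor([11, 12, 13])
    assert torch.equal(y[:-1], x[1:])  # y is x shifted by one position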
@torch.no_grad()
def ask(model, tok, device, question, temperature=0.8, top_k=50, max_new_tokens=120):
    # A few-shot prefix helps the model learn the "Q: ... A:" pattern at inference time
    examples = [
        ("Why does the narrator go to sea?",
         "To drive off the spleen and regulate the circulation."),
        ("What city are the Manhattoes?",
         "Manhattan (New York)."),
    ]
    fewshot = ""
    for q, a in examples:
        fewshot += f"Q: {q}\nA: {a}\n\n"
    prompt = fewshot + f"Q: {question}\nA:"
    ctx_ids = torch.tensor([tok.encode(prompt)], dtype=torch.long, device=device)
    out = model.generate(ctx_ids, max_new_tokens=max_new_tokens,
                         temperature=temperature, top_k=top_k)[0].tolist()
    # Return only the generated answer (strip the prompt)
    gen = tok.decode(out[len(ctx_ids[0]):])
    # Stop at the next "Q:" if it appears, or at a double newline
    stop_idx = min([i for i in [
        gen.find("\nQ:"), gen.find("\n\n")
    ] if i != -1] or [len(gen)])
    return gen[:stop_idx].strip()
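# Hypothetical usage of ask(...) after training (not invoked by the demo
# below; with a model this small, trained on one paragraph, expect the
# answers to be mostly gibberish):
#   model = train_demo(TEXT, cfg, device)
#   print(ask(model, ByteTokenizer(), device, "Why does he go to sea?"))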
# -----------------------
# 4) Demo entry point
# -----------------------
if __name__ == "__main__":
    set_seed(42)
    device = pick_device()
    print("Device:", device)

    # A small public-domain excerpt (byte-level is okay here).
    TEXT = (
        "Call me Ishmael. Some years ago—never mind how long precisely—having "
        "little or no money in my purse, and nothing particular to interest me "
        "on shore, I thought I would sail about a little and see the watery part "
        "of the world. It is a way I have of driving off the spleen and "
        "regulating the circulation. "
    )

    cfg = Config(
        vocab_size=256,   # byte-level
        d_model=128,
        n_heads=4,
        n_layers=2,
        block_size=64,
        dropout=0.1,
        lr=3e-3,
        max_iters=800,    # keep short for CPU demo
        eval_every=100,
        batch_size=16,
    )

    model = train_demo(TEXT, cfg, device)

    # Sample a short continuation
    # start = torch.zeros((1, 1), dtype=torch.long, device=device)  # start token 0
    tok = ByteTokenizer()
    prompt = "Q: Why does he go to sea?\nA:"
    start = torch.tensor([tok.encode(prompt)], dtype=torch.long, device=device)
    ids = model.generate(start, max_new_tokens=300, temperature=0.8, top_k=50)[0].tolist()
    print("\n=== SAMPLE ===")
    print(tok.decode(ids))