Mirage starting example

from typing import Optional, Callable, Sequence, Any

import torch
from torch import nn, fx
from torch.library import Library
import torch.nn.functional as F
import torch._inductor
import torch._inductor.compile_fx

mirage_lib = Library("mirage", "FRAGMENT")  # noqa

def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str] = [],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
    tags: tuple[torch.Tag, ...] = (),
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.
    """
    import torch.library
    schema_str = torch.library.infer_schema(op_func,
                                            mutates_args=mutates_args)
    my_lib = target_lib or mirage_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
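
# For comparison (illustrative only, not used here): the higher-level but
# higher-overhead route the docstring refers to is the torch.library.custom_op
# decorator, roughly:
#
#   @torch.library.custom_op("mirage::silu_mul", mutates_args=())
#   def silu_mul(input: torch.Tensor) -> torch.Tensor: ...
#   silu_mul.register_fake(
#       lambda input: torch.empty_like(input[..., :input.shape[-1] // 2]))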

# ============================================================
# Mirage placeholder op registration
# ============================================================

def rms_norm(input: torch.Tensor,
             weight: torch.Tensor,
             residual: Optional[torch.Tensor],
             epsilon: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("rms_norm")
    if residual is None:
        residual = input
    return torch.zeros_like(input), residual


def rms_norm_fake(input: torch.Tensor,
                  weight: torch.Tensor,
                  residual: Optional[torch.Tensor],
                  epsilon: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(input), torch.empty_like(input)


direct_register_custom_op("rms_norm", rms_norm, fake_impl=rms_norm_fake)


def silu_mul(input: torch.Tensor) -> torch.Tensor:
    # Never actually called
    print("silu_mul")
    return torch.zeros_like(input[..., 0:input.shape[1] // 2])


def silu_mul_fake(input: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(input[..., 0:input.shape[1] // 2])


direct_register_custom_op("silu_mul", silu_mul, fake_impl=silu_mul_fake)


def rope(q: torch.Tensor,
         k: torch.Tensor,
         positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("rope")
    return torch.zeros_like(q), torch.zeros_like(k)


def rope_fake(q: torch.Tensor,
              k: torch.Tensor,
              positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(q), torch.empty_like(k)


direct_register_custom_op("rope", rope, fake_impl=rope_fake)


def quantize(input: torch.Tensor,
             scale: Optional[torch.Tensor],
             dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("quantize")
    return torch.zeros_like(input, dtype=dtype), scale


def quantize_fake(input: torch.Tensor,
                  scale: Optional[torch.Tensor],
                  dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(input, dtype=dtype), scale


direct_register_custom_op("quantize", quantize, fake_impl=quantize_fake)


def attention(q: torch.Tensor,
              k: torch.Tensor,
              v: torch.Tensor) -> torch.Tensor:
    # Never actually called
    print("attention")
    return torch.zeros_like(q)


def attention_fake(q: torch.Tensor,
                   k: torch.Tensor,
                   v: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(q)


direct_register_custom_op("attention", attention, fake_impl=attention_fake)

# ============================================================
# Example PyTorch model
# ============================================================

class SimpleLlamaLayer(nn.Module):
    def __init__(self,
                 hidden_dim: int = 4096,
                 num_heads: int = 32,
                 num_kv_heads: int = 8,
                 head_size: int = 128,
                 dtype: torch.dtype = torch.float16,
                 qdtype: Optional[torch.dtype] = None,
                 ):
        super().__init__()
        if qdtype is None:
            qdtype = dtype
        self.hidden_dim = hidden_dim
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.dtype = dtype
        self.qdtype = qdtype
        self.quantized = qdtype != dtype
        rand_w = lambda *dims, **kwargs: torch.randn(*dims, **kwargs, dtype=dtype, device="cuda")
        rand_wq = lambda *dims, **kwargs: rand_w(*dims, **kwargs).to(dtype=qdtype).t().contiguous().t()  # column-major for scaled-mm
        self.weights = {
            "qkv_proj": rand_wq(hidden_dim, (num_heads + num_kv_heads * 2) * head_size),
            "o_proj": rand_wq(hidden_dim, hidden_dim),
            "gate_up_proj": rand_wq(hidden_dim, 2 * hidden_dim),
            "down_proj": rand_wq(hidden_dim, hidden_dim),
            "input_norm": rand_w(hidden_dim),
            "post_attn_norm": rand_w(hidden_dim),
        }
        if self.quantized:
            self.scales = {k: torch.ones(1, 1, dtype=torch.float32) for k in self.weights}
            self.wscales = {k: torch.ones(1, 1, dtype=torch.float32) for k in self.weights}

    def _linear(self, input: torch.Tensor, name: str) -> torch.Tensor:
        weight = self.weights[name]
        if not self.quantized:
            return input @ weight
        scale_a, scale_b = self.scales[name], self.wscales[name]
        qinput, scale_a = torch.ops.mirage.quantize(input, scale_a, dtype=self.qdtype)
        return torch._scaled_mm(qinput, weight, scale_a=scale_a, scale_b=scale_b)

    def forward(self, input: torch.Tensor, residual: Optional[torch.Tensor], positions: torch.Tensor) \
            -> tuple[torch.Tensor, torch.Tensor]:
        input_norm, residual = torch.ops.mirage.rms_norm(input, self.weights["input_norm"], residual)
        qkv = self._linear(input_norm, "qkv_proj")
        q, k, v = qkv.split_with_sizes([
            self.num_heads * self.head_size,
            self.num_kv_heads * self.head_size,
            self.num_kv_heads * self.head_size
        ], dim=-1)
        q, k = torch.ops.mirage.rope(q, k, positions)
        out = torch.ops.mirage.attention(q, k, v)
        out2 = self._linear(out, "o_proj")
        out_norm, residual = torch.ops.mirage.rms_norm(out2, self.weights["post_attn_norm"], residual)
        # mlp
        up_gate = self._linear(out_norm, "gate_up_proj")
        silu = torch.ops.mirage.silu_mul(up_gate)
        down = self._linear(silu, "down_proj")
        return down, residual

class SimpleLlama(nn.Module):
    def __init__(self,
                 num_layers: int = 32,
                 vocab_size: int = 128256,
                 hidden_dim: int = 4096,
                 num_heads: int = 32,
                 num_kv_heads: int = 8,
                 head_size: int = 128,
                 dtype: torch.dtype = torch.float16,
                 qdtype: Optional[torch.dtype] = None,
                 ):
        super().__init__()
        rand_w = lambda *dims, **kwargs: torch.randn(*dims, **kwargs, dtype=dtype, device="cuda")
        self.weights = {
            "embedding": rand_w(vocab_size, hidden_dim),
            "out_norm": rand_w(hidden_dim),
        }
        self.layers = nn.ModuleList([SimpleLlamaLayer(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=dtype,
            qdtype=qdtype,
        ) for _ in range(num_layers)])

    def forward(self, input: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        x_emb = F.embedding(input, self.weights["embedding"])
        x, residual = x_emb, None
        for layer in self.layers:
            x, residual = layer(x, residual, positions)
        x, _ = torch.ops.mirage.rms_norm(x, self.weights["out_norm"], residual)
        return x

# ============================================================
# Backends
# ============================================================

class AotBackend:
    """Boilerplate to hand the Mirage backend a normalized (post-grad) fx graph via aot_autograd."""

    def __init__(self, compile_fn: Callable[[fx.GraphModule, Sequence], Callable[[Sequence], Any]]):
        self.compile_fn = compile_fn

    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence):
        from torch._dynamo.backends.common import aot_autograd
        return aot_autograd(
            fw_compiler=self.compile_fn,
            decompositions=torch._inductor.compile_fx.select_decomp_table(),
        )(graph, example_inputs)

# ============================================================
# Skeleton for the actual Mirage backend that takes a Mirage-friendly fx graph and compiles it.
# ============================================================

class MirageBackend:
    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence):
        """
        Receives normalized (post-grad) IR.
        """
        print(graph.graph.python_code(root_module="self").src)
        return self.run

    def run(self, *args, **kwargs):
        print(f"Forward called with {len(args)=} args and {len(kwargs)=} kwargs.")
        return torch.empty_like(args[0], dtype=torch.float16)
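
# Illustrative helper (an assumption, not part of the skeleton above): a real
# MirageBackend would start by locating the mirage placeholder calls in the
# post-grad graph it receives, e.g. to hand them (together with the surrounding
# matmuls) to the Mirage compiler.
def list_mirage_ops(graph: fx.GraphModule) -> list[fx.Node]:
    return [
        n for n in graph.graph.nodes
        if n.op == "call_function"
        and isinstance(n.target, torch._ops.OpOverload)
        and n.target.namespace == "mirage"
    ]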

torch.set_default_device("cuda")

model = SimpleLlama()
inputs = [torch.randint(0, 4096, (5,)), torch.arange(0, 4096)]
model(*inputs)

compiled_model = torch.compile(model, backend=AotBackend(MirageBackend()), fullgraph=True)
compiled_model(*inputs)

qmodel = SimpleLlama(qdtype=torch.float8_e4m3fn)
qmodel(*inputs)

compiled_qmodel = torch.compile(qmodel, backend=AotBackend(MirageBackend()), fullgraph=True)
compiled_qmodel(*inputs)
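
# To reproduce fx.Graph printouts like the ones referenced below (assumption:
# the 4-layer graphs were obtained the same way, just with num_layers=4, and
# with qdtype=torch.float8_e4m3fn for the quantized variant):
small_model = SimpleLlama(num_layers=4)
compiled_small = torch.compile(small_model, backend=AotBackend(MirageBackend()), fullgraph=True)
compiled_small(*inputs)  # MirageBackend prints the post-grad graph's Python code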
Resulting fx.Graph for 4 layers:
Resulting fx.Graph for 4 layers (quantized to fp8):