Mirage starting example
from typing import Optional, Callable, Sequence, Any
import torch
from torch import nn, fx
from torch.library import Library
import torch.nn.functional as F
import torch._inductor
import torch._inductor.compile_fx
mirage_lib = Library("mirage", "FRAGMENT") # noqa
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: list[str] = [],
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
    tags: tuple[torch.Tag, ...] = (),
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.
    """
    import torch.library
    schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    my_lib = target_lib or mirage_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
# ============================================================
# Mirage placeholder op registration
# ============================================================
def rms_norm(input: torch.Tensor,
             weight: torch.Tensor,
             residual: Optional[torch.Tensor],
             epsilon: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("rms_norm")
    if residual is None:
        residual = input
    return torch.zeros_like(input), residual


def rms_norm_fake(input: torch.Tensor,
                  weight: torch.Tensor,
                  residual: Optional[torch.Tensor],
                  epsilon: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(input), torch.empty_like(input)
direct_register_custom_op("rms_norm", rms_norm, fake_impl=rms_norm_fake)
def silu_mul(input: torch.Tensor) -> torch.Tensor:
    # Never actually called
    print("silu_mul")
    return torch.zeros_like(input[..., 0:input.shape[1] // 2])


def silu_mul_fake(input: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(input[..., 0:input.shape[1] // 2])
direct_register_custom_op("silu_mul", silu_mul, fake_impl=silu_mul_fake)
def rope(q: torch.Tensor,
         k: torch.Tensor,
         positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("rope")
    return torch.zeros_like(q), torch.zeros_like(k)


def rope_fake(q: torch.Tensor,
              k: torch.Tensor,
              positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(q), torch.empty_like(k)
direct_register_custom_op("rope", rope, fake_impl=rope_fake)
def quantize(input: torch.Tensor,
             scale: Optional[torch.Tensor],
             dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    # Never actually called
    print("quantize")
    return torch.zeros_like(input, dtype=dtype), scale


def quantize_fake(input: torch.Tensor,
                  scale: Optional[torch.Tensor],
                  dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(input, dtype=dtype), scale
direct_register_custom_op("quantize", quantize, fake_impl=quantize_fake)
def attention(q: torch.Tensor,
              k: torch.Tensor,
              v: torch.Tensor) -> torch.Tensor:
    # Never actually called
    print("attention")
    return torch.zeros_like(q)


def attention_fake(q: torch.Tensor,
                   k: torch.Tensor,
                   v: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(q)
direct_register_custom_op("attention", attention, fake_impl=attention_fake)
# ============================================================
# Example PyTorch-model
# ============================================================
class SimpleLlamaLayer(nn.Module):
    def __init__(self,
                 hidden_dim: int = 4096,
                 num_heads: int = 32,
                 num_kv_heads: int = 8,
                 head_size: int = 128,
                 dtype: torch.dtype = torch.float16,
                 qdtype: Optional[torch.dtype] = None,
                 ):
        super().__init__()
        if qdtype is None:
            qdtype = dtype

        self.hidden_dim = hidden_dim
        self.head_size = head_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.dtype = dtype
        self.qdtype = qdtype
        self.quantized = qdtype != dtype

        rand_w = lambda *dims, **kwargs: torch.randn(*dims, **kwargs, dtype=dtype, device="cuda")
        # column-major for scaled-mm (see the layout sketch after this class)
        rand_wq = lambda *dims, **kwargs: rand_w(*dims, **kwargs).to(dtype=qdtype).t().contiguous().t()

        self.weights = {
            "qkv_proj": rand_wq(hidden_dim, (num_heads + num_kv_heads * 2) * head_size),
            "o_proj": rand_wq(hidden_dim, hidden_dim),
            "gate_up_proj": rand_wq(hidden_dim, 2 * hidden_dim),
            "down_proj": rand_wq(hidden_dim, hidden_dim),
            "input_norm": rand_w(hidden_dim),
            "post_attn_norm": rand_w(hidden_dim),
        }

        if self.quantized:
            self.scales = {k: torch.ones(1, 1, dtype=torch.float32) for k in self.weights}
            self.wscales = {k: torch.ones(1, 1, dtype=torch.float32) for k in self.weights}

    def _linear(self, input: torch.Tensor, name: str) -> torch.Tensor:
        weight = self.weights[name]
        if not self.quantized:
            return input @ weight

        scale_a, scale_b = self.scales[name], self.wscales[name]
        qinput, scale_a = torch.ops.mirage.quantize(input, scale_a, dtype=self.qdtype)
        return torch._scaled_mm(qinput, weight, scale_a=scale_a, scale_b=scale_b)

    def forward(self, input: torch.Tensor, residual: torch.Tensor, positions: torch.Tensor) \
            -> tuple[torch.Tensor, torch.Tensor]:
        input_norm, residual = torch.ops.mirage.rms_norm(input, self.weights["input_norm"], residual)

        qkv = self._linear(input_norm, "qkv_proj")
        q, k, v = qkv.split_with_sizes([
            self.num_heads * self.head_size,
            self.num_kv_heads * self.head_size,
            self.num_kv_heads * self.head_size
        ], dim=-1)
        q, k = torch.ops.mirage.rope(q, k, positions)
        out = torch.ops.mirage.attention(q, k, v)
        out2 = self._linear(out, "o_proj")
        out_norm, residual = torch.ops.mirage.rms_norm(out2, self.weights["post_attn_norm"], residual)

        # mlp
        up_gate = self._linear(out_norm, "gate_up_proj")
        silu = torch.ops.mirage.silu_mul(up_gate)
        down = self._linear(silu, "down_proj")
        return down, residual
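
# Aside (sketch, not in the original gist): the ".t().contiguous().t()" dance in
# rand_wq above produces a column-major ("Fortran-contiguous") weight, which is
# the layout torch._scaled_mm expects for its second operand. A quick way to see
# the resulting strides:
_w = torch.randn(4, 8).t().contiguous().t()
assert _w.shape == (4, 8) and _w.stride() == (1, 4)  # unit stride along dim 0
del _w
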
class SimpleLlama(nn.Module):
    def __init__(self,
                 num_layers: int = 32,
                 vocab_size: int = 128256,
                 hidden_dim: int = 4096,
                 num_heads: int = 32,
                 num_kv_heads: int = 8,
                 head_size: int = 128,
                 dtype: torch.dtype = torch.float16,
                 qdtype: Optional[torch.dtype] = None,
                 ):
        super().__init__()
        rand_w = lambda *dims, **kwargs: torch.randn(*dims, **kwargs, dtype=dtype, device="cuda")
        self.weights = {
            "embedding": rand_w(vocab_size, hidden_dim),
            "out_norm": rand_w(hidden_dim),
        }
        self.layers = nn.ModuleList([SimpleLlamaLayer(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_size=head_size,
            dtype=dtype,
            qdtype=qdtype,
        ) for _ in range(num_layers)])

    def forward(self, input: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        x_emb = F.embedding(input, self.weights["embedding"])
        x, residual = x_emb, None
        for layer in self.layers:
            x, residual = layer(x, residual, positions)
        x, _ = torch.ops.mirage.rms_norm(x, self.weights["out_norm"], residual)
        return x
# ============================================================
# Backends
# ============================================================
class AotBackend:
    """Boilerplate that wraps a compile function with AOTAutograd so it receives
    normalized (post-grad) ATen IR, using Inductor's decomposition table."""

    def __init__(self, compile_fn: Callable[[fx.GraphModule, Sequence], Callable[[Sequence], Any]]):
        self.compile_fn = compile_fn

    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence):
        from torch._dynamo.backends.common import aot_autograd
        return aot_autograd(
            fw_compiler=self.compile_fn,
            decompositions=torch._inductor.compile_fx.select_decomp_table(),
        )(graph, example_inputs)
# ============================================================
# Skeleton for the actual Mirage backend that takes a Mirage-friendly fx graph and compiles it.
# ============================================================
class MirageBackend:
    def __call__(self, graph: fx.GraphModule, example_inputs: Sequence):
        """
        Receives normalized (post-grad) IR.
        """
        print(graph.graph.python_code(root_module="self").src)
        return self.run

    def run(self, *args, **kwargs):
        print(f"Forward called with {len(args)=} args and {len(kwargs)=} kwargs.")
        return torch.empty_like(args[0], dtype=torch.float16)
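
# Hedged sketch (hypothetical helper, not part of the original gist and unused here):
# a real backend could start by walking the post-grad graph and collecting the mirage
# placeholder calls, e.g. to decide which regions to hand to the Mirage compiler and
# which to leave as plain ATen ops.
def _collect_mirage_nodes(graph: fx.GraphModule) -> list[fx.Node]:
    mirage_ops = {
        torch.ops.mirage.rms_norm.default,
        torch.ops.mirage.silu_mul.default,
        torch.ops.mirage.rope.default,
        torch.ops.mirage.quantize.default,
        torch.ops.mirage.attention.default,
    }
    return [node for node in graph.graph.nodes
            if node.op == "call_function" and node.target in mirage_ops]
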
torch.set_default_device("cuda")
model = SimpleLlama()
inputs = [torch.randint(0, 4096, (5,)), torch.arange(0, 4096)]
model(*inputs)
compiled_model = torch.compile(model, backend=AotBackend(MirageBackend()), fullgraph=True)
compiled_model(*inputs)
qmodel = SimpleLlama(qdtype=torch.float8_e4m3fn)
qmodel(*inputs)
compiled_qmodel = torch.compile(qmodel, backend=AotBackend(MirageBackend()), fullgraph=True)
compiled_qmodel(*inputs)

@ProExpertProg (Author) commented:
Resulting fx.Graph for 4 layers:

def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1):
    embedding = torch.ops.aten.embedding.default(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    rms_norm = torch.ops.mirage.rms_norm.default(embedding, arg2_1, None);  embedding = arg2_1 = None
    getitem = rms_norm[0]
    getitem_1 = rms_norm[1];  rms_norm = None
    mm = torch.ops.aten.mm.default(getitem, arg3_1);  getitem = arg3_1 = None
    split_with_sizes = torch.ops.aten.split_with_sizes.default(mm, [4096, 1024, 1024], -1);  mm = None
    getitem_2 = split_with_sizes[0]
    getitem_3 = split_with_sizes[1]
    getitem_4 = split_with_sizes[2];  split_with_sizes = None
    rope = torch.ops.mirage.rope.default(getitem_2, getitem_3, arg4_1);  getitem_2 = getitem_3 = None
    getitem_5 = rope[0]
    getitem_6 = rope[1];  rope = None
    attention = torch.ops.mirage.attention.default(getitem_5, getitem_6, getitem_4);  getitem_5 = getitem_6 = getitem_4 = None
    mm_1 = torch.ops.aten.mm.default(attention, arg5_1);  attention = arg5_1 = None
    rms_norm_1 = torch.ops.mirage.rms_norm.default(mm_1, arg6_1, getitem_1);  mm_1 = arg6_1 = getitem_1 = None
    getitem_7 = rms_norm_1[0]
    getitem_8 = rms_norm_1[1];  rms_norm_1 = None
    mm_2 = torch.ops.aten.mm.default(getitem_7, arg7_1);  getitem_7 = arg7_1 = None
    silu_mul = torch.ops.mirage.silu_mul.default(mm_2);  mm_2 = None
    mm_3 = torch.ops.aten.mm.default(silu_mul, arg8_1);  silu_mul = arg8_1 = None
    rms_norm_2 = torch.ops.mirage.rms_norm.default(mm_3, arg9_1, getitem_8);  mm_3 = arg9_1 = getitem_8 = None
    getitem_9 = rms_norm_2[0]
    getitem_10 = rms_norm_2[1];  rms_norm_2 = None
    mm_4 = torch.ops.aten.mm.default(getitem_9, arg10_1);  getitem_9 = arg10_1 = None
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(mm_4, [4096, 1024, 1024], -1);  mm_4 = None
    getitem_11 = split_with_sizes_1[0]
    getitem_12 = split_with_sizes_1[1]
    getitem_13 = split_with_sizes_1[2];  split_with_sizes_1 = None
    rope_1 = torch.ops.mirage.rope.default(getitem_11, getitem_12, arg4_1);  getitem_11 = getitem_12 = None
    getitem_14 = rope_1[0]
    getitem_15 = rope_1[1];  rope_1 = None
    attention_1 = torch.ops.mirage.attention.default(getitem_14, getitem_15, getitem_13);  getitem_14 = getitem_15 = getitem_13 = None
    mm_5 = torch.ops.aten.mm.default(attention_1, arg11_1);  attention_1 = arg11_1 = None
    rms_norm_3 = torch.ops.mirage.rms_norm.default(mm_5, arg12_1, getitem_10);  mm_5 = arg12_1 = getitem_10 = None
    getitem_16 = rms_norm_3[0]
    getitem_17 = rms_norm_3[1];  rms_norm_3 = None
    mm_6 = torch.ops.aten.mm.default(getitem_16, arg13_1);  getitem_16 = arg13_1 = None
    silu_mul_1 = torch.ops.mirage.silu_mul.default(mm_6);  mm_6 = None
    mm_7 = torch.ops.aten.mm.default(silu_mul_1, arg14_1);  silu_mul_1 = arg14_1 = None
    rms_norm_4 = torch.ops.mirage.rms_norm.default(mm_7, arg15_1, getitem_17);  mm_7 = arg15_1 = getitem_17 = None
    getitem_18 = rms_norm_4[0]
    getitem_19 = rms_norm_4[1];  rms_norm_4 = None
    mm_8 = torch.ops.aten.mm.default(getitem_18, arg16_1);  getitem_18 = arg16_1 = None
    split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(mm_8, [4096, 1024, 1024], -1);  mm_8 = None
    getitem_20 = split_with_sizes_2[0]
    getitem_21 = split_with_sizes_2[1]
    getitem_22 = split_with_sizes_2[2];  split_with_sizes_2 = None
    rope_2 = torch.ops.mirage.rope.default(getitem_20, getitem_21, arg4_1);  getitem_20 = getitem_21 = None
    getitem_23 = rope_2[0]
    getitem_24 = rope_2[1];  rope_2 = None
    attention_2 = torch.ops.mirage.attention.default(getitem_23, getitem_24, getitem_22);  getitem_23 = getitem_24 = getitem_22 = None
    mm_9 = torch.ops.aten.mm.default(attention_2, arg17_1);  attention_2 = arg17_1 = None
    rms_norm_5 = torch.ops.mirage.rms_norm.default(mm_9, arg18_1, getitem_19);  mm_9 = arg18_1 = getitem_19 = None
    getitem_25 = rms_norm_5[0]
    getitem_26 = rms_norm_5[1];  rms_norm_5 = None
    mm_10 = torch.ops.aten.mm.default(getitem_25, arg19_1);  getitem_25 = arg19_1 = None
    silu_mul_2 = torch.ops.mirage.silu_mul.default(mm_10);  mm_10 = None
    mm_11 = torch.ops.aten.mm.default(silu_mul_2, arg20_1);  silu_mul_2 = arg20_1 = None
    rms_norm_6 = torch.ops.mirage.rms_norm.default(mm_11, arg21_1, getitem_26);  mm_11 = arg21_1 = getitem_26 = None
    getitem_27 = rms_norm_6[0]
    getitem_28 = rms_norm_6[1];  rms_norm_6 = None
    mm_12 = torch.ops.aten.mm.default(getitem_27, arg22_1);  getitem_27 = arg22_1 = None
    split_with_sizes_3 = torch.ops.aten.split_with_sizes.default(mm_12, [4096, 1024, 1024], -1);  mm_12 = None
    getitem_29 = split_with_sizes_3[0]
    getitem_30 = split_with_sizes_3[1]
    getitem_31 = split_with_sizes_3[2];  split_with_sizes_3 = None
    rope_3 = torch.ops.mirage.rope.default(getitem_29, getitem_30, arg4_1);  getitem_29 = getitem_30 = arg4_1 = None
    getitem_32 = rope_3[0]
    getitem_33 = rope_3[1];  rope_3 = None
    attention_3 = torch.ops.mirage.attention.default(getitem_32, getitem_33, getitem_31);  getitem_32 = getitem_33 = getitem_31 = None
    mm_13 = torch.ops.aten.mm.default(attention_3, arg23_1);  attention_3 = arg23_1 = None
    rms_norm_7 = torch.ops.mirage.rms_norm.default(mm_13, arg24_1, getitem_28);  mm_13 = arg24_1 = getitem_28 = None
    getitem_34 = rms_norm_7[0]
    getitem_35 = rms_norm_7[1];  rms_norm_7 = None
    mm_14 = torch.ops.aten.mm.default(getitem_34, arg25_1);  getitem_34 = arg25_1 = None
    silu_mul_3 = torch.ops.mirage.silu_mul.default(mm_14);  mm_14 = None
    mm_15 = torch.ops.aten.mm.default(silu_mul_3, arg26_1);  silu_mul_3 = arg26_1 = None
    rms_norm_8 = torch.ops.mirage.rms_norm.default(mm_15, arg27_1, getitem_35);  mm_15 = arg27_1 = getitem_35 = None
    getitem_36 = rms_norm_8[0];  rms_norm_8 = None
    return (getitem_36,)

Resulting fx.Graph for 4 layers (quantized to fp8):

def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1, arg9_1, arg10_1, arg11_1, arg12_1, arg13_1, arg14_1, arg15_1, arg16_1, arg17_1, arg18_1, arg19_1, arg20_1, arg21_1, arg22_1, arg23_1, arg24_1, arg25_1, arg26_1, arg27_1, arg28_1, arg29_1, arg30_1, arg31_1, arg32_1, arg33_1, arg34_1, arg35_1, arg36_1, arg37_1, arg38_1, arg39_1, arg40_1, arg41_1, arg42_1, arg43_1, arg44_1, arg45_1, arg46_1, arg47_1, arg48_1, arg49_1, arg50_1, arg51_1, arg52_1, arg53_1, arg54_1, arg55_1, arg56_1, arg57_1, arg58_1, arg59_1):
    embedding = torch.ops.aten.embedding.default(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
    rms_norm = torch.ops.mirage.rms_norm.default(embedding, arg2_1, None);  embedding = arg2_1 = None
    getitem = rms_norm[0]
    getitem_1 = rms_norm[1];  rms_norm = None
    quantize = torch.ops.mirage.quantize.default(getitem, arg4_1, torch.float8_e4m3fnuz);  getitem = arg4_1 = None
    getitem_2 = quantize[0]
    getitem_3 = quantize[1];  quantize = None
    _scaled_mm = torch.ops.aten._scaled_mm.default(getitem_2, arg3_1, getitem_3, arg5_1);  getitem_2 = arg3_1 = getitem_3 = arg5_1 = None
    split_with_sizes = torch.ops.aten.split_with_sizes.default(_scaled_mm, [4096, 1024, 1024], -1);  _scaled_mm = None
    getitem_4 = split_with_sizes[0]
    getitem_5 = split_with_sizes[1]
    getitem_6 = split_with_sizes[2];  split_with_sizes = None
    rope = torch.ops.mirage.rope.default(getitem_4, getitem_5, arg6_1);  getitem_4 = getitem_5 = None
    getitem_7 = rope[0]
    getitem_8 = rope[1];  rope = None
    attention = torch.ops.mirage.attention.default(getitem_7, getitem_8, getitem_6);  getitem_7 = getitem_8 = getitem_6 = None
    quantize_1 = torch.ops.mirage.quantize.default(attention, arg8_1, torch.float8_e4m3fnuz);  attention = arg8_1 = None
    getitem_9 = quantize_1[0]
    getitem_10 = quantize_1[1];  quantize_1 = None
    _scaled_mm_1 = torch.ops.aten._scaled_mm.default(getitem_9, arg7_1, getitem_10, arg9_1);  getitem_9 = arg7_1 = getitem_10 = arg9_1 = None
    rms_norm_1 = torch.ops.mirage.rms_norm.default(_scaled_mm_1, arg10_1, getitem_1);  _scaled_mm_1 = arg10_1 = getitem_1 = None
    getitem_11 = rms_norm_1[0]
    getitem_12 = rms_norm_1[1];  rms_norm_1 = None
    quantize_2 = torch.ops.mirage.quantize.default(getitem_11, arg12_1, torch.float8_e4m3fnuz);  getitem_11 = arg12_1 = None
    getitem_13 = quantize_2[0]
    getitem_14 = quantize_2[1];  quantize_2 = None
    _scaled_mm_2 = torch.ops.aten._scaled_mm.default(getitem_13, arg11_1, getitem_14, arg13_1);  getitem_13 = arg11_1 = getitem_14 = arg13_1 = None
    silu_mul = torch.ops.mirage.silu_mul.default(_scaled_mm_2);  _scaled_mm_2 = None
    quantize_3 = torch.ops.mirage.quantize.default(silu_mul, arg15_1, torch.float8_e4m3fnuz);  silu_mul = arg15_1 = None
    getitem_15 = quantize_3[0]
    getitem_16 = quantize_3[1];  quantize_3 = None
    _scaled_mm_3 = torch.ops.aten._scaled_mm.default(getitem_15, arg14_1, getitem_16, arg16_1);  getitem_15 = arg14_1 = getitem_16 = arg16_1 = None
    rms_norm_2 = torch.ops.mirage.rms_norm.default(_scaled_mm_3, arg17_1, getitem_12);  _scaled_mm_3 = arg17_1 = getitem_12 = None
    getitem_17 = rms_norm_2[0]
    getitem_18 = rms_norm_2[1];  rms_norm_2 = None
    quantize_4 = torch.ops.mirage.quantize.default(getitem_17, arg19_1, torch.float8_e4m3fnuz);  getitem_17 = arg19_1 = None
    getitem_19 = quantize_4[0]
    getitem_20 = quantize_4[1];  quantize_4 = None
    _scaled_mm_4 = torch.ops.aten._scaled_mm.default(getitem_19, arg18_1, getitem_20, arg20_1);  getitem_19 = arg18_1 = getitem_20 = arg20_1 = None
    split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(_scaled_mm_4, [4096, 1024, 1024], -1);  _scaled_mm_4 = None
    getitem_21 = split_with_sizes_1[0]
    getitem_22 = split_with_sizes_1[1]
    getitem_23 = split_with_sizes_1[2];  split_with_sizes_1 = None
    rope_1 = torch.ops.mirage.rope.default(getitem_21, getitem_22, arg6_1);  getitem_21 = getitem_22 = None
    getitem_24 = rope_1[0]
    getitem_25 = rope_1[1];  rope_1 = None
    attention_1 = torch.ops.mirage.attention.default(getitem_24, getitem_25, getitem_23);  getitem_24 = getitem_25 = getitem_23 = None
    quantize_5 = torch.ops.mirage.quantize.default(attention_1, arg22_1, torch.float8_e4m3fnuz);  attention_1 = arg22_1 = None
    getitem_26 = quantize_5[0]
    getitem_27 = quantize_5[1];  quantize_5 = None
    _scaled_mm_5 = torch.ops.aten._scaled_mm.default(getitem_26, arg21_1, getitem_27, arg23_1);  getitem_26 = arg21_1 = getitem_27 = arg23_1 = None
    rms_norm_3 = torch.ops.mirage.rms_norm.default(_scaled_mm_5, arg24_1, getitem_18);  _scaled_mm_5 = arg24_1 = getitem_18 = None
    getitem_28 = rms_norm_3[0]
    getitem_29 = rms_norm_3[1];  rms_norm_3 = None
    quantize_6 = torch.ops.mirage.quantize.default(getitem_28, arg26_1, torch.float8_e4m3fnuz);  getitem_28 = arg26_1 = None
    getitem_30 = quantize_6[0]
    getitem_31 = quantize_6[1];  quantize_6 = None
    _scaled_mm_6 = torch.ops.aten._scaled_mm.default(getitem_30, arg25_1, getitem_31, arg27_1);  getitem_30 = arg25_1 = getitem_31 = arg27_1 = None
    silu_mul_1 = torch.ops.mirage.silu_mul.default(_scaled_mm_6);  _scaled_mm_6 = None
    quantize_7 = torch.ops.mirage.quantize.default(silu_mul_1, arg29_1, torch.float8_e4m3fnuz);  silu_mul_1 = arg29_1 = None
    getitem_32 = quantize_7[0]
    getitem_33 = quantize_7[1];  quantize_7 = None
    _scaled_mm_7 = torch.ops.aten._scaled_mm.default(getitem_32, arg28_1, getitem_33, arg30_1);  getitem_32 = arg28_1 = getitem_33 = arg30_1 = None
    rms_norm_4 = torch.ops.mirage.rms_norm.default(_scaled_mm_7, arg31_1, getitem_29);  _scaled_mm_7 = arg31_1 = getitem_29 = None
    getitem_34 = rms_norm_4[0]
    getitem_35 = rms_norm_4[1];  rms_norm_4 = None
    quantize_8 = torch.ops.mirage.quantize.default(getitem_34, arg33_1, torch.float8_e4m3fnuz);  getitem_34 = arg33_1 = None
    getitem_36 = quantize_8[0]
    getitem_37 = quantize_8[1];  quantize_8 = None
    _scaled_mm_8 = torch.ops.aten._scaled_mm.default(getitem_36, arg32_1, getitem_37, arg34_1);  getitem_36 = arg32_1 = getitem_37 = arg34_1 = None
    split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(_scaled_mm_8, [4096, 1024, 1024], -1);  _scaled_mm_8 = None
    getitem_38 = split_with_sizes_2[0]
    getitem_39 = split_with_sizes_2[1]
    getitem_40 = split_with_sizes_2[2];  split_with_sizes_2 = None
    rope_2 = torch.ops.mirage.rope.default(getitem_38, getitem_39, arg6_1);  getitem_38 = getitem_39 = None
    getitem_41 = rope_2[0]
    getitem_42 = rope_2[1];  rope_2 = None
    attention_2 = torch.ops.mirage.attention.default(getitem_41, getitem_42, getitem_40);  getitem_41 = getitem_42 = getitem_40 = None
    quantize_9 = torch.ops.mirage.quantize.default(attention_2, arg36_1, torch.float8_e4m3fnuz);  attention_2 = arg36_1 = None
    getitem_43 = quantize_9[0]
    getitem_44 = quantize_9[1];  quantize_9 = None
    _scaled_mm_9 = torch.ops.aten._scaled_mm.default(getitem_43, arg35_1, getitem_44, arg37_1);  getitem_43 = arg35_1 = getitem_44 = arg37_1 = None
    rms_norm_5 = torch.ops.mirage.rms_norm.default(_scaled_mm_9, arg38_1, getitem_35);  _scaled_mm_9 = arg38_1 = getitem_35 = None
    getitem_45 = rms_norm_5[0]
    getitem_46 = rms_norm_5[1];  rms_norm_5 = None
    quantize_10 = torch.ops.mirage.quantize.default(getitem_45, arg40_1, torch.float8_e4m3fnuz);  getitem_45 = arg40_1 = None
    getitem_47 = quantize_10[0]
    getitem_48 = quantize_10[1];  quantize_10 = None
    _scaled_mm_10 = torch.ops.aten._scaled_mm.default(getitem_47, arg39_1, getitem_48, arg41_1);  getitem_47 = arg39_1 = getitem_48 = arg41_1 = None
    silu_mul_2 = torch.ops.mirage.silu_mul.default(_scaled_mm_10);  _scaled_mm_10 = None
    quantize_11 = torch.ops.mirage.quantize.default(silu_mul_2, arg43_1, torch.float8_e4m3fnuz);  silu_mul_2 = arg43_1 = None
    getitem_49 = quantize_11[0]
    getitem_50 = quantize_11[1];  quantize_11 = None
    _scaled_mm_11 = torch.ops.aten._scaled_mm.default(getitem_49, arg42_1, getitem_50, arg44_1);  getitem_49 = arg42_1 = getitem_50 = arg44_1 = None
    rms_norm_6 = torch.ops.mirage.rms_norm.default(_scaled_mm_11, arg45_1, getitem_46);  _scaled_mm_11 = arg45_1 = getitem_46 = None
    getitem_51 = rms_norm_6[0]
    getitem_52 = rms_norm_6[1];  rms_norm_6 = None
    quantize_12 = torch.ops.mirage.quantize.default(getitem_51, arg47_1, torch.float8_e4m3fnuz);  getitem_51 = arg47_1 = None
    getitem_53 = quantize_12[0]
    getitem_54 = quantize_12[1];  quantize_12 = None
    _scaled_mm_12 = torch.ops.aten._scaled_mm.default(getitem_53, arg46_1, getitem_54, arg48_1);  getitem_53 = arg46_1 = getitem_54 = arg48_1 = None
    split_with_sizes_3 = torch.ops.aten.split_with_sizes.default(_scaled_mm_12, [4096, 1024, 1024], -1);  _scaled_mm_12 = None
    getitem_55 = split_with_sizes_3[0]
    getitem_56 = split_with_sizes_3[1]
    getitem_57 = split_with_sizes_3[2];  split_with_sizes_3 = None
    rope_3 = torch.ops.mirage.rope.default(getitem_55, getitem_56, arg6_1);  getitem_55 = getitem_56 = arg6_1 = None
    getitem_58 = rope_3[0]
    getitem_59 = rope_3[1];  rope_3 = None
    attention_3 = torch.ops.mirage.attention.default(getitem_58, getitem_59, getitem_57);  getitem_58 = getitem_59 = getitem_57 = None
    quantize_13 = torch.ops.mirage.quantize.default(attention_3, arg50_1, torch.float8_e4m3fnuz);  attention_3 = arg50_1 = None
    getitem_60 = quantize_13[0]
    getitem_61 = quantize_13[1];  quantize_13 = None
    _scaled_mm_13 = torch.ops.aten._scaled_mm.default(getitem_60, arg49_1, getitem_61, arg51_1);  getitem_60 = arg49_1 = getitem_61 = arg51_1 = None
    rms_norm_7 = torch.ops.mirage.rms_norm.default(_scaled_mm_13, arg52_1, getitem_52);  _scaled_mm_13 = arg52_1 = getitem_52 = None
    getitem_62 = rms_norm_7[0]
    getitem_63 = rms_norm_7[1];  rms_norm_7 = None
    quantize_14 = torch.ops.mirage.quantize.default(getitem_62, arg54_1, torch.float8_e4m3fnuz);  getitem_62 = arg54_1 = None
    getitem_64 = quantize_14[0]
    getitem_65 = quantize_14[1];  quantize_14 = None
    _scaled_mm_14 = torch.ops.aten._scaled_mm.default(getitem_64, arg53_1, getitem_65, arg55_1);  getitem_64 = arg53_1 = getitem_65 = arg55_1 = None
    silu_mul_3 = torch.ops.mirage.silu_mul.default(_scaled_mm_14);  _scaled_mm_14 = None
    quantize_15 = torch.ops.mirage.quantize.default(silu_mul_3, arg57_1, torch.float8_e4m3fnuz);  silu_mul_3 = arg57_1 = None
    getitem_66 = quantize_15[0]
    getitem_67 = quantize_15[1];  quantize_15 = None
    _scaled_mm_15 = torch.ops.aten._scaled_mm.default(getitem_66, arg56_1, getitem_67, arg58_1);  getitem_66 = arg56_1 = getitem_67 = arg58_1 = None
    rms_norm_8 = torch.ops.mirage.rms_norm.default(_scaled_mm_15, arg59_1, getitem_63);  _scaled_mm_15 = arg59_1 = getitem_63 = None
    getitem_68 = rms_norm_8[0];  rms_norm_8 = None
    return (getitem_68,)
