@Birch-san · Created June 30, 2025

Failed attempt to reproduce "torch.compile gets invalidated too easily by einops rearrange"
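
"""Run the same shift/scale modulation in two spellings (plain None-indexing vs.
einops.rearrange) under torch.compile(dynamic=False), feeding a new static shape
each iteration, to compare how readily each variant triggers recompilation."""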
from __future__ import annotations
from typing import NamedTuple, Optional
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
import torch
from torch import Tensor, inference_mode
from torch.nn import Module, Linear
from torch.nn.functional import relu
from einops import rearrange


@dataclass
class Args:
    iterations: int

    @staticmethod
    def get_parser() -> ArgumentParser:
        parser = ArgumentParser()
        parser.add_argument("--iterations", type=int, default=8)
        return parser

    @staticmethod
    def from_namespace(namespace: Namespace) -> Args:
        args = Args(**vars(namespace))
        return args


class Modulation(Module):
    """Projects a conditioning vector into gate/shift/scale modulation params."""

    class Out(NamedTuple):
        gate: Tensor
        shift: Tensor
        scale: Tensor

    def __init__(
        self,
        dim: int,
        device: Optional[str | int | torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.dim = dim
        self.lin = Linear(dim, 3 * dim, bias=False, **factory_kwargs)

    def forward(self, x: Tensor) -> Modulation.Out:
        # One linear projection, chunked into (gate, shift, scale) along the channel dim.
        return Modulation.Out(*self.lin(relu(x)).chunk(3, dim=-1))


class UnsqueezeMod(Module):
    """Applies shift/scale modulation, inserting singleton spatial axes via None-indexing."""

    def __init__(self, mod: Modulation):
        super().__init__()
        self.mod = mod

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        out: Modulation.Out = self.mod(t)
        _, shift, scale = out
        # Computes shift + x * scale, broadcast over the two spatial dims.
        return shift[..., None, None, :].addcmul(x, scale[..., None, None, :])


class EinopsMod(Module):
    """Same computation as UnsqueezeMod, but inserts the singleton axes with einops.rearrange."""

    def __init__(self, mod: Modulation):
        super().__init__()
        self.mod = mod

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        out: Modulation.Out = self.mod(t)
        _, shift, scale = out
        return rearrange(shift, "... c -> ... 1 1 c").addcmul(x, rearrange(scale, "... c -> ... 1 1 c"))


def main(args: Args) -> None:
    device = torch.device('cuda')
    dtype = torch.float16
    seed = 42
    gen = torch.Generator(device=device)
    dim = 320
    bsz = 2

    torch.manual_seed(seed)
    mod = Modulation(dim, device=device, dtype=dtype).eval()
    ein = EinopsMod(mod)
    usq = UnsqueezeMod(mod)
    # dynamic=False: every new input shape should specialize (and hence recompile).
    ein_c, usq_c = (torch.compile(m, dynamic=False) for m in (ein, usq))

    with inference_mode():
        for m, m_c in zip((ein, usq), (ein_c, usq_c)):
            print(f"Testing {m.__class__.__name__} implementation")
            for it in range(args.iterations):
                # Spatial size doubles each iteration (16, 32, 64, ...), so every
                # iteration presents a shape the compiled module has not yet seen.
                x = torch.randn(
                    bsz, 16 * (2**it), 16 * (2**it), dim,
                    device=device, dtype=dtype, generator=gen.manual_seed(seed),
                ).to(memory_format=torch.channels_last)
                t = torch.randn(bsz, dim, device=device, dtype=dtype, generator=gen.manual_seed(seed))
                out = m_c(x, t)


if __name__ == "__main__":
    parser = Args.get_parser()
    args_untyped: Namespace = parser.parse_args()
    args: Args = Args.from_namespace(args_untyped)
    main(args)
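
To judge whether the einops spelling recompiles more readily than the None-indexing one, it helps to surface Dynamo's recompile activity while the script runs. A minimal sketch, assuming PyTorch 2.x; `repro.py` stands in for this script's filename, and `torch._dynamo.utils.counters` is an internal API whose keys may change between releases:

```python
# Option 1: log every recompile (and the guard that failed) from the shell:
#   TORCH_LOGS="recompiles" python repro.py --iterations 8

# Option 2: count captured graphs programmatically.
import torch
import torch._dynamo

def f(a: torch.Tensor) -> torch.Tensor:
    return a * 2

torch._dynamo.reset()  # start from a clean compile cache
f_c = torch.compile(f, dynamic=False)
for n in (4, 8, 16):  # with dynamic=False, each new shape forces a fresh graph
    f_c(torch.randn(n))
print(dict(torch._dynamo.utils.counters["stats"]))  # expect unique_graphs == 3
```

If the einops path really did invalidate compiled graphs more aggressively, the `EinopsMod` loop would report more recompiles than the `UnsqueezeMod` loop for the same sequence of shapes.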