
@a-r-r-o-w
Created July 21, 2025 14:22
Nothing ever happens: https://arxiv.org/abs/2507.14111
import torch
import torch.nn as nn
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class Model(nn.Module):
"""
Simple model that performs a single square matrix multiplication (C = A * B)
"""
def __init__(self):
super(Model, self).__init__()
def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
"""
Performs the matrix multiplication.
Args:
A (torch.Tensor): Input matrix A of shape (N, N).
B (torch.Tensor): Input matrix B of shape (N, N).
Returns:
torch.Tensor: Output matrix C of shape (N, N).
"""
return torch.matmul(A, B)
# =======================================================================================================
class ModelNew(nn.Module):
"""
Optimized model that performs a single square matrix multiplication (C = A * B)
"""
def __init__(self):
super(ModelNew, self).__init__()
# Initialize optimization components
self.compute_stream = None
self.transfer_stream = None
self.output_cache = None
self.A_cache = None
self.B_cache = None
self.device = None
self.warmed_up = False
self.graph = None
self.transfer_event = None
self.compute_event = None
self.has_cuda = torch.cuda.is_available()
# Pre-initialize for optimal performance if CUDA is available
if self.has_cuda:
self.device = torch.device('cuda')
# Create streams with optimal priority settings
priority_range = torch.cuda.Stream.priority_range()
high_priority = priority_range[0] # Highest priority for computation
low_priority = priority_range[1] # Lower priority for transfers
self.compute_stream = torch.cuda.Stream(priority=high_priority)
self.transfer_stream = torch.cuda.Stream(priority=low_priority)
# Create high-performance events for synchronization
self.transfer_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.compute_event = torch.cuda.Event(enable_timing=False, blocking=False)
# Pre-allocate all tensors with optimal memory configuration
with torch.cuda.stream(self.compute_stream):
self.output_cache = torch.empty(
N, N,
dtype=torch.float32,
device=self.device,
memory_format=torch.contiguous_format
)
self.A_cache = torch.empty(
N, N,
dtype=torch.float32,
device=self.device,
memory_format=torch.contiguous_format
)
self.B_cache = torch.empty(
N, N,
dtype=torch.float32,
device=self.device,
memory_format=torch.contiguous_format
)
# Touch memory to ensure allocation
self.output_cache.zero_()
self.A_cache.zero_()
self.B_cache.zero_()
# Check CUDA graph support with enhanced detection
self.use_graph = (hasattr(torch.cuda, 'graph') and
hasattr(torch.cuda, 'CUDAGraph') and
torch.cuda.get_device_capability()[0] >= 7)
# Pre-warm the GPU to ensure it's at full clock speed
with torch.cuda.stream(self.compute_stream):
dummy_a = torch.randn(128, 128, device=self.device)
dummy_b = torch.randn(128, 128, device=self.device)
for _ in range(5):
_ = torch.matmul(dummy_a, dummy_b)
self.compute_stream.synchronize()
    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """
        Performs the matrix multiplication with ultra-optimized GPU utilization.

        Args:
            A (torch.Tensor): Input matrix A of shape (N, N).
            B (torch.Tensor): Input matrix B of shape (N, N).

        Returns:
            torch.Tensor: Output matrix C of shape (N, N).
        """
        if not self.has_cuda:
            # CPU fallback for systems without CUDA
            return torch.matmul(A, B)

        # Fast path for optimal case: both tensors already on GPU and contiguous
        if (A.is_cuda and B.is_cuda and A.is_contiguous() and B.is_contiguous()):
            # DIFF: need to synchronize the compute stream with main stream before using it
            self.compute_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.compute_stream):
                # Handle warmup and graph capture for optimal tensors
                if not self.warmed_up:
                    # Warmup with optimal iteration count
                    for _ in range(3):
                        torch.matmul(A, B, out=self.output_cache)
                    # Attempt CUDA graph capture for maximum performance
                    if self.use_graph:
                        try:
                            g = torch.cuda.CUDAGraph()
                            with torch.cuda.graph(g):
                                torch.matmul(A, B, out=self.output_cache)
                            self.graph = g
                        except Exception:
                            self.graph = None
                    self.warmed_up = True
                    self.compute_stream.synchronize()
                # Execute with graph if available, otherwise direct computation
                if self.graph is not None:
                    self.graph.replay()
                else:
                    torch.matmul(A, B, out=self.output_cache)
            # DIFF: wait stream must be outside the compute stream context
            torch.cuda.current_stream().wait_stream(self.compute_stream)
            return self.output_cache
        # DIFF: transfer stream needs to wait for current stream
        self.transfer_stream.wait_stream(torch.cuda.current_stream())
        # Optimized path for tensors requiring transfer or memory layout fixes
        with torch.cuda.stream(self.transfer_stream):
            # Handle A tensor with minimal overhead
            if not A.is_cuda:
                # Pin memory for faster transfer if not already pinned
                if not A.is_pinned() and hasattr(A, 'pin_memory'):
                    A = A.pin_memory()
                self.A_cache.copy_(A, non_blocking=True)
                A_gpu = self.A_cache
            elif not A.is_contiguous():
                # Fix memory layout if needed
                self.A_cache.copy_(A, non_blocking=True)
                A_gpu = self.A_cache
            else:
                A_gpu = A
            # Handle B tensor with minimal overhead
            if not B.is_cuda:
                # Pin memory for faster transfer if not already pinned
                if not B.is_pinned() and hasattr(B, 'pin_memory'):
                    B = B.pin_memory()
                self.B_cache.copy_(B, non_blocking=True)
                B_gpu = self.B_cache
            elif not B.is_contiguous():
                # Fix memory layout if needed
                self.B_cache.copy_(B, non_blocking=True)
                B_gpu = self.B_cache
            else:
                B_gpu = B
            # Signal transfer completion
            self.transfer_event.record(self.transfer_stream)
        # DIFF: need to wait for transfer to complete before compute stream can start
        self.compute_stream.wait_stream(self.transfer_stream)
        # Compute with optimal synchronization
        with torch.cuda.stream(self.compute_stream):
            # # Wait for transfers only if necessary
            # self.transfer_event.wait(self.compute_stream)
            # Handle warmup and graph capture
            if not self.warmed_up:
                # Optimal warmup iterations
                for _ in range(3):
                    torch.matmul(A_gpu, B_gpu, out=self.output_cache)
                # Attempt CUDA graph capture
                if self.use_graph:
                    try:
                        g = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(g):
                            torch.matmul(A_gpu, B_gpu, out=self.output_cache)
                        self.graph = g
                    except Exception:
                        self.graph = None
                self.warmed_up = True
                self.compute_stream.synchronize()
            # Execute computation
            if self.graph is not None:
                self.graph.replay()
            else:
                torch.matmul(A_gpu, B_gpu, out=self.output_cache)
            # Record completion for potential future synchronization
            self.compute_event.record(self.compute_stream)
        # DIFF: can't return output until computation is actually completed, otherwise results
        # will be wrong
        torch.cuda.current_stream().wait_stream(self.compute_stream)
        return self.output_cache
N = 2048
device = "cuda"
dtype = torch.float32
torch.manual_seed(42)
class ModelWithOut(nn.Module):
    def __init__(self):
        super().__init__()
        self.out = torch.empty(N, N, dtype=dtype, device=device)

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        return torch.matmul(A, B, out=self.out)
def get_inputs():
    A = torch.randn(N, N, dtype=dtype, device=device)
    B = torch.randn(N, N, dtype=dtype, device=device)
    return [A, B]
def quantile(a, q):
    import math

    n = len(a)
    a = sorted(a)

    def get_quantile(q):
        if not (0 <= q <= 1):
            raise ValueError("Quantiles must be in the range [0, 1]")
        point = q * (n - 1)
        lower = math.floor(point)
        upper = math.ceil(point)
        t = point - lower
        return (1 - t) * a[lower] + t * a[upper]

    return [get_quantile(qi) for qi in q]
model_normal = Model().to(device=device, dtype=dtype)
model_with_out = ModelWithOut().to(device=device, dtype=dtype)
model_theirs_optimized = ModelNew().to(device=device, dtype=dtype)
import triton
import triton.language as tl
import triton.runtime as runtime
num_warmups = 8
num_repeats = 32
inputs = get_inputs()
def benchmark(fn, A, B):
    fn(A, B)
    torch.cuda.synchronize()

    cache = runtime.driver.active.get_empty_cache_for_benchmark()
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_repeats)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_repeats)]

    for i in range(num_warmups):
        runtime.driver.active.clear_cache(cache)
        fn(A, B)
    torch.cuda.synchronize()

    for i in range(num_repeats):
        runtime.driver.active.clear_cache(cache)
        start_events[i].record()
        fn(A, B)
        end_events[i].record()
    torch.cuda.synchronize()

    elapsed_times = [start.elapsed_time(end) for start, end in zip(start_events, end_events)]
    mean_time = sum(elapsed_times) / len(elapsed_times)
    quantile_times = quantile(elapsed_times, [0.5, 0.2, 0.8])

    print(f"Benchmarking {fn.__class__.__name__}")
    print(f"Mean time: {mean_time:.3f} ms")
    print(f"Quantiles (0.5, 0.2, 0.8): {quantile_times}")
    print()
    return mean_time
normal_time = benchmark(model_normal, *inputs)
with_out_time = benchmark(model_with_out, *inputs)
theirs_optimized_time = benchmark(model_theirs_optimized, *inputs)
with_out_speedup = normal_time / with_out_time
theirs_speedup = normal_time / theirs_optimized_time
print(f"Speedup with out= {with_out_speedup:.2f}x")
print(f"Speedup with theirs optimized= {theirs_speedup:.2f}x")
@a-r-r-o-w
Author

Line 157 is incorrect: it should be the transfer stream waiting on the default stream, not the default stream waiting on the transfer stream.
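For reference, a minimal sketch of the stream ordering described above, using only the public torch.cuda stream/event API (the tensor names and shapes here are illustrative, not taken from the gist): each consumer stream waits on its producer stream before launching work, and the default stream waits on the compute stream before the output is read.

import torch

if torch.cuda.is_available():
    transfer_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()

    a_cpu = torch.randn(2048, 2048, pin_memory=True)
    b_cpu = torch.randn(2048, 2048, pin_memory=True)
    a_gpu = torch.empty(2048, 2048, device="cuda")
    b_gpu = torch.empty(2048, 2048, device="cuda")
    out = torch.empty(2048, 2048, device="cuda")

    # The transfer stream waits for pending work on the default stream that
    # may still be touching the source or destination buffers.
    transfer_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(transfer_stream):
        a_gpu.copy_(a_cpu, non_blocking=True)
        b_gpu.copy_(b_cpu, non_blocking=True)

    # The compute stream waits for the transfers before launching the matmul.
    compute_stream.wait_stream(transfer_stream)
    with torch.cuda.stream(compute_stream):
        torch.matmul(a_gpu, b_gpu, out=out)

    # The default stream waits for the compute stream before `out` is consumed.
    torch.cuda.current_stream().wait_stream(compute_stream)
    print(out.sum().item())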
