Nothing ever happens: https://arxiv.org/abs/2507.14111
import torch
import torch.nn as nn

# Disable TF32 so all matmuls run in full fp32 precision
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class Model(nn.Module):
    """
    Simple model that performs a single square matrix multiplication (C = A * B)
    """

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """
        Performs the matrix multiplication.

        Args:
            A (torch.Tensor): Input matrix A of shape (N, N).
            B (torch.Tensor): Input matrix B of shape (N, N).

        Returns:
            torch.Tensor: Output matrix C of shape (N, N).
        """
        return torch.matmul(A, B)


# =======================================================================================================


class ModelNew(nn.Module):
    """
    Optimized model that performs a single square matrix multiplication (C = A * B)
    """

    def __init__(self):
        super(ModelNew, self).__init__()
        # Initialize optimization components
        self.compute_stream = None
        self.transfer_stream = None
        self.output_cache = None
        self.A_cache = None
        self.B_cache = None
        self.device = None
        self.warmed_up = False
        self.graph = None
        self.transfer_event = None
        self.compute_event = None
        self.has_cuda = torch.cuda.is_available()

        # Pre-initialize for optimal performance if CUDA is available
        if self.has_cuda:
            self.device = torch.device('cuda')

            # Create streams with optimal priority settings
            priority_range = torch.cuda.Stream.priority_range()
            high_priority = priority_range[0]  # Highest priority for computation
            low_priority = priority_range[1]   # Lower priority for transfers
            self.compute_stream = torch.cuda.Stream(priority=high_priority)
            self.transfer_stream = torch.cuda.Stream(priority=low_priority)

            # Create high-performance events for synchronization
            self.transfer_event = torch.cuda.Event(enable_timing=False, blocking=False)
            self.compute_event = torch.cuda.Event(enable_timing=False, blocking=False)

            # Pre-allocate all tensors with optimal memory configuration
            # (N is a module-level global defined further down, before ModelNew is instantiated)
            with torch.cuda.stream(self.compute_stream):
                self.output_cache = torch.empty(
                    N, N,
                    dtype=torch.float32,
                    device=self.device,
                    memory_format=torch.contiguous_format
                )
                self.A_cache = torch.empty(
                    N, N,
                    dtype=torch.float32,
                    device=self.device,
                    memory_format=torch.contiguous_format
                )
                self.B_cache = torch.empty(
                    N, N,
                    dtype=torch.float32,
                    device=self.device,
                    memory_format=torch.contiguous_format
                )
                # Touch memory to ensure allocation
                self.output_cache.zero_()
                self.A_cache.zero_()
                self.B_cache.zero_()

            # Check CUDA graph support with enhanced detection
            self.use_graph = (hasattr(torch.cuda, 'graph') and
                              hasattr(torch.cuda, 'CUDAGraph') and
                              torch.cuda.get_device_capability()[0] >= 7)

            # Pre-warm the GPU to ensure it's at full clock speed
            with torch.cuda.stream(self.compute_stream):
                dummy_a = torch.randn(128, 128, device=self.device)
                dummy_b = torch.randn(128, 128, device=self.device)
                for _ in range(5):
                    _ = torch.matmul(dummy_a, dummy_b)
                self.compute_stream.synchronize()

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        """
        Performs the matrix multiplication with ultra-optimized GPU utilization.

        Args:
            A (torch.Tensor): Input matrix A of shape (N, N).
            B (torch.Tensor): Input matrix B of shape (N, N).

        Returns:
            torch.Tensor: Output matrix C of shape (N, N).
        """
        if not self.has_cuda:
            # CPU fallback for systems without CUDA
            return torch.matmul(A, B)

        # Fast path for optimal case: both tensors already on GPU and contiguous
        if A.is_cuda and B.is_cuda and A.is_contiguous() and B.is_contiguous():
            # DIFF: need to synchronize the compute stream with main stream before using it
            self.compute_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.compute_stream):
                # Handle warmup and graph capture for optimal tensors
                if not self.warmed_up:
                    # Warmup with optimal iteration count
                    for _ in range(3):
                        torch.matmul(A, B, out=self.output_cache)
                    # Attempt CUDA graph capture for maximum performance
                    if self.use_graph:
                        try:
                            g = torch.cuda.CUDAGraph()
                            with torch.cuda.graph(g):
                                torch.matmul(A, B, out=self.output_cache)
                            self.graph = g
                        except Exception:
                            self.graph = None
                    self.warmed_up = True
                    self.compute_stream.synchronize()
                # Execute with graph if available, otherwise direct computation
                if self.graph is not None:
                    self.graph.replay()
                else:
                    torch.matmul(A, B, out=self.output_cache)
            # DIFF: wait stream must be outside the compute stream context
            torch.cuda.current_stream().wait_stream(self.compute_stream)
            return self.output_cache

        # DIFF: the transfer stream must wait for the current (default) stream
        # before the copies start (not the other way around; see the comment below)
        self.transfer_stream.wait_stream(torch.cuda.current_stream())
        # Optimized path for tensors requiring transfer or memory layout fixes
        with torch.cuda.stream(self.transfer_stream):
            # Handle A tensor with minimal overhead
            if not A.is_cuda:
                # Pin memory for faster transfer if not already pinned
                if not A.is_pinned() and hasattr(A, 'pin_memory'):
                    A = A.pin_memory()
                self.A_cache.copy_(A, non_blocking=True)
                A_gpu = self.A_cache
            elif not A.is_contiguous():
                # Fix memory layout if needed
                self.A_cache.copy_(A, non_blocking=True)
                A_gpu = self.A_cache
            else:
                A_gpu = A

            # Handle B tensor with minimal overhead
            if not B.is_cuda:
                # Pin memory for faster transfer if not already pinned
                if not B.is_pinned() and hasattr(B, 'pin_memory'):
                    B = B.pin_memory()
                self.B_cache.copy_(B, non_blocking=True)
                B_gpu = self.B_cache
            elif not B.is_contiguous():
                # Fix memory layout if needed
                self.B_cache.copy_(B, non_blocking=True)
                B_gpu = self.B_cache
            else:
                B_gpu = B

            # Signal transfer completion
            self.transfer_event.record(self.transfer_stream)

        # DIFF: need to wait for transfer to complete before compute stream can start
        self.compute_stream.wait_stream(self.transfer_stream)

        # Compute with optimal synchronization
        with torch.cuda.stream(self.compute_stream):
            # # Wait for transfers only if necessary
            # self.transfer_event.wait(self.compute_stream)
            # Handle warmup and graph capture
            if not self.warmed_up:
                # Optimal warmup iterations
                for _ in range(3):
                    torch.matmul(A_gpu, B_gpu, out=self.output_cache)
                # Attempt CUDA graph capture
                if self.use_graph:
                    try:
                        g = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(g):
                            torch.matmul(A_gpu, B_gpu, out=self.output_cache)
                        self.graph = g
                    except Exception:
                        self.graph = None
                self.warmed_up = True
                self.compute_stream.synchronize()
            # Execute computation
            if self.graph is not None:
                self.graph.replay()
            else:
                torch.matmul(A_gpu, B_gpu, out=self.output_cache)
            # Record completion for potential future synchronization
            self.compute_event.record(self.compute_stream)

        # DIFF: can't return output until computation is actually completed,
        # otherwise results will be wrong
        torch.cuda.current_stream().wait_stream(self.compute_stream)
        return self.output_cache


N = 2048
device = "cuda"
dtype = torch.float32
torch.manual_seed(42)


class ModelWithOut(nn.Module):
    def __init__(self):
        super().__init__()
        self.out = torch.empty(N, N, dtype=dtype, device=device)

    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        return torch.matmul(A, B, out=self.out)


def get_inputs():
    A = torch.randn(N, N, dtype=dtype, device=device)
    B = torch.randn(N, N, dtype=dtype, device=device)
    return [A, B]


def quantile(a, q):
    # Linear-interpolation quantiles (same scheme as numpy's default)
    import math

    n = len(a)
    a = sorted(a)

    def get_quantile(q):
        if not (0 <= q <= 1):
            raise ValueError("Quantiles must be in the range [0, 1]")
        point = q * (n - 1)
        lower = math.floor(point)
        upper = math.ceil(point)
        t = point - lower
        return (1 - t) * a[lower] + t * a[upper]

    return [get_quantile(q) for q in q]


model_normal = Model().to(device=device, dtype=dtype)
model_with_out = ModelWithOut().to(device=device, dtype=dtype)
model_theirs_optimized = ModelNew().to(device=device, dtype=dtype)

import triton
import triton.language as tl
import triton.runtime as runtime

num_warmups = 8
num_repeats = 32
inputs = get_inputs()


def benchmark(fn, A, B):
    fn(A, B)
    torch.cuda.synchronize()

    # Buffer used to flush the L2 cache between runs, as triton's do_bench does
    cache = runtime.driver.active.get_empty_cache_for_benchmark()
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_repeats)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_repeats)]

    for i in range(num_warmups):
        runtime.driver.active.clear_cache(cache)
        fn(A, B)
    torch.cuda.synchronize()

    for i in range(num_repeats):
        runtime.driver.active.clear_cache(cache)
        start_events[i].record()
        fn(A, B)
        end_events[i].record()
    torch.cuda.synchronize()

    elapsed_times = [start.elapsed_time(end) for start, end in zip(start_events, end_events)]
    mean_time = sum(elapsed_times) / len(elapsed_times)
    quantile_times = quantile(elapsed_times, [0.5, 0.2, 0.8])
    print(f"Benchmarking {fn.__class__.__name__}")
    print(f"Mean time: {mean_time:.3f} ms")
    print(f"Quantiles (0.5, 0.2, 0.8): {quantile_times}")
    print()
    return mean_time


normal_time = benchmark(model_normal, *inputs)
with_out_time = benchmark(model_with_out, *inputs)
theirs_optimized_time = benchmark(model_theirs_optimized, *inputs)

with_out_speedup = normal_time / with_out_time
theirs_speedup = normal_time / theirs_optimized_time
print(f"Speedup with out= {with_out_speedup:.2f}x")
print(f"Speedup with theirs optimized= {theirs_speedup:.2f}x")

The synchronization before the transfer block in `ModelNew.forward` must run as `self.transfer_stream.wait_stream(torch.cuda.current_stream())`: the transfer stream waits for the default stream, not the other way around. The original revision had the call reversed, which let the async copies race ahead of any pending work on the default stream.
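
For illustration, a minimal self-contained sketch of the intended ordering (not from the gist; it assumes a CUDA device, and the names and shapes are made up): the side stream waits on the default stream before reading its tensors, and the default stream waits on the side stream before consuming the result.

import torch

side = torch.cuda.Stream()
x = torch.randn(2048, 2048, device="cuda")
x.mul_(2)                                      # queued on the default stream

side.wait_stream(torch.cuda.current_stream())  # side stream waits for pending default-stream work
with torch.cuda.stream(side):
    y = x @ x                                  # safe: sees the completed mul_

torch.cuda.current_stream().wait_stream(side)  # default stream waits before consuming y
print(y.sum())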