Skip to content

Instantly share code, notes, and snippets.

@msaroufim
Created June 25, 2025 17:34
Show Gist options
  • Save msaroufim/0a0452617d0eb08bf7c8a897a045f24e to your computer and use it in GitHub Desktop.
Save msaroufim/0a0452617d0eb08bf7c8a897a045f24e to your computer and use it in GitHub Desktop.
import torch
from torch.utils.cpp_extension import _get_cuda_arch_flags
def test_fix():
    """Check that a user-supplied CUDA arch flag suppresses default arch flags.

    Passes a single ``-gencode=arch=compute_86,code=sm_86`` flag to
    ``_get_cuda_arch_flags`` and expects an empty list back, i.e. no default
    ``-gencode`` flags are generated on top of the user's choice.

    Returns:
        True if the returned flag list is empty, False otherwise.
    """
    print("Testing CUDA arch flags fix...")
    user_arch_flags = ['-gencode=arch=compute_86,code=sm_86']
    result = _get_cuda_arch_flags(user_arch_flags)
    print(f"User provided: {user_arch_flags}")
    print(f"Function returned: {result}")
    if not result:
        print("PASS: User arch flags prevent default generation")
        return True
    print("FAIL: Default flags still generated despite user input")
    print("Expected: []")
    print(f"Got: {result}")
    return False
def test_backward_compatibility():
    """Check that default arch flags are still generated when no arch flag is given.

    Calls ``_get_cuda_arch_flags`` twice: with no arguments, and with only
    non-arch nvcc flags (``-O2``, ``--use-fast-math``). Both calls are
    expected to return a non-empty flag list.

    Returns:
        True if both calls produced flags, or if the check was skipped because
        there is neither a visible CUDA device nor a TORCH_CUDA_ARCH_LIST to
        derive defaults from. False otherwise.
    """
    import os

    print("Testing backward compatibility...")
    # NOTE(review): without TORCH_CUDA_ARCH_LIST (or with it set to "native"),
    # default flags are derived from the visible GPUs. With none visible there
    # is no architecture to fall back on, and _get_cuda_arch_flags() may raise
    # instead of returning flags. Skip in that case so the remaining tests and
    # the summary still run, like the compilation test does without CUDA.
    arch_list_env = os.environ.get("TORCH_CUDA_ARCH_LIST")
    if torch.cuda.device_count() == 0 and arch_list_env in (None, "", "native"):
        print("No CUDA device and TORCH_CUDA_ARCH_LIST unset, skipping backward compatibility test")
        return True
    # Test: No arguments - should still generate defaults
    default_flags = _get_cuda_arch_flags()
    print(f"No args provided: {len(default_flags)} flags generated")
    # Test: Non-arch flags - should still generate defaults
    non_arch_flags = _get_cuda_arch_flags(['-O2', '--use-fast-math'])
    print(f"Non-arch flags provided: {len(non_arch_flags)} flags generated")
    if default_flags and non_arch_flags:
        print("PASS: Backward compatibility preserved")
        return True
    print("FAIL: Backward compatibility broken")
    return False
def test_compilation():
    """Test compilation with user arch flags.

    Builds a trivial extension with ``load_inline``. Exactly one ``-gencode``
    flag, targeting the current device's compute capability, is passed via
    ``extra_cuda_cflags``.

    Returns:
        True if CUDA is unavailable (test skipped), the build succeeds, or the
        raised error mentions "no kernel image is available". False if the
        build raises any other exception.
    """
    if not torch.cuda.is_available():
        print("CUDA not available, skipping compilation test")
        return True
    print("Testing compilation with user arch flags...")
    try:
        from torch.utils.cpp_extension import load_inline

        cuda_code = '__global__ void dummy() {}'
        cpp_code = '''
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
'''
        major, minor = torch.cuda.get_device_capability()
        # Target exactly this device, e.g. compute_86 / sm_86.
        arch_flag = f'-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}'
        # Only the success or failure of the build matters; the returned
        # module is intentionally not kept.
        load_inline(
            name='test_fix',
            cpp_sources=[cpp_code],
            cuda_sources=[cuda_code],
            extra_cuda_cflags=[arch_flag],
            verbose=False,
        )
    except Exception as e:
        # NOTE(review): the dummy kernel is never launched here, so it is
        # unclear when this message could surface during a build. Kept as a
        # tolerated outcome to preserve the original test's semantics.
        if "no kernel image is available" in str(e):
            print("PASS: Arch targeting worked (runtime error expected)")
            return True
        print(f"FAIL: Compilation failed: {e}")
        return False
    print(f"PASS: Compilation succeeded with {arch_flag}")
    return True
if __name__ == "__main__":
    print("=" * 40)
    # Run every test unconditionally (no short-circuit) so each one prints
    # its own diagnostics even if an earlier one fails.
    success1 = test_fix()
    success2 = test_backward_compatibility()
    success3 = test_compilation()
    print("=" * 40)
    all_passed = success1 and success2 and success3
    if all_passed:
        print("ALL TESTS PASSED - BUG FIX VERIFIED")
    else:
        print("TESTS FAILED - BUG FIX NOT WORKING")
    print("=" * 40)
    # Report the outcome through the exit status so CI / shell callers can
    # detect a failure; previously the script always exited with 0.
    raise SystemExit(0 if all_passed else 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment