import torch
from torch.utils.cpp_extension import _get_cuda_arch_flags
def test_fix():
    """User-supplied arch flags should suppress the default -gencode flags."""
    print("Testing CUDA arch flags fix...")
    user_arch_flags = ['-gencode=arch=compute_86,code=sm_86']
    result = _get_cuda_arch_flags(user_arch_flags)
    print(f"User provided: {user_arch_flags}")
    print(f"Function returned: {result}")
    if len(result) == 0:
        print("PASS: User arch flags prevent default generation")
        return True
    else:
        print("FAIL: Default flags still generated despite user input")
        print("Expected: []")
        print(f"Got: {result}")
        return False
def test_backward_compatibility():
    """Calls without arch flags should still generate the default flags."""
    print("Testing backward compatibility...")
    # Test: no arguments - defaults should still be generated
    default_flags = _get_cuda_arch_flags()
    print(f"No args provided: {len(default_flags)} flags generated")
    # Test: non-arch flags - defaults should still be generated
    non_arch_flags = _get_cuda_arch_flags(['-O2', '--use_fast_math'])
    print(f"Non-arch flags provided: {len(non_arch_flags)} flags generated")
    if len(default_flags) > 0 and len(non_arch_flags) > 0:
        print("PASS: Backward compatibility preserved")
        return True
    else:
        print("FAIL: Backward compatibility broken")
        return False
def test_compilation():
    """Test compilation with user arch flags."""
    if not torch.cuda.is_available():
        print("CUDA not available, skipping compilation test")
        return True
    print("Testing compilation with user arch flags...")
    try:
        from torch.utils.cpp_extension import load_inline

        cuda_code = '__global__ void dummy() {}'
        cpp_code = '''
        #include <torch/extension.h>
        PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
        '''
        # Target the arch of the current device so the kernel image matches
        capability = torch.cuda.get_device_capability()
        arch_flag = (
            f'-gencode=arch=compute_{capability[0]}{capability[1]},'
            f'code=sm_{capability[0]}{capability[1]}'
        )
        module = load_inline(
            name='test_fix',
            cpp_sources=[cpp_code],
            cuda_sources=[cuda_code],
            extra_cuda_cflags=[arch_flag],
            verbose=False,
        )
        print(f"PASS: Compilation succeeded with {arch_flag}")
        return True
    except Exception as e:
        if "no kernel image is available" in str(e):
            print("PASS: Arch targeting worked (runtime error expected)")
            return True
        else:
            print(f"FAIL: Compilation failed: {e}")
            return False
if __name__ == "__main__":
    print("=" * 40)
    success1 = test_fix()
    success2 = test_backward_compatibility()
    success3 = test_compilation()
    print("=" * 40)
    if success1 and success2 and success3:
        print("ALL TESTS PASSED - BUG FIX VERIFIED")
    else:
        print("TESTS FAILED - BUG FIX NOT WORKING")
    print("=" * 40)