This file has been truncated.
============================= test session starts ============================== | |
platform linux -- Python 3.12.9, pytest-8.3.5, pluggy-1.5.0 | |
rootdir: /workspace/triton/python | |
configfile: pyproject.toml | |
collected 21530 items | |
unit/blackwell/test_tmem.py s [ 0%] | |
unit/cuda/test_experimental_tma.py .........F.......F................... [ 0%] | |
......................s..s................................s..s.......... [ 0%] | |
......................s..s.....s..s.....s..s.....s..s................... [ 0%] | |
........................................................................ [ 1%] | |
......................s..s.....s..s.....s..s.....s..s................... [ 1%] | |
........................................................................ [ 1%] | |
..........................F.FFssssssssssssssssssssssssssssssssssssssssss [ 2%] | |
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 2%] | |
unit/cuda/test_flashattention.py .FFFFF [ 2%] | |
unit/cuda/test_gemm.py ........FFFF............FFFF....FFFFFFFFFFFFFFFFF [ 2%] | |
FFFFFFF....FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFssssFFFFFFFFF [ 3%] | |
FFFssssFFFFFFFFFFFFFFFFFFFFFFFFFFFFssssFFFFFFFFFFFFssssFFFFFFFFFFFFFFFFF [ 3%] | |
FFFFFFFFFFFFFFFFFFF................................ssssssssssssssss..... [ 3%] | |
...........ssssssssssssssssFFFFFFFFFFFFFFFFFFFFFFFFssssFFFFFFFFFFFFFFFFF [ 4%] | |
FFFFFFFFFFFssssFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 4%] | |
FFF................ssssssssFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 4%] | |
unit/cuda/test_gemm_fusion.py ss [ 4%] | |
unit/cuda/test_mixed_io.py ... [ 4%] | |
unit/cuda/test_tma_descriptor.py sssssssss [ 4%] | |
unit/cuda/test_tma_store_gemm.py ........ [ 4%] | |
unit/instrumentation/test_gpuhello.py . [ 4%] | |
unit/language/test_annotations.py .......... [ 4%] | |
unit/language/test_block_pointer.py ......sss......sss......sss......sss [ 5%] | |
......sss......sss......sss......sss......sss......sss......sss......sss [ 5%] | |
......sss......sss......sss............................................. [ 5%] | |
........................................................................ [ 6%] | |
........................ [ 6%] | |
unit/language/test_compile_errors.py .............................. [ 6%] | |
unit/language/test_compile_only.py .... [ 6%] | |
unit/language/test_conversions.py .....................ss....ssss [ 6%] | |
unit/language/test_core.py ............................................. [ 6%] | |
........................................................................ [ 7%] | |
........................................................................ [ 7%] | |
........................................................................ [ 7%] | |
........................................................................ [ 8%] | |
........................................................................ [ 8%] | |
........................................................................ [ 8%] | |
........................................................................ [ 9%] | |
........................................................................ [ 9%] | |
........................................................................ [ 9%] | |
........................................................................ [ 10%] | |
........................................................................ [ 10%] | |
........................................................................ [ 10%] | |
........................................................................ [ 11%] | |
........................................................................ [ 11%] | |
........................................................................ [ 11%] | |
........................................................................ [ 12%] | |
........................................................................ [ 12%] | |
........................................................................ [ 12%] | |
........................................................................ [ 13%] | |
........................................................................ [ 13%] | |
........................................................................ [ 13%] | |
........................................................................ [ 14%] | |
........................................................................ [ 14%] | |
........................................................................ [ 14%] | |
........................................................................ [ 15%] | |
........................................................................ [ 15%] | |
........................................................................ [ 15%] | |
........................................................................ [ 16%] | |
........................................................................ [ 16%] | |
........................................................................ [ 16%] | |
........................................................................ [ 17%] | |
........................................................................ [ 17%] | |
........................................................................ [ 17%] | |
........................................................................ [ 18%] | |
........................................................................ [ 18%] | |
........................................................................ [ 18%] | |
........................................................................ [ 19%] | |
........................................................................ [ 19%] | |
........................................................................ [ 19%] | |
........................................................................ [ 20%] | |
........................................................................ [ 20%] | |
........................................................................ [ 20%] | |
........................................................................ [ 21%] | |
........................................................................ [ 21%] | |
........................................................................ [ 21%] | |
....s................................................................... [ 22%] | |
........................................................................ [ 22%] | |
........................................................................ [ 22%] | |
........................................................................ [ 23%] | |
........................................................................ [ 23%] | |
........................................................................ [ 23%] | |
........................................................................ [ 24%] | |
........................................................................ [ 24%] | |
........................................................................ [ 24%] | |
........................................................................ [ 25%] | |
........................................................................ [ 25%] | |
........................................................................ [ 25%] | |
........................................................................ [ 26%] | |
........................................................................ [ 26%] | |
........................................................................ [ 26%] | |
........................................................................ [ 27%] | |
........................................................................ [ 27%] | |
........................................................................ [ 27%] | |
........................................................................ [ 28%] | |
........................................................................ [ 28%] | |
........................................................................ [ 28%] | |
........................................................................ [ 29%] | |
........................................................................ [ 29%] | |
...................................................................ss... [ 29%] | |
.ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss... [ 30%] | |
.ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss... [ 30%] | |
.ss....ss....ss......................................................... [ 30%] | |
........................................................................ [ 31%] | |
........................................................................ [ 31%] | |
........................................................................ [ 31%] | |
...................................................................ss... [ 32%] | |
.ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss... [ 32%] | |
.ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss... [ 32%] | |
.ss....ss....ss......................................................... [ 33%] | |
........................................................................ [ 33%] | |
........................................................................ [ 33%] | |
........................................................................ [ 34%] | |
........................................................................ [ 34%] | |
........................................................................ [ 34%] | |
........................................................................ [ 35%] | |
........................................................................ [ 35%] | |
........................................................................ [ 35%] | |
........................................................................ [ 36%] | |
........................................................................ [ 36%] | |
........................................................................ [ 36%] | |
........................................................................ [ 37%] | |
........................................................................ [ 37%] | |
...............................................................sss..sss. [ 37%] | |
.sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..s [ 38%] | |
ss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss [ 38%] | |
..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss.. [ 38%] | |
sss..sss..sss..sss..sss..sss..sss..sss..sss............................. [ 39%] | |
........................................................................ [ 39%] | |
........................................................................ [ 39%] | |
........................................................................ [ 40%] | |
........................................................................ [ 40%] | |
........................................................................ [ 40%] | |
........................................................................ [ 41%] | |
........................................................................ [ 41%] | |
........................................................................ [ 41%] | |
........................................................................ [ 42%] | |
........................................................................ [ 42%] | |
........................................................................ [ 42%] | |
........................................................................ [ 43%] | |
........................................................................ [ 43%] | |
........................................................................ [ 43%] | |
........................................................................ [ 44%] | |
........................................................................ [ 44%] | |
........................................................................ [ 44%] | |
........................................................................ [ 45%] | |
........................................................................ [ 45%] | |
........................................................................ [ 45%] | |
........................................................................ [ 46%] | |
..................................................ss.................... [ 46%] | |
........................................................................ [ 46%] | |
........................................................................ [ 47%] | |
........................................................................ [ 47%] | |
........................................................................ [ 47%] | |
........................................................................ [ 48%] | |
........................................................................ [ 48%] | |
........................................................................ [ 48%] | |
................ssssssssssssssssssssssssssssssssssssssss................ [ 49%] | |
........................................................................ [ 49%] | |
........................................................................ [ 49%] | |
........................................................................ [ 50%] | |
........................................................................ [ 50%] | |
........................................................................ [ 50%] | |
........................................................................ [ 51%] | |
........................................................................ [ 51%] | |
........................................................................ [ 51%] | |
........................................................................ [ 52%] | |
........................................................................ [ 52%] | |
........................................................................ [ 52%] | |
........................................................................ [ 53%] | |
........................................................................ [ 53%] | |
........................................................................ [ 53%] | |
........................................................................ [ 54%] | |
........................................................................ [ 54%] | |
........................................................................ [ 54%] | |
........................................................................ [ 55%] | |
........................................................................ [ 55%] | |
........................................................................ [ 55%] | |
........................................................................ [ 56%] | |
........................................................................ [ 56%] | |
F....................................................................... [ 56%] | |
........................................................................ [ 57%] | |
........................................................................ [ 57%] | |
...................................s.......s.......s.......s.......s.... [ 57%] | |
.ssssssssssssssssssssssssssssssssssssssssssssssssssssss................. [ 58%] | |
........................................................................ [ 58%] | |
........................................................................ [ 58%] | |
........................................................................ [ 59%] | |
........................................................................ [ 59%] | |
........................................................................ [ 59%] | |
.....................................................ssss............... [ 60%] | |
.............................................ssss........ssssssss....sss [ 60%] | |
sssssssssssssssssssssssssssss........ssss........ssssssss....sssssssssss [ 60%] | |
sssssssssssssssssssss........ssss........ssssssss....sssssssssssssssssss [ 61%] | |
sssssssssssss........ssss........ssssssss....sssssssssssssssssssssssssss [ 61%] | |
sssss............ssss................................................... [ 61%] | |
.........ssss....ssssssss....ssssssssssssssssssssssssssssssss........... [ 62%] | |
.ssss....ssssssss....ssssssssssssssssssssssssssssssss............ssss... [ 62%] | |
.ssssssss....ssssssssssssssssssssssssssssssss............ssss....sssssss [ 62%] | |
s....ssssssssssssssssssssssssssssssss................ssss............... [ 63%] | |
.............................................ssssssssssss....sssssssssss [ 63%] | |
sssssssssssssssssssss................ssssssssssss....sssssssssssssssssss [ 63%] | |
sssssssssssss................ssssssssssss....sssssssssssssssssssssssssss [ 64%] | |
sssss................ssssssssssss....ssssssssssssssssssssssssssssssss... [ 64%] | |
.................ssss................................................sss [ 64%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 65%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 65%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 65%] | |
sssssssssssssssssssssssssssssssssssss................ssss............... [ 66%] | |
.............................sssssssssssssssssssssssssssssssssssssssssss [ 66%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 66%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 67%] | |
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss... [ 67%] | |
.................ssss................................................... [ 67%] | |
.ssssssssssssssssssssssssssssssssssssssssssss....................sssssss [ 68%] | |
sssssssssssssssssssssssssssssssssssss....................sssssssssssssss [ 68%] | |
sssssssssssssssssssssssssssss....................sssssssssssssssssssssss [ 68%] | |
sssssssssssssssssssss................................ssss............... [ 69%] | |
.....................sssssssssssssssssssssssssssssssssssssssssssssssssss [ 69%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 69%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 70%] | |
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss........... [ 70%] | |
.................ssss................................sssssssssssssssssss [ 70%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 71%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 71%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 71%] | |
sssssssssssssssssssss................................ssss............... [ 72%] | |
.............sssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 72%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 72%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 73%] | |
sssssssssssssssssssssssssssssssssssssssssssssssssssss................... [ 73%] | |
.................ssss........................sssssssssssssssssssssssssss [ 73%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 74%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 74%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 74%] | |
sssssssssssss........................................ssss............... [ 75%] | |
.....sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 75%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 75%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 76%] | |
sssssssssssssssssssssssssssssssssssssssssssss........................... [ 76%] | |
.................ssss................sssssssssssssssssssssssssssssssssss [ 76%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 77%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 77%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 77%] | |
sssss................................................ssss............sss [ 78%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 78%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 78%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 79%] | |
sssssssssssssssssssssssssssssssssssss................................... [ 79%] | |
.................ssss........sssssssssssssssssssssssssssssssssssssssssss [ 79%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 80%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 80%] | |
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss... [ 80%] | |
.....................................................ssss............... [ 81%] | |
.ssssssss....ssssssssssssssssssssssssssssssssssss................sssssss [ 81%] | |
s....ssssssssssssssssssssssssssssssssssss................ssssssss....sss [ 81%] | |
sssssssssssssssssssssssssssssssss................ssssssss....sssssssssss [ 82%] | |
sssssssssssssssssssssssss............................................... [ 82%] | |
.................ssss............ssssssss....sssssssssssssssssssssssssss [ 82%] | |
sssss....ssss............ssssssss....ssssssssssssssssssssssssssssssss... [ 83%] | |
.ssss............ssssssss....ssssssssssssssssssssssssssssssss....ssss... [ 83%] | |
.........ssssssss....ssssssssssssssssssssssssssssssss....ssss........... [ 83%] | |
........................................................................ [ 84%] | |
................................................F............ss..sssssss [ 84%] | |
sssss....ssssssss..ss..ssssssssss......ssssssss..ss..ssssssssss......sss [ 84%] | |
sssss..ss..ssss..ssss......ssssssss..ss......ssssss......ssssssss..ss... [ 85%] | |
.ssssssss....ssssssss....ss..ssssssss....ssssss........ssssssss......... [ 85%] | |
.ssssssssssssss..ssssssss......ssssssssss..ss..ssssssss......ssssssssss. [ 85%] | |
.ss..ssssssss......ssss..ssss..ss..ssssssss..........ssssss..ss..sssssss [ 86%] | |
s........ssssssss..ss.......ss....ss....ss....ss........................ [ 86%] | |
........................................................................ [ 86%] | |
.... [ 86%] | |
unit/language/test_decorator.py .. [ 86%] | |
unit/language/test_libdevice.py ............. [ 87%] | |
unit/language/test_line_info.py ......ssssss.. [ 87%] | |
unit/language/test_matmul.py ...sss.s....ss......ss..FF..ss......ss..... [ 87%] | |
.ss.....sss.ssFFsssFssFFFssFFsssssssssFFFssFFsFFFssFFsFFFssFFssssssss... [ 87%] | |
sss.s....ss......ss..FF..ss......ss......ss.....sss.ssFFsssFssFFFssFFsss [ 87%] | |
ssssssFFFssFFsFFFssFFsFFFssFFssssssss............................FFFFFFF [ 88%] | |
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 88%] | |
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 88%] | |
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 89%] | |
FFFFFFFFFFFxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx [ 89%] | |
xxxxxxxxxxxFFFFFFFFFFFFFssssssssssssssssssssssssssssssssssssFFFFFFFFFFFF [ 89%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 90%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 90%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 90%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 91%] | |
ssssssssssssssssssssssssssssssssssssFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 91%] | |
FFFFFFFFFFFFssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 91%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 92%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 92%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 92%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 93%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 93%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 93%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 94%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 94%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 95%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 95%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 95%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 96%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 96%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 96%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 97%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 97%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 97%] | |
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 98%] | |
ssssssssssssssssssssssssssssssssssss [ 98%] | |
unit/language/test_mxfp.py ........................ [ 98%] | |
unit/language/test_pipeliner.py .............................F..s [ 98%] | |
unit/language/test_random.py ........................................... [ 98%] | |
................................................................ [ 98%] | |
unit/language/test_reproducer.py . [ 98%] | |
unit/language/test_standard.py ......................................... [ 99%] | |
.................................. [ 99%] | |
unit/language/test_subprocess.py ....................F......... [ 99%] | |
unit/language/test_tuple.py ........ [ 99%] | |
unit/language/test_warp_specialization.py .. [ 99%] | |
unit/runtime/test_autotuner.py .......s [ 99%] | |
unit/runtime/test_bindings.py .. [ 99%] | |
unit/runtime/test_cache.py ..........................s. [ 99%] | |
unit/runtime/test_cublas.py ........ [ 99%] | |
unit/runtime/test_driver.py .. [ 99%] | |
unit/runtime/test_jit.py . [ 99%] | |
unit/runtime/test_launch.py .. [ 99%] | |
unit/runtime/test_subproc.py ... [ 99%] | |
unit/test_debug.py ..FFFFFFFFFFFFFFFFFFFFFFF.FFFFFFFFFFFFFFFFFFFFF [ 99%] | |
unit/test_debug_dump.py F [ 99%] | |
unit/test_perf_warning.py s.F [ 99%] | |
unit/tools/test_aot.py ..... [ 99%] | |
unit/tools/test_disasm.py F [ 99%] | |
unit/tools/test_irsource.py . [100%] | |
=================================== FAILURES =================================== | |
_______________ test_experimental_tma_matmul[True-128-256-64-4] ________________ | |
num_stages = 4, BLOCK_M = 128, BLOCK_N = 256, BLOCK_K = 64, byval_tma = True | |
@pytest.mark.parametrize("num_stages", [1, 4]) | |
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) | |
@pytest.mark.parametrize("byval_tma", [True, False]) | |
def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma): | |
if not supports_tma(byval_tma): | |
pytest.skip(tma_skip_msg(byval_tma)) | |
device = "cuda" | |
M, N, K = 8192, 8192, 1024 | |
torch.manual_seed(42) | |
A = torch.randn((M, K), dtype=torch.float16, device=device) | |
B = torch.randn((K, N), dtype=torch.float16, device=device) | |
C = torch.empty((M, N), dtype=torch.float16, device=device) | |
if byval_tma: | |
desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size()) | |
desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size()) | |
desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size()) | |
else: | |
desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size()) | |
desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size()) | |
desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size()) | |
> kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, | |
1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma, | |
num_warps=8, num_stages=num_stages, dtype=tl.float16) | |
unit/cuda/test_experimental_tma.py:108: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:591: in run | |
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, | |
../triton/compiler/compiler.py:413: in __getattribute__ | |
self._init_handles() | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af0f64b90> | |
def _init_handles(self): | |
if self.module is not None: | |
return | |
device = driver.active.get_current_device() | |
# create launcher | |
self.run = driver.active.launcher_cls(self.src, self.metadata) | |
# not enough shared memory to run the kernel | |
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] | |
if self.metadata.shared > max_shared: | |
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory") | |
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 147480, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. | |
../triton/compiler/compiler.py:401: OutOfResources | |
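The shared-memory requirement reported above is roughly what the software pipeliner needs to keep num_stages copies of the fp16 A and B tiles resident at once. The estimator below is a sketch under that assumption only; the compiler's exact accounting also covers barriers and epilogue buffers, so it will not match the reported figure to the byte.

# Rough shared-memory estimate for a pipelined GEMM: num_stages copies of
# the A (BLOCK_M x BLOCK_K) and B (BLOCK_K x BLOCK_N) tiles.
def approx_smem_bytes(block_m, block_n, block_k, num_stages, elem_size=2):
    return num_stages * (block_m * block_k + block_k * block_n) * elem_size

# For the failing 128x256x64 fp16 configuration, one stage is 49152 bytes,
# so two stages (98304) still fit under the 101376-byte limit reported
# above, while three or more do not.
for stages in (1, 2, 3, 4):
    print(stages, approx_smem_bytes(128, 256, 64, stages))

This is consistent with the error's suggestion: dropping num_stages (or shrinking the block sizes) brings the footprint back under the hardware limit.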
_______________ test_experimental_tma_matmul[False-128-256-64-4] _______________ | |
num_stages = 4, BLOCK_M = 128, BLOCK_N = 256, BLOCK_K = 64, byval_tma = False | |
@pytest.mark.parametrize("num_stages", [1, 4]) | |
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) | |
@pytest.mark.parametrize("byval_tma", [True, False]) | |
def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma): | |
if not supports_tma(byval_tma): | |
pytest.skip(tma_skip_msg(byval_tma)) | |
device = "cuda" | |
M, N, K = 8192, 8192, 1024 | |
torch.manual_seed(42) | |
A = torch.randn((M, K), dtype=torch.float16, device=device) | |
B = torch.randn((K, N), dtype=torch.float16, device=device) | |
C = torch.empty((M, N), dtype=torch.float16, device=device) | |
if byval_tma: | |
desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size()) | |
desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size()) | |
desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size()) | |
else: | |
desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size()) | |
desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size()) | |
desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size()) | |
> kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, | |
1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma, | |
num_warps=8, num_stages=num_stages, dtype=tl.float16) | |
unit/cuda/test_experimental_tma.py:108: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:591: in run | |
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, | |
../triton/compiler/compiler.py:413: in __getattribute__ | |
self._init_handles() | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af0f780e0> | |
def _init_handles(self): | |
if self.module is not None: | |
return | |
device = driver.active.get_current_device() | |
# create launcher | |
self.run = driver.active.launcher_cls(self.src, self.metadata) | |
# not enough shared memory to run the kernel | |
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] | |
if self.metadata.shared > max_shared: | |
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory") | |
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 147480, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. | |
../triton/compiler/compiler.py:401: OutOfResources | |
________ test_experimental_make_tensor_descriptor_matmul[128-256-64-4] _________ | |
num_stages = 4, BLOCK_M = 128, BLOCK_N = 256, BLOCK_K = 64 | |
@requires_tma | |
@pytest.mark.interpreter | |
@pytest.mark.parametrize("num_stages", [1, 4]) | |
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)]) | |
def test_experimental_make_tensor_descriptor_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K): | |
device = "cuda" | |
if is_interpreter(): | |
M, N, K = BLOCK_M, BLOCK_N, BLOCK_K | |
else: | |
M, N, K = 8192, 8192, 1024 | |
torch.manual_seed(42) | |
A = torch.randn((M, K), dtype=torch.float16, device=device) | |
B = torch.randn((K, N), dtype=torch.float16, device=device) | |
C = torch.empty((M, N), dtype=torch.float16, device=device) | |
grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1) | |
def alloc_fn(size: int, align: int, stream: Optional[int]): | |
assert size == 3 * 128 * grid[0] | |
assert align == 128 | |
assert stream == 0 | |
return torch.empty(size, dtype=torch.int8, device="cuda") | |
triton.set_allocator(alloc_fn) | |
> kernel = matmul_kernel_make_tensor_desciptor[grid]( | |
A, | |
B, | |
C, | |
M, | |
N, | |
K, | |
BLOCK_M, | |
BLOCK_N, | |
BLOCK_K, | |
num_warps=8, | |
num_stages=num_stages, | |
) | |
unit/cuda/test_experimental_tma.py:723: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:591: in run | |
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, | |
../triton/compiler/compiler.py:413: in __getattribute__ | |
self._init_handles() | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1d68557d70> | |
def _init_handles(self): | |
if self.module is not None: | |
return | |
device = driver.active.get_current_device() | |
# create launcher | |
self.run = driver.active.launcher_cls(self.src, self.metadata) | |
# not enough shared memory to run the kernel | |
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] | |
if self.metadata.shared > max_shared: | |
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory") | |
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 147480, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. | |
../triton/compiler/compiler.py:401: OutOfResources | |
__________________ test_tensor_descriptor_batched_gemm_2d_tma __________________ | |
@requires_tma | |
@pytest.mark.interpreter | |
def test_tensor_descriptor_batched_gemm_2d_tma(): | |
device = "cuda" | |
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64 | |
if is_interpreter(): | |
B, M, N, K = 2, BLOCK_M, BLOCK_N, BLOCK_K | |
else: | |
B, M, N, K = 2, 1024, 1024, 128 | |
NUM_SMS = 96 | |
num_stages = 3 | |
grid = (min(NUM_SMS, B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), ) | |
a = torch.randn((B, M, K), device=device, dtype=torch.float16) | |
b = torch.randn((B, N, K), device=device, dtype=torch.float16) | |
c = torch.empty((B, M, N), device=device, dtype=torch.float16) | |
expect = torch.bmm(a, b.mT) | |
def alloc_fn(size: int, align: int, stream: Optional[int]): | |
# TODO: should only need num_stages * 3 descriptors per SM | |
assert size == 128 * 3 * (num_stages + 1) * grid[0] | |
assert align == 128 | |
assert stream == 0 | |
return torch.empty(size, dtype=torch.int8, device="cuda") | |
triton.set_allocator(alloc_fn) | |
> batched_gemm_2d_tma_kernel[grid]( | |
a, b, c, # | |
B, M, N, K, # | |
tl.float16, # | |
BLOCK_M, BLOCK_N, BLOCK_K, # | |
NUM_SMS, # | |
num_stages=num_stages, num_warps=8) | |
unit/cuda/test_experimental_tma.py:920: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:591: in run | |
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, | |
../triton/compiler/compiler.py:413: in __getattribute__ | |
self._init_handles() | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af14fafc0> | |
def _init_handles(self): | |
if self.module is not None: | |
return | |
device = driver.active.get_current_device() | |
# create launcher | |
self.run = driver.active.launcher_cls(self.src, self.metadata) | |
# not enough shared memory to run the kernel | |
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] | |
if self.metadata.shared > max_shared: | |
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory") | |
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 196624, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. | |
../triton/compiler/compiler.py:401: OutOfResources | |
__________________ test_tensor_descriptor_batched_gemm_3d_tma __________________ | |
@requires_tma | |
@pytest.mark.interpreter | |
def test_tensor_descriptor_batched_gemm_3d_tma(): | |
device = "cuda" | |
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64 | |
if is_interpreter(): | |
B, M, N, K = 2, BLOCK_M, BLOCK_N, BLOCK_K | |
else: | |
B, M, N, K = 2, 1024, 1024, 128 | |
NUM_SMS = 96 | |
num_stages = 3 | |
grid = (min(NUM_SMS, B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), ) | |
a = torch.randn((B, M, K), device=device, dtype=torch.float16) | |
b = torch.randn((B, N, K), device=device, dtype=torch.float16) | |
c = torch.empty((B, M, N), device=device, dtype=torch.float16) | |
expect = torch.bmm(a, b.mT) | |
def alloc_fn(size: int, align: int, stream: Optional[int]): | |
# TODO: should only need num_stages * 3 descriptors per SM | |
assert size == 128 * 3 * grid[0] | |
assert align == 128 | |
assert stream == 0 | |
return torch.empty(size, dtype=torch.int8, device="cuda") | |
triton.set_allocator(alloc_fn) | |
> h = batched_gemm_3d_tma_kernel[grid]( | |
a, b, c, # | |
B, M, N, K, # | |
tl.float16, # | |
BLOCK_M, BLOCK_N, BLOCK_K, # | |
NUM_SMS, # | |
num_stages=num_stages, num_warps=8) | |
unit/cuda/test_experimental_tma.py:1020: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:591: in run | |
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, | |
../triton/compiler/compiler.py:413: in __getattribute__ | |
self._init_handles() | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af212ce90> | |
def _init_handles(self): | |
if self.module is not None: | |
return | |
device = driver.active.get_current_device() | |
# create launcher | |
self.run = driver.active.launcher_cls(self.src, self.metadata) | |
# not enough shared memory to run the kernel | |
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] | |
if self.metadata.shared > max_shared: | |
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory") | |
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 196624, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. | |
../triton/compiler/compiler.py:401: OutOfResources | |
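All of the failures above raise triton.runtime.errors.OutOfResources at launch time, so one way to keep the larger tile shapes is to catch that error and retry with a shallower pipeline. A minimal sketch, assuming a hypothetical launch callable that simply forwards num_stages to one of the kernel launches shown above (it is not part of the test code):

from triton.runtime.errors import OutOfResources

def launch_with_backoff(launch, max_stages=4, min_stages=1):
    # Try progressively fewer pipeline stages until the tile buffers fit
    # in shared memory.
    for stages in range(max_stages, min_stages - 1, -1):
        try:
            return launch(num_stages=stages)
        except OutOfResources:
            continue
    raise RuntimeError("kernel does not fit in shared memory even with "
                       f"num_stages={min_stages}; reduce the block sizes")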
_____________________________ test_op[4-48-256-64] _____________________________ | |
Z = 4, H = 48, N_CTX = 256, D_HEAD = 64, dtype = torch.float16 | |
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ | |
(4, 48, 128, 64), | |
(4, 48, 256, 64), | |
(4, 48, 512, 64), | |
(4, 48, 1024, 64), | |
(4, 48, 2048, 64), | |
(4, 48, 4096, 64), | |
# (4, 48, 8192, 64), out of memory | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") | |
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): | |
torch.manual_seed(20) | |
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() | |
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() | |
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() | |
sm_scale = 0.2 | |
dout = torch.randn_like(q) | |
# reference implementation | |
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) | |
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale | |
for z in range(Z): | |
for h in range(H): | |
p[:, :, M == 0] = float("-inf") | |
p = torch.softmax(p.float(), dim=-1).half() | |
# p = torch.exp(p) | |
ref_out = torch.matmul(p, v) | |
ref_out.backward(dout) | |
ref_dv, v.grad = v.grad.clone(), None | |
ref_dk, k.grad = k.grad.clone(), None | |
ref_dq, q.grad = q.grad.clone(), None | |
# triton implementation | |
tri_out = attention(q, k, v, sm_scale) | |
# print(ref_out) | |
# print(tri_out) | |
> tri_out.backward(dout) | |
unit/cuda/test_flashattention.py:395: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/_tensor.py:648: in backward | |
torch.autograd.backward( | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/__init__.py:353: in backward | |
_engine_run_backward( | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/graph.py:824: in _engine_run_backward | |
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/function.py:307: in apply | |
return user_fn(self, *args) | |
unit/cuda/test_flashattention.py:342: in backward | |
_bwd_kernel[(ctx.grid[1], )]( | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:591: in run | |
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, | |
../triton/compiler/compiler.py:413: in __getattribute__ | |
self._init_handles() | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af0248620> | |
def _init_handles(self): | |
if self.module is not None: | |
return | |
device = driver.active.get_current_device() | |
# create launcher | |
self.run = driver.active.launcher_cls(self.src, self.metadata) | |
# not enough shared memory to run the kernel | |
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] | |
if self.metadata.shared > max_shared: | |
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory") | |
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. | |
../triton/compiler/compiler.py:401: OutOfResources | |
_____________________________ test_op[4-48-512-64] _____________________________ | |
Z = 4, H = 48, N_CTX = 512, D_HEAD = 64, dtype = torch.float16 | |
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ | |
(4, 48, 128, 64), | |
(4, 48, 256, 64), | |
(4, 48, 512, 64), | |
(4, 48, 1024, 64), | |
(4, 48, 2048, 64), | |
(4, 48, 4096, 64), | |
# (4, 48, 8192, 64), out of memory | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") | |
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): | |
torch.manual_seed(20) | |
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() | |
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() | |
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() | |
sm_scale = 0.2 | |
dout = torch.randn_like(q) | |
# reference implementation | |
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) | |
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale | |
for z in range(Z): | |
for h in range(H): | |
p[:, :, M == 0] = float("-inf") | |
p = torch.softmax(p.float(), dim=-1).half() | |
# p = torch.exp(p) | |
ref_out = torch.matmul(p, v) | |
ref_out.backward(dout) | |
ref_dv, v.grad = v.grad.clone(), None | |
ref_dk, k.grad = k.grad.clone(), None | |
ref_dq, q.grad = q.grad.clone(), None | |
# triton implementation | |
tri_out = attention(q, k, v, sm_scale) | |
# print(ref_out) | |
# print(tri_out) | |
> tri_out.backward(dout) | |
unit/cuda/test_flashattention.py:395: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/_tensor.py:648: in backward | |
torch.autograd.backward( | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/__init__.py:353: in backward | |
_engine_run_backward( | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/graph.py:824: in _engine_run_backward | |
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/function.py:307: in apply | |
return user_fn(self, *args) | |
unit/cuda/test_flashattention.py:342: in backward | |
_bwd_kernel[(ctx.grid[1], )]( | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:591: in run | |
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, | |
../triton/compiler/compiler.py:413: in __getattribute__ | |
self._init_handles() | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af0248620> | |
def _init_handles(self): | |
if self.module is not None: | |
return | |
device = driver.active.get_current_device() | |
# create launcher | |
self.run = driver.active.launcher_cls(self.src, self.metadata) | |
# not enough shared memory to run the kernel | |
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] | |
if self.metadata.shared > max_shared: | |
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory") | |
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. | |
../triton/compiler/compiler.py:401: OutOfResources | |
____________________________ test_op[4-48-1024-64] _____________________________ | |
Z = 4, H = 48, N_CTX = 1024, D_HEAD = 64, dtype = torch.float16 | |
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ | |
(4, 48, 128, 64), | |
(4, 48, 256, 64), | |
(4, 48, 512, 64), | |
(4, 48, 1024, 64), | |
(4, 48, 2048, 64), | |
(4, 48, 4096, 64), | |
# (4, 48, 8192, 64), out of memory | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") | |
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): | |
torch.manual_seed(20) | |
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() | |
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() | |
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() | |
sm_scale = 0.2 | |
dout = torch.randn_like(q) | |
# reference implementation | |
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) | |
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale | |
for z in range(Z): | |
for h in range(H): | |
p[:, :, M == 0] = float("-inf") | |
p = torch.softmax(p.float(), dim=-1).half() | |
# p = torch.exp(p) | |
ref_out = torch.matmul(p, v) | |
ref_out.backward(dout) | |
ref_dv, v.grad = v.grad.clone(), None | |
ref_dk, k.grad = k.grad.clone(), None | |
ref_dq, q.grad = q.grad.clone(), None | |
# triton implementation | |
tri_out = attention(q, k, v, sm_scale) | |
# print(ref_out) | |
# print(tri_out) | |
> tri_out.backward(dout) | |
unit/cuda/test_flashattention.py:395: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/_tensor.py:648: in backward | |
torch.autograd.backward( | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/__init__.py:353: in backward | |
_engine_run_backward( | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/graph.py:824: in _engine_run_backward | |
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/function.py:307: in apply | |
return user_fn(self, *args) | |
unit/cuda/test_flashattention.py:342: in backward | |
_bwd_kernel[(ctx.grid[1], )]( | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:591: in run | |
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, | |
../triton/compiler/compiler.py:413: in __getattribute__ | |
self._init_handles() | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af0248620> | |
def _init_handles(self): | |
if self.module is not None: | |
return | |
device = driver.active.get_current_device() | |
# create launcher | |
self.run = driver.active.launcher_cls(self.src, self.metadata) | |
# not enough shared memory to run the kernel | |
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"] | |
if self.metadata.shared > max_shared: | |
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory") | |
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or `num_stages` may help. | |
../triton/compiler/compiler.py:401: OutOfResources | |
____________________________ test_op[4-48-2048-64] _____________________________ | |
Z = 4, H = 48, N_CTX = 2048, D_HEAD = 64, dtype = torch.float16 | |
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ | |
(4, 48, 128, 64), | |
(4, 48, 256, 64), | |
(4, 48, 512, 64), | |
(4, 48, 1024, 64), | |
(4, 48, 2048, 64), | |
(4, 48, 4096, 64), | |
# (4, 48, 8192, 64), out of memory | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") | |
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): | |
torch.manual_seed(20) | |
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() | |
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() | |
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() | |
sm_scale = 0.2 | |
dout = torch.randn_like(q) | |
# reference implementation | |
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) | |
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale | |
for z in range(Z): | |
for h in range(H): | |
p[:, :, M == 0] = float("-inf") | |
p = torch.softmax(p.float(), dim=-1).half() | |
# p = torch.exp(p) | |
ref_out = torch.matmul(p, v) | |
> ref_out.backward(dout) | |
unit/cuda/test_flashattention.py:387: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/_tensor.py:648: in backward | |
torch.autograd.backward( | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/__init__.py:353: in backward | |
_engine_run_backward( | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
t_outputs = (tensor([[[[ 0.3435, 0.1052, 0.6152, ..., 0.2805, 0.1891, 0.5288], | |
[ 0.4490, 0.2205, 0.3833, ..., ...993, ..., 0.2983, 0.3005, 0.2988]]]], | |
device='cuda:0', dtype=torch.float16, grad_fn=<UnsafeViewBackward0>),) | |
args = ((tensor([[[[-8.0859e-01, 7.9541e-01, 2.0740e-01, ..., 7.9773e-02, | |
-1.7402e+00, 4.6851e-01], | |
... ..., 6.1133e-01, | |
1.3564e+00, -1.8457e+00]]]], device='cuda:0', dtype=torch.float16),), False, False, ()) | |
kwargs = {'accumulate_grad': True, 'allow_unreachable': True} | |
attach_logging_hooks = False | |
def _engine_run_backward( | |
t_outputs: Sequence[Union[torch.Tensor, GradientEdge]], | |
*args: Any, | |
**kwargs: Any, | |
) -> tuple[torch.Tensor, ...]: | |
attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG | |
if attach_logging_hooks: | |
unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) | |
try: | |
> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass | |
t_outputs, *args, **kwargs | |
) # Calls into the C++ engine to run the backward pass | |
E torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.00 GiB. GPU 0 has a total capacity of 15.48 GiB of which 2.32 GiB is free. Including non-PyTorch memory, this process has 13.15 GiB memory in use. Of the allocated memory 9.70 GiB is allocated by PyTorch, and 1.74 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) | |
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/graph.py:824: OutOfMemoryError | |
____________________________ test_op[4-48-4096-64] _____________________________ | |
Z = 4, H = 48, N_CTX = 4096, D_HEAD = 64, dtype = torch.float16 | |
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ | |
(4, 48, 128, 64), | |
(4, 48, 256, 64), | |
(4, 48, 512, 64), | |
(4, 48, 1024, 64), | |
(4, 48, 2048, 64), | |
(4, 48, 4096, 64), | |
# (4, 48, 8192, 64), out of memory | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") | |
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): | |
torch.manual_seed(20) | |
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() | |
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() | |
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() | |
sm_scale = 0.2 | |
dout = torch.randn_like(q) | |
# reference implementation | |
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) | |
> p = torch.matmul(q, k.transpose(2, 3)) * sm_scale | |
E torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.00 GiB. GPU 0 has a total capacity of 15.48 GiB of which 5.32 GiB is free. Including non-PyTorch memory, this process has 10.15 GiB memory in use. Of the allocated memory 7.13 GiB is allocated by PyTorch, and 1.30 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) | |
unit/cuda/test_flashattention.py:380: OutOfMemoryError | |
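The 6.00 GiB allocation reported above matches the size of the full attention-score matrix p that the reference implementation materializes at line 380 for this parametrization. A quick arithmetic check (shapes and dtype taken from the test; the snippet is purely illustrative):

    # p = torch.matmul(q, k.transpose(2, 3)) has shape (Z, H, N_CTX, N_CTX) in float16.
    Z, H, N_CTX = 4, 48, 4096
    bytes_per_fp16 = 2
    p_bytes = Z * H * N_CTX * N_CTX * bytes_per_fp16
    print(p_bytes / 2**30)  # 6.0 -> matches "Tried to allocate 6.00 GiB"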
___________ test_gemm_no_scf[256-64-16-4-4-False-True-float16-True] ____________ | |
M = 256, N = 64, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False | |
TRANS_B = True, OUTPUT_TYPE = 'float16', USE_TMA_EPILOGUE = True | |
@pytest.mark.parametrize( | |
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE', | |
itertools.chain(*[[ | |
# numCTAs = 1, no TMA multicast: | |
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# static mask, cluster 4x1 | |
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# dynamic mask, cluster 2x2 | |
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# small M, N | |
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
] for USE_TMA_EPILOGUE in [True, False]])) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE): | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
if OUTPUT_TYPE == "float16": | |
c = torch.empty((M, N), device=a.device, dtype=torch.float16) | |
else: | |
c = torch.empty((M, N), device=a.device, dtype=torch.float32) | |
> matmul_no_scf_kernel[(1, 1)]( | |
a_ptr=a, b_ptr=b, c_ptr=c, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_cm=c.stride(0), stride_cn=c.stride(1), # | |
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # | |
num_warps=NUM_WARPS, # | |
num_ctas=NUM_CTAS, # | |
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # | |
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE) | |
unit/cuda/test_gemm.py:125: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:569: in run | |
kernel = self.compile(src, target=target, options=options.__dict__) | |
../triton/compiler/compiler.py:284: in compile | |
next_module = compile_ir(module, metadata) | |
../triton/backends/nvidia/compiler.py:450: in <lambda> | |
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <nvidia.CUDABackend object at 0x7f19bb9dd220> | |
src = <triton._C.libtriton.ir.module object at 0x7f19bb94bbf0> | |
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (4, 1, 1), ...} | |
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120') | |
capability = 120 | |
def make_llir(self, src, metadata, options, capability): | |
ptx_version = get_ptx_version_from_options(options, self.target.arch) | |
mod = src | |
# TritonGPU -> LLVM-IR (MLIR) | |
pm = ir.pass_manager(mod.context) | |
pm.enable_debug() | |
nvidia.passes.ttnvgpuir.add_lower_mma(pm) | |
passes.ttgpuir.add_combine_tensor_select_and_if(pm) | |
passes.ttgpuir.add_allocate_warp_groups(pm) | |
passes.convert.add_scf_to_cf(pm) | |
passes.ttgpuir.add_allocate_shared_memory(pm) | |
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm) | |
passes.ttgpuir.add_allocate_global_scratch_memory(pm) | |
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) | |
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
passes.common.add_symbol_dce(pm) | |
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": | |
passes.llvmir.add_di_scope(pm) | |
> pm.run(mod) | |
E RuntimeError: PassManager::run failed | |
../triton/backends/nvidia/compiler.py:341: RuntimeError | |
----------------------------- Captured stderr call ----------------------------- | |
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed. | |
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#blocked2 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> | |
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0], instrShape = [16, 8]}> | |
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> | |
#smem = #ttg.shared_memory | |
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} { | |
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { | |
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> | |
%0 = arith.extsi %arg6 : i32 to i64 | |
%1 = arith.extsi %arg7 : i32 to i64 | |
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x16x!tt.ptr<f16>, #blocked> | |
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%5 = arith.extsi %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%6 = arith.extsi %4 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi64, #blocked> | |
%8 = tt.expand_dims %5 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi64, #blocked1> | |
%9 = tt.splat %0 : i64 -> tensor<256x1xi64, #blocked> | |
%10 = arith.muli %7, %9 : tensor<256x1xi64, #blocked> | |
%11 = tt.broadcast %10 : tensor<256x1xi64, #blocked> -> tensor<256x16xi64, #blocked> | |
%12 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%13 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%14 = arith.extsi %12 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%15 = arith.extsi %13 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked> | |
%17 = tt.broadcast %16 : tensor<1x16xi64, #blocked> -> tensor<256x16xi64, #blocked> | |
%18 = arith.addi %11, %17 : tensor<256x16xi64, #blocked> | |
%19 = tt.addptr %2, %18 : tensor<256x16x!tt.ptr<f16>, #blocked>, tensor<256x16xi64, #blocked> | |
%20 = tt.load %19 : tensor<256x16x!tt.ptr<f16>, #blocked> | |
%21 = ttg.local_alloc %20 : (tensor<256x16xf16, #blocked>) -> !ttg.memdesc<256x16xf16, #shared, #smem> | |
%22 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x64x!tt.ptr<f16>, #blocked2> | |
%23 = tt.expand_dims %14 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2> | |
%24 = tt.broadcast %23 : tensor<16x1xi64, #blocked2> -> tensor<16x64xi64, #blocked2> | |
%25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%26 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%27 = arith.extsi %26 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%28 = arith.extsi %25 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%29 = tt.expand_dims %27 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2> | |
%30 = tt.expand_dims %28 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi64, #blocked1> | |
%31 = tt.splat %1 : i64 -> tensor<1x64xi64, #blocked2> | |
%32 = arith.muli %29, %31 : tensor<1x64xi64, #blocked2> | |
%33 = tt.broadcast %32 : tensor<1x64xi64, #blocked2> -> tensor<16x64xi64, #blocked2> | |
%34 = arith.addi %24, %33 : tensor<16x64xi64, #blocked2> | |
%35 = tt.addptr %22, %34 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi64, #blocked2> | |
%36 = tt.load %35 : tensor<16x64x!tt.ptr<f16>, #blocked2> | |
%37 = ttg.local_alloc %36 : (tensor<16x64xf16, #blocked2>) -> !ttg.memdesc<16x64xf16, #shared1, #smem> | |
%38 = ttg.local_load %21 : !ttg.memdesc<256x16xf16, #shared, #smem> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> | |
%39 = ttg.local_load %37 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> | |
%40 = tt.dot %38, %39, %cst, inputPrecision = tf32 : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x64xf32, #mma> | |
%41 = arith.truncf %40 : tensor<256x64xf32, #mma> to tensor<256x64xf16, #mma> | |
%42 = arith.extsi %arg8 : i32 to i64 | |
%43 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<256x64x!tt.ptr<f16>, #blocked1> | |
%44 = tt.splat %42 : i64 -> tensor<256x1xi64, #blocked1> | |
%45 = arith.muli %8, %44 : tensor<256x1xi64, #blocked1> | |
%46 = tt.broadcast %45 : tensor<256x1xi64, #blocked1> -> tensor<256x64xi64, #blocked1> | |
%47 = tt.broadcast %30 : tensor<1x64xi64, #blocked1> -> tensor<256x64xi64, #blocked1> | |
%48 = arith.addi %46, %47 : tensor<256x64xi64, #blocked1> | |
%49 = tt.addptr %43, %48 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi64, #blocked1> | |
%50 = ttg.convert_layout %41 : tensor<256x64xf16, #mma> -> tensor<256x64xf16, #blocked1> | |
tt.store %49, %50 : tensor<256x64x!tt.ptr<f16>, #blocked1> | |
tt.return | |
} | |
} | |
{-# | |
external_resources: { | |
mlir_reproducer: { | |
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)", | |
disable_threading: false, | |
verify_each: true | |
} | |
} | |
#-} | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` | |
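The assertion quoted in the captured stderr (LinearLayout::reshapeIns) fires when the product of the requested new input-dimension sizes does not equal the layout's total input size. A small Python illustration of that invariant; the function and dimension names here are illustrative, not Triton's actual API:

    # Mirrors the checked condition:
    #   getTotalInDimSize() == accumulate(newInDims, 1, multiply)
    from math import prod

    def reshape_ins(total_in_dim_size: int, new_in_dims: dict[str, int]) -> None:
        assert total_in_dim_size == prod(new_in_dims.values()), (
            f"cannot reshape a layout of total input size {total_in_dim_size} "
            f"into dims {new_in_dims}"
        )

    reshape_ins(128, {"register": 8, "lane": 16})      # ok: 8 * 16 == 128
    try:
        reshape_ins(128, {"register": 8, "lane": 32})  # mismatched product, as in the failing pass
    except AssertionError as e:
        print(e)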
___________ test_gemm_no_scf[256-64-16-4-4-False-True-float32-True] ____________ | |
M = 256, N = 64, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False | |
TRANS_B = True, OUTPUT_TYPE = 'float32', USE_TMA_EPILOGUE = True | |
@pytest.mark.parametrize( | |
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE', | |
itertools.chain(*[[ | |
# numCTAs = 1, no TMA multicast: | |
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# static mask, cluster 4x1 | |
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# dynamic mask, cluster 2x2 | |
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# small M, N | |
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
] for USE_TMA_EPILOGUE in [True, False]])) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE): | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
if OUTPUT_TYPE == "float16": | |
c = torch.empty((M, N), device=a.device, dtype=torch.float16) | |
else: | |
c = torch.empty((M, N), device=a.device, dtype=torch.float32) | |
> matmul_no_scf_kernel[(1, 1)]( | |
a_ptr=a, b_ptr=b, c_ptr=c, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_cm=c.stride(0), stride_cn=c.stride(1), # | |
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # | |
num_warps=NUM_WARPS, # | |
num_ctas=NUM_CTAS, # | |
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # | |
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE) | |
unit/cuda/test_gemm.py:125: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:569: in run | |
kernel = self.compile(src, target=target, options=options.__dict__) | |
../triton/compiler/compiler.py:284: in compile | |
next_module = compile_ir(module, metadata) | |
../triton/backends/nvidia/compiler.py:450: in <lambda> | |
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <nvidia.CUDABackend object at 0x7f19bba4d370> | |
src = <triton._C.libtriton.ir.module object at 0x7f19bba77890> | |
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (4, 1, 1), ...} | |
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120') | |
capability = 120 | |
def make_llir(self, src, metadata, options, capability): | |
ptx_version = get_ptx_version_from_options(options, self.target.arch) | |
mod = src | |
# TritonGPU -> LLVM-IR (MLIR) | |
pm = ir.pass_manager(mod.context) | |
pm.enable_debug() | |
nvidia.passes.ttnvgpuir.add_lower_mma(pm) | |
passes.ttgpuir.add_combine_tensor_select_and_if(pm) | |
passes.ttgpuir.add_allocate_warp_groups(pm) | |
passes.convert.add_scf_to_cf(pm) | |
passes.ttgpuir.add_allocate_shared_memory(pm) | |
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm) | |
passes.ttgpuir.add_allocate_global_scratch_memory(pm) | |
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) | |
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
passes.common.add_symbol_dce(pm) | |
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": | |
passes.llvmir.add_di_scope(pm) | |
> pm.run(mod) | |
E RuntimeError: PassManager::run failed | |
../triton/backends/nvidia/compiler.py:341: RuntimeError | |
----------------------------- Captured stderr call ----------------------------- | |
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed. | |
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#blocked2 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> | |
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0], instrShape = [16, 8]}> | |
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> | |
#smem = #ttg.shared_memory | |
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} { | |
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { | |
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> | |
%0 = arith.extsi %arg6 : i32 to i64 | |
%1 = arith.extsi %arg7 : i32 to i64 | |
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x16x!tt.ptr<f16>, #blocked> | |
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%5 = arith.extsi %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%6 = arith.extsi %4 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%7 = tt.expand_dims %5 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi64, #blocked1> | |
%8 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi64, #blocked> | |
%9 = tt.splat %0 : i64 -> tensor<256x1xi64, #blocked> | |
%10 = arith.muli %8, %9 : tensor<256x1xi64, #blocked> | |
%11 = tt.broadcast %10 : tensor<256x1xi64, #blocked> -> tensor<256x16xi64, #blocked> | |
%12 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%13 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%14 = arith.extsi %12 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%15 = arith.extsi %13 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked> | |
%17 = tt.broadcast %16 : tensor<1x16xi64, #blocked> -> tensor<256x16xi64, #blocked> | |
%18 = arith.addi %11, %17 : tensor<256x16xi64, #blocked> | |
%19 = tt.addptr %2, %18 : tensor<256x16x!tt.ptr<f16>, #blocked>, tensor<256x16xi64, #blocked> | |
%20 = tt.load %19 : tensor<256x16x!tt.ptr<f16>, #blocked> | |
%21 = ttg.local_alloc %20 : (tensor<256x16xf16, #blocked>) -> !ttg.memdesc<256x16xf16, #shared, #smem> | |
%22 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x64x!tt.ptr<f16>, #blocked2> | |
%23 = tt.expand_dims %14 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2> | |
%24 = tt.broadcast %23 : tensor<16x1xi64, #blocked2> -> tensor<16x64xi64, #blocked2> | |
%25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%26 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%27 = arith.extsi %26 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%28 = arith.extsi %25 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%29 = tt.expand_dims %27 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi64, #blocked1> | |
%30 = tt.expand_dims %28 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2> | |
%31 = tt.splat %1 : i64 -> tensor<1x64xi64, #blocked2> | |
%32 = arith.muli %30, %31 : tensor<1x64xi64, #blocked2> | |
%33 = tt.broadcast %32 : tensor<1x64xi64, #blocked2> -> tensor<16x64xi64, #blocked2> | |
%34 = arith.addi %24, %33 : tensor<16x64xi64, #blocked2> | |
%35 = tt.addptr %22, %34 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi64, #blocked2> | |
%36 = tt.load %35 : tensor<16x64x!tt.ptr<f16>, #blocked2> | |
%37 = ttg.local_alloc %36 : (tensor<16x64xf16, #blocked2>) -> !ttg.memdesc<16x64xf16, #shared1, #smem> | |
%38 = ttg.local_load %21 : !ttg.memdesc<256x16xf16, #shared, #smem> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> | |
%39 = ttg.local_load %37 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> | |
%40 = tt.dot %38, %39, %cst, inputPrecision = tf32 : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x64xf32, #mma> | |
%41 = arith.extsi %arg8 : i32 to i64 | |
%42 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x64x!tt.ptr<f32>, #blocked1> | |
%43 = tt.splat %41 : i64 -> tensor<256x1xi64, #blocked1> | |
%44 = arith.muli %7, %43 : tensor<256x1xi64, #blocked1> | |
%45 = tt.broadcast %44 : tensor<256x1xi64, #blocked1> -> tensor<256x64xi64, #blocked1> | |
%46 = tt.broadcast %29 : tensor<1x64xi64, #blocked1> -> tensor<256x64xi64, #blocked1> | |
%47 = arith.addi %45, %46 : tensor<256x64xi64, #blocked1> | |
%48 = tt.addptr %42, %47 : tensor<256x64x!tt.ptr<f32>, #blocked1>, tensor<256x64xi64, #blocked1> | |
%49 = ttg.convert_layout %40 : tensor<256x64xf32, #mma> -> tensor<256x64xf32, #blocked1> | |
tt.store %48, %49 : tensor<256x64x!tt.ptr<f32>, #blocked1> | |
tt.return | |
} | |
} | |
{-# | |
external_resources: { | |
mlir_reproducer: { | |
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)", | |
disable_threading: false, | |
verify_each: true | |
} | |
} | |
#-} | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` | |
___________ test_gemm_no_scf[128-128-16-4-4-False-True-float16-True] ___________ | |
M = 128, N = 128, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False | |
TRANS_B = True, OUTPUT_TYPE = 'float16', USE_TMA_EPILOGUE = True | |
@pytest.mark.parametrize( | |
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE', | |
itertools.chain(*[[ | |
# numCTAs = 1, no TMA multicast: | |
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# static mask, cluster 4x1 | |
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# dynamic mask, cluster 2x2 | |
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# small M, N | |
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
] for USE_TMA_EPILOGUE in [True, False]])) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE): | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
if OUTPUT_TYPE == "float16": | |
c = torch.empty((M, N), device=a.device, dtype=torch.float16) | |
else: | |
c = torch.empty((M, N), device=a.device, dtype=torch.float32) | |
> matmul_no_scf_kernel[(1, 1)]( | |
a_ptr=a, b_ptr=b, c_ptr=c, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_cm=c.stride(0), stride_cn=c.stride(1), # | |
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # | |
num_warps=NUM_WARPS, # | |
num_ctas=NUM_CTAS, # | |
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # | |
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE) | |
unit/cuda/test_gemm.py:125: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:569: in run | |
kernel = self.compile(src, target=target, options=options.__dict__) | |
../triton/compiler/compiler.py:284: in compile | |
next_module = compile_ir(module, metadata) | |
../triton/backends/nvidia/compiler.py:450: in <lambda> | |
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <nvidia.CUDABackend object at 0x7f1d685f4110> | |
src = <triton._C.libtriton.ir.module object at 0x7f19bba77170> | |
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (2, 2, 1), ...} | |
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120') | |
capability = 120 | |
def make_llir(self, src, metadata, options, capability): | |
ptx_version = get_ptx_version_from_options(options, self.target.arch) | |
mod = src | |
# TritonGPU -> LLVM-IR (MLIR) | |
pm = ir.pass_manager(mod.context) | |
pm.enable_debug() | |
nvidia.passes.ttnvgpuir.add_lower_mma(pm) | |
passes.ttgpuir.add_combine_tensor_select_and_if(pm) | |
passes.ttgpuir.add_allocate_warp_groups(pm) | |
passes.convert.add_scf_to_cf(pm) | |
passes.ttgpuir.add_allocate_shared_memory(pm) | |
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm) | |
passes.ttgpuir.add_allocate_global_scratch_memory(pm) | |
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) | |
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
passes.common.add_symbol_dce(pm) | |
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": | |
passes.llvmir.add_di_scope(pm) | |
> pm.run(mod) | |
E RuntimeError: PassManager::run failed | |
../triton/backends/nvidia/compiler.py:341: RuntimeError | |
----------------------------- Captured stderr call ----------------------------- | |
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed. | |
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> | |
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0]}> | |
#blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> | |
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0], instrShape = [16, 8]}> | |
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> | |
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> | |
#smem = #ttg.shared_memory | |
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} { | |
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { | |
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> | |
%0 = arith.extsi %arg6 : i32 to i64 | |
%1 = arith.extsi %arg7 : i32 to i64 | |
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked> | |
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%7 = arith.extsi %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%8 = arith.extsi %6 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%9 = arith.extsi %5 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%10 = arith.extsi %4 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%11 = tt.expand_dims %10 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi64, #blocked1> | |
%12 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> | |
%13 = tt.splat %0 : i64 -> tensor<128x1xi64, #blocked> | |
%14 = arith.muli %12, %13 : tensor<128x1xi64, #blocked> | |
%15 = tt.broadcast %14 : tensor<128x1xi64, #blocked> -> tensor<128x16xi64, #blocked> | |
%16 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%17 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%18 = arith.extsi %17 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%19 = arith.extsi %16 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%20 = tt.expand_dims %19 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked> | |
%21 = tt.broadcast %20 : tensor<1x16xi64, #blocked> -> tensor<128x16xi64, #blocked> | |
%22 = arith.addi %15, %21 : tensor<128x16xi64, #blocked> | |
%23 = tt.addptr %2, %22 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi64, #blocked> | |
%24 = tt.load %23 : tensor<128x16x!tt.ptr<f16>, #blocked> | |
%25 = ttg.local_alloc %24 : (tensor<128x16xf16, #blocked>) -> !ttg.memdesc<128x16xf16, #shared, #smem> | |
%26 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x128x!tt.ptr<f16>, #blocked2> | |
%27 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2> | |
%28 = tt.broadcast %27 : tensor<16x1xi64, #blocked2> -> tensor<16x128xi64, #blocked2> | |
%29 = tt.expand_dims %8 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x128xi64, #blocked2> | |
%30 = tt.expand_dims %9 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi64, #blocked1> | |
%31 = tt.splat %1 : i64 -> tensor<1x128xi64, #blocked2> | |
%32 = arith.muli %29, %31 : tensor<1x128xi64, #blocked2> | |
%33 = tt.broadcast %32 : tensor<1x128xi64, #blocked2> -> tensor<16x128xi64, #blocked2> | |
%34 = arith.addi %28, %33 : tensor<16x128xi64, #blocked2> | |
%35 = tt.addptr %26, %34 : tensor<16x128x!tt.ptr<f16>, #blocked2>, tensor<16x128xi64, #blocked2> | |
%36 = tt.load %35 : tensor<16x128x!tt.ptr<f16>, #blocked2> | |
%37 = ttg.local_alloc %36 : (tensor<16x128xf16, #blocked2>) -> !ttg.memdesc<16x128xf16, #shared1, #smem> | |
%38 = ttg.local_load %25 : !ttg.memdesc<128x16xf16, #shared, #smem> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> | |
%39 = ttg.local_load %37 : !ttg.memdesc<16x128xf16, #shared1, #smem> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> | |
%40 = tt.dot %38, %39, %cst, inputPrecision = tf32 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> | |
%41 = arith.truncf %40 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> | |
%42 = arith.extsi %arg8 : i32 to i64 | |
%43 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked1> | |
%44 = tt.splat %42 : i64 -> tensor<128x1xi64, #blocked1> | |
%45 = arith.muli %11, %44 : tensor<128x1xi64, #blocked1> | |
%46 = tt.broadcast %45 : tensor<128x1xi64, #blocked1> -> tensor<128x128xi64, #blocked1> | |
%47 = tt.broadcast %30 : tensor<1x128xi64, #blocked1> -> tensor<128x128xi64, #blocked1> | |
%48 = arith.addi %46, %47 : tensor<128x128xi64, #blocked1> | |
%49 = tt.addptr %43, %48 : tensor<128x128x!tt.ptr<f16>, #blocked1>, tensor<128x128xi64, #blocked1> | |
%50 = ttg.convert_layout %41 : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #blocked1> | |
tt.store %49, %50 : tensor<128x128x!tt.ptr<f16>, #blocked1> | |
tt.return | |
} | |
} | |
{-# | |
external_resources: { | |
mlir_reproducer: { | |
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)", | |
disable_threading: false, | |
verify_each: true | |
} | |
} | |
#-} | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` | |
___________ test_gemm_no_scf[128-128-16-4-4-False-True-float32-True] ___________ | |
M = 128, N = 128, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False | |
TRANS_B = True, OUTPUT_TYPE = 'float32', USE_TMA_EPILOGUE = True | |
@pytest.mark.parametrize( | |
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE', | |
itertools.chain(*[[ | |
# numCTAs = 1, no TMA multicast: | |
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# static mask, cluster 4x1 | |
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# dynamic mask, cluster 2x2 | |
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# small M, N | |
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
] for USE_TMA_EPILOGUE in [True, False]])) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE): | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
if OUTPUT_TYPE == "float16": | |
c = torch.empty((M, N), device=a.device, dtype=torch.float16) | |
else: | |
c = torch.empty((M, N), device=a.device, dtype=torch.float32) | |
> matmul_no_scf_kernel[(1, 1)]( | |
a_ptr=a, b_ptr=b, c_ptr=c, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_cm=c.stride(0), stride_cn=c.stride(1), # | |
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # | |
num_warps=NUM_WARPS, # | |
num_ctas=NUM_CTAS, # | |
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # | |
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE) | |
unit/cuda/test_gemm.py:125: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:569: in run | |
kernel = self.compile(src, target=target, options=options.__dict__) | |
../triton/compiler/compiler.py:284: in compile | |
next_module = compile_ir(module, metadata) | |
../triton/backends/nvidia/compiler.py:450: in <lambda> | |
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <nvidia.CUDABackend object at 0x7f1ae013cd10> | |
src = <triton._C.libtriton.ir.module object at 0x7f19bba76570> | |
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (2, 2, 1), ...} | |
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120') | |
capability = 120 | |
def make_llir(self, src, metadata, options, capability): | |
ptx_version = get_ptx_version_from_options(options, self.target.arch) | |
mod = src | |
# TritonGPU -> LLVM-IR (MLIR) | |
pm = ir.pass_manager(mod.context) | |
pm.enable_debug() | |
nvidia.passes.ttnvgpuir.add_lower_mma(pm) | |
passes.ttgpuir.add_combine_tensor_select_and_if(pm) | |
passes.ttgpuir.add_allocate_warp_groups(pm) | |
passes.convert.add_scf_to_cf(pm) | |
passes.ttgpuir.add_allocate_shared_memory(pm) | |
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm) | |
passes.ttgpuir.add_allocate_global_scratch_memory(pm) | |
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) | |
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
passes.common.add_symbol_dce(pm) | |
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": | |
passes.llvmir.add_di_scope(pm) | |
> pm.run(mod) | |
E RuntimeError: PassManager::run failed | |
../triton/backends/nvidia/compiler.py:341: RuntimeError | |
----------------------------- Captured stderr call ----------------------------- | |
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed. | |
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> | |
#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> | |
#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0]}> | |
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0], instrShape = [16, 8]}> | |
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> | |
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> | |
#smem = #ttg.shared_memory | |
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} { | |
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { | |
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> | |
%0 = arith.extsi %arg6 : i32 to i64 | |
%1 = arith.extsi %arg7 : i32 to i64 | |
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked> | |
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%7 = arith.extsi %4 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%8 = arith.extsi %6 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%9 = arith.extsi %3 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%10 = arith.extsi %5 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%11 = tt.expand_dims %8 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> | |
%12 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi64, #blocked2> | |
%13 = tt.splat %0 : i64 -> tensor<128x1xi64, #blocked> | |
%14 = arith.muli %11, %13 : tensor<128x1xi64, #blocked> | |
%15 = tt.broadcast %14 : tensor<128x1xi64, #blocked> -> tensor<128x16xi64, #blocked> | |
%16 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%17 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%18 = arith.extsi %16 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%19 = arith.extsi %17 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked> | |
%21 = tt.broadcast %20 : tensor<1x16xi64, #blocked> -> tensor<128x16xi64, #blocked> | |
%22 = arith.addi %15, %21 : tensor<128x16xi64, #blocked> | |
%23 = tt.addptr %2, %22 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi64, #blocked> | |
%24 = tt.load %23 : tensor<128x16x!tt.ptr<f16>, #blocked> | |
%25 = ttg.local_alloc %24 : (tensor<128x16xf16, #blocked>) -> !ttg.memdesc<128x16xf16, #shared, #smem> | |
%26 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x128x!tt.ptr<f16>, #blocked1> | |
%27 = tt.expand_dims %19 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi64, #blocked1> | |
%28 = tt.broadcast %27 : tensor<16x1xi64, #blocked1> -> tensor<16x128xi64, #blocked1> | |
%29 = tt.expand_dims %9 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi64, #blocked1> | |
%30 = tt.expand_dims %10 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x128xi64, #blocked2> | |
%31 = tt.splat %1 : i64 -> tensor<1x128xi64, #blocked1> | |
%32 = arith.muli %29, %31 : tensor<1x128xi64, #blocked1> | |
%33 = tt.broadcast %32 : tensor<1x128xi64, #blocked1> -> tensor<16x128xi64, #blocked1> | |
%34 = arith.addi %28, %33 : tensor<16x128xi64, #blocked1> | |
%35 = tt.addptr %26, %34 : tensor<16x128x!tt.ptr<f16>, #blocked1>, tensor<16x128xi64, #blocked1> | |
%36 = tt.load %35 : tensor<16x128x!tt.ptr<f16>, #blocked1> | |
%37 = ttg.local_alloc %36 : (tensor<16x128xf16, #blocked1>) -> !ttg.memdesc<16x128xf16, #shared1, #smem> | |
%38 = ttg.local_load %25 : !ttg.memdesc<128x16xf16, #shared, #smem> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> | |
%39 = ttg.local_load %37 : !ttg.memdesc<16x128xf16, #shared1, #smem> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> | |
%40 = tt.dot %38, %39, %cst, inputPrecision = tf32 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> | |
%41 = arith.extsi %arg8 : i32 to i64 | |
%42 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #blocked2> | |
%43 = tt.splat %41 : i64 -> tensor<128x1xi64, #blocked2> | |
%44 = arith.muli %12, %43 : tensor<128x1xi64, #blocked2> | |
%45 = tt.broadcast %44 : tensor<128x1xi64, #blocked2> -> tensor<128x128xi64, #blocked2> | |
%46 = tt.broadcast %30 : tensor<1x128xi64, #blocked2> -> tensor<128x128xi64, #blocked2> | |
%47 = arith.addi %45, %46 : tensor<128x128xi64, #blocked2> | |
%48 = tt.addptr %42, %47 : tensor<128x128x!tt.ptr<f32>, #blocked2>, tensor<128x128xi64, #blocked2> | |
%49 = ttg.convert_layout %40 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked2> | |
tt.store %48, %49 : tensor<128x128x!tt.ptr<f32>, #blocked2> | |
tt.return | |
} | |
} | |
{-# | |
external_resources: { | |
mlir_reproducer: { | |
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)", | |
disable_threading: false, | |
verify_each: true | |
} | |
} | |
#-} | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` | |
___________ test_gemm_no_scf[256-64-16-4-4-False-True-float16-False] ___________ | |
M = 256, N = 64, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False | |
TRANS_B = True, OUTPUT_TYPE = 'float16', USE_TMA_EPILOGUE = False | |
@pytest.mark.parametrize( | |
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE', | |
itertools.chain(*[[ | |
# numCTAs = 1, no TMA multicast: | |
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# static mask, cluster 4x1 | |
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# dynamic mask, cluster 2x2 | |
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# small M, N | |
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
] for USE_TMA_EPILOGUE in [True, False]])) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE): | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
if OUTPUT_TYPE == "float16": | |
c = torch.empty((M, N), device=a.device, dtype=torch.float16) | |
else: | |
c = torch.empty((M, N), device=a.device, dtype=torch.float32) | |
> matmul_no_scf_kernel[(1, 1)]( | |
a_ptr=a, b_ptr=b, c_ptr=c, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_cm=c.stride(0), stride_cn=c.stride(1), # | |
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # | |
num_warps=NUM_WARPS, # | |
num_ctas=NUM_CTAS, # | |
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # | |
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE) | |
unit/cuda/test_gemm.py:125: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:569: in run | |
kernel = self.compile(src, target=target, options=options.__dict__) | |
../triton/compiler/compiler.py:284: in compile | |
next_module = compile_ir(module, metadata) | |
../triton/backends/nvidia/compiler.py:450: in <lambda> | |
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <nvidia.CUDABackend object at 0x7f1af0f3ab70> | |
src = <triton._C.libtriton.ir.module object at 0x7f1af0fb3e90> | |
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (4, 1, 1), ...} | |
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120') | |
capability = 120 | |
def make_llir(self, src, metadata, options, capability): | |
ptx_version = get_ptx_version_from_options(options, self.target.arch) | |
mod = src | |
# TritonGPU -> LLVM-IR (MLIR) | |
pm = ir.pass_manager(mod.context) | |
pm.enable_debug() | |
nvidia.passes.ttnvgpuir.add_lower_mma(pm) | |
passes.ttgpuir.add_combine_tensor_select_and_if(pm) | |
passes.ttgpuir.add_allocate_warp_groups(pm) | |
passes.convert.add_scf_to_cf(pm) | |
passes.ttgpuir.add_allocate_shared_memory(pm) | |
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm) | |
passes.ttgpuir.add_allocate_global_scratch_memory(pm) | |
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) | |
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
passes.common.add_symbol_dce(pm) | |
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": | |
passes.llvmir.add_di_scope(pm) | |
> pm.run(mod) | |
E RuntimeError: PassManager::run failed | |
../triton/backends/nvidia/compiler.py:341: RuntimeError | |
----------------------------- Captured stderr call ----------------------------- | |
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed. | |
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#blocked2 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> | |
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0], instrShape = [16, 8]}> | |
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> | |
#smem = #ttg.shared_memory | |
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} { | |
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { | |
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> | |
%0 = arith.extsi %arg6 : i32 to i64 | |
%1 = arith.extsi %arg7 : i32 to i64 | |
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x16x!tt.ptr<f16>, #blocked> | |
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%5 = arith.extsi %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi64, #blocked> | |
%7 = tt.splat %0 : i64 -> tensor<256x1xi64, #blocked> | |
%8 = arith.muli %6, %7 : tensor<256x1xi64, #blocked> | |
%9 = tt.broadcast %8 : tensor<256x1xi64, #blocked> -> tensor<256x16xi64, #blocked> | |
%10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%11 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%12 = arith.extsi %11 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%13 = arith.extsi %10 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%14 = tt.expand_dims %12 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked> | |
%15 = tt.broadcast %14 : tensor<1x16xi64, #blocked> -> tensor<256x16xi64, #blocked> | |
%16 = arith.addi %9, %15 : tensor<256x16xi64, #blocked> | |
%17 = tt.addptr %2, %16 : tensor<256x16x!tt.ptr<f16>, #blocked>, tensor<256x16xi64, #blocked> | |
%18 = tt.load %17 : tensor<256x16x!tt.ptr<f16>, #blocked> | |
%19 = ttg.local_alloc %18 : (tensor<256x16xf16, #blocked>) -> !ttg.memdesc<256x16xf16, #shared, #smem> | |
%20 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x64x!tt.ptr<f16>, #blocked2> | |
%21 = tt.expand_dims %13 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2> | |
%22 = tt.broadcast %21 : tensor<16x1xi64, #blocked2> -> tensor<16x64xi64, #blocked2> | |
%23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%25 = arith.extsi %24 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%26 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2> | |
%27 = tt.splat %1 : i64 -> tensor<1x64xi64, #blocked2> | |
%28 = arith.muli %26, %27 : tensor<1x64xi64, #blocked2> | |
%29 = tt.broadcast %28 : tensor<1x64xi64, #blocked2> -> tensor<16x64xi64, #blocked2> | |
%30 = arith.addi %22, %29 : tensor<16x64xi64, #blocked2> | |
%31 = tt.addptr %20, %30 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi64, #blocked2> | |
%32 = tt.load %31 : tensor<16x64x!tt.ptr<f16>, #blocked2> | |
%33 = ttg.local_alloc %32 : (tensor<16x64xf16, #blocked2>) -> !ttg.memdesc<16x64xf16, #shared1, #smem> | |
%34 = ttg.local_load %19 : !ttg.memdesc<256x16xf16, #shared, #smem> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> | |
%35 = ttg.local_load %33 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> | |
%36 = tt.dot %34, %35, %cst, inputPrecision = tf32 : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x64xf32, #mma> | |
%37 = arith.truncf %36 : tensor<256x64xf32, #mma> to tensor<256x64xf16, #mma> | |
%38 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> | |
%39 = tt.splat %arg8 : i32 -> tensor<256x1xi32, #blocked1> | |
%40 = arith.muli %38, %39 : tensor<256x1xi32, #blocked1> | |
%41 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1> | |
%42 = tt.addptr %41, %40 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1> | |
%43 = tt.expand_dims %23 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> | |
%44 = tt.broadcast %42 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1> | |
%45 = tt.broadcast %43 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1> | |
%46 = tt.addptr %44, %45 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1> | |
%47 = ttg.convert_layout %37 : tensor<256x64xf16, #mma> -> tensor<256x64xf16, #blocked1> | |
tt.store %46, %47 : tensor<256x64x!tt.ptr<f16>, #blocked1> | |
tt.return | |
} | |
} | |
{-# | |
external_resources: { | |
mlir_reproducer: { | |
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)", | |
disable_threading: false, | |
verify_each: true | |
} | |
} | |
#-} | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` | |
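The failing test_gemm_no_scf cases in this stretch all use num_ctas=4 with USE_TMA_EPILOGUE=False on this sm120 device. A rough standalone reproduction of the float16 case directly above (M=256, N=64, K=16, TRANS_A=False, TRANS_B=True), assuming it is run from python/test/unit/cuda so that matmul_no_scf_kernel can be imported from test_gemm.py; the launch mirrors the call shown in the traceback:

    # Hypothetical standalone repro for test_gemm_no_scf[256-64-16-4-4-False-True-float16-False].
    # Assumes this script sits next to test_gemm.py so the test's kernel imports cleanly.
    import torch
    from test_gemm import matmul_no_scf_kernel

    M, N, K = 256, 64, 16
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)       # TRANS_A = False
    b = torch.randn((N, K), device='cuda', dtype=torch.float16).T     # TRANS_B = True
    c = torch.empty((M, N), device='cuda', dtype=torch.float16)       # OUTPUT_TYPE = "float16"

    matmul_no_scf_kernel[(1, 1)](
        a_ptr=a, b_ptr=b, c_ptr=c,
        M=M, N=N, K=K,
        stride_am=a.stride(0), stride_ak=a.stride(1),
        stride_bk=b.stride(0), stride_bn=b.stride(1),
        stride_cm=c.stride(0), stride_cn=c.stride(1),
        BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
        num_warps=4, num_ctas=4,
        FLOAT16_OUTPUT=True, USE_TMA_EPILOGUE=False)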
___________ test_gemm_no_scf[256-64-16-4-4-False-True-float32-False] ___________ | |
M = 256, N = 64, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False | |
TRANS_B = True, OUTPUT_TYPE = 'float32', USE_TMA_EPILOGUE = False | |
@pytest.mark.parametrize( | |
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE', | |
itertools.chain(*[[ | |
# numCTAs = 1, no TMA multicast: | |
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# static mask, cluster 4x1 | |
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# dynamic mask, cluster 2x2 | |
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# small M, N | |
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
] for USE_TMA_EPILOGUE in [True, False]])) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE): | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
if OUTPUT_TYPE == "float16": | |
c = torch.empty((M, N), device=a.device, dtype=torch.float16) | |
else: | |
c = torch.empty((M, N), device=a.device, dtype=torch.float32) | |
> matmul_no_scf_kernel[(1, 1)]( | |
a_ptr=a, b_ptr=b, c_ptr=c, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_cm=c.stride(0), stride_cn=c.stride(1), # | |
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # | |
num_warps=NUM_WARPS, # | |
num_ctas=NUM_CTAS, # | |
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # | |
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE) | |
unit/cuda/test_gemm.py:125: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:569: in run | |
kernel = self.compile(src, target=target, options=options.__dict__) | |
../triton/compiler/compiler.py:284: in compile | |
next_module = compile_ir(module, metadata) | |
../triton/backends/nvidia/compiler.py:450: in <lambda> | |
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <nvidia.CUDABackend object at 0x7f1af0fd04d0> | |
src = <triton._C.libtriton.ir.module object at 0x7f1af0fb37d0> | |
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (4, 1, 1), ...} | |
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120') | |
capability = 120 | |
def make_llir(self, src, metadata, options, capability): | |
ptx_version = get_ptx_version_from_options(options, self.target.arch) | |
mod = src | |
# TritonGPU -> LLVM-IR (MLIR) | |
pm = ir.pass_manager(mod.context) | |
pm.enable_debug() | |
nvidia.passes.ttnvgpuir.add_lower_mma(pm) | |
passes.ttgpuir.add_combine_tensor_select_and_if(pm) | |
passes.ttgpuir.add_allocate_warp_groups(pm) | |
passes.convert.add_scf_to_cf(pm) | |
passes.ttgpuir.add_allocate_shared_memory(pm) | |
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm) | |
passes.ttgpuir.add_allocate_global_scratch_memory(pm) | |
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) | |
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
passes.common.add_symbol_dce(pm) | |
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": | |
passes.llvmir.add_di_scope(pm) | |
> pm.run(mod) | |
E RuntimeError: PassManager::run failed | |
../triton/backends/nvidia/compiler.py:341: RuntimeError | |
----------------------------- Captured stderr call ----------------------------- | |
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed. | |
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#blocked2 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> | |
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0], instrShape = [16, 8]}> | |
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}> | |
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> | |
#smem = #ttg.shared_memory | |
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} { | |
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { | |
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> | |
%0 = arith.extsi %arg6 : i32 to i64 | |
%1 = arith.extsi %arg7 : i32 to i64 | |
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x16x!tt.ptr<f16>, #blocked> | |
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%5 = arith.extsi %4 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi64, #blocked> | |
%7 = tt.splat %0 : i64 -> tensor<256x1xi64, #blocked> | |
%8 = arith.muli %6, %7 : tensor<256x1xi64, #blocked> | |
%9 = tt.broadcast %8 : tensor<256x1xi64, #blocked> -> tensor<256x16xi64, #blocked> | |
%10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%11 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%12 = arith.extsi %10 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%13 = arith.extsi %11 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked> | |
%15 = tt.broadcast %14 : tensor<1x16xi64, #blocked> -> tensor<256x16xi64, #blocked> | |
%16 = arith.addi %9, %15 : tensor<256x16xi64, #blocked> | |
%17 = tt.addptr %2, %16 : tensor<256x16x!tt.ptr<f16>, #blocked>, tensor<256x16xi64, #blocked> | |
%18 = tt.load %17 : tensor<256x16x!tt.ptr<f16>, #blocked> | |
%19 = ttg.local_alloc %18 : (tensor<256x16xf16, #blocked>) -> !ttg.memdesc<256x16xf16, #shared, #smem> | |
%20 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x64x!tt.ptr<f16>, #blocked2> | |
%21 = tt.expand_dims %12 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2> | |
%22 = tt.broadcast %21 : tensor<16x1xi64, #blocked2> -> tensor<16x64xi64, #blocked2> | |
%23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%25 = arith.extsi %24 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%26 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2> | |
%27 = tt.splat %1 : i64 -> tensor<1x64xi64, #blocked2> | |
%28 = arith.muli %26, %27 : tensor<1x64xi64, #blocked2> | |
%29 = tt.broadcast %28 : tensor<1x64xi64, #blocked2> -> tensor<16x64xi64, #blocked2> | |
%30 = arith.addi %22, %29 : tensor<16x64xi64, #blocked2> | |
%31 = tt.addptr %20, %30 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi64, #blocked2> | |
%32 = tt.load %31 : tensor<16x64x!tt.ptr<f16>, #blocked2> | |
%33 = ttg.local_alloc %32 : (tensor<16x64xf16, #blocked2>) -> !ttg.memdesc<16x64xf16, #shared1, #smem> | |
%34 = ttg.local_load %19 : !ttg.memdesc<256x16xf16, #shared, #smem> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> | |
%35 = ttg.local_load %33 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> | |
%36 = tt.dot %34, %35, %cst, inputPrecision = tf32 : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x64xf32, #mma> | |
%37 = tt.expand_dims %3 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> | |
%38 = tt.splat %arg8 : i32 -> tensor<256x1xi32, #blocked1> | |
%39 = arith.muli %37, %38 : tensor<256x1xi32, #blocked1> | |
%40 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x1x!tt.ptr<f32>, #blocked1> | |
%41 = tt.addptr %40, %39 : tensor<256x1x!tt.ptr<f32>, #blocked1>, tensor<256x1xi32, #blocked1> | |
%42 = tt.expand_dims %23 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> | |
%43 = tt.broadcast %41 : tensor<256x1x!tt.ptr<f32>, #blocked1> -> tensor<256x64x!tt.ptr<f32>, #blocked1> | |
%44 = tt.broadcast %42 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1> | |
%45 = tt.addptr %43, %44 : tensor<256x64x!tt.ptr<f32>, #blocked1>, tensor<256x64xi32, #blocked1> | |
%46 = ttg.convert_layout %36 : tensor<256x64xf32, #mma> -> tensor<256x64xf32, #blocked1> | |
tt.store %45, %46 : tensor<256x64x!tt.ptr<f32>, #blocked1> | |
tt.return | |
} | |
} | |
{-# | |
external_resources: { | |
mlir_reproducer: { | |
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)", | |
disable_threading: false, | |
verify_each: true | |
} | |
} | |
#-} | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` | |
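The multi-CTA failures above can be re-run in isolation with pytest's -k substring filter. A small sketch using pytest.main; the -k expression is an assumption about how the parametrize ids are spelled, based on the failure headers above (e.g. test_gemm_no_scf[256-64-16-4-4-False-True-float16-False]):

    import pytest

    # Re-run only the NUM_CTAS=4 test_gemm_no_scf cases seen failing above.
    if __name__ == "__main__":
        raise SystemExit(pytest.main([
            "unit/cuda/test_gemm.py::test_gemm_no_scf",
            "-k", "256-64-16-4-4 or 128-128-16-4-4",
            "-q",
        ]))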
__________ test_gemm_no_scf[128-128-16-4-4-False-True-float16-False] ___________ | |
M = 128, N = 128, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False | |
TRANS_B = True, OUTPUT_TYPE = 'float16', USE_TMA_EPILOGUE = False | |
@pytest.mark.parametrize( | |
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE', | |
itertools.chain(*[[ | |
# numCTAs = 1, no TMA multicast: | |
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# static mask, cluster 4x1 | |
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# dynamic mask, cluster 2x2 | |
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# small M, N | |
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
] for USE_TMA_EPILOGUE in [True, False]])) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE): | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
if OUTPUT_TYPE == "float16": | |
c = torch.empty((M, N), device=a.device, dtype=torch.float16) | |
else: | |
c = torch.empty((M, N), device=a.device, dtype=torch.float32) | |
> matmul_no_scf_kernel[(1, 1)]( | |
a_ptr=a, b_ptr=b, c_ptr=c, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_cm=c.stride(0), stride_cn=c.stride(1), # | |
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # | |
num_warps=NUM_WARPS, # | |
num_ctas=NUM_CTAS, # | |
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # | |
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE) | |
unit/cuda/test_gemm.py:125: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:569: in run | |
kernel = self.compile(src, target=target, options=options.__dict__) | |
../triton/compiler/compiler.py:284: in compile | |
next_module = compile_ir(module, metadata) | |
../triton/backends/nvidia/compiler.py:450: in <lambda> | |
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <nvidia.CUDABackend object at 0x7f1af0f756a0> | |
src = <triton._C.libtriton.ir.module object at 0x7f19bba77dd0> | |
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (2, 2, 1), ...} | |
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120') | |
capability = 120 | |
def make_llir(self, src, metadata, options, capability): | |
ptx_version = get_ptx_version_from_options(options, self.target.arch) | |
mod = src | |
# TritonGPU -> LLVM-IR (MLIR) | |
pm = ir.pass_manager(mod.context) | |
pm.enable_debug() | |
nvidia.passes.ttnvgpuir.add_lower_mma(pm) | |
passes.ttgpuir.add_combine_tensor_select_and_if(pm) | |
passes.ttgpuir.add_allocate_warp_groups(pm) | |
passes.convert.add_scf_to_cf(pm) | |
passes.ttgpuir.add_allocate_shared_memory(pm) | |
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm) | |
passes.ttgpuir.add_allocate_global_scratch_memory(pm) | |
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) | |
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
passes.common.add_symbol_dce(pm) | |
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": | |
passes.llvmir.add_di_scope(pm) | |
> pm.run(mod) | |
E RuntimeError: PassManager::run failed | |
../triton/backends/nvidia/compiler.py:341: RuntimeError | |
----------------------------- Captured stderr call ----------------------------- | |
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed. | |
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> | |
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0]}> | |
#blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> | |
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0], instrShape = [16, 8]}> | |
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> | |
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> | |
#smem = #ttg.shared_memory | |
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} { | |
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { | |
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> | |
%0 = arith.extsi %arg6 : i32 to i64 | |
%1 = arith.extsi %arg7 : i32 to i64 | |
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked> | |
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%7 = arith.extsi %6 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%8 = arith.extsi %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%9 = tt.expand_dims %8 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> | |
%10 = tt.splat %0 : i64 -> tensor<128x1xi64, #blocked> | |
%11 = arith.muli %9, %10 : tensor<128x1xi64, #blocked> | |
%12 = tt.broadcast %11 : tensor<128x1xi64, #blocked> -> tensor<128x16xi64, #blocked> | |
%13 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%14 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%15 = arith.extsi %14 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%16 = arith.extsi %13 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%17 = tt.expand_dims %16 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked> | |
%18 = tt.broadcast %17 : tensor<1x16xi64, #blocked> -> tensor<128x16xi64, #blocked> | |
%19 = arith.addi %12, %18 : tensor<128x16xi64, #blocked> | |
%20 = tt.addptr %2, %19 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi64, #blocked> | |
%21 = tt.load %20 : tensor<128x16x!tt.ptr<f16>, #blocked> | |
%22 = ttg.local_alloc %21 : (tensor<128x16xf16, #blocked>) -> !ttg.memdesc<128x16xf16, #shared, #smem> | |
%23 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x128x!tt.ptr<f16>, #blocked2> | |
%24 = tt.expand_dims %15 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2> | |
%25 = tt.broadcast %24 : tensor<16x1xi64, #blocked2> -> tensor<16x128xi64, #blocked2> | |
%26 = tt.expand_dims %7 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x128xi64, #blocked2> | |
%27 = tt.splat %1 : i64 -> tensor<1x128xi64, #blocked2> | |
%28 = arith.muli %26, %27 : tensor<1x128xi64, #blocked2> | |
%29 = tt.broadcast %28 : tensor<1x128xi64, #blocked2> -> tensor<16x128xi64, #blocked2> | |
%30 = arith.addi %25, %29 : tensor<16x128xi64, #blocked2> | |
%31 = tt.addptr %23, %30 : tensor<16x128x!tt.ptr<f16>, #blocked2>, tensor<16x128xi64, #blocked2> | |
%32 = tt.load %31 : tensor<16x128x!tt.ptr<f16>, #blocked2> | |
%33 = ttg.local_alloc %32 : (tensor<16x128xf16, #blocked2>) -> !ttg.memdesc<16x128xf16, #shared1, #smem> | |
%34 = ttg.local_load %22 : !ttg.memdesc<128x16xf16, #shared, #smem> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> | |
%35 = ttg.local_load %33 : !ttg.memdesc<16x128xf16, #shared1, #smem> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> | |
%36 = tt.dot %34, %35, %cst, inputPrecision = tf32 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> | |
%37 = arith.truncf %36 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> | |
%38 = tt.expand_dims %3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> | |
%39 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked1> | |
%40 = arith.muli %38, %39 : tensor<128x1xi32, #blocked1> | |
%41 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1> | |
%42 = tt.addptr %41, %40 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1> | |
%43 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> | |
%44 = tt.broadcast %42 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x128x!tt.ptr<f16>, #blocked1> | |
%45 = tt.broadcast %43 : tensor<1x128xi32, #blocked1> -> tensor<128x128xi32, #blocked1> | |
%46 = tt.addptr %44, %45 : tensor<128x128x!tt.ptr<f16>, #blocked1>, tensor<128x128xi32, #blocked1> | |
%47 = ttg.convert_layout %37 : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #blocked1> | |
tt.store %46, %47 : tensor<128x128x!tt.ptr<f16>, #blocked1> | |
tt.return | |
} | |
} | |
{-# | |
external_resources: { | |
mlir_reproducer: { | |
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)", | |
disable_threading: false, | |
verify_each: true | |
} | |
} | |
#-} | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` | |
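In the cluster 2x2 reproducer above, the portion of the 128x128 output tile that one CTA covers can be read off the #blocked1 layout by multiplying its per-dimension factors. A rough worked check, under the usual interpretation that one traversal covers sizePerThread * threadsPerWarp * warpsPerCTA elements per dimension and that CTASplitNum divides the tensor across the CTAs of the CGA (an interpretation, not code from the compiler):

    # Per-dimension arithmetic for #blocked1 from the reproducer above:
    # sizePerThread=[1,8], threadsPerWarp=[4,8], warpsPerCTA=[4,1], CTASplitNum=[2,2]
    size_per_thread  = [1, 8]
    threads_per_warp = [4, 8]
    warps_per_cta    = [4, 1]
    cta_split_num    = [2, 2]
    tensor_shape     = [128, 128]

    per_pass_cover = [s * t * w for s, t, w in zip(size_per_thread, threads_per_warp, warps_per_cta)]
    per_cta_shape  = [d // c for d, c in zip(tensor_shape, cta_split_num)]
    print(per_pass_cover)  # [16, 64]: one traversal covers a 16x64 patch per CTA
    print(per_cta_shape)   # [64, 64]: each CTA owns a 64x64 slice, so dim 0 needs 4 passes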
__________ test_gemm_no_scf[128-128-16-4-4-False-True-float32-False] ___________ | |
M = 128, N = 128, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False | |
TRANS_B = True, OUTPUT_TYPE = 'float32', USE_TMA_EPILOGUE = False | |
@pytest.mark.parametrize( | |
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE', | |
itertools.chain(*[[ | |
# numCTAs = 1, no TMA multicast: | |
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# static mask, cluster 4x1 | |
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# dynamic mask, cluster 2x2 | |
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE], | |
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
# small M, N | |
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE], | |
] for USE_TMA_EPILOGUE in [True, False]])) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE): | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
if OUTPUT_TYPE == "float16": | |
c = torch.empty((M, N), device=a.device, dtype=torch.float16) | |
else: | |
c = torch.empty((M, N), device=a.device, dtype=torch.float32) | |
> matmul_no_scf_kernel[(1, 1)]( | |
a_ptr=a, b_ptr=b, c_ptr=c, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_cm=c.stride(0), stride_cn=c.stride(1), # | |
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # | |
num_warps=NUM_WARPS, # | |
num_ctas=NUM_CTAS, # | |
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # | |
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE) | |
unit/cuda/test_gemm.py:125: | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
../triton/runtime/jit.py:347: in <lambda> | |
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) | |
../triton/runtime/jit.py:569: in run | |
kernel = self.compile(src, target=target, options=options.__dict__) | |
../triton/compiler/compiler.py:284: in compile | |
next_module = compile_ir(module, metadata) | |
../triton/backends/nvidia/compiler.py:450: in <lambda> | |
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) | |
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ | |
self = <nvidia.CUDABackend object at 0x7f1a540990a0> | |
src = <triton._C.libtriton.ir.module object at 0x7f19bb94b050> | |
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (2, 2, 1), ...} | |
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120') | |
capability = 120 | |
def make_llir(self, src, metadata, options, capability): | |
ptx_version = get_ptx_version_from_options(options, self.target.arch) | |
mod = src | |
# TritonGPU -> LLVM-IR (MLIR) | |
pm = ir.pass_manager(mod.context) | |
pm.enable_debug() | |
nvidia.passes.ttnvgpuir.add_lower_mma(pm) | |
passes.ttgpuir.add_combine_tensor_select_and_if(pm) | |
passes.ttgpuir.add_allocate_warp_groups(pm) | |
passes.convert.add_scf_to_cf(pm) | |
passes.ttgpuir.add_allocate_shared_memory(pm) | |
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm) | |
passes.ttgpuir.add_allocate_global_scratch_memory(pm) | |
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm) | |
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm) | |
passes.common.add_canonicalizer(pm) | |
passes.common.add_cse(pm) | |
passes.common.add_symbol_dce(pm) | |
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": | |
passes.llvmir.add_di_scope(pm) | |
> pm.run(mod) | |
E RuntimeError: PassManager::run failed | |
../triton/backends/nvidia/compiler.py:341: RuntimeError | |
----------------------------- Captured stderr call ----------------------------- | |
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed. | |
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> | |
#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> | |
#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0]}> | |
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0], instrShape = [16, 8]}> | |
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> | |
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> | |
#smem = #ttg.shared_memory | |
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} { | |
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { | |
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> | |
%0 = arith.extsi %arg6 : i32 to i64 | |
%1 = arith.extsi %arg7 : i32 to i64 | |
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked> | |
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> | |
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> | |
%7 = arith.extsi %3 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> | |
%8 = arith.extsi %4 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%9 = tt.expand_dims %8 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked> | |
%10 = tt.splat %0 : i64 -> tensor<128x1xi64, #blocked> | |
%11 = arith.muli %9, %10 : tensor<128x1xi64, #blocked> | |
%12 = tt.broadcast %11 : tensor<128x1xi64, #blocked> -> tensor<128x16xi64, #blocked> | |
%13 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%14 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%15 = arith.extsi %14 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%16 = arith.extsi %13 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%17 = tt.expand_dims %15 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked> | |
%18 = tt.broadcast %17 : tensor<1x16xi64, #blocked> -> tensor<128x16xi64, #blocked> | |
%19 = arith.addi %12, %18 : tensor<128x16xi64, #blocked> | |
%20 = tt.addptr %2, %19 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi64, #blocked> | |
%21 = tt.load %20 : tensor<128x16x!tt.ptr<f16>, #blocked> | |
%22 = ttg.local_alloc %21 : (tensor<128x16xf16, #blocked>) -> !ttg.memdesc<128x16xf16, #shared, #smem> | |
%23 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x128x!tt.ptr<f16>, #blocked1> | |
%24 = tt.expand_dims %16 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi64, #blocked1> | |
%25 = tt.broadcast %24 : tensor<16x1xi64, #blocked1> -> tensor<16x128xi64, #blocked1> | |
%26 = tt.expand_dims %7 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi64, #blocked1> | |
%27 = tt.splat %1 : i64 -> tensor<1x128xi64, #blocked1> | |
%28 = arith.muli %26, %27 : tensor<1x128xi64, #blocked1> | |
%29 = tt.broadcast %28 : tensor<1x128xi64, #blocked1> -> tensor<16x128xi64, #blocked1> | |
%30 = arith.addi %25, %29 : tensor<16x128xi64, #blocked1> | |
%31 = tt.addptr %23, %30 : tensor<16x128x!tt.ptr<f16>, #blocked1>, tensor<16x128xi64, #blocked1> | |
%32 = tt.load %31 : tensor<16x128x!tt.ptr<f16>, #blocked1> | |
%33 = ttg.local_alloc %32 : (tensor<16x128xf16, #blocked1>) -> !ttg.memdesc<16x128xf16, #shared1, #smem> | |
%34 = ttg.local_load %22 : !ttg.memdesc<128x16xf16, #shared, #smem> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> | |
%35 = ttg.local_load %33 : !ttg.memdesc<16x128xf16, #shared1, #smem> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> | |
%36 = tt.dot %34, %35, %cst, inputPrecision = tf32 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma> | |
%37 = tt.expand_dims %6 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> | |
%38 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2> | |
%39 = arith.muli %37, %38 : tensor<128x1xi32, #blocked2> | |
%40 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<128x1x!tt.ptr<f32>, #blocked2> | |
%41 = tt.addptr %40, %39 : tensor<128x1x!tt.ptr<f32>, #blocked2>, tensor<128x1xi32, #blocked2> | |
%42 = tt.expand_dims %5 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x128xi32, #blocked2> | |
%43 = tt.broadcast %41 : tensor<128x1x!tt.ptr<f32>, #blocked2> -> tensor<128x128x!tt.ptr<f32>, #blocked2> | |
%44 = tt.broadcast %42 : tensor<1x128xi32, #blocked2> -> tensor<128x128xi32, #blocked2> | |
%45 = tt.addptr %43, %44 : tensor<128x128x!tt.ptr<f32>, #blocked2>, tensor<128x128xi32, #blocked2> | |
%46 = ttg.convert_layout %36 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked2> | |
tt.store %45, %46 : tensor<128x128x!tt.ptr<f32>, #blocked2> | |
tt.return | |
} | |
} | |
{-# | |
external_resources: { | |
mlir_reproducer: { | |
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)", | |
disable_threading: false, | |
verify_each: true | |
} | |
} | |
#-} | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline | |
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.` | |
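Each failure above also prints an MLIR reproducer whose external_resources footer embeds the exact pass pipeline that failed. Assuming the dumped module (including that footer) is saved to a file and a triton-opt binary built from this workspace is on PATH (both assumptions), the pipeline can be replayed outside pytest; a sketch via subprocess:

    import subprocess

    # Hypothetical replay of one of the reproducers dumped above, saved as repro.mlir.
    # --run-reproducer asks the opt tool to read the pipeline from the reproducer itself.
    subprocess.run(["triton-opt", "--run-reproducer", "repro.mlir"], check=True)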
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-True-none-float16-False-3] __ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False, TRANS_OUTPUT = True | |
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
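A note on the assertion that recurs in these failures: the numerical check (assert_close) has already passed by the time line 463 runs, so what fails is only the PTX pattern check. The PTX quoted in the error targets sm_120a, and per the PTX ISA wgmma.mma_async is emitted only for sm_90a-class targets while tcgen05.mma is limited to the sm_100 family, so neither branch of the check can match on this device. The short worked example below reproduces the n128 suffix in the failing regex from the parameters printed in this report (BLOCK_M=128, BLOCK_N=128, NUM_WARPS=4); every value is taken from the log itself.

    # Reproduces the expected-instruction suffix used in the failing regex above,
    # using the parameters printed in this failure report.
    BLOCK_M, BLOCK_N, NUM_WARPS = 128, 128, 4
    wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
    assert wgmma_n == 128  # hence the 'm\d+n128k16' in the pattern above
    # The PTX shown in the error targets sm_120a, which emits plain mma.sync
    # rather than wgmma or tcgen05.mma, so re.search(pattern, ptx) returns None.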
__ test_gemm[128-128-64-4-1-4096-1-1024-False-False-True-none-float16-True-3] __ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False, TRANS_OUTPUT = True | |
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
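Any one of these failing parametrizations can be re-run in isolation by the node id shown in its failure header. A minimal sketch using pytest's programmatic entry point follows; it assumes the working directory is the one containing unit/cuda/test_gemm.py, and the node id is copied verbatim from the header of the report just above.

    # Minimal repro sketch for a single failing parametrization from this log.
    # Assumes the current working directory contains unit/cuda/test_gemm.py.
    import pytest

    node_id = ("unit/cuda/test_gemm.py::test_gemm"
               "[128-128-64-4-1-4096-1-1024-False-False-True-none-float16-True-3]")
    raise SystemExit(pytest.main(["-q", node_id]))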
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-True-none-float32-False-3] __ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False, TRANS_OUTPUT = True | |
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
__ test_gemm[128-128-64-4-1-4096-1-1024-False-False-True-none-float32-True-3] __ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False, TRANS_OUTPUT = True | |
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-True-none-float16-False-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False | |
TRANS_OUTPUT = True, epilogue = 'none', out_dtype = triton.language.float16 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-True-none-float16-True-3] __ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False | |
TRANS_OUTPUT = True, epilogue = 'none', out_dtype = triton.language.float16 | |
USE_TMA_STORE = True, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-True-none-float32-False-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False | |
TRANS_OUTPUT = True, epilogue = 'none', out_dtype = triton.language.float32 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-True-none-float32-True-3] __ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False | |
TRANS_OUTPUT = True, epilogue = 'none', out_dtype = triton.language.float32 | |
USE_TMA_STORE = True, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-False-none-float16-False-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False | |
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float16 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-False-none-float16-True-3] __ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False | |
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float16 | |
USE_TMA_STORE = True, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-False-none-float32-False-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False | |
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float32 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-False-none-float32-True-3] __ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False | |
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float32 | |
USE_TMA_STORE = True, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# the check below is CUDA-backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
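Note: the assertion that fails above can be reproduced in isolation. The following is a minimal sketch, not part of the test file: it reuses the parameters of the failing case (BLOCK_M=128, BLOCK_N=128, NUM_WARPS=4) and the sm_120a PTX header captured in the error; the helper name expected_wgmma_pattern is hypothetical and only mirrors the formula shown at unit/cuda/test_gemm.py:463.

import re

def expected_wgmma_pattern(block_m, block_n, num_warps):
    # Mirrors the wgmma_n computation in the assertion above: the N extent of
    # one wgmma instruction is BLOCK_N divided by the number of warp groups
    # along M, clamped to a minimum of 8.
    wgmma_n = int(max(block_n / max(num_warps / max(block_m / 16, 1), 1), 8))
    return r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)

# Parameters of the failing case above; the PTX string is just the header
# captured in the error message, so it contains no wgmma instruction and the
# search returns None, which is exactly what the assertion reports.
pattern = expected_wgmma_pattern(128, 128, 4)   # -> ...m\d+n128k16...
ptx_header = ('//\n// Generated by LLVM NVPTX Back-End\n//\n\n'
              '.version 8.7\n.target sm_120a\n.address_size 64\n')
print(pattern)
print(re.search(pattern, ptx_header))           # None on an sm_120a build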
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-False-none-float16-False-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False | |
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float16 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues other than softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# the check below is CUDA-backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-False-none-float16-True-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False | |
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float16 | |
USE_TMA_STORE = True, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues other than softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# the check below is CUDA-backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
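Note: the same assertion repeats for every eligible parametrization because of how the post-run PTX check dispatches. Below is a minimal sketch of that dispatch, not taken from the test file: the capability tuple is an assumed input rather than a torch.cuda query, the DISABLE_MMA_V3 environment check is omitted, and the function name expected_mma_family is hypothetical. The captured PTX targets sm_120a, i.e. capability (12, 0), which lands in the wgmma branch even though that target emits neither tcgen05 nor wgmma f16 MMA instructions, so the regex search finds nothing.

def expected_mma_family(capability, block_m, block_n, num_warps, num_ctas):
    """Sketch of the branch guarding the PTX assertions in test_gemm.

    capability is a (major, minor) tuple; the real test reads it from
    torch.cuda.get_device_capability() and also honors DISABLE_MMA_V3.
    """
    if not (block_m >= 64 and num_ctas == 1 and block_n <= 256):
        return None  # no PTX check is performed for these shapes
    is_tcgen5 = (capability[0] == 10 and num_warps % 4 == 0
                 and block_m % 64 == 0 and block_n % 8 == 0)
    return 'tcgen05' if is_tcgen5 else 'wgmma'

# Blackwell (sm_100) picks the tcgen05 check, Hopper (sm_90) the wgmma check,
# and capability (12, 0) also falls into the wgmma branch, which is why every
# eligible case on this machine trips the wgmma regex assertion.
for cap in [(9, 0), (10, 0), (12, 0)]:
    print(cap, expected_mma_family(cap, 128, 128, 4, 1))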
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-False-none-float32-False-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False | |
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float32 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues other than softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# the check below is CUDA-backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-False-none-float32-True-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1 | |
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False | |
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float32 | |
USE_TMA_STORE = True, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues other than softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# the check below is CUDA-backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
__ test_gemm[64-64-16-4-1-64-64-64-False-True-False-softmax-float16-False-3] ___ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 64 | |
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues other than softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# the check below is CUDA-backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
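Note: the n64 in the expected wgmma shape comes from the test's wgmma_n arithmetic. A minimal sketch, plain Python and no GPU required, that reproduces the value for the failing parameters above (BLOCK_M=64, BLOCK_N=64, NUM_WARPS=4):

    BLOCK_M, BLOCK_N, NUM_WARPS = 64, 64, 4
    # same expression as in the test body shown above
    wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
    print(wgmma_n)  # 64 -> the assertion's regex expects ...n64k16...
    # the 128x128 tiles failing later in this log give 128 by the same formula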
___ test_gemm[64-64-16-4-1-64-64-64-False-True-False-softmax-float16-True-3] ___ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 64 | |
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
        # loop over epilogues besides softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
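Note: the failing assertion can be reproduced with nothing more than the pattern from the test and the PTX header quoted in the error message; a sketch (the PTX in the message is truncated, so only the visible header is used here):

    import re

    pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(64)
    ptx_header = ("//\n// Generated by LLVM NVPTX Back-End\n//\n\n"
                  ".version 8.7\n.target sm_120a\n.address_size 64\n")
    print(re.search(pattern, ptx_header))  # None -> the assert fails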
__ test_gemm[64-64-16-4-1-64-64-64-False-True-False-softmax-float32-False-3] ___ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 64 | |
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
        # loop over epilogues besides softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
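Note: the wgmma branch is only reached because is_tcgen5 is False. The test checks for tcgen05 instructions only when the device reports compute capability major 10; the '.target sm_120a' line in the error suggests a capability-12 part, which is an inference from the PTX target rather than something the log states directly. A sketch of that branch under that assumption:

    capability = (12, 0)  # hypothetical value inferred from the '.target sm_120a' line
    NUM_WARPS, BLOCK_M, BLOCK_N = 4, 64, 64
    is_tcgen5 = (capability[0] == 10) and NUM_WARPS % 4 == 0 \
        and BLOCK_M % 64 == 0 and BLOCK_N % 8 == 0
    print(is_tcgen5)  # False -> the test falls through to the wgmma assertion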
___ test_gemm[64-64-16-4-1-64-64-64-False-True-False-softmax-float32-True-3] ___ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 64 | |
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
        # loop over epilogues besides softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
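Note: the 'softmax' branch of process_epilogue above is a numerically stable row softmax. A small standalone check, CPU tensors and torch only, that it agrees with torch.softmax:

    import torch

    d = torch.randn(4, 8)
    num = torch.exp(d - torch.max(d, dim=-1, keepdim=True)[0])
    ref = num / torch.sum(num, dim=-1, keepdim=True)
    print(torch.allclose(ref, torch.softmax(d, dim=-1), atol=1e-6))  # True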
_ test_gemm[128-128-64-4-1-None-None-None-False-True-False-softmax-float16-False-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
        # loop over epilogues besides softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
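Note: the grid function above launches one program per (BLOCK_M, BLOCK_N) tile of the output; for the 128x128 shape in this failure that is a single program. A sketch with a plain-Python cdiv standing in for triton.cdiv:

    def cdiv(x, y):
        return (x + y - 1) // y

    M, N, BLOCK_M, BLOCK_N = 128, 128, 128, 128
    print(cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N))  # 1 -> grid size for this case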
_ test_gemm[128-128-64-4-1-None-None-None-False-True-False-softmax-float16-True-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
        # loop over epilogues besides softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
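Note: the numerical comparison normalizes both tensors before calling assert_close with rtol=1e-2, atol=1e-3. A self-contained sketch of that comparison step, assuming assert_close here refers to torch.testing.assert_close (the import is not visible in this excerpt):

    import torch
    from torch.testing import assert_close

    golden = torch.randn(8, 8)
    z = golden + 1e-4 * torch.randn(8, 8)  # stand-in for the kernel output
    assert_close(torch.nn.functional.normalize(z),
                 torch.nn.functional.normalize(golden),
                 rtol=1e-2, atol=1e-3, check_dtype=False)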
_ test_gemm[128-128-64-4-1-None-None-None-False-True-False-softmax-float32-False-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
        # loop over epilogues besides softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-None-None-None-False-True-False-softmax-float32-True-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
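Note on this failure mode: the generated PTX in the error above targets sm_120a, so the device's compute capability major version is 12 — large enough to pass the `>= 9` skip condition but not equal to 10, so `is_tcgen5` is False and the check falls through to the wgmma branch. The sm_120a PTX in the error message contains no `wgmma.mma_async` instruction (a Hopper-style warpgroup MMA), so `re.search` returns None and the assertion fails. Below is a minimal, self-contained sketch of that check, using the parameters of this failure (BLOCK_M=128, BLOCK_N=128, NUM_WARPS=4) and a stand-in PTX header modeled on the snippet quoted in the error message; the `ptx` string here is illustrative only, not the full compiler output.

    import re

    # Parameters of the failing case shown above.
    BLOCK_M, BLOCK_N, NUM_WARPS = 128, 128, 4

    # Same arithmetic as the test: expected N dimension of the wgmma instruction.
    wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
    assert wgmma_n == 128

    # Stand-in for pgm.asm['ptx']: an sm_120a module that contains no wgmma instructions.
    ptx = "//\n// Generated by LLVM NVPTX Back-End\n//\n.version 8.7\n.target sm_120a\n.address_size 64\n"

    pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)
    print(re.search(pattern, ptx))  # None -> the test's `assert re.search(...)` raises AssertionError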
__ test_gemm[64-64-32-8-1-64-64-64-False-True-False-softmax-float16-False-3] ___ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 32, NUM_WARPS = 8, NUM_CTAS = 1, M = 64 | |
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(32) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
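For the 64x64 cases in this and the following tracebacks (BLOCK_M=64, BLOCK_N=64, NUM_WARPS=8), the same `wgmma_n` expression yields the n32 pattern seen above: BLOCK_M / 16 = 4, NUM_WARPS / 4 = 2, BLOCK_N / 2 = 32, and max(32, 8) = 32, so the test searches for `wgmma.mma_async.sync.aligned.m\d+n32k16...`, which is likewise absent from the sm_120a PTX.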
___ test_gemm[64-64-32-8-1-64-64-64-False-True-False-softmax-float16-True-3] ___ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 32, NUM_WARPS = 8, NUM_CTAS = 1, M = 64 | |
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(32) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
__ test_gemm[64-64-32-8-1-64-64-64-False-True-False-softmax-float32-False-3] ___ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 32, NUM_WARPS = 8, NUM_CTAS = 1, M = 64 | |
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(32) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
___ test_gemm[64-64-32-8-1-64-64-64-False-True-False-softmax-float32-True-3] ___ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 32, NUM_WARPS = 8, NUM_CTAS = 1, M = 64 | |
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(32) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
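The wgmma failures recorded above and below all break on the same assertion at unit/cuda/test_gemm.py:463: the non-Blackwell branch searches the generated PTX for a Hopper-style wgmma.mma_async instruction whose N dimension equals wgmma_n, but the captured PTX (header shown truncated above, target sm_120a) contains no wgmma instruction at all, presumably because wgmma is only emitted for sm_90a targets, so re.search returns None. A minimal sketch of that check, assuming a hypothetical ptx string in place of pgm.asm['ptx']:

    import re

    # Hypothetical stand-in for pgm.asm['ptx']; the real PTX captured in this log
    # targets sm_120a and contains no wgmma instructions, which is why the search
    # below finds nothing.
    ptx = """
    // Generated by LLVM NVPTX Back-End
    .version 8.7
    .target sm_120a
    .address_size 64
    """

    wgmma_n = 32  # the value formatted into the pattern in the failure above
    pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)
    print(re.search(pattern, ptx))  # None -> the assert fails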
_ test_gemm[128-128-64-4-1-128-128-128-False-True-False-softmax-float16-False-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 128, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
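For the parametrization above (BLOCK_M=128, BLOCK_N=128, NUM_WARPS=4), the wgmma_n expression in the test works out to 128, matching the n128 in the pattern that failed to match. A small sketch of that arithmetic, assuming nothing beyond the values printed in the failure:

    # Values from the test id above: BLOCK_M=128, BLOCK_N=128, NUM_WARPS=4.
    BLOCK_M, BLOCK_N, NUM_WARPS = 128, 128, 4

    # Same expression as in the test body.
    wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
    # BLOCK_M / 16 = 8.0; NUM_WARPS / 8.0 = 0.5; max(0.5, 1) = 1; 128 / 1 = 128.0
    print(wgmma_n)  # 128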
_ test_gemm[128-128-64-4-1-128-128-128-False-True-False-softmax-float16-True-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 128, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-128-128-128-False-True-False-softmax-float32-False-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 128, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[128-128-64-4-1-128-128-128-False-True-False-softmax-float32-True-3] _ | |
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 128, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
___ test_gemm[64-64-16-4-1-128-128-64-False-True-False-none-float16-False-3] ___ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
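For this parametrization (BLOCK_M=64, BLOCK_N=64, NUM_WARPS=4) the expected instruction shape is n64, as shown in the pattern above. A short sketch of how the test arrives at that value, assuming only the values printed in the failure; note that the sm_120a target implies a device capability major of 12, not 10, so is_tcgen5 is False and the wgmma branch is the one exercised:

    # Values from the test id above: BLOCK_M=64, BLOCK_N=64, NUM_WARPS=4.
    BLOCK_M, BLOCK_N, NUM_WARPS = 64, 64, 4

    # is_tcgen5 requires capability major == 10; sm_120a reports major 12,
    # so the test falls through to the wgmma regex check.
    wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
    # BLOCK_M / 16 = 4.0; NUM_WARPS / 4.0 = 1.0; 64 / 1.0 = 64.0
    print(wgmma_n)  # 64, matching the n64 in the failed pattern above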
___ test_gemm[64-64-16-4-1-128-128-64-False-True-False-none-float16-True-3] ____ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
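Note on the failure mode above: the PTX header in the error output shows `.target sm_120a`, so the device capability major is 12, `is_tcgen5` is False, and the test falls into the wgmma branch; the generated PTX contains no matching `wgmma.mma_async` instruction, so `re.search` returns None and the assert fires. Below is a minimal standalone sketch of the pattern construction for the failing parameters (values copied from the log); it is illustrative only and not part of the test suite.

    import re

    # Failing case from the log: BLOCK_M=64, BLOCK_N=64, NUM_WARPS=4.
    BLOCK_M, BLOCK_N, NUM_WARPS = 64, 64, 4

    # Same arithmetic as the test: the N dimension of the wgmma instruction
    # expected to appear in the PTX.
    wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
    assert wgmma_n == 64  # matches the 'n64' in the formatted pattern shown in the error

    pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)
    # The PTX in the log targets sm_120a and contains no wgmma.mma_async
    # instructions, so re.search(pattern, ptx) returns None, which triggers
    # the AssertionError recorded above.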
___ test_gemm[64-64-16-4-1-128-128-64-False-True-False-none-float32-False-3] ___ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
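For context on the accuracy check in the test body above: both the reference (`golden`) and the kernel output (`z`) are passed through `torch.nn.functional.normalize` (which defaults to p=2 along dim=1, i.e. row-wise L2 normalization) before the loose-tolerance comparison. A small self-contained sketch of that comparison style, using hypothetical stand-in tensors rather than real kernel output:

    import torch
    from torch.testing import assert_close

    # Hypothetical stand-ins for the reference result and the kernel output.
    golden = torch.randn(4, 8)
    z = golden + 1e-4 * torch.randn(4, 8)

    # Row-wise L2 normalization (normalize defaults to p=2, dim=1), as in the
    # test body above, followed by a loose-tolerance, dtype-agnostic comparison.
    golden = torch.nn.functional.normalize(golden)
    z = torch.nn.functional.normalize(z)
    assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)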
___ test_gemm[64-64-16-4-1-128-128-64-False-True-False-none-float32-True-3] ____ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-matrix-float16-False-3] _ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-matrix', out_dtype = triton.language.float16 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-matrix-float16-True-3] _ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-matrix', out_dtype = triton.language.float16 | |
USE_TMA_STORE = True, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-matrix-float32-False-3] _ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-matrix', out_dtype = triton.language.float32 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
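For reference, a minimal standalone sketch of the arithmetic behind the failing assertion above. The values BLOCK_M=64, BLOCK_N=64, NUM_WARPS=4 come from the parameter summary for this case; the sample PTX line is hypothetical and only illustrates what the regex would accept.

import re

BLOCK_M, BLOCK_N, NUM_WARPS = 64, 64, 4
# Same expression as in the test body; it collapses to 64 for this case.
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
assert wgmma_n == 64  # hence the 'n64' in the formatted pattern shown by pytest

pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)
# A hypothetical PTX line that would satisfy the check:
sample_ptx = 'wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 ...'
assert re.search(pattern, sample_ptx) is not None
# The sm_120a PTX captured in the error above contains no such instruction,
# so re.search returns None and the test's assertion fails.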
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-matrix-float32-True-3] _ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-matrix', out_dtype = triton.language.float32 | |
USE_TMA_STORE = True, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-rows-float16-False-3] _ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-rows', out_dtype = triton.language.float16 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
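The same failure repeats for each parametrization because the instruction check is selected by device capability. A minimal sketch of that branch, assuming a CUDA-enabled PyTorch build with a visible GPU; the logic mirrors the test body above, where only a major capability of 10 takes the tcgen05 path.

import torch

BLOCK_M, BLOCK_N, NUM_WARPS = 64, 64, 4
major = torch.cuda.get_device_capability()[0]  # 12 on the sm_120a target reported above
is_tcgen5 = (major == 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
# Major 10 expects a tcgen05.mma instruction; every other major (including the
# 12 reported here) falls into the wgmma.mma_async branch, which is why the
# pattern search fails on this target.
print('expects tcgen05.mma in the PTX' if is_tcgen5 else 'expects wgmma.mma_async in the PTX')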
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-rows-float16-True-3] __ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-rows', out_dtype = triton.language.float16, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-rows-float32-False-3] _ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-rows', out_dtype = triton.language.float32 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
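For reference, a minimal sketch of the launch-grid computation used by the test above, instantiated with this case's problem size (M=N=128) and tile size (BLOCK_M=BLOCK_N=64); it assumes only that triton is importable.

import triton

M, N = 128, 128
META = {'BLOCK_M': 64, 'BLOCK_N': 64}
# Same expression as the grid() helper in the test: one program per output tile.
grid = (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
assert grid == (4, )  # 2 tiles along M times 2 tiles along N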
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-rows-float32-True-3] __ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-rows', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-cols-float16-False-3] _ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-cols', out_dtype = triton.language.float16 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
    # the PTX check below is CUDA backend specific |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
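Note on the failure above: the regex expects a Hopper-style wgmma instruction, but the PTX header quoted in the error shows `.target sm_120a`, i.e. a compute-capability-12 device. The tcgen5 branch only triggers when the major capability equals 10, so the test falls through to the wgmma check, and as far as I can tell sm_120 supports neither wgmma nor tcgen05, leaving nothing in the PTX for the pattern to match. Below is a minimal sketch of a capability-gated pattern choice; the helper name, the mma.sync fallback pattern, and the capability mapping are illustrative assumptions, not something the test or Triton defines.

    import re
    import torch

    def expected_mma_pattern(block_m, block_n, num_warps):
        # Assumed mapping: major 9 -> Hopper wgmma, major 10 -> tcgen05,
        # anything else (e.g. the sm_120a target in this log) -> plain mma.sync.
        major, _ = torch.cuda.get_device_capability()
        if major == 10:
            return r'tcgen05\.mma\.cta_group::1\.kind::f16'
        if major == 9:
            wgmma_n = int(max(block_n / max(num_warps / max(block_m / 16, 1), 1), 8))
            return r'wgmma\.mma_async\.sync\.aligned\.m\d+n{}k16'.format(wgmma_n)
        return r'mma\.sync\.aligned'

    # hypothetical usage with the PTX captured in the failure above:
    # assert re.search(expected_mma_pattern(64, 64, 4), ptx)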
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-cols-float16-True-3] __ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-cols', out_dtype = triton.language.float16, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
        # loop over epilogues other than softmax |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
    # the PTX check below is CUDA backend specific |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
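For reference, the `n64` baked into the formatted pattern above follows directly from the failing parameters; a quick standalone check of the same arithmetic (plain Python, no Triton needed):

    # wgmma_n for BLOCK_M=64, BLOCK_N=64, NUM_WARPS=4, the case reported above
    BLOCK_M, BLOCK_N, NUM_WARPS = 64, 64, 4
    wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
    assert wgmma_n == 64  # matches the 'n64' in the regex printed in the error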
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-cols-float32-False-3] _ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-cols', out_dtype = triton.language.float32 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
        # loop over epilogues other than softmax |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
    # the PTX check below is CUDA backend specific |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
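The hard-coded skip lists near the top of test_gemm compare a dash-joined parameter string against known-bad configurations. A tiny sketch of that key construction, using the parameter values copied from the failure header above, shows why this case is not skipped:

    # BLOCK_M..TRANS_B for the failing case above
    params = [64, 64, 16, 4, 1, 128, 128, 64, False, True]
    key = '-'.join(map(str, params))
    assert key == '64-64-16-4-1-128-128-64-False-True'  # not in either skip list, so the test runs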
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-cols-float32-True-3] __ | |
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128 | |
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-cols', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
        # loop over epilogues other than softmax |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
    # the PTX check below is CUDA backend specific |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
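As a standalone illustration of the reference ('golden') value the test builds for the 'add-cols' case reported above (M=N=128, K=64): an fp32 matmul of the fp16 inputs plus the first bias row broadcast across rows, followed by the same normalization applied before assert_close. CPU tensors stand in for the test's CUDA tensors; only the math is shown.

    import torch

    M, N, K = 128, 128, 64
    a = torch.randn(M, K, dtype=torch.float16)
    b = torch.randn(K, N, dtype=torch.float16)
    bias = torch.randn(M, N, dtype=torch.float32)
    dot = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    golden = dot + bias[0, :][None, :]              # the 'add-cols' branch of process_epilogue
    golden = torch.nn.functional.normalize(golden)  # same normalization applied before assert_close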
__ test_gemm[256-64-16-4-1-256-256-64-False-True-False-none-float16-False-3] ___ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
        # loop over epilogues other than softmax |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
    # the PTX check below is CUDA backend specific |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
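The launch grid used by the test is one-dimensional: cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N) programs. Evaluated for the failing case just reported (M=N=256, BLOCK_M=256, BLOCK_N=64) it comes out to 4 programs; a local cdiv is used here instead of triton.cdiv so the sketch runs standalone.

    def cdiv(x, y):
        # ceiling division, same result as triton.cdiv for positive ints
        return (x + y - 1) // y

    M, N, BLOCK_M, BLOCK_N = 256, 256, 256, 64
    grid = (cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N),)
    assert grid == (4,)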
___ test_gemm[256-64-16-4-1-256-256-64-False-True-False-none-float16-True-3] ___ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
        # loop over epilogues other than softmax |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
    # the PTX check below is CUDA backend specific |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
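The assertion works by grepping the compiled kernel's PTX: `pgm` is the handle returned by the `matmul_kernel[grid](...)` launch shown above, and the test reads its `asm['ptx']` text with re.search. A small sketch of that inspection, which also pulls out the `.target` line (`.target sm_120a` in this log) since that is what determines whether wgmma or tcgen05 can appear at all; the helper name and the target-extraction regex are illustrative additions, not part of the test.

    import re

    def first_match(pgm, pattern):
        # Same field the test reads for its PTX-level assertions.
        ptx = pgm.asm['ptx']
        target = re.search(r'\.target\s+(\S+)', ptx)
        print('PTX target:', target.group(1) if target else 'unknown')
        return re.search(pattern, ptx)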
__ test_gemm[256-64-16-4-1-256-256-64-False-True-False-none-float32-False-3] ___ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = False | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
___ test_gemm[256-64-16-4-1-256-256-64-False-True-False-none-float32-True-3] ___ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-matrix-float16-False-3] _ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-matrix', out_dtype = triton.language.float16 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-matrix-float16-True-3] _ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-matrix', out_dtype = triton.language.float16 | |
USE_TMA_STORE = True, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-matrix-float32-False-3] _ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-matrix', out_dtype = triton.language.float32 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
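Editorial sketch (not part of the captured output): all of these failures take the same path through the PTX check in test_gemm. The sketch below mirrors the branch condition from the test source above to show why a device whose PTX targets sm_120a lands in the wgmma branch; mapping sm_120a to a compute capability of (12, 0) is an assumption on my part, not something the log states.

    # Hypothetical mirror of the PTX-check branch in test_gemm.
    def expected_mma_kind(capability_major, num_warps, block_m, block_n):
        # Same condition as in the test: only capability 10.x selects tcgen05.
        is_tcgen5 = (capability_major == 10) and num_warps % 4 == 0 \
            and block_m % 64 == 0 and block_n % 8 == 0
        # capability 9.x (Hopper) and anything else fall through to wgmma.
        return 'tcgen05.mma' if is_tcgen5 else 'wgmma.mma_async'

    print(expected_mma_kind(12, num_warps=4, block_m=256, block_n=64))
    # -> 'wgmma.mma_async', which the sm_120a PTX in these reports does not
    #    appear to contain, so re.search(...) returns None and the assert fails.
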
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-matrix-float32-True-3] _ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-matrix', out_dtype = triton.language.float32 | |
USE_TMA_STORE = True, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-rows-float16-False-3] _ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-rows', out_dtype = triton.language.float16 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-rows-float16-True-3] _ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-rows', out_dtype = triton.language.float16, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-rows-float32-False-3] _ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-rows', out_dtype = triton.language.float32 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-rows-float32-True-3] _ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-rows', out_dtype = triton.language.float32, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-cols-float16-False-3] _ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-cols', out_dtype = triton.language.float16 | |
USE_TMA_STORE = False, NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], | |
[128, 128, 128, 4, 2, 256, 256, 192], | |
# small BLOCK_M and BLOCK_K | |
[16, 32, 32, 4, 1, 128, 256, 64], | |
[32, 32, 16, 4, 1, 256, 256, 192], | |
[16, 32, 64, 4, 4, 512, 256, 64], | |
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in | |
[False, True] for trans_output in [False, True] for num_stages in [3] | |
] + [ | |
# loop over instr shapes & pipeline stages | |
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) | |
for n in [16, 32, 64, 128, 256] | |
for trans_output in [False] | |
for out_dtype in ['float32'] | |
for use_tma_store in [False] | |
for num_stages in [2, 4, 5, 7] | |
] + [ | |
# irregular shapes | |
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[128, 128, 64, 4, 1], | |
[256, 128, 64, 4, 2], | |
[128, 128, 128, 4, 2], | |
] for shape in [ | |
[512, 360, 1024], | |
[360, 4096, 512], | |
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in | |
[3, 4] | |
]) | |
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") | |
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, | |
out_dtype, USE_TMA_STORE, NUM_STAGES): | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ | |
'16-32-64-4-4-512-256-64-True-False', | |
'16-32-64-4-4-512-256-64-True-True', | |
'16-32-64-4-4-512-256-64-False-False', | |
'16-32-64-4-4-512-256-64-False-True', | |
]: | |
pytest.skip('shapePerCTA[1] < 16 not supported') | |
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ | |
'16-32-64-4-1-256-256-256-False', | |
'16-32-64-4-2-256-256-256-False', | |
'16-32-64-4-2-256-256-256-True', | |
'16-32-64-8-2-256-256-256-False', | |
'16-32-64-8-2-256-256-256-True', | |
]: | |
pytest.skip('Known legacy issue, ldmatrix can only support x4') | |
if is_hip() and NUM_CTAS > 1: | |
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend") | |
if epilogue == 'add-rows' and NUM_CTAS > 1: | |
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.') | |
M = BLOCK_M if M is None else M | |
N = BLOCK_N if N is None else N | |
K = BLOCK_K if K is None else K | |
if (TRANS_A): | |
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T | |
a_order = [0, 1] | |
else: | |
a = torch.randn((M, K), device='cuda', dtype=torch.float16) | |
a_order = [1, 0] | |
if (TRANS_B): | |
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T | |
b_order = [0, 1] | |
else: | |
b = torch.randn((K, N), device='cuda', dtype=torch.float16) | |
b_order = [1, 0] | |
if out_dtype == 'float16' and epilogue != 'softmax': | |
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will | |
# fail with the following error: 'llvm.fmul' op requires the same type | |
# for all operands and results | |
out_dtype = tl.float16 | |
torch_out_dtype = torch.float16 | |
else: | |
out_dtype = tl.float32 | |
torch_out_dtype = torch.float32 | |
# avoid out of memory | |
if epilogue in ['add-matrix', 'add-rows', 'add-cols']: | |
if (TRANS_OUTPUT): | |
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T | |
else: | |
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype) | |
else: | |
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype) | |
# for chain-dot only | |
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T | |
w_order = [0, 1] | |
if (TRANS_OUTPUT): | |
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T | |
z_order = [0, 1] | |
else: | |
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype) | |
z_order = [1, 0] | |
# torch result | |
a_f32 = a.to(torch.float32) | |
b_f32 = b.to(torch.float32) | |
dot = torch.matmul(a_f32, b_f32) | |
def process_epilogue(d, bias, w, epilogue): | |
if epilogue == 'add-matrix': | |
ref = d + bias | |
elif epilogue == 'add-rows': | |
ref = d + bias[:, 0][:, None] | |
elif epilogue == 'add-cols': | |
ref = d + bias[0, :][None, :] | |
elif epilogue == 'softmax': | |
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0]) | |
denom = torch.sum(num, dim=-1, keepdims=True) | |
ref = num / denom | |
# ref = torch.softmax(d, 1) | |
elif epilogue == 'chain-dot': | |
ref = torch.matmul(d, w.to(torch.float32)) | |
else: | |
ref = d | |
return ref | |
golden = process_epilogue(dot, bias, w, epilogue) | |
def grid(META): | |
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) | |
pgm = matmul_kernel[grid]( | |
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # | |
M=M, N=N, K=K, # | |
stride_am=a.stride(0), stride_ak=a.stride(1), # | |
stride_bk=b.stride(0), stride_bn=b.stride(1), # | |
stride_wm=w.stride(0), stride_wn=w.stride(1), # | |
stride_zm=z.stride(0), stride_zn=z.stride(1), # | |
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # | |
out_dtype=out_dtype, # | |
USE_TMA_STORE=USE_TMA_STORE, # | |
ADD_MATRIX=epilogue == 'add-matrix', # | |
ADD_ROWS=epilogue == 'add-rows', # | |
ADD_COLS=epilogue == 'add-cols', # | |
DO_SOFTMAX=epilogue == 'softmax', # | |
CHAIN_DOT=epilogue == 'chain-dot', # | |
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # | |
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # | |
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # | |
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # | |
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES) | |
torch.set_printoptions(profile="full") | |
golden = torch.nn.functional.normalize(golden) | |
z = torch.nn.functional.normalize(z) | |
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) | |
# check is cuda backend specific | |
if is_hip(): | |
return | |
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() | |
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: | |
is_tcgen5 = (torch.cuda.get_device_capability()[0] | |
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0 | |
ptx = pgm.asm['ptx'] | |
if is_tcgen5: | |
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) | |
else: | |
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8)) | |
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) | |
E AssertionError: assert None | |
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n') | |
E + where <function search at 0x7f1ec8c06b60> = re.search | |
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64) | |
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format | |
unit/cuda/test_gemm.py:463: AssertionError | |
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-cols-float16-True-3] _ | |
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256 | |
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False | |
epilogue = 'add-cols', out_dtype = triton.language.float16, USE_TMA_STORE = True | |
NUM_STAGES = 3 | |
@pytest.mark.parametrize( | |
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES', | |
[ | |
# corner shapes | |
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3) | |
for shape_w_c in [ | |
[4096, 1, 1024, False, False, True], | |
[2048, 204, 1000, True, False, True], | |
[4096, 1, 1024, False, False, False], | |
[2048, 204, 1000, True, False, False], | |
] | |
for out_dtype in ['float16', 'float32'] # | |
for use_tma_store in [False, True] # | |
] + [ | |
# softmax epilogue | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, None, None, None], | |
[16, 16, 64, 4, 1, 16, 16, 64], | |
[64, 64, 32, 8, 1, 64, 64, 64], | |
[128, 128, 64, 4, 1, 128, 128, 128], | |
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for | |
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3] | |
] + [ | |
# loop over epilogues besides of softmax | |
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 16, 4, 1, 128, 128, 64], | |
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], | |
# for chain-dot | |
[128, 128, 64, 4, 1, None, None, None], | |
[64, 64, 16, 4, 1, None, None, None], | |
# small BLOCK_M and BLOCK_K | |
[16, 16, 64, 4, 1, 128, 128, 64], | |
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], | |
# repeat | |
[64, 64, 32, 8, 1, 128, 256, 64], | |
[64, 64, 16, 8, 2, 128, 128, 64], | |
# irregular shape | |
[128, 128, 64, 4, 1, 500, 200, 128], | |
[128, 128, 64, 4, 2, 513, 193, 192], | |
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot' | |
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in | |
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not ( | |
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) | |
] + [ | |
# loop over tile shapes and transpose combinations | |
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [ | |
[64, 64, 32, 4, 1, 128, 256, 64], | |
[128, 128, 16, 4, 4, 512, 256, 64], | |
[128, 256, 32, 4, 8, 256, 256, 192], | |
[512, 256, 32, 4, 8, 1024, 256, 192], | |
# BLOCK_K >= 128 | |
[64, 128, 128, 4, 1, 512, 256, 256], | |
[128, 128, 128, 4, 1, 256, 256, 192], |