@davidberard98
Last active May 9, 2025 01:26
This file has been truncated.
============================= test session starts ==============================
platform linux -- Python 3.12.9, pytest-8.3.5, pluggy-1.5.0
rootdir: /workspace/triton/python
configfile: pyproject.toml
collected 21530 items
unit/blackwell/test_tmem.py s [ 0%]
unit/cuda/test_experimental_tma.py .........F.......F................... [ 0%]
......................s..s................................s..s.......... [ 0%]
......................s..s.....s..s.....s..s.....s..s................... [ 0%]
........................................................................ [ 1%]
......................s..s.....s..s.....s..s.....s..s................... [ 1%]
........................................................................ [ 1%]
..........................F.FFssssssssssssssssssssssssssssssssssssssssss [ 2%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 2%]
unit/cuda/test_flashattention.py .FFFFF [ 2%]
unit/cuda/test_gemm.py ........FFFF............FFFF....FFFFFFFFFFFFFFFFF [ 2%]
FFFFFFF....FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFssssFFFFFFFFF [ 3%]
FFFssssFFFFFFFFFFFFFFFFFFFFFFFFFFFFssssFFFFFFFFFFFFssssFFFFFFFFFFFFFFFFF [ 3%]
FFFFFFFFFFFFFFFFFFF................................ssssssssssssssss..... [ 3%]
...........ssssssssssssssssFFFFFFFFFFFFFFFFFFFFFFFFssssFFFFFFFFFFFFFFFFF [ 4%]
FFFFFFFFFFFssssFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 4%]
FFF................ssssssssFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 4%]
unit/cuda/test_gemm_fusion.py ss [ 4%]
unit/cuda/test_mixed_io.py ... [ 4%]
unit/cuda/test_tma_descriptor.py sssssssss [ 4%]
unit/cuda/test_tma_store_gemm.py ........ [ 4%]
unit/instrumentation/test_gpuhello.py . [ 4%]
unit/language/test_annotations.py .......... [ 4%]
unit/language/test_block_pointer.py ......sss......sss......sss......sss [ 5%]
......sss......sss......sss......sss......sss......sss......sss......sss [ 5%]
......sss......sss......sss............................................. [ 5%]
........................................................................ [ 6%]
........................ [ 6%]
unit/language/test_compile_errors.py .............................. [ 6%]
unit/language/test_compile_only.py .... [ 6%]
unit/language/test_conversions.py .....................ss....ssss [ 6%]
unit/language/test_core.py ............................................. [ 6%]
........................................................................ [ 7%]
........................................................................ [ 7%]
........................................................................ [ 7%]
........................................................................ [ 8%]
........................................................................ [ 8%]
........................................................................ [ 8%]
........................................................................ [ 9%]
........................................................................ [ 9%]
........................................................................ [ 9%]
........................................................................ [ 10%]
........................................................................ [ 10%]
........................................................................ [ 10%]
........................................................................ [ 11%]
........................................................................ [ 11%]
........................................................................ [ 11%]
........................................................................ [ 12%]
........................................................................ [ 12%]
........................................................................ [ 12%]
........................................................................ [ 13%]
........................................................................ [ 13%]
........................................................................ [ 13%]
........................................................................ [ 14%]
........................................................................ [ 14%]
........................................................................ [ 14%]
........................................................................ [ 15%]
........................................................................ [ 15%]
........................................................................ [ 15%]
........................................................................ [ 16%]
........................................................................ [ 16%]
........................................................................ [ 16%]
........................................................................ [ 17%]
........................................................................ [ 17%]
........................................................................ [ 17%]
........................................................................ [ 18%]
........................................................................ [ 18%]
........................................................................ [ 18%]
........................................................................ [ 19%]
........................................................................ [ 19%]
........................................................................ [ 19%]
........................................................................ [ 20%]
........................................................................ [ 20%]
........................................................................ [ 20%]
........................................................................ [ 21%]
........................................................................ [ 21%]
........................................................................ [ 21%]
....s................................................................... [ 22%]
........................................................................ [ 22%]
........................................................................ [ 22%]
........................................................................ [ 23%]
........................................................................ [ 23%]
........................................................................ [ 23%]
........................................................................ [ 24%]
........................................................................ [ 24%]
........................................................................ [ 24%]
........................................................................ [ 25%]
........................................................................ [ 25%]
........................................................................ [ 25%]
........................................................................ [ 26%]
........................................................................ [ 26%]
........................................................................ [ 26%]
........................................................................ [ 27%]
........................................................................ [ 27%]
........................................................................ [ 27%]
........................................................................ [ 28%]
........................................................................ [ 28%]
........................................................................ [ 28%]
........................................................................ [ 29%]
........................................................................ [ 29%]
...................................................................ss... [ 29%]
.ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss... [ 30%]
.ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss... [ 30%]
.ss....ss....ss......................................................... [ 30%]
........................................................................ [ 31%]
........................................................................ [ 31%]
........................................................................ [ 31%]
...................................................................ss... [ 32%]
.ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss... [ 32%]
.ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss....ss... [ 32%]
.ss....ss....ss......................................................... [ 33%]
........................................................................ [ 33%]
........................................................................ [ 33%]
........................................................................ [ 34%]
........................................................................ [ 34%]
........................................................................ [ 34%]
........................................................................ [ 35%]
........................................................................ [ 35%]
........................................................................ [ 35%]
........................................................................ [ 36%]
........................................................................ [ 36%]
........................................................................ [ 36%]
........................................................................ [ 37%]
........................................................................ [ 37%]
...............................................................sss..sss. [ 37%]
.sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..s [ 38%]
ss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss [ 38%]
..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss..sss.. [ 38%]
sss..sss..sss..sss..sss..sss..sss..sss..sss............................. [ 39%]
........................................................................ [ 39%]
........................................................................ [ 39%]
........................................................................ [ 40%]
........................................................................ [ 40%]
........................................................................ [ 40%]
........................................................................ [ 41%]
........................................................................ [ 41%]
........................................................................ [ 41%]
........................................................................ [ 42%]
........................................................................ [ 42%]
........................................................................ [ 42%]
........................................................................ [ 43%]
........................................................................ [ 43%]
........................................................................ [ 43%]
........................................................................ [ 44%]
........................................................................ [ 44%]
........................................................................ [ 44%]
........................................................................ [ 45%]
........................................................................ [ 45%]
........................................................................ [ 45%]
........................................................................ [ 46%]
..................................................ss.................... [ 46%]
........................................................................ [ 46%]
........................................................................ [ 47%]
........................................................................ [ 47%]
........................................................................ [ 47%]
........................................................................ [ 48%]
........................................................................ [ 48%]
........................................................................ [ 48%]
................ssssssssssssssssssssssssssssssssssssssss................ [ 49%]
........................................................................ [ 49%]
........................................................................ [ 49%]
........................................................................ [ 50%]
........................................................................ [ 50%]
........................................................................ [ 50%]
........................................................................ [ 51%]
........................................................................ [ 51%]
........................................................................ [ 51%]
........................................................................ [ 52%]
........................................................................ [ 52%]
........................................................................ [ 52%]
........................................................................ [ 53%]
........................................................................ [ 53%]
........................................................................ [ 53%]
........................................................................ [ 54%]
........................................................................ [ 54%]
........................................................................ [ 54%]
........................................................................ [ 55%]
........................................................................ [ 55%]
........................................................................ [ 55%]
........................................................................ [ 56%]
........................................................................ [ 56%]
F....................................................................... [ 56%]
........................................................................ [ 57%]
........................................................................ [ 57%]
...................................s.......s.......s.......s.......s.... [ 57%]
.ssssssssssssssssssssssssssssssssssssssssssssssssssssss................. [ 58%]
........................................................................ [ 58%]
........................................................................ [ 58%]
........................................................................ [ 59%]
........................................................................ [ 59%]
........................................................................ [ 59%]
.....................................................ssss............... [ 60%]
.............................................ssss........ssssssss....sss [ 60%]
sssssssssssssssssssssssssssss........ssss........ssssssss....sssssssssss [ 60%]
sssssssssssssssssssss........ssss........ssssssss....sssssssssssssssssss [ 61%]
sssssssssssss........ssss........ssssssss....sssssssssssssssssssssssssss [ 61%]
sssss............ssss................................................... [ 61%]
.........ssss....ssssssss....ssssssssssssssssssssssssssssssss........... [ 62%]
.ssss....ssssssss....ssssssssssssssssssssssssssssssss............ssss... [ 62%]
.ssssssss....ssssssssssssssssssssssssssssssss............ssss....sssssss [ 62%]
s....ssssssssssssssssssssssssssssssss................ssss............... [ 63%]
.............................................ssssssssssss....sssssssssss [ 63%]
sssssssssssssssssssss................ssssssssssss....sssssssssssssssssss [ 63%]
sssssssssssss................ssssssssssss....sssssssssssssssssssssssssss [ 64%]
sssss................ssssssssssss....ssssssssssssssssssssssssssssssss... [ 64%]
.................ssss................................................sss [ 64%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 65%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 65%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 65%]
sssssssssssssssssssssssssssssssssssss................ssss............... [ 66%]
.............................sssssssssssssssssssssssssssssssssssssssssss [ 66%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 66%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 67%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss... [ 67%]
.................ssss................................................... [ 67%]
.ssssssssssssssssssssssssssssssssssssssssssss....................sssssss [ 68%]
sssssssssssssssssssssssssssssssssssss....................sssssssssssssss [ 68%]
sssssssssssssssssssssssssssss....................sssssssssssssssssssssss [ 68%]
sssssssssssssssssssss................................ssss............... [ 69%]
.....................sssssssssssssssssssssssssssssssssssssssssssssssssss [ 69%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 69%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 70%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss........... [ 70%]
.................ssss................................sssssssssssssssssss [ 70%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 71%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 71%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 71%]
sssssssssssssssssssss................................ssss............... [ 72%]
.............sssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 72%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 72%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 73%]
sssssssssssssssssssssssssssssssssssssssssssssssssssss................... [ 73%]
.................ssss........................sssssssssssssssssssssssssss [ 73%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 74%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 74%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 74%]
sssssssssssss........................................ssss............... [ 75%]
.....sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 75%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 75%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 76%]
sssssssssssssssssssssssssssssssssssssssssssss........................... [ 76%]
.................ssss................sssssssssssssssssssssssssssssssssss [ 76%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 77%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 77%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 77%]
sssss................................................ssss............sss [ 78%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 78%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 78%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 79%]
sssssssssssssssssssssssssssssssssssss................................... [ 79%]
.................ssss........sssssssssssssssssssssssssssssssssssssssssss [ 79%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 80%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 80%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss... [ 80%]
.....................................................ssss............... [ 81%]
.ssssssss....ssssssssssssssssssssssssssssssssssss................sssssss [ 81%]
s....ssssssssssssssssssssssssssssssssssss................ssssssss....sss [ 81%]
sssssssssssssssssssssssssssssssss................ssssssss....sssssssssss [ 82%]
sssssssssssssssssssssssss............................................... [ 82%]
.................ssss............ssssssss....sssssssssssssssssssssssssss [ 82%]
sssss....ssss............ssssssss....ssssssssssssssssssssssssssssssss... [ 83%]
.ssss............ssssssss....ssssssssssssssssssssssssssssssss....ssss... [ 83%]
.........ssssssss....ssssssssssssssssssssssssssssssss....ssss........... [ 83%]
........................................................................ [ 84%]
................................................F............ss..sssssss [ 84%]
sssss....ssssssss..ss..ssssssssss......ssssssss..ss..ssssssssss......sss [ 84%]
sssss..ss..ssss..ssss......ssssssss..ss......ssssss......ssssssss..ss... [ 85%]
.ssssssss....ssssssss....ss..ssssssss....ssssss........ssssssss......... [ 85%]
.ssssssssssssss..ssssssss......ssssssssss..ss..ssssssss......ssssssssss. [ 85%]
.ss..ssssssss......ssss..ssss..ss..ssssssss..........ssssss..ss..sssssss [ 86%]
s........ssssssss..ss.......ss....ss....ss....ss........................ [ 86%]
........................................................................ [ 86%]
.... [ 86%]
unit/language/test_decorator.py .. [ 86%]
unit/language/test_libdevice.py ............. [ 87%]
unit/language/test_line_info.py ......ssssss.. [ 87%]
unit/language/test_matmul.py ...sss.s....ss......ss..FF..ss......ss..... [ 87%]
.ss.....sss.ssFFsssFssFFFssFFsssssssssFFFssFFsFFFssFFsFFFssFFssssssss... [ 87%]
sss.s....ss......ss..FF..ss......ss......ss.....sss.ssFFsssFssFFFssFFsss [ 87%]
ssssssFFFssFFsFFFssFFsFFFssFFssssssss............................FFFFFFF [ 88%]
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 88%]
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 88%]
FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 89%]
FFFFFFFFFFFxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx [ 89%]
xxxxxxxxxxxFFFFFFFFFFFFFssssssssssssssssssssssssssssssssssssFFFFFFFFFFFF [ 89%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 90%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 90%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 90%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 91%]
ssssssssssssssssssssssssssssssssssssFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF [ 91%]
FFFFFFFFFFFFssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 91%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 92%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 92%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 92%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 93%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 93%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 93%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 94%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 94%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 95%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 95%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 95%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 96%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 96%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 96%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 97%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 97%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 97%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 98%]
ssssssssssssssssssssssssssssssssssss [ 98%]
unit/language/test_mxfp.py ........................ [ 98%]
unit/language/test_pipeliner.py .............................F..s [ 98%]
unit/language/test_random.py ........................................... [ 98%]
................................................................ [ 98%]
unit/language/test_reproducer.py . [ 98%]
unit/language/test_standard.py ......................................... [ 99%]
.................................. [ 99%]
unit/language/test_subprocess.py ....................F......... [ 99%]
unit/language/test_tuple.py ........ [ 99%]
unit/language/test_warp_specialization.py .. [ 99%]
unit/runtime/test_autotuner.py .......s [ 99%]
unit/runtime/test_bindings.py .. [ 99%]
unit/runtime/test_cache.py ..........................s. [ 99%]
unit/runtime/test_cublas.py ........ [ 99%]
unit/runtime/test_driver.py .. [ 99%]
unit/runtime/test_jit.py . [ 99%]
unit/runtime/test_launch.py .. [ 99%]
unit/runtime/test_subproc.py ... [ 99%]
unit/test_debug.py ..FFFFFFFFFFFFFFFFFFFFFFF.FFFFFFFFFFFFFFFFFFFFF [ 99%]
unit/test_debug_dump.py F [ 99%]
unit/test_perf_warning.py s.F [ 99%]
unit/tools/test_aot.py ..... [ 99%]
unit/tools/test_disasm.py F [ 99%]
unit/tools/test_irsource.py . [100%]
=================================== FAILURES ===================================
_______________ test_experimental_tma_matmul[True-128-256-64-4] ________________
num_stages = 4, BLOCK_M = 128, BLOCK_N = 256, BLOCK_K = 64, byval_tma = True
@pytest.mark.parametrize("num_stages", [1, 4])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)])
@pytest.mark.parametrize("byval_tma", [True, False])
def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma):
if not supports_tma(byval_tma):
pytest.skip(tma_skip_msg(byval_tma))
device = "cuda"
M, N, K = 8192, 8192, 1024
torch.manual_seed(42)
A = torch.randn((M, K), dtype=torch.float16, device=device)
B = torch.randn((K, N), dtype=torch.float16, device=device)
C = torch.empty((M, N), dtype=torch.float16, device=device)
if byval_tma:
desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size())
desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size())
desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size())
else:
desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size())
desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size())
desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size())
> kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1,
1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma,
num_warps=8, num_stages=num_stages, dtype=tl.float16)
unit/cuda/test_experimental_tma.py:108:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:591: in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
../triton/compiler/compiler.py:413: in __getattribute__
self._init_handles()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af0f64b90>
def _init_handles(self):
if self.module is not None:
return
device = driver.active.get_current_device()
# create launcher
self.run = driver.active.launcher_cls(self.src, self.metadata)
# not enough shared memory to run the kernel
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
if self.metadata.shared > max_shared:
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 147480, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
../triton/compiler/compiler.py:401: OutOfResources
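The 147480-byte requirement above is consistent with pipelining the fp16 A (128x64) and B (64x256) tiles across roughly three in-flight buffers: (128*64 + 64*256) * 2 bytes = 49152 bytes per stage, and 3 * 49152 = 147456 bytes, just under the reported figure (the small remainder is presumably barrier/metadata space). The sketch below redoes that arithmetic and queries the hardware limit the same way the traceback does; the import path and the three-buffer count are assumptions, not something the log confirms.

from triton.runtime.driver import driver

BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64
ELEM_SIZE = 2            # float16
BUFFERS_IN_FLIGHT = 3    # assumed: best matches the reported 147480 bytes

per_stage = (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N) * ELEM_SIZE  # 49152 bytes: one A tile plus one B tile
estimate = BUFFERS_IN_FLIGHT * per_stage                         # 147456 bytes

device = driver.active.get_current_device()
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
print(f"estimated {estimate} B of shared memory vs. hardware limit {max_shared} B")

As the error message suggests, shrinking the tiles or num_stages closes the gap: with only two buffers in flight the same tiles need 98304 bytes, which fits under the 101376-byte limit reported here.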
_______________ test_experimental_tma_matmul[False-128-256-64-4] _______________
num_stages = 4, BLOCK_M = 128, BLOCK_N = 256, BLOCK_K = 64, byval_tma = False
@pytest.mark.parametrize("num_stages", [1, 4])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)])
@pytest.mark.parametrize("byval_tma", [True, False])
def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma):
if not supports_tma(byval_tma):
pytest.skip(tma_skip_msg(byval_tma))
device = "cuda"
M, N, K = 8192, 8192, 1024
torch.manual_seed(42)
A = torch.randn((M, K), dtype=torch.float16, device=device)
B = torch.randn((K, N), dtype=torch.float16, device=device)
C = torch.empty((M, N), dtype=torch.float16, device=device)
if byval_tma:
desc_a = create_2d_tma_descriptor(A.data_ptr(), M, K, BLOCK_M, BLOCK_K, A.element_size())
desc_b = create_2d_tma_descriptor(B.data_ptr(), K, N, BLOCK_K, BLOCK_N, B.element_size())
desc_c = create_2d_tma_descriptor(C.data_ptr(), M, N, BLOCK_M, BLOCK_N, C.element_size())
else:
desc_a = create_tma_desc_gmem_ptr(A.data_ptr(), [M, K], [BLOCK_M, BLOCK_K], A.element_size())
desc_b = create_tma_desc_gmem_ptr(B.data_ptr(), [K, N], [BLOCK_K, BLOCK_N], B.element_size())
desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size())
> kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1,
1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma,
num_warps=8, num_stages=num_stages, dtype=tl.float16)
unit/cuda/test_experimental_tma.py:108:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:591: in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
../triton/compiler/compiler.py:413: in __getattribute__
self._init_handles()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af0f780e0>
def _init_handles(self):
if self.module is not None:
return
device = driver.active.get_current_device()
# create launcher
self.run = driver.active.launcher_cls(self.src, self.metadata)
# not enough shared memory to run the kernel
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
if self.metadata.shared > max_shared:
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 147480, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
../triton/compiler/compiler.py:401: OutOfResources
________ test_experimental_make_tensor_descriptor_matmul[128-256-64-4] _________
num_stages = 4, BLOCK_M = 128, BLOCK_N = 256, BLOCK_K = 64
@requires_tma
@pytest.mark.interpreter
@pytest.mark.parametrize("num_stages", [1, 4])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)])
def test_experimental_make_tensor_descriptor_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K):
device = "cuda"
if is_interpreter():
M, N, K = BLOCK_M, BLOCK_N, BLOCK_K
else:
M, N, K = 8192, 8192, 1024
torch.manual_seed(42)
A = torch.randn((M, K), dtype=torch.float16, device=device)
B = torch.randn((K, N), dtype=torch.float16, device=device)
C = torch.empty((M, N), dtype=torch.float16, device=device)
grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1, 1)
def alloc_fn(size: int, align: int, stream: Optional[int]):
assert size == 3 * 128 * grid[0]
assert align == 128
assert stream == 0
return torch.empty(size, dtype=torch.int8, device="cuda")
triton.set_allocator(alloc_fn)
> kernel = matmul_kernel_make_tensor_desciptor[grid](
A,
B,
C,
M,
N,
K,
BLOCK_M,
BLOCK_N,
BLOCK_K,
num_warps=8,
num_stages=num_stages,
)
unit/cuda/test_experimental_tma.py:723:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:591: in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
../triton/compiler/compiler.py:413: in __getattribute__
self._init_handles()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1d68557d70>
def _init_handles(self):
if self.module is not None:
return
device = driver.active.get_current_device()
# create launcher
self.run = driver.active.launcher_cls(self.src, self.metadata)
# not enough shared memory to run the kernel
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
if self.metadata.shared > max_shared:
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 147480, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
../triton/compiler/compiler.py:401: OutOfResources
__________________ test_tensor_descriptor_batched_gemm_2d_tma __________________
@requires_tma
@pytest.mark.interpreter
def test_tensor_descriptor_batched_gemm_2d_tma():
device = "cuda"
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64
if is_interpreter():
B, M, N, K = 2, BLOCK_M, BLOCK_N, BLOCK_K
else:
B, M, N, K = 2, 1024, 1024, 128
NUM_SMS = 96
num_stages = 3
grid = (min(NUM_SMS, B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), )
a = torch.randn((B, M, K), device=device, dtype=torch.float16)
b = torch.randn((B, N, K), device=device, dtype=torch.float16)
c = torch.empty((B, M, N), device=device, dtype=torch.float16)
expect = torch.bmm(a, b.mT)
def alloc_fn(size: int, align: int, stream: Optional[int]):
# TODO: should only need num_stages * 3 descriptors per SM
assert size == 128 * 3 * (num_stages + 1) * grid[0]
assert align == 128
assert stream == 0
return torch.empty(size, dtype=torch.int8, device="cuda")
triton.set_allocator(alloc_fn)
> batched_gemm_2d_tma_kernel[grid](
a, b, c, #
B, M, N, K, #
tl.float16, #
BLOCK_M, BLOCK_N, BLOCK_K, #
NUM_SMS, #
num_stages=num_stages, num_warps=8)
unit/cuda/test_experimental_tma.py:920:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:591: in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
../triton/compiler/compiler.py:413: in __getattribute__
self._init_handles()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af14fafc0>
def _init_handles(self):
if self.module is not None:
return
device = driver.active.get_current_device()
# create launcher
self.run = driver.active.launcher_cls(self.src, self.metadata)
# not enough shared memory to run the kernel
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
if self.metadata.shared > max_shared:
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 196624, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
../triton/compiler/compiler.py:401: OutOfResources
__________________ test_tensor_descriptor_batched_gemm_3d_tma __________________
@requires_tma
@pytest.mark.interpreter
def test_tensor_descriptor_batched_gemm_3d_tma():
device = "cuda"
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64
if is_interpreter():
B, M, N, K = 2, BLOCK_M, BLOCK_N, BLOCK_K
else:
B, M, N, K = 2, 1024, 1024, 128
NUM_SMS = 96
num_stages = 3
grid = (min(NUM_SMS, B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), )
a = torch.randn((B, M, K), device=device, dtype=torch.float16)
b = torch.randn((B, N, K), device=device, dtype=torch.float16)
c = torch.empty((B, M, N), device=device, dtype=torch.float16)
expect = torch.bmm(a, b.mT)
def alloc_fn(size: int, align: int, stream: Optional[int]):
# TODO: should only need num_stages * 3 descriptors per SM
assert size == 128 * 3 * grid[0]
assert align == 128
assert stream == 0
return torch.empty(size, dtype=torch.int8, device="cuda")
triton.set_allocator(alloc_fn)
> h = batched_gemm_3d_tma_kernel[grid](
a, b, c, #
B, M, N, K, #
tl.float16, #
BLOCK_M, BLOCK_N, BLOCK_K, #
NUM_SMS, #
num_stages=num_stages, num_warps=8)
unit/cuda/test_experimental_tma.py:1020:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:591: in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
../triton/compiler/compiler.py:413: in __getattribute__
self._init_handles()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af212ce90>
def _init_handles(self):
if self.module is not None:
return
device = driver.active.get_current_device()
# create launcher
self.run = driver.active.launcher_cls(self.src, self.metadata)
# not enough shared memory to run the kernel
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
if self.metadata.shared > max_shared:
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 196624, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
../triton/compiler/compiler.py:401: OutOfResources
_____________________________ test_op[4-48-256-64] _____________________________
Z = 4, H = 48, N_CTX = 256, D_HEAD = 64, dtype = torch.float16
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
(4, 48, 128, 64),
(4, 48, 256, 64),
(4, 48, 512, 64),
(4, 48, 1024, 64),
(4, 48, 2048, 64),
(4, 48, 4096, 64),
# (4, 48, 8192, 64), out of memory
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, sm_scale)
# print(ref_out)
# print(tri_out)
> tri_out.backward(dout)
unit/cuda/test_flashattention.py:395:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/_tensor.py:648: in backward
torch.autograd.backward(
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/__init__.py:353: in backward
_engine_run_backward(
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/graph.py:824: in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/function.py:307: in apply
return user_fn(self, *args)
unit/cuda/test_flashattention.py:342: in backward
_bwd_kernel[(ctx.grid[1], )](
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:591: in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
../triton/compiler/compiler.py:413: in __getattribute__
self._init_handles()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af0248620>
def _init_handles(self):
if self.module is not None:
return
device = driver.active.get_current_device()
# create launcher
self.run = driver.active.launcher_cls(self.src, self.metadata)
# not enough shared memory to run the kernel
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
if self.metadata.shared > max_shared:
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
../triton/compiler/compiler.py:401: OutOfResources
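The same shared-memory failure repeats for _bwd_kernel at N_CTX = 512 and 1024 below; the 114688-byte figure is identical for N_CTX = 256, 512, and 1024, so it is the block configuration, not the problem size, that exceeds the limit. A caller that prefers to skip such configurations instead of failing can catch the exception. In the sketch below, launch_config is a hypothetical stand-in for the kernel launch; the exception type and its constructor arguments follow the raise OutOfResources(self.metadata.shared, max_shared, "shared memory") line in the traceback, and the required/limit attribute names are assumed from those arguments.

from triton.runtime.errors import OutOfResources

def launch_config(block_m, block_n):
    # Hypothetical stand-in for the _bwd_kernel launch above; it simply re-raises
    # the error recorded in this log so the handling pattern can be demonstrated.
    raise OutOfResources(114688, 101376, "shared memory")

for block_m, block_n in [(128, 128), (64, 128)]:  # illustrative tile sizes only
    try:
        launch_config(block_m, block_n)
    except OutOfResources as e:
        print(f"skipping BLOCK_M={block_m}, BLOCK_N={block_n}: "
              f"needs {e.required} B shared, limit is {e.limit} B")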
_____________________________ test_op[4-48-512-64] _____________________________
Z = 4, H = 48, N_CTX = 512, D_HEAD = 64, dtype = torch.float16
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
(4, 48, 128, 64),
(4, 48, 256, 64),
(4, 48, 512, 64),
(4, 48, 1024, 64),
(4, 48, 2048, 64),
(4, 48, 4096, 64),
# (4, 48, 8192, 64), out of memory
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, sm_scale)
# print(ref_out)
# print(tri_out)
> tri_out.backward(dout)
unit/cuda/test_flashattention.py:395:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/_tensor.py:648: in backward
torch.autograd.backward(
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/__init__.py:353: in backward
_engine_run_backward(
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/graph.py:824: in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/function.py:307: in apply
return user_fn(self, *args)
unit/cuda/test_flashattention.py:342: in backward
_bwd_kernel[(ctx.grid[1], )](
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:591: in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
../triton/compiler/compiler.py:413: in __getattribute__
self._init_handles()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af0248620>
def _init_handles(self):
if self.module is not None:
return
device = driver.active.get_current_device()
# create launcher
self.run = driver.active.launcher_cls(self.src, self.metadata)
# not enough shared memory to run the kernel
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
if self.metadata.shared > max_shared:
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
../triton/compiler/compiler.py:401: OutOfResources
____________________________ test_op[4-48-1024-64] _____________________________
Z = 4, H = 48, N_CTX = 1024, D_HEAD = 64, dtype = torch.float16
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
(4, 48, 128, 64),
(4, 48, 256, 64),
(4, 48, 512, 64),
(4, 48, 1024, 64),
(4, 48, 2048, 64),
(4, 48, 4096, 64),
# (4, 48, 8192, 64), out of memory
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, sm_scale)
# print(ref_out)
# print(tri_out)
> tri_out.backward(dout)
unit/cuda/test_flashattention.py:395:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/_tensor.py:648: in backward
torch.autograd.backward(
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/__init__.py:353: in backward
_engine_run_backward(
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/graph.py:824: in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/function.py:307: in apply
return user_fn(self, *args)
unit/cuda/test_flashattention.py:342: in backward
_bwd_kernel[(ctx.grid[1], )](
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:591: in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
../triton/compiler/compiler.py:413: in __getattribute__
self._init_handles()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <triton.compiler.compiler.CompiledKernel object at 0x7f1af0248620>
def _init_handles(self):
if self.module is not None:
return
device = driver.active.get_current_device()
# create launcher
self.run = driver.active.launcher_cls(self.src, self.metadata)
# not enough shared memory to run the kernel
max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
if self.metadata.shared > max_shared:
> raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
E triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 114688, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
../triton/compiler/compiler.py:401: OutOfResources
____________________________ test_op[4-48-2048-64] _____________________________
Z = 4, H = 48, N_CTX = 2048, D_HEAD = 64, dtype = torch.float16
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
(4, 48, 128, 64),
(4, 48, 256, 64),
(4, 48, 512, 64),
(4, 48, 1024, 64),
(4, 48, 2048, 64),
(4, 48, 4096, 64),
# (4, 48, 8192, 64), out of memory
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
> ref_out.backward(dout)
unit/cuda/test_flashattention.py:387:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/_tensor.py:648: in backward
torch.autograd.backward(
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/__init__.py:353: in backward
_engine_run_backward(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
t_outputs = (tensor([[[[ 0.3435, 0.1052, 0.6152, ..., 0.2805, 0.1891, 0.5288],
[ 0.4490, 0.2205, 0.3833, ..., ...993, ..., 0.2983, 0.3005, 0.2988]]]],
device='cuda:0', dtype=torch.float16, grad_fn=<UnsafeViewBackward0>),)
args = ((tensor([[[[-8.0859e-01, 7.9541e-01, 2.0740e-01, ..., 7.9773e-02,
-1.7402e+00, 4.6851e-01],
... ..., 6.1133e-01,
1.3564e+00, -1.8457e+00]]]], device='cuda:0', dtype=torch.float16),), False, False, ())
kwargs = {'accumulate_grad': True, 'allow_unreachable': True}
attach_logging_hooks = False
def _engine_run_backward(
t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
*args: Any,
**kwargs: Any,
) -> tuple[torch.Tensor, ...]:
attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
if attach_logging_hooks:
unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
try:
> return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
t_outputs, *args, **kwargs
) # Calls into the C++ engine to run the backward pass
E torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.00 GiB. GPU 0 has a total capacity of 15.48 GiB of which 2.32 GiB is free. Including non-PyTorch memory, this process has 13.15 GiB memory in use. Of the allocated memory 9.70 GiB is allocated by PyTorch, and 1.74 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
/root/miniconda3/envs/triton/lib/python3.12/site-packages/torch/autograd/graph.py:824: OutOfMemoryError
____________________________ test_op[4-48-4096-64] _____________________________
Z = 4, H = 48, N_CTX = 4096, D_HEAD = 64, dtype = torch.float16
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
(4, 48, 128, 64),
(4, 48, 256, 64),
(4, 48, 512, 64),
(4, 48, 1024, 64),
(4, 48, 2048, 64),
(4, 48, 4096, 64),
# (4, 48, 8192, 64), out of memory
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
sm_scale = 0.2
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
> p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
E torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.00 GiB. GPU 0 has a total capacity of 15.48 GiB of which 5.32 GiB is free. Including non-PyTorch memory, this process has 10.15 GiB memory in use. Of the allocated memory 7.13 GiB is allocated by PyTorch, and 1.30 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
unit/cuda/test_flashattention.py:380: OutOfMemoryError
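The last two test_op cases fail in the pure-PyTorch reference path rather than in a Triton kernel: the reference implementation materializes the dense N_CTX x N_CTX attention matrix, and at these sizes that matrix alone accounts for the reported allocations. A back-of-the-envelope check using only the sizes from the parametrization (reading the 3.00 GiB backward allocation as the float32 copy made by p.float() is a plausible interpretation, not something the log states):

Z, H = 4, 48
for n_ctx in (2048, 4096):
    elems = Z * H * n_ctx * n_ctx          # elements in p = q @ k.transpose(2, 3)
    print(n_ctx, elems * 2 / 2**30, elems * 4 / 2**30)   # GiB as float16 and float32
# 2048 -> 1.5 GiB (fp16) and 3.0 GiB (fp32): matches the 3.00 GiB backward allocation above
# 4096 -> 6.0 GiB (fp16): matches the 6.00 GiB allocation that fails at the matmul itself

The PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True setting suggested in the error text only mitigates fragmentation; it does not shrink these tensors, so on a 15.48 GiB device the larger N_CTX values would still be tight.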
___________ test_gemm_no_scf[256-64-16-4-4-False-True-float16-True] ____________
M = 256, N = 64, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False
TRANS_B = True, OUTPUT_TYPE = 'float16', USE_TMA_EPILOGUE = True
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
] for USE_TMA_EPILOGUE in [True, False]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE):
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if OUTPUT_TYPE == "float16":
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
else:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
> matmul_no_scf_kernel[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE)
unit/cuda/test_gemm.py:125:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:569: in run
kernel = self.compile(src, target=target, options=options.__dict__)
../triton/compiler/compiler.py:284: in compile
next_module = compile_ir(module, metadata)
../triton/backends/nvidia/compiler.py:450: in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <nvidia.CUDABackend object at 0x7f19bb9dd220>
src = <triton._C.libtriton.ir.module object at 0x7f19bb94bbf0>
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (4, 1, 1), ...}
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120')
capability = 120
def make_llir(self, src, metadata, options, capability):
ptx_version = get_ptx_version_from_options(options, self.target.arch)
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_allocate_warp_groups(pm)
passes.convert.add_scf_to_cf(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
> pm.run(mod)
E RuntimeError: PassManager::run failed
../triton/backends/nvidia/compiler.py:341: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0], instrShape = [16, 8]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
%0 = arith.extsi %arg6 : i32 to i64
%1 = arith.extsi %arg7 : i32 to i64
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x16x!tt.ptr<f16>, #blocked>
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%5 = arith.extsi %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked1}>>
%6 = arith.extsi %4 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
%7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi64, #blocked>
%8 = tt.expand_dims %5 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi64, #blocked1>
%9 = tt.splat %0 : i64 -> tensor<256x1xi64, #blocked>
%10 = arith.muli %7, %9 : tensor<256x1xi64, #blocked>
%11 = tt.broadcast %10 : tensor<256x1xi64, #blocked> -> tensor<256x16xi64, #blocked>
%12 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
%13 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%14 = arith.extsi %12 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>>
%15 = arith.extsi %13 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>>
%16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked>
%17 = tt.broadcast %16 : tensor<1x16xi64, #blocked> -> tensor<256x16xi64, #blocked>
%18 = arith.addi %11, %17 : tensor<256x16xi64, #blocked>
%19 = tt.addptr %2, %18 : tensor<256x16x!tt.ptr<f16>, #blocked>, tensor<256x16xi64, #blocked>
%20 = tt.load %19 : tensor<256x16x!tt.ptr<f16>, #blocked>
%21 = ttg.local_alloc %20 : (tensor<256x16xf16, #blocked>) -> !ttg.memdesc<256x16xf16, #shared, #smem>
%22 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x64x!tt.ptr<f16>, #blocked2>
%23 = tt.expand_dims %14 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2>
%24 = tt.broadcast %23 : tensor<16x1xi64, #blocked2> -> tensor<16x64xi64, #blocked2>
%25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%26 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%27 = arith.extsi %26 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>>
%28 = arith.extsi %25 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked1}>>
%29 = tt.expand_dims %27 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2>
%30 = tt.expand_dims %28 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi64, #blocked1>
%31 = tt.splat %1 : i64 -> tensor<1x64xi64, #blocked2>
%32 = arith.muli %29, %31 : tensor<1x64xi64, #blocked2>
%33 = tt.broadcast %32 : tensor<1x64xi64, #blocked2> -> tensor<16x64xi64, #blocked2>
%34 = arith.addi %24, %33 : tensor<16x64xi64, #blocked2>
%35 = tt.addptr %22, %34 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi64, #blocked2>
%36 = tt.load %35 : tensor<16x64x!tt.ptr<f16>, #blocked2>
%37 = ttg.local_alloc %36 : (tensor<16x64xf16, #blocked2>) -> !ttg.memdesc<16x64xf16, #shared1, #smem>
%38 = ttg.local_load %21 : !ttg.memdesc<256x16xf16, #shared, #smem> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%39 = ttg.local_load %37 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%40 = tt.dot %38, %39, %cst, inputPrecision = tf32 : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x64xf32, #mma>
%41 = arith.truncf %40 : tensor<256x64xf32, #mma> to tensor<256x64xf16, #mma>
%42 = arith.extsi %arg8 : i32 to i64
%43 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
%44 = tt.splat %42 : i64 -> tensor<256x1xi64, #blocked1>
%45 = arith.muli %8, %44 : tensor<256x1xi64, #blocked1>
%46 = tt.broadcast %45 : tensor<256x1xi64, #blocked1> -> tensor<256x64xi64, #blocked1>
%47 = tt.broadcast %30 : tensor<1x64xi64, #blocked1> -> tensor<256x64xi64, #blocked1>
%48 = arith.addi %46, %47 : tensor<256x64xi64, #blocked1>
%49 = tt.addptr %43, %48 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi64, #blocked1>
%50 = ttg.convert_layout %41 : tensor<256x64xf16, #mma> -> tensor<256x64xf16, #blocked1>
tt.store %49, %50 : tensor<256x64x!tt.ptr<f16>, #blocked1>
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)",
disable_threading: false,
verify_each: true
}
}
#-}
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
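The test_gemm_no_scf failures reported in this section all abort inside ConvertTritonGPUToLLVM with the LinearLayout::reshapeIns assertion shown in the captured stderr. A standalone repro sketch for the case above, mirroring the parameters from the failure header (M=256, N=64, K=16, NUM_CTAS=4, NUM_WARPS=4, TRANS_B=True, float16 output, TMA epilogue); the import of matmul_no_scf_kernel is a hypothetical path, the kernel itself lives in python/test/unit/cuda/test_gemm.py.

    # Repro sketch for the failing case above; mirrors the test body shown in the log.
    # Assumption: matmul_no_scf_kernel is importable from the test module
    # (python/test/unit/cuda/test_gemm.py).
    import torch
    from test_gemm import matmul_no_scf_kernel  # hypothetical import path

    M, N, K = 256, 64, 16
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((N, K), device="cuda", dtype=torch.float16).T   # TRANS_B=True
    c = torch.empty((M, N), device="cuda", dtype=torch.float16)
    matmul_no_scf_kernel[(1, 1)](
        a_ptr=a, b_ptr=b, c_ptr=c,
        M=M, N=N, K=K,
        stride_am=a.stride(0), stride_ak=a.stride(1),
        stride_bk=b.stride(0), stride_bn=b.stride(1),
        stride_cm=c.stride(0), stride_cn=c.stride(1),
        BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,
        num_warps=4, num_ctas=4,
        FLOAT16_OUTPUT=True, USE_TMA_EPILOGUE=True)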
___________ test_gemm_no_scf[256-64-16-4-4-False-True-float32-True] ____________
M = 256, N = 64, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False
TRANS_B = True, OUTPUT_TYPE = 'float32', USE_TMA_EPILOGUE = True
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
] for USE_TMA_EPILOGUE in [True, False]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE):
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if OUTPUT_TYPE == "float16":
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
else:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
> matmul_no_scf_kernel[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE)
unit/cuda/test_gemm.py:125:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:569: in run
kernel = self.compile(src, target=target, options=options.__dict__)
../triton/compiler/compiler.py:284: in compile
next_module = compile_ir(module, metadata)
../triton/backends/nvidia/compiler.py:450: in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <nvidia.CUDABackend object at 0x7f19bba4d370>
src = <triton._C.libtriton.ir.module object at 0x7f19bba77890>
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (4, 1, 1), ...}
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120')
capability = 120
def make_llir(self, src, metadata, options, capability):
ptx_version = get_ptx_version_from_options(options, self.target.arch)
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_allocate_warp_groups(pm)
passes.convert.add_scf_to_cf(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
> pm.run(mod)
E RuntimeError: PassManager::run failed
../triton/backends/nvidia/compiler.py:341: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0], instrShape = [16, 8]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
%0 = arith.extsi %arg6 : i32 to i64
%1 = arith.extsi %arg7 : i32 to i64
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x16x!tt.ptr<f16>, #blocked>
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%5 = arith.extsi %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked1}>>
%6 = arith.extsi %4 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
%7 = tt.expand_dims %5 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi64, #blocked1>
%8 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi64, #blocked>
%9 = tt.splat %0 : i64 -> tensor<256x1xi64, #blocked>
%10 = arith.muli %8, %9 : tensor<256x1xi64, #blocked>
%11 = tt.broadcast %10 : tensor<256x1xi64, #blocked> -> tensor<256x16xi64, #blocked>
%12 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
%13 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%14 = arith.extsi %12 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>>
%15 = arith.extsi %13 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>>
%16 = tt.expand_dims %15 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked>
%17 = tt.broadcast %16 : tensor<1x16xi64, #blocked> -> tensor<256x16xi64, #blocked>
%18 = arith.addi %11, %17 : tensor<256x16xi64, #blocked>
%19 = tt.addptr %2, %18 : tensor<256x16x!tt.ptr<f16>, #blocked>, tensor<256x16xi64, #blocked>
%20 = tt.load %19 : tensor<256x16x!tt.ptr<f16>, #blocked>
%21 = ttg.local_alloc %20 : (tensor<256x16xf16, #blocked>) -> !ttg.memdesc<256x16xf16, #shared, #smem>
%22 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x64x!tt.ptr<f16>, #blocked2>
%23 = tt.expand_dims %14 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2>
%24 = tt.broadcast %23 : tensor<16x1xi64, #blocked2> -> tensor<16x64xi64, #blocked2>
%25 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%26 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%27 = arith.extsi %26 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked1}>>
%28 = arith.extsi %25 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>>
%29 = tt.expand_dims %27 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi64, #blocked1>
%30 = tt.expand_dims %28 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2>
%31 = tt.splat %1 : i64 -> tensor<1x64xi64, #blocked2>
%32 = arith.muli %30, %31 : tensor<1x64xi64, #blocked2>
%33 = tt.broadcast %32 : tensor<1x64xi64, #blocked2> -> tensor<16x64xi64, #blocked2>
%34 = arith.addi %24, %33 : tensor<16x64xi64, #blocked2>
%35 = tt.addptr %22, %34 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi64, #blocked2>
%36 = tt.load %35 : tensor<16x64x!tt.ptr<f16>, #blocked2>
%37 = ttg.local_alloc %36 : (tensor<16x64xf16, #blocked2>) -> !ttg.memdesc<16x64xf16, #shared1, #smem>
%38 = ttg.local_load %21 : !ttg.memdesc<256x16xf16, #shared, #smem> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%39 = ttg.local_load %37 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%40 = tt.dot %38, %39, %cst, inputPrecision = tf32 : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x64xf32, #mma>
%41 = arith.extsi %arg8 : i32 to i64
%42 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x64x!tt.ptr<f32>, #blocked1>
%43 = tt.splat %41 : i64 -> tensor<256x1xi64, #blocked1>
%44 = arith.muli %7, %43 : tensor<256x1xi64, #blocked1>
%45 = tt.broadcast %44 : tensor<256x1xi64, #blocked1> -> tensor<256x64xi64, #blocked1>
%46 = tt.broadcast %29 : tensor<1x64xi64, #blocked1> -> tensor<256x64xi64, #blocked1>
%47 = arith.addi %45, %46 : tensor<256x64xi64, #blocked1>
%48 = tt.addptr %42, %47 : tensor<256x64x!tt.ptr<f32>, #blocked1>, tensor<256x64xi64, #blocked1>
%49 = ttg.convert_layout %40 : tensor<256x64xf32, #mma> -> tensor<256x64xf32, #blocked1>
tt.store %48, %49 : tensor<256x64x!tt.ptr<f32>, #blocked1>
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)",
disable_threading: false,
verify_each: true
}
}
#-}
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
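For readers unfamiliar with the assertion text repeated in these failures: what reshapeIns checks is a size-preservation invariant, the product of the requested input-dimension sizes must equal the layout's total input size. A toy Python illustration of that arithmetic, using assumed dimension names (register/lane/warp/block); this is not the C++ implementation, only the invariant being asserted.

    # Toy illustration of the invariant asserted by LinearLayout::reshapeIns:
    # reshaping the input dimensions must preserve the total number of input
    # points, just as reshaping an array must preserve its element count.
    import math

    total_in_dim_size = 16384                                          # e.g. getTotalInDimSize()
    new_in_dims = {"register": 8, "lane": 32, "warp": 4, "block": 16}  # assumed dimension names
    assert total_in_dim_size == math.prod(new_in_dims.values())        # 8*32*4*16 == 16384

    # The failing compiles above hit the case where this product does NOT match,
    # which aborts the pass pipeline and surfaces as "PassManager::run failed".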
___________ test_gemm_no_scf[128-128-16-4-4-False-True-float16-True] ___________
M = 128, N = 128, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False
TRANS_B = True, OUTPUT_TYPE = 'float16', USE_TMA_EPILOGUE = True
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
] for USE_TMA_EPILOGUE in [True, False]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE):
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if OUTPUT_TYPE == "float16":
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
else:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
> matmul_no_scf_kernel[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE)
unit/cuda/test_gemm.py:125:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:569: in run
kernel = self.compile(src, target=target, options=options.__dict__)
../triton/compiler/compiler.py:284: in compile
next_module = compile_ir(module, metadata)
../triton/backends/nvidia/compiler.py:450: in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <nvidia.CUDABackend object at 0x7f1d685f4110>
src = <triton._C.libtriton.ir.module object at 0x7f19bba77170>
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (2, 2, 1), ...}
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120')
capability = 120
def make_llir(self, src, metadata, options, capability):
ptx_version = get_ptx_version_from_options(options, self.target.arch)
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_allocate_warp_groups(pm)
passes.convert.add_scf_to_cf(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
> pm.run(mod)
E RuntimeError: PassManager::run failed
../triton/backends/nvidia/compiler.py:341: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0], instrShape = [16, 8]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%0 = arith.extsi %arg6 : i32 to i64
%1 = arith.extsi %arg7 : i32 to i64
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%7 = arith.extsi %3 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
%8 = arith.extsi %6 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>>
%9 = arith.extsi %5 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>>
%10 = arith.extsi %4 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked1}>>
%11 = tt.expand_dims %10 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi64, #blocked1>
%12 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked>
%13 = tt.splat %0 : i64 -> tensor<128x1xi64, #blocked>
%14 = arith.muli %12, %13 : tensor<128x1xi64, #blocked>
%15 = tt.broadcast %14 : tensor<128x1xi64, #blocked> -> tensor<128x16xi64, #blocked>
%16 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%17 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
%18 = arith.extsi %17 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>>
%19 = arith.extsi %16 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>>
%20 = tt.expand_dims %19 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked>
%21 = tt.broadcast %20 : tensor<1x16xi64, #blocked> -> tensor<128x16xi64, #blocked>
%22 = arith.addi %15, %21 : tensor<128x16xi64, #blocked>
%23 = tt.addptr %2, %22 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi64, #blocked>
%24 = tt.load %23 : tensor<128x16x!tt.ptr<f16>, #blocked>
%25 = ttg.local_alloc %24 : (tensor<128x16xf16, #blocked>) -> !ttg.memdesc<128x16xf16, #shared, #smem>
%26 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x128x!tt.ptr<f16>, #blocked2>
%27 = tt.expand_dims %18 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2>
%28 = tt.broadcast %27 : tensor<16x1xi64, #blocked2> -> tensor<16x128xi64, #blocked2>
%29 = tt.expand_dims %8 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x128xi64, #blocked2>
%30 = tt.expand_dims %9 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi64, #blocked1>
%31 = tt.splat %1 : i64 -> tensor<1x128xi64, #blocked2>
%32 = arith.muli %29, %31 : tensor<1x128xi64, #blocked2>
%33 = tt.broadcast %32 : tensor<1x128xi64, #blocked2> -> tensor<16x128xi64, #blocked2>
%34 = arith.addi %28, %33 : tensor<16x128xi64, #blocked2>
%35 = tt.addptr %26, %34 : tensor<16x128x!tt.ptr<f16>, #blocked2>, tensor<16x128xi64, #blocked2>
%36 = tt.load %35 : tensor<16x128x!tt.ptr<f16>, #blocked2>
%37 = ttg.local_alloc %36 : (tensor<16x128xf16, #blocked2>) -> !ttg.memdesc<16x128xf16, #shared1, #smem>
%38 = ttg.local_load %25 : !ttg.memdesc<128x16xf16, #shared, #smem> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%39 = ttg.local_load %37 : !ttg.memdesc<16x128xf16, #shared1, #smem> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%40 = tt.dot %38, %39, %cst, inputPrecision = tf32 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
%41 = arith.truncf %40 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%42 = arith.extsi %arg8 : i32 to i64
%43 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked1>
%44 = tt.splat %42 : i64 -> tensor<128x1xi64, #blocked1>
%45 = arith.muli %11, %44 : tensor<128x1xi64, #blocked1>
%46 = tt.broadcast %45 : tensor<128x1xi64, #blocked1> -> tensor<128x128xi64, #blocked1>
%47 = tt.broadcast %30 : tensor<1x128xi64, #blocked1> -> tensor<128x128xi64, #blocked1>
%48 = arith.addi %46, %47 : tensor<128x128xi64, #blocked1>
%49 = tt.addptr %43, %48 : tensor<128x128x!tt.ptr<f16>, #blocked1>, tensor<128x128xi64, #blocked1>
%50 = ttg.convert_layout %41 : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #blocked1>
tt.store %49, %50 : tensor<128x128x!tt.ptr<f16>, #blocked1>
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)",
disable_threading: false,
verify_each: true
}
}
#-}
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
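While the multi-CTA parametrizations are being triaged, they can be deselected so the rest of unit/cuda/test_gemm.py still runs. A sketch using pytest's --deselect with the node IDs taken from the failure headers above, driven through pytest.main purely to keep the example in Python.

    # Sketch: deselect the known-failing multi-CTA cases while keeping the rest
    # of the file in the run. Node IDs are copied from the failure headers above.
    import pytest

    pytest.main([
        "unit/cuda/test_gemm.py",
        "--deselect", "unit/cuda/test_gemm.py::test_gemm_no_scf[256-64-16-4-4-False-True-float16-True]",
        "--deselect", "unit/cuda/test_gemm.py::test_gemm_no_scf[256-64-16-4-4-False-True-float32-True]",
        "--deselect", "unit/cuda/test_gemm.py::test_gemm_no_scf[128-128-16-4-4-False-True-float16-True]",
        "--deselect", "unit/cuda/test_gemm.py::test_gemm_no_scf[128-128-16-4-4-False-True-float32-True]",
    ])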
___________ test_gemm_no_scf[128-128-16-4-4-False-True-float32-True] ___________
M = 128, N = 128, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False
TRANS_B = True, OUTPUT_TYPE = 'float32', USE_TMA_EPILOGUE = True
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
] for USE_TMA_EPILOGUE in [True, False]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE):
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if OUTPUT_TYPE == "float16":
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
else:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
> matmul_no_scf_kernel[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE)
unit/cuda/test_gemm.py:125:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:569: in run
kernel = self.compile(src, target=target, options=options.__dict__)
../triton/compiler/compiler.py:284: in compile
next_module = compile_ir(module, metadata)
../triton/backends/nvidia/compiler.py:450: in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <nvidia.CUDABackend object at 0x7f1ae013cd10>
src = <triton._C.libtriton.ir.module object at 0x7f19bba76570>
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (2, 2, 1), ...}
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120')
capability = 120
def make_llir(self, src, metadata, options, capability):
ptx_version = get_ptx_version_from_options(options, self.target.arch)
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_allocate_warp_groups(pm)
passes.convert.add_scf_to_cf(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
> pm.run(mod)
E RuntimeError: PassManager::run failed
../triton/backends/nvidia/compiler.py:341: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0], instrShape = [16, 8]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%0 = arith.extsi %arg6 : i32 to i64
%1 = arith.extsi %arg7 : i32 to i64
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
%5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%7 = arith.extsi %4 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>>
%8 = arith.extsi %6 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
%9 = arith.extsi %3 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>>
%10 = arith.extsi %5 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>>
%11 = tt.expand_dims %8 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked>
%12 = tt.expand_dims %7 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi64, #blocked2>
%13 = tt.splat %0 : i64 -> tensor<128x1xi64, #blocked>
%14 = arith.muli %11, %13 : tensor<128x1xi64, #blocked>
%15 = tt.broadcast %14 : tensor<128x1xi64, #blocked> -> tensor<128x16xi64, #blocked>
%16 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%17 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%18 = arith.extsi %16 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>>
%19 = arith.extsi %17 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked1}>>
%20 = tt.expand_dims %18 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked>
%21 = tt.broadcast %20 : tensor<1x16xi64, #blocked> -> tensor<128x16xi64, #blocked>
%22 = arith.addi %15, %21 : tensor<128x16xi64, #blocked>
%23 = tt.addptr %2, %22 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi64, #blocked>
%24 = tt.load %23 : tensor<128x16x!tt.ptr<f16>, #blocked>
%25 = ttg.local_alloc %24 : (tensor<128x16xf16, #blocked>) -> !ttg.memdesc<128x16xf16, #shared, #smem>
%26 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x128x!tt.ptr<f16>, #blocked1>
%27 = tt.expand_dims %19 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi64, #blocked1>
%28 = tt.broadcast %27 : tensor<16x1xi64, #blocked1> -> tensor<16x128xi64, #blocked1>
%29 = tt.expand_dims %9 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi64, #blocked1>
%30 = tt.expand_dims %10 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x128xi64, #blocked2>
%31 = tt.splat %1 : i64 -> tensor<1x128xi64, #blocked1>
%32 = arith.muli %29, %31 : tensor<1x128xi64, #blocked1>
%33 = tt.broadcast %32 : tensor<1x128xi64, #blocked1> -> tensor<16x128xi64, #blocked1>
%34 = arith.addi %28, %33 : tensor<16x128xi64, #blocked1>
%35 = tt.addptr %26, %34 : tensor<16x128x!tt.ptr<f16>, #blocked1>, tensor<16x128xi64, #blocked1>
%36 = tt.load %35 : tensor<16x128x!tt.ptr<f16>, #blocked1>
%37 = ttg.local_alloc %36 : (tensor<16x128xf16, #blocked1>) -> !ttg.memdesc<16x128xf16, #shared1, #smem>
%38 = ttg.local_load %25 : !ttg.memdesc<128x16xf16, #shared, #smem> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%39 = ttg.local_load %37 : !ttg.memdesc<16x128xf16, #shared1, #smem> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%40 = tt.dot %38, %39, %cst, inputPrecision = tf32 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
%41 = arith.extsi %arg8 : i32 to i64
%42 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<128x128x!tt.ptr<f32>, #blocked2>
%43 = tt.splat %41 : i64 -> tensor<128x1xi64, #blocked2>
%44 = arith.muli %12, %43 : tensor<128x1xi64, #blocked2>
%45 = tt.broadcast %44 : tensor<128x1xi64, #blocked2> -> tensor<128x128xi64, #blocked2>
%46 = tt.broadcast %30 : tensor<1x128xi64, #blocked2> -> tensor<128x128xi64, #blocked2>
%47 = arith.addi %45, %46 : tensor<128x128xi64, #blocked2>
%48 = tt.addptr %42, %47 : tensor<128x128x!tt.ptr<f32>, #blocked2>, tensor<128x128xi64, #blocked2>
%49 = ttg.convert_layout %40 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked2>
tt.store %48, %49 : tensor<128x128x!tt.ptr<f32>, #blocked2>
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)",
disable_threading: false,
verify_each: true
}
}
#-}
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
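All of these failures report arch 'sm120' (capability 120) with num_ctas=4, while the skipif guard in the test only checks capability >= 9. A small sketch of the same capability query the test uses, e.g. for gating multi-CTA cases locally while the LinearLayout assertion is investigated; whether num_ctas > 1 is expected to be supported on this capability is not something this log establishes.

    # The device-capability query used by the skipif guard above, shown in
    # isolation. The failing runs report capability 120 (arch "sm120").
    import torch

    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability: {major}.{minor}")
    print("guard `capability >= 9` passes:", major >= 9)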
___________ test_gemm_no_scf[256-64-16-4-4-False-True-float16-False] ___________
M = 256, N = 64, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False
TRANS_B = True, OUTPUT_TYPE = 'float16', USE_TMA_EPILOGUE = False
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
] for USE_TMA_EPILOGUE in [True, False]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE):
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if OUTPUT_TYPE == "float16":
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
else:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
> matmul_no_scf_kernel[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE)
unit/cuda/test_gemm.py:125:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:569: in run
kernel = self.compile(src, target=target, options=options.__dict__)
../triton/compiler/compiler.py:284: in compile
next_module = compile_ir(module, metadata)
../triton/backends/nvidia/compiler.py:450: in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <nvidia.CUDABackend object at 0x7f1af0f3ab70>
src = <triton._C.libtriton.ir.module object at 0x7f1af0fb3e90>
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (4, 1, 1), ...}
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120')
capability = 120
def make_llir(self, src, metadata, options, capability):
ptx_version = get_ptx_version_from_options(options, self.target.arch)
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_allocate_warp_groups(pm)
passes.convert.add_scf_to_cf(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
> pm.run(mod)
E RuntimeError: PassManager::run failed
../triton/backends/nvidia/compiler.py:341: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0], instrShape = [16, 8]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
%0 = arith.extsi %arg6 : i32 to i64
%1 = arith.extsi %arg7 : i32 to i64
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x16x!tt.ptr<f16>, #blocked>
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%5 = arith.extsi %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
%6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi64, #blocked>
%7 = tt.splat %0 : i64 -> tensor<256x1xi64, #blocked>
%8 = arith.muli %6, %7 : tensor<256x1xi64, #blocked>
%9 = tt.broadcast %8 : tensor<256x1xi64, #blocked> -> tensor<256x16xi64, #blocked>
%10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
%11 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%12 = arith.extsi %11 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>>
%13 = arith.extsi %10 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>>
%14 = tt.expand_dims %12 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked>
%15 = tt.broadcast %14 : tensor<1x16xi64, #blocked> -> tensor<256x16xi64, #blocked>
%16 = arith.addi %9, %15 : tensor<256x16xi64, #blocked>
%17 = tt.addptr %2, %16 : tensor<256x16x!tt.ptr<f16>, #blocked>, tensor<256x16xi64, #blocked>
%18 = tt.load %17 : tensor<256x16x!tt.ptr<f16>, #blocked>
%19 = ttg.local_alloc %18 : (tensor<256x16xf16, #blocked>) -> !ttg.memdesc<256x16xf16, #shared, #smem>
%20 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x64x!tt.ptr<f16>, #blocked2>
%21 = tt.expand_dims %13 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2>
%22 = tt.broadcast %21 : tensor<16x1xi64, #blocked2> -> tensor<16x64xi64, #blocked2>
%23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%25 = arith.extsi %24 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>>
%26 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2>
%27 = tt.splat %1 : i64 -> tensor<1x64xi64, #blocked2>
%28 = arith.muli %26, %27 : tensor<1x64xi64, #blocked2>
%29 = tt.broadcast %28 : tensor<1x64xi64, #blocked2> -> tensor<16x64xi64, #blocked2>
%30 = arith.addi %22, %29 : tensor<16x64xi64, #blocked2>
%31 = tt.addptr %20, %30 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi64, #blocked2>
%32 = tt.load %31 : tensor<16x64x!tt.ptr<f16>, #blocked2>
%33 = ttg.local_alloc %32 : (tensor<16x64xf16, #blocked2>) -> !ttg.memdesc<16x64xf16, #shared1, #smem>
%34 = ttg.local_load %19 : !ttg.memdesc<256x16xf16, #shared, #smem> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%35 = ttg.local_load %33 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%36 = tt.dot %34, %35, %cst, inputPrecision = tf32 : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x64xf32, #mma>
%37 = arith.truncf %36 : tensor<256x64xf32, #mma> to tensor<256x64xf16, #mma>
%38 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
%39 = tt.splat %arg8 : i32 -> tensor<256x1xi32, #blocked1>
%40 = arith.muli %38, %39 : tensor<256x1xi32, #blocked1>
%41 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<256x1x!tt.ptr<f16>, #blocked1>
%42 = tt.addptr %41, %40 : tensor<256x1x!tt.ptr<f16>, #blocked1>, tensor<256x1xi32, #blocked1>
%43 = tt.expand_dims %23 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
%44 = tt.broadcast %42 : tensor<256x1x!tt.ptr<f16>, #blocked1> -> tensor<256x64x!tt.ptr<f16>, #blocked1>
%45 = tt.broadcast %43 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
%46 = tt.addptr %44, %45 : tensor<256x64x!tt.ptr<f16>, #blocked1>, tensor<256x64xi32, #blocked1>
%47 = ttg.convert_layout %37 : tensor<256x64xf16, #mma> -> tensor<256x64xf16, #blocked1>
tt.store %46, %47 : tensor<256x64x!tt.ptr<f16>, #blocked1>
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)",
disable_threading: false,
verify_each: true
}
}
#-}
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
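Note on the assertion above: LinearLayout::reshapeIns aborts because the product of the requested new input-dimension sizes does not equal the layout's total input-dimension size, the same invariant an ordinary tensor reshape enforces. A minimal Python sketch of that check follows; the function name and example dimensions are illustrative only, not Triton API.

import math

def reshape_ins_total_size_ok(total_in_dim_size, new_in_dims):
    # Mirrors the C++ assertion: the product of the new input-dimension
    # sizes must equal the layout's current total input-dimension size.
    product = math.prod(size for _, size in new_in_dims)
    return total_in_dim_size == product

# Example: an input space of size 128 can only be reshaped into
# dimensions whose sizes multiply back to 128.
assert reshape_ins_total_size_ok(128, [("register", 4), ("lane", 32)])      # 4 * 32 == 128
assert not reshape_ins_total_size_ok(128, [("register", 4), ("lane", 16)])  # 64 != 128, would trip the assert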
___________ test_gemm_no_scf[256-64-16-4-4-False-True-float32-False] ___________
M = 256, N = 64, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False
TRANS_B = True, OUTPUT_TYPE = 'float32', USE_TMA_EPILOGUE = False
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
] for USE_TMA_EPILOGUE in [True, False]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE):
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if OUTPUT_TYPE == "float16":
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
else:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
> matmul_no_scf_kernel[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE)
unit/cuda/test_gemm.py:125:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:569: in run
kernel = self.compile(src, target=target, options=options.__dict__)
../triton/compiler/compiler.py:284: in compile
next_module = compile_ir(module, metadata)
../triton/backends/nvidia/compiler.py:450: in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <nvidia.CUDABackend object at 0x7f1af0fd04d0>
src = <triton._C.libtriton.ir.module object at 0x7f1af0fb37d0>
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (4, 1, 1), ...}
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120')
capability = 120
def make_llir(self, src, metadata, options, capability):
ptx_version = get_ptx_version_from_options(options, self.target.arch)
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_allocate_warp_groups(pm)
passes.convert.add_scf_to_cf(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
> pm.run(mod)
E RuntimeError: PassManager::run failed
../triton/backends/nvidia/compiler.py:341: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0], instrShape = [16, 8]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [4, 1], CTASplitNum = [4, 1], CTAOrder = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [4, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
%0 = arith.extsi %arg6 : i32 to i64
%1 = arith.extsi %arg7 : i32 to i64
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x16x!tt.ptr<f16>, #blocked>
%3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%4 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%5 = arith.extsi %4 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
%6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<256xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi64, #blocked>
%7 = tt.splat %0 : i64 -> tensor<256x1xi64, #blocked>
%8 = arith.muli %6, %7 : tensor<256x1xi64, #blocked>
%9 = tt.broadcast %8 : tensor<256x1xi64, #blocked> -> tensor<256x16xi64, #blocked>
%10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
%11 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%12 = arith.extsi %10 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>>
%13 = arith.extsi %11 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>>
%14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked>
%15 = tt.broadcast %14 : tensor<1x16xi64, #blocked> -> tensor<256x16xi64, #blocked>
%16 = arith.addi %9, %15 : tensor<256x16xi64, #blocked>
%17 = tt.addptr %2, %16 : tensor<256x16x!tt.ptr<f16>, #blocked>, tensor<256x16xi64, #blocked>
%18 = tt.load %17 : tensor<256x16x!tt.ptr<f16>, #blocked>
%19 = ttg.local_alloc %18 : (tensor<256x16xf16, #blocked>) -> !ttg.memdesc<256x16xf16, #shared, #smem>
%20 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x64x!tt.ptr<f16>, #blocked2>
%21 = tt.expand_dims %12 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2>
%22 = tt.broadcast %21 : tensor<16x1xi64, #blocked2> -> tensor<16x64xi64, #blocked2>
%23 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%24 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%25 = arith.extsi %24 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>>
%26 = tt.expand_dims %25 {axis = 0 : i32} : tensor<64xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi64, #blocked2>
%27 = tt.splat %1 : i64 -> tensor<1x64xi64, #blocked2>
%28 = arith.muli %26, %27 : tensor<1x64xi64, #blocked2>
%29 = tt.broadcast %28 : tensor<1x64xi64, #blocked2> -> tensor<16x64xi64, #blocked2>
%30 = arith.addi %22, %29 : tensor<16x64xi64, #blocked2>
%31 = tt.addptr %20, %30 : tensor<16x64x!tt.ptr<f16>, #blocked2>, tensor<16x64xi64, #blocked2>
%32 = tt.load %31 : tensor<16x64x!tt.ptr<f16>, #blocked2>
%33 = ttg.local_alloc %32 : (tensor<16x64xf16, #blocked2>) -> !ttg.memdesc<16x64xf16, #shared1, #smem>
%34 = ttg.local_load %19 : !ttg.memdesc<256x16xf16, #shared, #smem> -> tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%35 = ttg.local_load %33 : !ttg.memdesc<16x64xf16, #shared1, #smem> -> tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%36 = tt.dot %34, %35, %cst, inputPrecision = tf32 : tensor<256x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x64xf32, #mma>
%37 = tt.expand_dims %3 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1>
%38 = tt.splat %arg8 : i32 -> tensor<256x1xi32, #blocked1>
%39 = arith.muli %37, %38 : tensor<256x1xi32, #blocked1>
%40 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x1x!tt.ptr<f32>, #blocked1>
%41 = tt.addptr %40, %39 : tensor<256x1x!tt.ptr<f32>, #blocked1>, tensor<256x1xi32, #blocked1>
%42 = tt.expand_dims %23 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
%43 = tt.broadcast %41 : tensor<256x1x!tt.ptr<f32>, #blocked1> -> tensor<256x64x!tt.ptr<f32>, #blocked1>
%44 = tt.broadcast %42 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1>
%45 = tt.addptr %43, %44 : tensor<256x64x!tt.ptr<f32>, #blocked1>, tensor<256x64xi32, #blocked1>
%46 = ttg.convert_layout %36 : tensor<256x64xf32, #mma> -> tensor<256x64xf32, #blocked1>
tt.store %45, %46 : tensor<256x64x!tt.ptr<f32>, #blocked1>
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)",
disable_threading: false,
verify_each: true
}
}
#-}
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
__________ test_gemm_no_scf[128-128-16-4-4-False-True-float16-False] ___________
M = 128, N = 128, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False
TRANS_B = True, OUTPUT_TYPE = 'float16', USE_TMA_EPILOGUE = False
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
] for USE_TMA_EPILOGUE in [True, False]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE):
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if OUTPUT_TYPE == "float16":
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
else:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
> matmul_no_scf_kernel[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE)
unit/cuda/test_gemm.py:125:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:569: in run
kernel = self.compile(src, target=target, options=options.__dict__)
../triton/compiler/compiler.py:284: in compile
next_module = compile_ir(module, metadata)
../triton/backends/nvidia/compiler.py:450: in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <nvidia.CUDABackend object at 0x7f1af0f756a0>
src = <triton._C.libtriton.ir.module object at 0x7f19bba77dd0>
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (2, 2, 1), ...}
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120')
capability = 120
def make_llir(self, src, metadata, options, capability):
ptx_version = get_ptx_version_from_options(options, self.target.arch)
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_allocate_warp_groups(pm)
passes.convert.add_scf_to_cf(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
> pm.run(mod)
E RuntimeError: PassManager::run failed
../triton/backends/nvidia/compiler.py:341: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0], instrShape = [16, 8]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%0 = arith.extsi %arg6 : i32 to i64
%1 = arith.extsi %arg7 : i32 to i64
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%7 = arith.extsi %6 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>>
%8 = arith.extsi %5 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
%9 = tt.expand_dims %8 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked>
%10 = tt.splat %0 : i64 -> tensor<128x1xi64, #blocked>
%11 = arith.muli %9, %10 : tensor<128x1xi64, #blocked>
%12 = tt.broadcast %11 : tensor<128x1xi64, #blocked> -> tensor<128x16xi64, #blocked>
%13 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%14 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
%15 = arith.extsi %14 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>>
%16 = arith.extsi %13 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>>
%17 = tt.expand_dims %16 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked>
%18 = tt.broadcast %17 : tensor<1x16xi64, #blocked> -> tensor<128x16xi64, #blocked>
%19 = arith.addi %12, %18 : tensor<128x16xi64, #blocked>
%20 = tt.addptr %2, %19 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi64, #blocked>
%21 = tt.load %20 : tensor<128x16x!tt.ptr<f16>, #blocked>
%22 = ttg.local_alloc %21 : (tensor<128x16xf16, #blocked>) -> !ttg.memdesc<128x16xf16, #shared, #smem>
%23 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x128x!tt.ptr<f16>, #blocked2>
%24 = tt.expand_dims %15 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<16x1xi64, #blocked2>
%25 = tt.broadcast %24 : tensor<16x1xi64, #blocked2> -> tensor<16x128xi64, #blocked2>
%26 = tt.expand_dims %7 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x128xi64, #blocked2>
%27 = tt.splat %1 : i64 -> tensor<1x128xi64, #blocked2>
%28 = arith.muli %26, %27 : tensor<1x128xi64, #blocked2>
%29 = tt.broadcast %28 : tensor<1x128xi64, #blocked2> -> tensor<16x128xi64, #blocked2>
%30 = arith.addi %25, %29 : tensor<16x128xi64, #blocked2>
%31 = tt.addptr %23, %30 : tensor<16x128x!tt.ptr<f16>, #blocked2>, tensor<16x128xi64, #blocked2>
%32 = tt.load %31 : tensor<16x128x!tt.ptr<f16>, #blocked2>
%33 = ttg.local_alloc %32 : (tensor<16x128xf16, #blocked2>) -> !ttg.memdesc<16x128xf16, #shared1, #smem>
%34 = ttg.local_load %22 : !ttg.memdesc<128x16xf16, #shared, #smem> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%35 = ttg.local_load %33 : !ttg.memdesc<16x128xf16, #shared1, #smem> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%36 = tt.dot %34, %35, %cst, inputPrecision = tf32 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
%37 = arith.truncf %36 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma>
%38 = tt.expand_dims %3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
%39 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked1>
%40 = arith.muli %38, %39 : tensor<128x1xi32, #blocked1>
%41 = tt.splat %arg2 : !tt.ptr<f16> -> tensor<128x1x!tt.ptr<f16>, #blocked1>
%42 = tt.addptr %41, %40 : tensor<128x1x!tt.ptr<f16>, #blocked1>, tensor<128x1xi32, #blocked1>
%43 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1>
%44 = tt.broadcast %42 : tensor<128x1x!tt.ptr<f16>, #blocked1> -> tensor<128x128x!tt.ptr<f16>, #blocked1>
%45 = tt.broadcast %43 : tensor<1x128xi32, #blocked1> -> tensor<128x128xi32, #blocked1>
%46 = tt.addptr %44, %45 : tensor<128x128x!tt.ptr<f16>, #blocked1>, tensor<128x128xi32, #blocked1>
%47 = ttg.convert_layout %37 : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #blocked1>
tt.store %46, %47 : tensor<128x128x!tt.ptr<f16>, #blocked1>
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)",
disable_threading: false,
verify_each: true
}
}
#-}
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
__________ test_gemm_no_scf[128-128-16-4-4-False-True-float32-False] ___________
M = 128, N = 128, K = 16, NUM_CTAS = 4, NUM_WARPS = 4, TRANS_A = False
TRANS_B = True, OUTPUT_TYPE = 'float32', USE_TMA_EPILOGUE = False
@pytest.mark.parametrize(
'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE',
itertools.chain(*[[
# numCTAs = 1, no TMA multicast:
[64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
# static mask, cluster 4x1
[256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# dynamic mask, cluster 2x2
[128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE],
[128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE],
# small M, N
[16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
[32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE],
] for USE_TMA_EPILOGUE in [True, False]]))
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE):
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
if OUTPUT_TYPE == "float16":
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
else:
c = torch.empty((M, N), device=a.device, dtype=torch.float32)
> matmul_no_scf_kernel[(1, 1)](
a_ptr=a, b_ptr=b, c_ptr=c, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_cm=c.stride(0), stride_cn=c.stride(1), #
BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, #
num_warps=NUM_WARPS, #
num_ctas=NUM_CTAS, #
FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), #
USE_TMA_EPILOGUE=USE_TMA_EPILOGUE)
unit/cuda/test_gemm.py:125:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../triton/runtime/jit.py:347: in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
../triton/runtime/jit.py:569: in run
kernel = self.compile(src, target=target, options=options.__dict__)
../triton/compiler/compiler.py:284: in compile
next_module = compile_ir(module, metadata)
../triton/backends/nvidia/compiler.py:450: in <lambda>
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <nvidia.CUDABackend object at 0x7f1a540990a0>
src = <triton._C.libtriton.ir.module object at 0x7f19bb94b050>
metadata = {'allowed_dot_input_precisions': ('tf32', 'tf32x3', 'ieee'), 'arch': 'sm120', 'backend_name': 'cuda', 'cluster_dims': (2, 2, 1), ...}
options = CUDAOptions(num_warps=4, num_ctas=4, num_stages=3, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0,...triton/backends/nvidia/lib/libdevice.10.bc'),), debug=False, backend_name='cuda', sanitize_overflow=True, arch='sm120')
capability = 120
def make_llir(self, src, metadata, options, capability):
ptx_version = get_ptx_version_from_options(options, self.target.arch)
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttnvgpuir.add_lower_mma(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_allocate_warp_groups(pm)
passes.convert.add_scf_to_cf(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
passes.ttgpuir.add_allocate_global_scratch_memory(pm)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
> pm.run(mod)
E RuntimeError: PassManager::run failed
../triton/backends/nvidia/compiler.py:341: RuntimeError
----------------------------- Captured stderr call -----------------------------
python: /workspace/triton/lib/Tools/LinearLayout.cpp:441: mlir::triton::LinearLayout mlir::triton::LinearLayout::reshapeIns(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) const: Assertion `getTotalInDimSize() == std::accumulate(newInDims.begin(), newInDims.end(), 1, [&](int32_t acc, auto &inDim) { return acc * inDim.second; })' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], CTAsPerCGA = [2, 2], CTASplitNum = [2, 2], CTAOrder = [1, 0], instrShape = [16, 8]}>
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [2, 1], CTAOrder = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [2, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:120", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @matmul_no_scf_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
%0 = arith.extsi %arg6 : i32 to i64
%1 = arith.extsi %arg7 : i32 to i64
%2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
%3 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%5 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
%6 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
%7 = arith.extsi %3 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> to tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>>
%8 = arith.extsi %4 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> to tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>>
%9 = tt.expand_dims %8 {axis = 1 : i32} : tensor<128xi64, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi64, #blocked>
%10 = tt.splat %0 : i64 -> tensor<128x1xi64, #blocked>
%11 = arith.muli %9, %10 : tensor<128x1xi64, #blocked>
%12 = tt.broadcast %11 : tensor<128x1xi64, #blocked> -> tensor<128x16xi64, #blocked>
%13 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%14 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%15 = arith.extsi %14 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> to tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>>
%16 = arith.extsi %13 : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> to tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked1}>>
%17 = tt.expand_dims %15 {axis = 0 : i32} : tensor<16xi64, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi64, #blocked>
%18 = tt.broadcast %17 : tensor<1x16xi64, #blocked> -> tensor<128x16xi64, #blocked>
%19 = arith.addi %12, %18 : tensor<128x16xi64, #blocked>
%20 = tt.addptr %2, %19 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi64, #blocked>
%21 = tt.load %20 : tensor<128x16x!tt.ptr<f16>, #blocked>
%22 = ttg.local_alloc %21 : (tensor<128x16xf16, #blocked>) -> !ttg.memdesc<128x16xf16, #shared, #smem>
%23 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<16x128x!tt.ptr<f16>, #blocked1>
%24 = tt.expand_dims %16 {axis = 1 : i32} : tensor<16xi64, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<16x1xi64, #blocked1>
%25 = tt.broadcast %24 : tensor<16x1xi64, #blocked1> -> tensor<16x128xi64, #blocked1>
%26 = tt.expand_dims %7 {axis = 0 : i32} : tensor<128xi64, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi64, #blocked1>
%27 = tt.splat %1 : i64 -> tensor<1x128xi64, #blocked1>
%28 = arith.muli %26, %27 : tensor<1x128xi64, #blocked1>
%29 = tt.broadcast %28 : tensor<1x128xi64, #blocked1> -> tensor<16x128xi64, #blocked1>
%30 = arith.addi %25, %29 : tensor<16x128xi64, #blocked1>
%31 = tt.addptr %23, %30 : tensor<16x128x!tt.ptr<f16>, #blocked1>, tensor<16x128xi64, #blocked1>
%32 = tt.load %31 : tensor<16x128x!tt.ptr<f16>, #blocked1>
%33 = ttg.local_alloc %32 : (tensor<16x128xf16, #blocked1>) -> !ttg.memdesc<16x128xf16, #shared1, #smem>
%34 = ttg.local_load %22 : !ttg.memdesc<128x16xf16, #shared, #smem> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%35 = ttg.local_load %33 : !ttg.memdesc<16x128xf16, #shared1, #smem> -> tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%36 = tt.dot %34, %35, %cst, inputPrecision = tf32 : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
%37 = tt.expand_dims %6 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2>
%38 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #blocked2>
%39 = arith.muli %37, %38 : tensor<128x1xi32, #blocked2>
%40 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<128x1x!tt.ptr<f32>, #blocked2>
%41 = tt.addptr %40, %39 : tensor<128x1x!tt.ptr<f32>, #blocked2>, tensor<128x1xi32, #blocked2>
%42 = tt.expand_dims %5 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x128xi32, #blocked2>
%43 = tt.broadcast %41 : tensor<128x1x!tt.ptr<f32>, #blocked2> -> tensor<128x128x!tt.ptr<f32>, #blocked2>
%44 = tt.broadcast %42 : tensor<1x128xi32, #blocked2> -> tensor<128x128xi32, #blocked2>
%45 = tt.addptr %43, %44 : tensor<128x128x!tt.ptr<f32>, #blocked2>, tensor<128x128xi32, #blocked2>
%46 = ttg.convert_layout %36 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked2>
tt.store %45, %46 : tensor<128x128x!tt.ptr<f32>, #blocked2>
tt.return
}
}
{-#
external_resources: {
mlir_reproducer: {
pipeline: "builtin.module(triton-nvidia-mma-lowering, tritongpu-combine-tensor-select-and-if, tritongpu-allocate-warp-groups, convert-scf-to-cf, allocate-shared-memory, triton-tensor-memory-allocation, tritongpu-global-scratch-memory-allocation, convert-triton-gpu-to-llvm{compute-capability=120 ptx-version=87}, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-nv-gpu-to-llvm, convert-warp-specialize-to-llvm, canonicalize{ max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info)",
disable_threading: false,
verify_each: true
}
}
#-}
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: error: Failures have been detected while processing an MLIR pass pipeline
/workspace/triton/python/test/unit/cuda/test_gemm.py:39:0: note: Pipeline failed while executing [`ConvertTritonGPUToLLVM` on 'builtin.module' operation]: reproducer generated at `std::errs, please share the reproducer above with Triton project.`
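To iterate on just these NUM_CTAS=4 failures without re-running the whole suite, the node ids printed in the section headers above can be passed to pytest directly. A minimal sketch, assuming the same working directory as this log:

import pytest

# Node ids copied verbatim from the failure headers above.
failing = [
    "unit/cuda/test_gemm.py::test_gemm_no_scf[256-64-16-4-4-False-True-float32-False]",
    "unit/cuda/test_gemm.py::test_gemm_no_scf[128-128-16-4-4-False-True-float16-False]",
    "unit/cuda/test_gemm.py::test_gemm_no_scf[128-128-16-4-4-False-True-float32-False]",
]
pytest.main(["-x", "-q", *failing])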
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-True-none-float16-False-3] __
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False, TRANS_OUTPUT = True
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
# loop over epilogues besides of softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
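For this failure the expected WGMMA shape is computed correctly (n128 for BLOCK_M=128, BLOCK_N=128, NUM_WARPS=4); the assert trips because the captured PTX targets sm_120a and contains no wgmma.mma_async instruction at all, so the Hopper-oriented expectation in the test does not hold on this target. A standalone re-computation of the expression from the test, not Triton code:

import re

# Re-evaluate the test's expression for the parameters shown above.
BLOCK_M, BLOCK_N, NUM_WARPS = 128, 128, 4
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
assert wgmma_n == 128  # matches the n128 in the failing regex above

pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)
# Against PTX with no wgmma instructions, re.search(pattern, ptx) is None,
# which is exactly the AssertionError reported above.
print(pattern)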
__ test_gemm[128-128-64-4-1-4096-1-1024-False-False-True-none-float16-True-3] __
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False, TRANS_OUTPUT = True
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
# loop over epilogues besides of softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
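
Note: the assertion failing in these test_gemm cases searches the generated PTX for a wgmma.mma_async instruction whose N dimension equals wgmma_n. Below is a minimal sketch, using only the formula and parameters visible in the traceback above (BLOCK_M=128, BLOCK_N=128, NUM_WARPS=4), showing why the expected pattern is n128; the instruction string it is matched against is illustrative and not taken from the PTX quoted in this log.

import re

# Reproduce the wgmma_n computation from the traceback above.
BLOCK_M, BLOCK_N, NUM_WARPS = 128, 128, 4
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
assert wgmma_n == 128

# The pattern the test searches for in the PTX.
pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)

# Illustrative only: a Hopper-style instruction of this shape would match ...
assert re.search(pattern, 'wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16')
# ... but in the failures above re.search returns None, because the quoted
# sm_120a PTX contains no instruction matching this pattern.
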
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-True-none-float32-False-3] __
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False, TRANS_OUTPUT = True
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = False
NUM_STAGES = 3
    (duplicate test_gemm source listing elided; identical to the first listing above)
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
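
Note: the reason these runs reach the wgmma assertion at all is the is_tcgen5 branch shown in the test, which only selects the tcgen05 pattern when the device's compute capability major version is exactly 10. The sketch below evaluates that branch under the assumption that the '.target sm_120a' line in the error corresponds to compute capability (12, 0); it only illustrates which assertion is reached, not what PTX that target should emit.

# Assumption: '.target sm_120a' in the error corresponds to
# torch.cuda.get_device_capability() == (12, 0).
capability = (12, 0)
BLOCK_M, BLOCK_N, NUM_WARPS = 128, 128, 4

is_tcgen5 = (capability[0] == 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
assert not is_tcgen5  # major version is 12, not 10, so the test falls through
# to the wgmma.mma_async assertion, which then fails because no instruction
# matching that pattern appears in the quoted PTX.
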
__ test_gemm[128-128-64-4-1-4096-1-1024-False-False-True-none-float32-True-3] __
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False, TRANS_OUTPUT = True
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
    (duplicate test_gemm source listing elided; identical to the first listing above)
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
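
Note: the bracketed test ids in the headers above are simply the parametrize values joined with '-', in the order given in the @pytest.mark.parametrize string. The helper below is hypothetical (it is not part of test_gemm.py) and only works for ids, like the ones in this excerpt, whose epilogue value contains no '-' of its own.

# Hypothetical helper: map a pytest id from this log back to named parameters.
names = ('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,'
         'TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES').split(',')
test_id = '128-128-64-4-1-4096-1-1024-False-False-True-none-float32-True-3'
params = dict(zip(names, test_id.split('-')))
assert params['M'] == '4096' and params['N'] == '1' and params['K'] == '1024'
# Caveat: ids for epilogues such as 'add-rows' contain extra '-' characters
# and would need smarter splitting.
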
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-True-none-float16-False-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False
TRANS_OUTPUT = True, epilogue = 'none', out_dtype = triton.language.float16
USE_TMA_STORE = False, NUM_STAGES = 3
    (duplicate test_gemm source listing elided; identical to the first listing above)
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
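
Note: the 1-D launch grid used by the test is cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N). A quick sketch for the corner shape in this failure (M=2048, N=204 with 128x128 tiles), using a plain-Python cdiv in place of triton.cdiv:

def cdiv(a, b):
    # Same rounding-up division as triton.cdiv.
    return (a + b - 1) // b

M, N, BLOCK_M, BLOCK_N = 2048, 204, 128, 128
num_programs = cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)
assert num_programs == 16 * 2  # 32 program instances
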
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-True-none-float16-True-3] __
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False
TRANS_OUTPUT = True, epilogue = 'none', out_dtype = triton.language.float16
USE_TMA_STORE = True, NUM_STAGES = 3
    (duplicate test_gemm source listing elided; identical to the first listing above)
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-True-none-float32-False-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False
TRANS_OUTPUT = True, epilogue = 'none', out_dtype = triton.language.float32
USE_TMA_STORE = False, NUM_STAGES = 3
    (duplicate test_gemm source listing elided; identical to the first listing above)
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
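
Note: before the accuracy check, the test L2-normalizes both the kernel output and the reference along dim=1 (the torch.nn.functional.normalize default), so the comparison is insensitive to per-row scaling. The sketch below illustrates that behaviour; it assumes the assert_close used by the test has the torch.testing.assert_close signature, which may differ from the actual import in test_gemm.py.

import torch

golden = torch.randn(4, 8)
z = 3.0 * golden                               # same direction, different scale
gn = torch.nn.functional.normalize(golden)     # row-wise L2 normalization (dim=1)
zn = torch.nn.functional.normalize(z)
torch.testing.assert_close(zn, gn, rtol=1e-2, atol=1e-3)  # passes: scaling is removed
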
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-True-none-float32-True-3] __
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False
TRANS_OUTPUT = True, epilogue = 'none', out_dtype = triton.language.float32
USE_TMA_STORE = True, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
# loop over epilogues besides of softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
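
Note: every trace in this group fails on the same pattern check at unit/cuda/test_gemm.py:463. As a reader aid, here is a minimal, standalone sketch of the arithmetic and regex that the failing assert performs, using the parameter values from the failure headers (BLOCK_M=128, BLOCK_N=128, NUM_WARPS=4); fake_ptx below is a placeholder string, not the real dump from this log.

import re

# Parameter values copied from the failure headers above.
BLOCK_M, BLOCK_N, NUM_WARPS = 128, 128, 4

# Same expression as the failing line: NUM_WARPS / max(BLOCK_M / 16, 1) = 4 / 8 = 0.5,
# clamped to 1, so wgmma_n = max(128 / 1, 8) = 128 -- the "n128" seen in the error message.
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
assert wgmma_n == 128

# The test then searches the generated PTX for this instruction; in the traces above
# re.search returns None because no matching wgmma.mma_async line is present in the dump.
pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)
fake_ptx = '.target sm_120a\n// placeholder: no wgmma instructions here\n'
print(re.search(pattern, fake_ptx))  # -> None, so the assertion fails
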
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-False-none-float16-False-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float16
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
            # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-False-none-float16-True-3] __
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float16
USE_TMA_STORE = True, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
            # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
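
Note: all of these traces take the non-tcgen05 branch of the PTX check. Below is a small sketch of that branch's condition, assuming the device capability implied by the ".target sm_120a" line in the dumps (an sm_120 part reports compute capability (12, 0)); the capability tuple is inferred from that line, not queried from hardware.

# Capability assumed from the ".target sm_120a" header in the PTX dumps above.
capability = (12, 0)
NUM_WARPS, BLOCK_M, BLOCK_N = 4, 128, 128  # from the failure headers

# Same condition as in the test body: only capability 10.x devices take the tcgen05 path.
is_tcgen5 = (capability[0] == 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
print(is_tcgen5)  # False -> the test falls through to the wgmma regex, which finds no match in this PTX
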
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-False-none-float32-False-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float32
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
            # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[128-128-64-4-1-4096-1-1024-False-False-False-none-float32-True-3] __
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 4096, N = 1, K = 1024, TRANS_A = False, TRANS_B = False
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float32
USE_TMA_STORE = True, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
            # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-False-none-float16-False-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float16
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
            # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
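
Note: to reproduce one of these failures in isolation, the node id printed in each failure header can be passed to pytest directly. A sketch using pytest.main, assuming it is run from the test suite's rootdir; the id below is copied from the header above.

import pytest

# -x stops after the first failure, -q keeps the output short.
node_id = "unit/cuda/test_gemm.py::test_gemm[128-128-64-4-1-2048-204-1000-True-False-False-none-float16-False-3]"
pytest.main(["-x", "-q", node_id])
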
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-False-none-float16-True-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float16
USE_TMA_STORE = True, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
            # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
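
The assertion that fails here searches the generated PTX for a warp-group MMA instruction whose N dimension is derived from the tile configuration. A minimal sketch of that derivation, using only the formula and regex shown in the traceback (sample_ptx below is an illustrative stand-in for the captured sm_120a PTX, not the real output):

import re

# wgmma_n as computed in the failing assertion, e.g. for a 128x128 tile with 4 warps.
BLOCK_M, BLOCK_N, NUM_WARPS = 128, 128, 4
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
assert wgmma_n == 128  # matches the n128 in the error message above

# Pattern the test expects to find in the PTX.
pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)

# The captured PTX targets sm_120a and contains no such instruction,
# so re.search returns None and the assert fails.
sample_ptx = '.target sm_120a\n// ... no wgmma.mma_async instructions ...'
print(re.search(pattern, sample_ptx))  # None
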
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-False-none-float32-False-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float32
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
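
The kernel in these failures is launched over a one-dimensional grid of cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N) programs, as defined by the grid function in the test source. A small worked example for the M=2048, N=204, 128x128-tile case reported just above (cdiv here is a local stand-in for triton.cdiv):

def cdiv(x, y):
    # ceiling division, same result as triton.cdiv
    return (x + y - 1) // y

M, N = 2048, 204
BLOCK_M, BLOCK_N = 128, 128
num_programs = cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)
print(num_programs)  # 16 * 2 = 32 program instances
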
_ test_gemm[128-128-64-4-1-2048-204-1000-True-False-False-none-float32-True-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1
M = 2048, N = 204, K = 1000, TRANS_A = True, TRANS_B = False
TRANS_OUTPUT = False, epilogue = 'none', out_dtype = triton.language.float32
USE_TMA_STORE = True, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
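
Before the PTX check, the test compares the kernel output against the torch reference after L2-normalizing both and uses loose tolerances. A short sketch of that comparison, assuming assert_close in the test refers to torch.testing.assert_close (its import is not visible in this excerpt):

import torch
from torch.testing import assert_close

golden = torch.randn(4, 8)
z = golden + 1e-4 * torch.randn(4, 8)  # stand-in for kernel output with small error
# torch.nn.functional.normalize defaults to an L2 norm along dim=1 (rows).
golden_n = torch.nn.functional.normalize(golden)
z_n = torch.nn.functional.normalize(z)
assert_close(z_n, golden_n, rtol=1e-2, atol=1e-3, check_dtype=False)
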
__ test_gemm[64-64-16-4-1-64-64-64-False-True-False-softmax-float16-False-3] ___
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 64
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
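
For the softmax epilogue in this group of failures, process_epilogue builds the reference with the usual max-subtraction trick rather than calling torch.softmax directly. A brief sketch confirming the two are equivalent along the last dimension:

import torch

d = torch.randn(64, 64)
# numerically stable softmax, as in process_epilogue above
num = torch.exp(d - torch.max(d, dim=-1, keepdim=True)[0])
denom = torch.sum(num, dim=-1, keepdim=True)
ref = num / denom
# equivalent to the commented-out torch.softmax(d, 1) reference in the test
torch.testing.assert_close(ref, torch.softmax(d, dim=-1))
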
___ test_gemm[64-64-16-4-1-64-64-64-False-True-False-softmax-float16-True-3] ___
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 64
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
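
The PTX assertion only runs when MMAv3 has not been disabled via DISABLE_MMA_V3 and the tile shape allows a warp-group MMA; on compute-capability-10 devices it checks for tcgen05 instead of wgmma. A pure-Python restatement of that gate (expected_ptx_check is a hypothetical helper; capability 12 is inferred from the sm_120a target in the captured PTX):

import os

def expected_ptx_check(block_m, block_n, num_warps, num_ctas, capability_major):
    # mirrors the gating logic around the failing assertion (sketch only)
    disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
    if disable_mmav3 in ["on", "true", "1"] or block_m < 64 or num_ctas != 1 or block_n > 256:
        return None  # the test makes no PTX assertion in this case
    is_tcgen5 = (capability_major == 10 and num_warps % 4 == 0
                 and block_m % 64 == 0 and block_n % 8 == 0)
    return 'tcgen05' if is_tcgen5 else 'wgmma'

# For the 64x64 tile, 4 warps, 1 CTA cases above on a non-capability-10 GPU,
# the test expects a wgmma instruction in the PTX, which is what fails here.
print(expected_ptx_check(64, 64, 4, 1, capability_major=12))  # 'wgmma'
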
__ test_gemm[64-64-16-4-1-64-64-64-False-True-False-softmax-float32-False-3] ___
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 64
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
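
The TRANS_A / TRANS_B branches above construct transposed views rather than contiguous transposed copies, and the a_order / b_order lists record which dimension is contiguous in memory. A brief sketch of what that produces, using the same construction as the test (run on CPU here so it is self-contained):

import torch

K, M = 64, 64
# TRANS_A path: allocate (K, M) and view it as (M, K) via .T
a_t = torch.randn((K, M), dtype=torch.float16).T
print(a_t.shape, a_t.stride())  # torch.Size([64, 64]) (1, 64): dim 0 varies fastest
# non-transposed path: a plain row-major (M, K) tensor
a = torch.randn((M, K), dtype=torch.float16)
print(a.shape, a.stride())      # torch.Size([64, 64]) (64, 1): dim 1 varies fastest
# a_order = [0, 1] vs [1, 0] encodes this memory order for the kernel
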
___ test_gemm[64-64-16-4-1-64-64-64-False-True-False-softmax-float32-True-3] ___
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 64
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is CUDA backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
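Note on the unmatched regex above: the expected WGMMA width comes from the wgmma_n arithmetic in the test body, and the captured PTX targets sm_120a, where the Hopper-specific wgmma.mma_async instruction does not appear to be emitted at all, which is consistent with re.search returning None. A minimal sketch (parameter triples taken from the failure headers in this log; the 64/64/4 case is inferred for the first traceback) reproduces the n64, n128 and n32 widths seen in these errors:
    # Minimal sketch (not part of the test suite): recompute the expected WGMMA
    # width using the same formula as the wgmma_n line in the test body above.
    def expected_wgmma_n(block_m, block_n, num_warps):
        return int(max(block_n / max(num_warps / max(block_m / 16, 1), 1), 8))
    print(expected_wgmma_n(64, 64, 4))    # 64  -> the n64 pattern in the traceback above
    print(expected_wgmma_n(128, 128, 4))  # 128 -> the n128 patterns in the failures below
    print(expected_wgmma_n(64, 64, 8))    # 32  -> the n32 pattern in the 64-64-32-8 failure below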
_ test_gemm[128-128-64-4-1-None-None-None-False-True-False-softmax-float16-False-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is CUDA backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[128-128-64-4-1-None-None-None-False-True-False-softmax-float16-True-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is CUDA backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[128-128-64-4-1-None-None-None-False-True-False-softmax-float32-False-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is CUDA backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[128-128-64-4-1-None-None-None-False-True-False-softmax-float32-True-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is CUDA backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
__ test_gemm[64-64-32-8-1-64-64-64-False-True-False-softmax-float16-False-3] ___
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 32, NUM_WARPS = 8, NUM_CTAS = 1, M = 64
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # this check is CUDA backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(32)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
___ test_gemm[64-64-32-8-1-64-64-64-False-True-False-softmax-float16-True-3] ___
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 32, NUM_WARPS = 8, NUM_CTAS = 1, M = 64
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(32)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
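
A hedged reading of the failure above (an observation from the log, not a confirmed diagnosis): the PTX quoted in the error targets sm_120a, while the assertion only distinguishes two cases, tcgen05.mma when the compute capability major is 10 and wgmma.mma_async otherwise, so on a capability-12 device the else branch runs and the expected pattern appears never to be present in the generated code. The standalone sketch below only reproduces the mechanics of the failing re.search call; both PTX strings are made-up stand-ins for pgm.asm['ptx']:

    import re

    # Hypothetical stand-ins for pgm.asm['ptx']: a Hopper-style fragment and an
    # sm_120a-style fragment that contains no wgmma instruction at all.
    ptx_sm90 = "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 {...};"
    ptx_sm120 = "// Generated by LLVM NVPTX Back-End\n.target sm_120a\n// no wgmma here"

    wgmma_n = 32  # the value the test computes for BLOCK_M=64, BLOCK_N=64, NUM_WARPS=8
    pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)

    print(re.search(pattern, ptx_sm90))   # <re.Match ...> on the Hopper-style string
    print(re.search(pattern, ptx_sm120))  # None -> the AssertionError shown above
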
__ test_gemm[64-64-32-8-1-64-64-64-False-True-False-softmax-float32-False-3] ___
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 32, NUM_WARPS = 8, NUM_CTAS = 1, M = 64
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(32)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
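
For reference, the n32 in the pattern above falls straight out of the test's wgmma_n expression for this parametrization (BLOCK_M=64, BLOCK_N=64, NUM_WARPS=8); the step-by-step re-derivation below is purely illustrative:

    BLOCK_M, BLOCK_N, NUM_WARPS = 64, 64, 8

    # Same arithmetic as the one-liner in the test body, spelled out.
    m_tiles = max(BLOCK_M / 16, 1)                   # 4.0
    warps_per_m_tile = max(NUM_WARPS / m_tiles, 1)   # 2.0
    wgmma_n = int(max(BLOCK_N / warps_per_m_tile, 8))

    assert wgmma_n == 32  # hence the "n32" in the expected wgmma instruction
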
___ test_gemm[64-64-32-8-1-64-64-64-False-True-False-softmax-float32-True-3] ___
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 32, NUM_WARPS = 8, NUM_CTAS = 1, M = 64
N = 64, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n32k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(32)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
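
The softmax branch of process_epilogue above is the standard numerically stable formulation (subtract the row max before exponentiating), which should coincide with torch.softmax along the last dimension. A small CPU-only check of that equivalence, independent of the kernel under test:

    import torch

    d = torch.randn(4, 8, dtype=torch.float32)
    num = torch.exp(d - torch.max(d, dim=-1, keepdim=True)[0])
    ref = num / torch.sum(num, dim=-1, keepdim=True)

    # Matches the commented-out reference in the test ("ref = torch.softmax(d, 1)")
    # up to floating-point rounding.
    torch.testing.assert_close(ref, torch.softmax(d, dim=-1))
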
_ test_gemm[128-128-64-4-1-128-128-128-False-True-False-softmax-float16-False-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 128, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
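
One detail of the failure header just above that is easy to misread: the test id says ...softmax-float16-False-3, yet pytest reports out_dtype = triton.language.float32. That is consistent with the test body, which rebinds the out_dtype argument to tl.float32 for every softmax epilogue (see the TODO about the 'llvm.fmul' error), and pytest prints the rebound value at failure time. A minimal sketch of just that rebinding, with plain strings standing in for the tl dtypes:

    # Illustrative only: mirrors the dtype rebinding in the test body.
    epilogue, out_dtype = 'softmax', 'float16'   # what the test id encodes
    if out_dtype == 'float16' and epilogue != 'softmax':
        out_dtype = 'tl.float16'
    else:
        out_dtype = 'tl.float32'
    print(out_dtype)  # 'tl.float32' -> the value pytest shows in the failure header
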
_ test_gemm[128-128-64-4-1-128-128-128-False-True-False-softmax-float16-True-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 128, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
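
Also worth noting in the comparison step above: both tensors are passed through torch.nn.functional.normalize before assert_close. With the default arguments (p=2, dim=1) each row is scaled to unit L2 norm, so the rtol/atol check is insensitive to a common per-row scale factor. A tiny CPU-only illustration of that effect, with made-up values:

    import torch

    golden = torch.randn(4, 8)
    z = golden * 3.0  # same per-row direction, different scale

    golden_n = torch.nn.functional.normalize(golden)  # p=2, dim=1 by default
    z_n = torch.nn.functional.normalize(z)

    # Passes despite the 3x scale difference, because only row directions are compared.
    torch.testing.assert_close(z_n, golden_n, rtol=1e-2, atol=1e-3)
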
_ test_gemm[128-128-64-4-1-128-128-128-False-True-False-softmax-float32-False-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 128, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
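
Every traceback in this group ends at the same PTX assertion, and the reported target is sm_120a each time, so a hedged way to make the check self-describing on such devices would be to gate it on the compute capabilities whose MMA instructions it actually knows about (sm_90 wgmma, sm_100 tcgen05) and skip otherwise. The helper below is only a sketch of that idea under those assumptions; its name, structure, and skip message are not taken from test_gemm.py:

    import re
    import pytest
    import torch

    def check_mma_ptx(ptx, block_m, block_n, num_warps):
        """Hypothetical helper: inspect PTX only on architectures the check understands."""
        major, _minor = torch.cuda.get_device_capability()
        if major not in (9, 10):
            pytest.skip(f"PTX MMA check not defined for compute capability {major}.x")
        if major == 10 and num_warps % 4 == 0 and block_m % 64 == 0 and block_n % 8 == 0:
            assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
        else:
            wgmma_n = int(max(block_n / max(num_warps / max(block_m / 16, 1), 1), 8))
            assert re.search(
                r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
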
_ test_gemm[128-128-64-4-1-128-128-128-False-True-False-softmax-float32-True-3] _
BLOCK_M = 128, BLOCK_N = 128, BLOCK_K = 64, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 128, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'softmax', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n128k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(128)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
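
Note: the test_gemm failures in this section all reduce to the same thing: re.search returns None because the generated PTX targets sm_120a and contains no wgmma instruction. Below is a minimal standalone sketch of the failing check, with the PTX stubbed down to the header visible in the error message (in the test it comes from pgm.asm['ptx']); the wgmma_n value is the one the test computed for the case reported above.

import re

# Stub standing in for pgm.asm['ptx']; the real string in the failure above also
# targets sm_120a and contains no wgmma instructions, so the search returns None.
ptx = "//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n"

wgmma_n = 128  # value computed by the test for the case reported above
pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)
print(re.search(pattern, ptx))  # None -- exactly what the assert trips over
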
___ test_gemm[64-64-16-4-1-128-128-64-False-True-False-none-float16-False-3] ___
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
# loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
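
For the case just reported (BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4), the wgmma_n expression works out to 64, which is where the n64 in the unmatched pattern comes from. A worked instance of the test's formula, with the values taken from the parameter dump above:

# wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
wgmma_n = int(max(64 / max(4 / max(64 / 16, 1), 1), 8))
assert wgmma_n == 64  # hence the 'n64' in the regex that re.search fails to find
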
___ test_gemm[64-64-16-4-1-128-128-64-False-True-False-none-float16-True-3] ____
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
# loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
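
These failures share one signature: the PTX header reports .target sm_120a, so torch.cuda.get_device_capability() presumably returns (12, 0). That makes is_tcgen5 false (it requires a major version of exactly 10), the test falls through to the wgmma branch, and the sm_120a PTX appears to contain neither tcgen05 nor wgmma f16 MMA instructions, so the regex never matches. A hypothetical helper, not part of test_gemm and shown only as a sketch of scoping the PTX check to the architectures it expects:

import re

def expected_mma_found(ptx: str, wgmma_n: int, capability_major: int) -> bool:
    # Hypothetical sketch only: report whether the PTX contains the MMA instruction
    # this test would look for on the given compute capability.
    if capability_major == 10:  # e.g. sm_100: expect tcgen05 MMA
        return re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx) is not None
    if capability_major == 9:   # sm_90 (Hopper): expect wgmma MMA
        return re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx) is not None
    return True                 # other targets (e.g. the sm_120a seen here): nothing to assert on
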
___ test_gemm[64-64-16-4-1-128-128-64-False-True-False-none-float32-False-3] ___
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
# loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
___ test_gemm[64-64-16-4-1-128-128-64-False-True-False-none-float32-True-3] ____
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
# loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-matrix-float16-False-3] _
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-matrix', out_dtype = triton.language.float16
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
# loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-matrix-float16-True-3] _
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-matrix', out_dtype = triton.language.float16
USE_TMA_STORE = True, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
# loop over epilogues besides softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
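The assertion above can be reproduced outside pytest. The following is a minimal sketch, assuming the parameters of the failing case (BLOCK_M=64, BLOCK_N=64, NUM_WARPS=4) and using only the PTX header quoted in the error as a stand-in for pgm.asm['ptx']; it shows how the expected WGMMA tile width is derived and why re.search returns None for the sm_120a PTX.

import re

BLOCK_M, BLOCK_N, NUM_WARPS = 64, 64, 4
# Expected N dimension of the wgmma instruction, computed exactly as in the test body above.
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))  # -> 64 here
pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)
# Stand-in for pgm.asm['ptx']: the header of the PTX quoted in the error above.
ptx = '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n'
# The sm_120a PTX contains no matching wgmma.mma_async instruction, so this prints None,
# which is what makes the assert in test_gemm fail.
print(re.search(pattern, ptx))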
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-matrix-float32-False-3] _
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-matrix', out_dtype = triton.language.float32
USE_TMA_STORE = False, NUM_STAGES = 3
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-matrix-float32-True-3] _
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-matrix', out_dtype = triton.language.float32
USE_TMA_STORE = True, NUM_STAGES = 3
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-rows-float16-False-3] _
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-rows', out_dtype = triton.language.float16
USE_TMA_STORE = False, NUM_STAGES = 3
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-rows-float16-True-3] __
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-rows', out_dtype = triton.language.float16, USE_TMA_STORE = True
NUM_STAGES = 3
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-rows-float32-False-3] _
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-rows', out_dtype = triton.language.float32
USE_TMA_STORE = False, NUM_STAGES = 3
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-rows-float32-True-3] __
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-rows', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
# loop over epilogues besides of softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
        # the following check is CUDA-backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
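A minimal, CPU-only sketch of the arithmetic behind the failing assertion above, using the values from the failure header (BLOCK_M=64, BLOCK_N=64, NUM_WARPS=4); it reproduces the 'n64' pattern quoted in the AssertionError without needing Triton or a GPU:

BLOCK_M, BLOCK_N, NUM_WARPS = 64, 64, 4   # values reported in the failure header
# Same expression as the failing line: the n-dimension of the expected wgmma instruction.
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
print(wgmma_n)  # 64
pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)
print(pattern)  # the 'n64' pattern that re.search failed to find in the dumped PTX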
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-cols-float16-False-3] _
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-cols', out_dtype = triton.language.float16
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
        # the following check is CUDA-backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
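The comparison step above L2-normalizes both tensors before calling assert_close. A small CPU-only sketch of that post-processing on made-up data, assuming assert_close is torch.testing.assert_close:

import torch
from torch.testing import assert_close

golden = torch.randn(4, 8)
z = golden + 1e-4 * torch.randn(4, 8)           # stand-in for the kernel output
golden = torch.nn.functional.normalize(golden)  # row-wise L2 normalization (dim=1), as in the test
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)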
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-cols-float16-True-3] __
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-cols', out_dtype = triton.language.float16, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
        # the following check is CUDA-backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
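Every PTX dump quoted in these failures targets sm_120a, while the assertion only distinguishes a tcgen05 path (capability major == 10) from a wgmma path. The sketch below is a hypothetical capability gate, not the test's actual logic; it assumes wgmma.mma_async is a Hopper (sm_90) instruction and tcgen05.mma an sm_100-class one, so the pattern check is only meaningful on those parts:

import torch

def ptx_check_expected() -> bool:
    # Hypothetical helper (not in test_gemm.py): only run the instruction-pattern
    # assertion on capabilities expected to emit wgmma (sm_90) or tcgen05 (sm_100).
    major = torch.cuda.get_device_capability()[0]
    return major in (9, 10)

# Illustrative use inside the test body:
# if not ptx_check_expected():
#     return  # on e.g. sm_120a neither pattern is expected in the PTX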
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-cols-float32-False-3] _
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-cols', out_dtype = triton.language.float32
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
        # the following check is CUDA-backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
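For the shape in the failure above (M = N = 128 with BLOCK_M = BLOCK_N = 64), the grid callable resolves to four program instances. A minimal sketch of the same ceiling-division arithmetic, with a local cdiv standing in for triton.cdiv:

def cdiv(a, b):
    # mirrors triton.cdiv: ceiling integer division
    return (a + b - 1) // b

M, N, BLOCK_M, BLOCK_N = 128, 128, 64, 64
grid_size = cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)
print(grid_size)  # 4 -> size of the 1-D launch grid used by matmul_kernel[grid]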
_ test_gemm[64-64-16-4-1-128-128-64-False-True-False-add-cols-float32-True-3] __
BLOCK_M = 64, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 128
N = 128, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-cols', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
        # the following check is CUDA-backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
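As the test body shows, the PTX-instruction assertion only runs when DISABLE_MMA_V3 is not set to a truthy value. A minimal sketch of that gate; setting the variable here is purely illustrative:

import os

os.environ['DISABLE_MMA_V3'] = '1'   # illustrative; "on", "true" and "1" all count
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
skip_ptx_check = disable_mmav3 in ["on", "true", "1"]
print(skip_ptx_check)  # True -> the wgmma/tcgen05 pattern check would be skipped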
__ test_gemm[256-64-16-4-1-256-256-64-False-True-False-none-float16-False-3] ___
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
        # the following check is CUDA-backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
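The softmax branch of process_epilogue builds its reference by hand; the commented-out line in the test suggests it matches torch.softmax. A CPU-only sketch checking that equivalence on a made-up tensor:

import torch

d = torch.randn(4, 8)
num = torch.exp(d - torch.max(d, dim=-1, keepdim=True)[0])
ref = num / torch.sum(num, dim=-1, keepdim=True)
print(torch.allclose(ref, torch.softmax(d, dim=-1)))  # True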
___ test_gemm[256-64-16-4-1-256-256-64-False-True-False-none-float16-True-3] ___
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'none', out_dtype = triton.language.float16, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
        # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
    # the following check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
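
A note on this group of failures: the captured PTX header shows ".target sm_120a", i.e. the kernel was compiled for a compute capability 12.x device. The skipif gate only requires capability >= 9, so the test still runs, but the PTX check routes only capability 10 to the tcgen05 pattern; every other capability falls through to the Hopper wgmma branch, and wgmma.mma_async instructions are not expected in sm_120a PTX, so the regex finds nothing. The sketch below is a standalone illustration (not part of test_gemm.py; expected_mma_pattern is a made-up helper) that mirrors the test's gating and the wgmma_n arithmetic for the failing BLOCK_M=256, BLOCK_N=64, NUM_WARPS=4 configuration.

import re

# Stand-in for pgm.asm['ptx']; only the .target line matters for this illustration.
ptx_header = ".version 8.7\n.target sm_120a\n.address_size 64\n"

def expected_mma_pattern(capability_major, block_m=256, block_n=64, num_warps=4):
    # Mirrors test_gemm: only capability 10 takes the tcgen05 path.
    is_tcgen5 = capability_major == 10 and num_warps % 4 == 0 and block_m % 64 == 0 and block_n % 8 == 0
    if is_tcgen5:
        return r'tcgen05.mma.cta_group::1.kind::f16'
    # Same arithmetic as the test; for (256, 64, 4) this evaluates to 64, matching the n64 in the error above.
    wgmma_n = int(max(block_n / max(num_warps / max(block_m / 16, 1), 1), 8))
    return r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)

pattern = expected_mma_pattern(capability_major=12)
print(re.search(pattern, ptx_header))  # None: this PTX contains no wgmma instruction
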
__ test_gemm[256-64-16-4-1-256-256-64-False-True-False-none-float32-False-3] ___
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = False
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
          # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
        # the following check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
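
For context, the numerical comparison in these runs had already passed before the PTX assertion: both the kernel output z and the PyTorch reference are passed through torch.nn.functional.normalize (unit L2 norm along dim=1 by default) before assert_close, so the check tolerates a uniform rescaling while still catching values in the wrong positions. A small self-contained sketch of that behaviour; the tensors and tolerances below are illustrative, not taken from the test.

import torch
from torch.testing import assert_close

torch.manual_seed(0)
ref = torch.randn(4, 8)

# A uniform rescaling survives the normalize step ...
assert_close(torch.nn.functional.normalize(3.0 * ref),
             torch.nn.functional.normalize(ref), rtol=1e-2, atol=1e-3)

# ... but the same values in the wrong positions do not.
rolled = ref.roll(1, dims=1)
try:
    assert_close(torch.nn.functional.normalize(rolled),
                 torch.nn.functional.normalize(ref), rtol=1e-2, atol=1e-3)
except AssertionError:
    print("misplaced elements are still detected")
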
___ test_gemm[256-64-16-4-1-256-256-64-False-True-False-none-float32-True-3] ___
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'none', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
          # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
        # the following check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
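
For reference, the 'softmax' branch of process_epilogue computes the row softmax by hand in the numerically stable way (subtract the row max before exponentiating) rather than calling torch.softmax, as the commented-out line hints. A standalone sanity check, assuming nothing beyond stock PyTorch, that the two formulations agree:

import torch

torch.manual_seed(0)
d = torch.randn(128, 128)

# Hand-rolled, numerically stable softmax, as in process_epilogue.
num = torch.exp(d - torch.max(d, dim=-1, keepdim=True)[0])
manual = num / torch.sum(num, dim=-1, keepdim=True)

torch.testing.assert_close(manual, torch.softmax(d, dim=-1))
print("manual softmax matches torch.softmax")
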
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-matrix-float16-False-3] _
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-matrix', out_dtype = triton.language.float16
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
          # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
        # the following check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
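
The grid callable used in these runs launches a 1D grid with one program per output tile, triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']) in total; for the failing configuration (M = N = 256, BLOCK_M = 256, BLOCK_N = 64) that is 1 * 4 = 4 programs, which matmul_kernel presumably regroups into (pid_m, pid_n) pairs using GROUP_SIZE_M. A plain-Python sketch of the arithmetic only, with no kernel launch involved:

def cdiv(a, b):
    # The same ceiling division triton.cdiv performs.
    return (a + b - 1) // b

M, N = 256, 256
BLOCK_M, BLOCK_N = 256, 64

num_pid_m = cdiv(M, BLOCK_M)    # 1 tile along M
num_pid_n = cdiv(N, BLOCK_N)    # 4 tiles along N
print(num_pid_m * num_pid_n)    # 4 programs in the 1D launch grid
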
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-matrix-float16-True-3] _
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-matrix', out_dtype = triton.language.float16
USE_TMA_STORE = True, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
          # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
        # the following check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
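
The TRANS_A / TRANS_B handling in the test never copies data: it allocates the transposed shape and takes .T, so the kernel sees the alternative layout purely through strides, while a_order / b_order (fastest-varying axis first) presumably tell the kernel's block pointers which axis is contiguous. A minimal sketch of what .T does to the strides; CPU tensors and arbitrary sizes are enough to show the point.

import torch

M, K = 8, 4

a_row_major = torch.randn(M, K)       # TRANS_A == False: last dim contiguous, order [1, 0]
a_transposed = torch.randn(K, M).T    # TRANS_A == True: a view, no copy, order [0, 1]

print(a_row_major.shape, a_row_major.stride())    # torch.Size([8, 4]) (4, 1)
print(a_transposed.shape, a_transposed.stride())  # torch.Size([8, 4]) (1, 8)
print(a_transposed.is_contiguous())               # False: same logical shape, column-major layout
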
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-matrix-float32-False-3] _
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-matrix', out_dtype = triton.language.float32
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
          # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
        # the following check is specific to the CUDA backend
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
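
In process_epilogue, 'add-rows' adds one scalar per row (taken from the first column of the bias) and 'add-cols' adds one scalar per column (taken from the first row); only 'add-matrix' consumes the full (M, N) bias. A minimal broadcasting sketch with made-up shapes:

import torch

M, N = 4, 6
d = torch.zeros(M, N)
bias = torch.arange(M * N, dtype=torch.float32).reshape(M, N)

add_rows = d + bias[:, 0][:, None]    # (M, 1) broadcast across columns: one value per row
add_cols = d + bias[0, :][None, :]    # (1, N) broadcast across rows: one value per column

print(torch.equal(add_rows[:, 0], bias[:, 0]))    # True
print(torch.equal(add_cols[0, :], bias[0, :]))    # True
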
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-matrix-float32-True-3] _
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-matrix', out_dtype = triton.language.float32
USE_TMA_STORE = True, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
          # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
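The assertion that fails here is the PTX sanity check at the end of test_gemm: for the parameters above (BLOCK_M = 256, BLOCK_N = 64, NUM_WARPS = 4) the test derives the N dimension of the expected WGMMA instruction and then searches the generated PTX for a matching wgmma.mma_async mnemonic. Below is a minimal standalone sketch of just that arithmetic and the regex search, with the PTX string left as a placeholder (in the test it comes from pgm.asm['ptx']); it only restates what the traceback already shows.

    import re

    # Values taken from the failing parametrization above; the PTX string is a placeholder.
    BLOCK_M, BLOCK_N, NUM_WARPS = 256, 64, 4

    # Same derivation as in test_gemm: expected N dimension of the wgmma instruction.
    wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
    assert wgmma_n == 64  # matches the 'n64' that appears in the formatted regex above

    pattern = r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n)
    ptx = "// placeholder for the PTX dumped via pgm.asm['ptx']"
    print(re.search(pattern, ptx))  # None, which is exactly what the AssertionError reports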
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-rows-float16-False-3] _
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-rows', out_dtype = triton.language.float16
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
    # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
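A hedged note on why the search returns None in each of these failures: the PTX header quoted in the error output shows .target sm_120a, i.e. a compute-capability 12.x device. test_gemm only special-cases major capability 10 for the tcgen05 check and otherwise expects a Hopper-style wgmma.mma_async instruction, so on an sm_120a target it appears that neither pattern is present in the PTX and the regex finds nothing. The sketch below restates that gating; the capability value is assumed from the .target line rather than queried from a device.

    # Assumed from the '.target sm_120a' line in the error output, not queried from a GPU.
    capability_major = 12
    NUM_WARPS, BLOCK_M, BLOCK_N = 4, 256, 64  # from the failing parametrization above

    # Same predicate as in test_gemm: only major capability 10 takes the tcgen05 branch.
    is_tcgen5 = (capability_major == 10) and NUM_WARPS % 4 == 0 and BLOCK_M % 64 == 0 and BLOCK_N % 8 == 0
    print(is_tcgen5)  # False -> the test falls through to the wgmma.mma_async check,
                      # which the sm_120a PTX above apparently does not contain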
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-rows-float16-True-3] _
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-rows', out_dtype = triton.language.float16, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
    # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-rows-float32-False-3] _
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-rows', out_dtype = triton.language.float32
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
    # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-rows-float32-True-3] _
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-rows', out_dtype = triton.language.float32, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
    # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-cols-float16-False-3] _
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-cols', out_dtype = triton.language.float16
USE_TMA_STORE = False, NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
    # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],
[128, 128, 128, 4, 2, 256, 256, 192],
# small BLOCK_M and BLOCK_K
[16, 32, 32, 4, 1, 128, 256, 64],
[32, 32, 16, 4, 1, 256, 256, 192],
[16, 32, 64, 4, 4, 512, 256, 64],
] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in
[False, True] for trans_output in [False, True] for num_stages in [3]
] + [
# loop over instr shapes & pipeline stages
(64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages)
for n in [16, 32, 64, 128, 256]
for trans_output in [False]
for out_dtype in ['float32']
for use_tma_store in [False]
for num_stages in [2, 4, 5, 7]
] + [
# irregular shapes
(*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[128, 128, 64, 4, 1],
[256, 128, 64, 4, 2],
[128, 128, 128, 4, 2],
] for shape in [
[512, 360, 1024],
[360, 4096, 512],
] for trans_output in [False] for out_dtype in ['float32'] for use_tma_store in [False, True] for num_stages in
[3, 4]
])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9")
def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue,
out_dtype, USE_TMA_STORE, NUM_STAGES):
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [
'16-32-64-4-4-512-256-64-True-False',
'16-32-64-4-4-512-256-64-True-True',
'16-32-64-4-4-512-256-64-False-False',
'16-32-64-4-4-512-256-64-False-True',
]:
pytest.skip('shapePerCTA[1] < 16 not supported')
if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [
'16-32-64-4-1-256-256-256-False',
'16-32-64-4-2-256-256-256-False',
'16-32-64-4-2-256-256-256-True',
'16-32-64-8-2-256-256-256-False',
'16-32-64-8-2-256-256-256-True',
]:
pytest.skip('Known legacy issue, ldmatrix can only support x4')
if is_hip() and NUM_CTAS > 1:
pytest.skip("NUM_CTAS > 1 is not supported in HIP backend")
if epilogue == 'add-rows' and NUM_CTAS > 1:
pytest.skip('known failure: error getCTAsPerCGA for SliceEncodingAttr is not well-defined.')
M = BLOCK_M if M is None else M
N = BLOCK_N if N is None else N
K = BLOCK_K if K is None else K
if (TRANS_A):
a = torch.randn((K, M), device='cuda', dtype=torch.float16).T
a_order = [0, 1]
else:
a = torch.randn((M, K), device='cuda', dtype=torch.float16)
a_order = [1, 0]
if (TRANS_B):
b = torch.randn((N, K), device='cuda', dtype=torch.float16).T
b_order = [0, 1]
else:
b = torch.randn((K, N), device='cuda', dtype=torch.float16)
b_order = [1, 0]
if out_dtype == 'float16' and epilogue != 'softmax':
# TODO: for out_dtype == 'float16' and epilogue == 'softmax', it will
# fail with the following error: 'llvm.fmul' op requires the same type
# for all operands and results
out_dtype = tl.float16
torch_out_dtype = torch.float16
else:
out_dtype = tl.float32
torch_out_dtype = torch.float32
# avoid out of memory
if epilogue in ['add-matrix', 'add-rows', 'add-cols']:
if (TRANS_OUTPUT):
bias = torch.randn((N, M), device='cuda', dtype=torch_out_dtype).T
else:
bias = torch.randn((M, N), device='cuda', dtype=torch_out_dtype)
else:
bias = torch.randn((1, 1), device='cuda', dtype=torch_out_dtype)
# for chain-dot only
w = torch.randn((N, N), device='cuda', dtype=torch.float16).T
w_order = [0, 1]
if (TRANS_OUTPUT):
z = torch.full((N, M), 1., device='cuda', dtype=torch_out_dtype).T
z_order = [0, 1]
else:
z = torch.full((M, N), 1., device='cuda', dtype=torch_out_dtype)
z_order = [1, 0]
# torch result
a_f32 = a.to(torch.float32)
b_f32 = b.to(torch.float32)
dot = torch.matmul(a_f32, b_f32)
def process_epilogue(d, bias, w, epilogue):
if epilogue == 'add-matrix':
ref = d + bias
elif epilogue == 'add-rows':
ref = d + bias[:, 0][:, None]
elif epilogue == 'add-cols':
ref = d + bias[0, :][None, :]
elif epilogue == 'softmax':
num = torch.exp(d - torch.max(d, dim=-1, keepdims=True)[0])
denom = torch.sum(num, dim=-1, keepdims=True)
ref = num / denom
# ref = torch.softmax(d, 1)
elif epilogue == 'chain-dot':
ref = torch.matmul(d, w.to(torch.float32))
else:
ref = d
return ref
golden = process_epilogue(dot, bias, w, epilogue)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
pgm = matmul_kernel[grid](
a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, #
M=M, N=N, K=K, #
stride_am=a.stride(0), stride_ak=a.stride(1), #
stride_bk=b.stride(0), stride_bn=b.stride(1), #
stride_wm=w.stride(0), stride_wn=w.stride(1), #
stride_zm=z.stride(0), stride_zn=z.stride(1), #
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, #
out_dtype=out_dtype, #
USE_TMA_STORE=USE_TMA_STORE, #
ADD_MATRIX=epilogue == 'add-matrix', #
ADD_ROWS=epilogue == 'add-rows', #
ADD_COLS=epilogue == 'add-cols', #
DO_SOFTMAX=epilogue == 'softmax', #
CHAIN_DOT=epilogue == 'chain-dot', #
A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], #
B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], #
W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], #
Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], #
num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES)
torch.set_printoptions(profile="full")
golden = torch.nn.functional.normalize(golden)
z = torch.nn.functional.normalize(z)
assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False)
# check is cuda backend specific
if is_hip():
return
disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower()
if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256:
is_tcgen5 = (torch.cuda.get_device_capability()[0]
== 10) and (NUM_WARPS % 4) == 0 and (BLOCK_M % 64) == 0 and (BLOCK_N % 8) == 0
ptx = pgm.asm['ptx']
if is_tcgen5:
assert re.search(r'tcgen05.mma.cta_group::1.kind::f16', ptx)
else:
wgmma_n = int(max(BLOCK_N / max(NUM_WARPS / max(BLOCK_M / 16, 1), 1), 8))
> assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(wgmma_n), ptx)
E AssertionError: assert None
E + where None = <function search at 0x7f1ec8c06b60>('wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16', '//\n// Generated by LLVM NVPTX Back-End\n//\n\n.version 8.7\n.target sm_120a\n.address_size 64\n\n\t// .globl\tmatmul...hildren Mark\n.b8 0 // End Of Children Mark\n\t}\n\t.section\t.debug_macinfo\t{\t}\n')
E + where <function search at 0x7f1ec8c06b60> = re.search
E + and 'wgmma.mma_async.sync.aligned.m\\d+n64k16(?:.row.col)?.f32.f16.f16' = <built-in method format of str object at 0x7f1d684945e0>(64)
E + where <built-in method format of str object at 0x7f1d684945e0> = 'wgmma.mma_async.sync.aligned.m\\d+n{}k16(?:.row.col)?.f32.f16.f16'.format
unit/cuda/test_gemm.py:463: AssertionError
_ test_gemm[256-64-16-4-1-256-256-64-False-True-False-add-cols-float16-True-3] _
BLOCK_M = 256, BLOCK_N = 64, BLOCK_K = 16, NUM_WARPS = 4, NUM_CTAS = 1, M = 256
N = 256, K = 64, TRANS_A = False, TRANS_B = True, TRANS_OUTPUT = False
epilogue = 'add-cols', out_dtype = triton.language.float16, USE_TMA_STORE = True
NUM_STAGES = 3
@pytest.mark.parametrize(
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES',
[
# corner shapes
(128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3)
for shape_w_c in [
[4096, 1, 1024, False, False, True],
[2048, 204, 1000, True, False, True],
[4096, 1, 1024, False, False, False],
[2048, 204, 1000, True, False, False],
]
for out_dtype in ['float16', 'float32'] #
for use_tma_store in [False, True] #
] + [
# softmax epilogue
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 64, 64, 64],
[128, 128, 64, 4, 1, None, None, None],
[16, 16, 64, 4, 1, 16, 16, 64],
[64, 64, 32, 8, 1, 64, 64, 64],
[128, 128, 64, 4, 1, 128, 128, 128],
] for epilogue in ['softmax'] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for
trans_a in [False] for trans_b in [True] for trans_output in [False] for num_stages in [3]
] + [
    # loop over epilogues other than softmax
(*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 16, 4, 1, 128, 128, 64],
*[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]],
# for chain-dot
[128, 128, 64, 4, 1, None, None, None],
[64, 64, 16, 4, 1, None, None, None],
# small BLOCK_M and BLOCK_K
[16, 16, 64, 4, 1, 128, 128, 64],
*[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]],
# repeat
[64, 64, 32, 8, 1, 128, 256, 64],
[64, 64, 16, 8, 2, 128, 128, 64],
# irregular shape
[128, 128, 64, 4, 1, 500, 200, 128],
[128, 128, 64, 4, 2, 513, 193, 192],
] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'
] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in
[False] for trans_b in [True] for trans_output in [False] for num_stages in [3] if not (
epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6]))
] + [
# loop over tile shapes and transpose combinations
(*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages) for shape_w_c in [
[64, 64, 32, 4, 1, 128, 256, 64],
[128, 128, 16, 4, 4, 512, 256, 64],
[128, 256, 32, 4, 8, 256, 256, 192],
[512, 256, 32, 4, 8, 1024, 256, 192],
# BLOCK_K >= 128
[64, 128, 128, 4, 1, 512, 256, 256],
[128, 128, 128, 4, 1, 256, 256, 192],