@makslevental
Created April 4, 2025 19:34
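The two modules below are Triton GPU IR (TTGIR) dumps of a persistent FP16 matmul kernel, matmul_kernel_persistent, compiled for AMD gfx942 with 8 warps of 64 threads. The constants in the IR give a 128x256 output tile, a K block of 64, a grouping factor of 8, and 304 persistent programs. For orientation, here is a minimal Triton Python sketch of a persistent matmul with that shape; it is an illustration only, and the argument names, stride layout, and scheduling details are assumptions rather than anything taken from this gist.

# Hedged sketch only: names and stride layout are assumptions; only the tile sizes
# and the persistent loop structure mirror the IR below.
import triton
import triton.language as tl

@triton.jit
def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, M, N, K,
                             stride_am, stride_bn, stride_cm,
                             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                             BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr,
                             NUM_SMS: tl.constexpr):
    # One resident program per compute unit walks over several output tiles.
    start_pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_tiles = num_pid_m * num_pid_n
    for tile_id in range(start_pid, num_tiles, NUM_SMS):
        # Grouped ordering of tiles along M for better cache reuse.
        num_pid_in_group = GROUP_M * num_pid_n
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
        pid_m = first_pid_m + (tile_id % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Clamp out-of-range row/col indices to 0 (the arith.select pattern in the IR);
        # the out-of-range lanes are masked off again at the store.
        offs_am = tl.where(offs_m < M, offs_m, 0)
        offs_bn = tl.where(offs_n < N, offs_n, 0)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            offs_k = k * BLOCK_K + tl.arange(0, BLOCK_K)
            a = tl.load(a_ptr + offs_am[:, None] * stride_am + offs_k[None, :],
                        mask=offs_k[None, :] < K, other=0.0)
            b = tl.load(b_ptr + offs_k[:, None] + offs_bn[None, :] * stride_bn,
                        mask=offs_k[:, None] < K, other=0.0)
            acc = tl.dot(a, b, acc)
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :]
        c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        tl.store(c_ptrs, acc.to(tl.float16), mask=c_mask)

A persistent launch would use a grid of min(NUM_SMS, num_tiles) programs; the 304 here presumably matches the compute-unit count of the target, so each CU keeps one resident program that loops over its share of output tiles.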
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%true = arith.constant true
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%cst_1 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%c1_i32 = arith.constant 1 : i32
%c304_i32 = arith.constant 304 : i32
%c64_i32 = arith.constant 64 : i32
%c63_i32 = arith.constant 63 : i32
%cst_3 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
%c128_i32 = arith.constant 128 : i32
%c127_i32 = arith.constant 127 : i32
%c8_i32 = arith.constant 8 : i32
%c256_i32 = arith.constant 256 : i32
%c255_i32 = arith.constant 255 : i32
%c0_i32 = arith.constant 0 : i32
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
%1 = tt.get_program_id x : i32
%2 = arith.addi %arg4, %c255_i32 : i32
%3 = arith.divsi %2, %c256_i32 : i32
%4 = arith.muli %3, %c8_i32 : i32
%5 = arith.divsi %1, %4 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.addi %arg3, %c127_i32 : i32
%8 = arith.divsi %7, %c128_i32 : i32
%9 = arith.subi %8, %6 : i32
%10 = arith.minsi %9, %c8_i32 : i32
%11 = arith.remsi %1, %10 : i32
%12 = arith.addi %6, %11 : i32
%13 = arith.muli %12, %c128_i32 : i32
%14 = arith.subi %arg3, %13 : i32
%15 = tt.splat %14 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
%16 = arith.cmpi slt, %0, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
%17 = arith.select %16, %0, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
%18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
%19 = tt.expand_dims %17 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>> -> tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%20 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%21 = arith.muli %19, %20 : tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%22 = tt.broadcast %21 : tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%23 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>> -> tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%24 = tt.broadcast %23 : tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%25 = arith.addi %22, %24 : tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%26 = arith.addi %arg5, %c63_i32 : i32
%27 = arith.divsi %26, %c64_i32 : i32
%28 = arith.muli %8, %3 : i32
%29 = arith.divsi %28, %c304_i32 : i32
%30 = arith.remsi %28, %c304_i32 : i32
%31 = arith.cmpi slt, %1, %30 : i32
%32 = scf.if %31 -> (i32) {
%92 = arith.addi %29, %c1_i32 : i32
scf.yield %92 : i32
} else {
scf.yield %29 : i32
}
%33 = arith.muli %27, %32 : i32
%34 = arith.cmpi sgt, %33, %c0_i32 : i32
%35 = tt.splat %34 : i1 -> tensor<128x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%36 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
%37 = tt.expand_dims %36 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>> -> tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%38 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%39 = arith.cmpi slt, %37, %38 : tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%40 = tt.broadcast %39 : tensor<1x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%41 = arith.andi %35, %40 : tensor<128x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%42 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%43 = tt.addptr %42, %25 : tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>, tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%44 = tt.load %43, %41, %cst_2 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%45 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
%46 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
%47 = arith.remsi %1, %4 : i32
%48 = arith.divsi %47, %10 : i32
%49 = arith.muli %48, %c256_i32 : i32
%50 = arith.subi %arg4, %49 : i32
%51 = tt.splat %50 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
%52 = arith.cmpi slt, %46, %51 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
%53 = arith.select %52, %46, %cst_1 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
%54 = tt.expand_dims %45 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>> -> tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%55 = tt.broadcast %54 : tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%56 = tt.expand_dims %53 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>> -> tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%57 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%58 = arith.muli %56, %57 : tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%59 = tt.broadcast %58 : tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%60 = arith.addi %55, %59 : tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%61 = tt.splat %34 : i1 -> tensor<64x256xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%62 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
%63 = tt.expand_dims %62 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>> -> tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%64 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%65 = arith.cmpi slt, %63, %64 : tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%66 = tt.broadcast %65 : tensor<64x1xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%67 = arith.andi %61, %66 : tensor<64x256xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%68 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%69 = tt.addptr %68, %60 : tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>, tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%70 = tt.load %69, %67, %cst_0 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%71 = arith.cmpi sgt, %arg3, %c0_i32 : i32
llvm.intr.assume %71 : i1
%72 = arith.cmpi sgt, %arg4, %c0_i32 : i32
llvm.intr.assume %72 : i1
%73 = arith.cmpi sgt, %arg5, %c0_i32 : i32
llvm.intr.assume %73 : i1
%74 = arith.cmpi sgt, %arg6, %c0_i32 : i32
llvm.intr.assume %74 : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
%75 = arith.cmpi sgt, %arg7, %c0_i32 : i32
llvm.intr.assume %75 : i1
%76 = arith.cmpi sgt, %arg8, %c0_i32 : i32
llvm.intr.assume %76 : i1
llvm.intr.assume %true : i1
%77 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
%78 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
%79 = arith.subi %27, %c1_i32 : i32
%80 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
%81 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
%82 = ttg.memdesc_subview %80[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
ttg.local_store %44, %82 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
%83 = ttg.memdesc_subview %81[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
ttg.local_store %70, %83 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256xf16, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
%84 = arith.subi %33, %c1_i32 : i32
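// Persistent main loop: %arg10 is the K-block index within the current output tile and
// %arg11 the flattened tile id. Each iteration prefetches the next A/B blocks from global
// memory, runs tt.dot on the blocks already staged in shared memory, and, once the current
// tile's last K block has been accumulated, stores the f16 result to C and resets the
// accumulator.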
%85:10 = scf.for %arg9 = %c0_i32 to %84 step %c1_i32 iter_args(%arg10 = %c0_i32, %arg11 = %1, %arg12 = %12, %arg13 = %48, %arg14 = %cst, %arg15 = %17, %arg16 = %53, %arg17 = %c0_i32, %arg18 = %82, %arg19 = %83) -> (i32, i32, i32, i32, tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>, i32, !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>) : i32 {
%92 = arith.cmpi eq, %arg10, %79 : i32
%93 = arith.addi %arg10, %c1_i32 : i32
%94 = arith.select %92, %c0_i32, %93 : i32
%95 = arith.cmpi eq, %94, %c0_i32 : i32
%96:5 = scf.if %95 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>) {
%137 = arith.addi %arg11, %c304_i32 : i32
%138 = arith.divsi %137, %4 : i32
%139 = arith.muli %138, %c8_i32 : i32
%140 = arith.subi %8, %139 : i32
%141 = arith.minsi %140, %c8_i32 : i32
%142 = arith.remsi %137, %141 : i32
%143 = arith.addi %139, %142 : i32
%144 = arith.remsi %137, %4 : i32
%145 = arith.divsi %144, %141 : i32
%146 = arith.muli %143, %c128_i32 : i32
%147 = arith.muli %145, %c256_i32 : i32
%148 = arith.subi %arg3, %146 : i32
%149 = tt.splat %148 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
%150 = arith.cmpi slt, %0, %149 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
%151 = arith.select %150, %0, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
%152 = arith.subi %arg4, %147 : i32
%153 = tt.splat %152 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
%154 = arith.cmpi slt, %46, %153 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
%155 = arith.select %154, %46, %cst_1 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
scf.yield %137, %143, %145, %151, %155 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
} else {
scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
}
%97 = arith.muli %94, %c64_i32 : i32
%98 = tt.expand_dims %96#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>> -> tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%99 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%100 = arith.muli %98, %99 : tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%101 = tt.broadcast %100 : tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%102 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>> -> tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%103 = tt.broadcast %102 : tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%104 = arith.addi %101, %103 : tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%105 = tt.addptr %arg0, %97 : !tt.ptr<f16>, i32
%106 = arith.subi %arg5, %97 : i32
%107 = tt.splat %106 : i32 -> tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%108 = arith.cmpi slt, %37, %107 : tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%109 = tt.broadcast %108 : tensor<1x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%110 = tt.splat %105 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%111 = tt.addptr %110, %104 : tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>, tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%112 = tt.load %111, %109, %cst_2 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
%113 = tt.expand_dims %45 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>> -> tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%114 = tt.broadcast %113 : tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%115 = tt.expand_dims %96#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>> -> tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%116 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%117 = arith.muli %115, %116 : tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%118 = tt.broadcast %117 : tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%119 = arith.addi %114, %118 : tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%120 = tt.addptr %arg1, %97 : !tt.ptr<f16>, i32
%121 = tt.splat %106 : i32 -> tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%122 = arith.cmpi slt, %63, %121 : tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%123 = tt.broadcast %122 : tensor<64x1xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%124 = ttg.local_load %arg18 : !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>>
%125 = ttg.local_load %arg19 : !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>>
%126 = tt.splat %120 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%127 = tt.addptr %126, %119 : tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>, tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%128 = tt.load %127, %123, %cst_0 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
%129 = tt.dot %124, %125, %arg14, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>> -> tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%130 = arith.cmpi eq, %arg10, %79 : i32
%131 = arith.select %130, %cst, %129 : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
scf.if %130 {
%137 = arith.muli %arg12, %c128_i32 : i32
%138 = tt.splat %137 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
%139 = arith.addi %138, %77 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
%140 = arith.muli %arg13, %c256_i32 : i32
%141 = tt.splat %140 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
%142 = arith.addi %141, %78 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
%143 = tt.expand_dims %139 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%144 = tt.expand_dims %77 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%145 = arith.muli %arg8, %137 : i32
%146 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%147 = arith.muli %146, %144 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%148 = tt.addptr %arg2, %145 : !tt.ptr<f16>, i32
%149 = tt.expand_dims %142 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%150 = tt.broadcast %147 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%151 = tt.expand_dims %78 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%152 = tt.broadcast %151 : tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%153 = tt.addptr %148, %140 : !tt.ptr<f16>, i32
%154 = arith.addi %152, %150 : tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%155 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%156 = arith.cmpi slt, %143, %155 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%157 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%158 = arith.cmpi slt, %149, %157 : tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%159 = tt.broadcast %156 : tensor<128x1xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%160 = tt.broadcast %158 : tensor<1x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%161 = arith.andi %159, %160 : tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%162 = arith.truncf %129 : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> to tensor<128x256xf16, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%163 = tt.splat %153 : !tt.ptr<f16> -> tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%164 = tt.addptr %163, %154 : tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>, tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
tt.store %164, %162, %161 : tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
}
%132 = arith.addi %arg17, %c1_i32 : i32
%133 = arith.cmpi slt, %132, %c1_i32 : i32
%134 = arith.select %133, %132, %c0_i32 : i32
%135 = ttg.memdesc_subview %80[%134, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
ttg.local_store %112, %135 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
%136 = ttg.memdesc_subview %81[%134, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
ttg.local_store %128, %136 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256xf16, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
scf.yield %94, %96#0, %96#1, %96#2, %131, %96#3, %96#4, %134, %135, %136 : i32, i32, i32, i32, tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>, i32, !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
}
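// Epilogue: one final tt.dot on the A/B blocks still resident in shared memory, followed
// by the masked store of the last output tile (guarded by %86 and %90).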
%86 = arith.cmpi sge, %33, %c1_i32 : i32
%87 = ttg.local_load %85#8 : !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>>
%88 = ttg.local_load %85#9 : !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>>
%89 = scf.if %86 -> (tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>) {
%92 = tt.dot %87, %88, %85#4, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>> -> tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
scf.yield %92 : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
} else {
scf.yield %85#4 : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
}
%90 = arith.cmpi eq, %85#0, %79 : i32
%91 = arith.andi %86, %90 : i1
scf.if %91 {
%92 = arith.muli %85#2, %c128_i32 : i32
%93 = tt.splat %92 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
%94 = arith.addi %93, %77 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
%95 = arith.muli %85#3, %c256_i32 : i32
%96 = tt.splat %95 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
%97 = arith.addi %96, %78 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
%98 = tt.expand_dims %94 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%99 = tt.expand_dims %77 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%100 = arith.muli %arg8, %92 : i32
%101 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%102 = arith.muli %101, %99 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%103 = tt.addptr %arg2, %100 : !tt.ptr<f16>, i32
%104 = tt.expand_dims %97 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%105 = tt.broadcast %102 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%106 = tt.expand_dims %78 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%107 = tt.broadcast %106 : tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%108 = tt.addptr %103, %95 : !tt.ptr<f16>, i32
%109 = arith.addi %107, %105 : tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%110 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%111 = arith.cmpi slt, %98, %110 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%112 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%113 = arith.cmpi slt, %104, %112 : tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%114 = tt.broadcast %111 : tensor<128x1xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%115 = tt.broadcast %113 : tensor<1x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%116 = arith.andi %114, %115 : tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%117 = arith.truncf %89 : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> to tensor<128x256xf16, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%118 = tt.splat %108 : !tt.ptr<f16> -> tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
%119 = tt.addptr %118, %109 : tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>, tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
tt.store %119, %117, %116 : tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
}
ttg.local_dealloc %80 : !ttg.memdesc<1x128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
ttg.local_dealloc %81 : !ttg.memdesc<1x64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
tt.return
}
}
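// The module below is the same kernel printed with layout-attribute aliases (#blocked,
// #blocked1, #mma, #shared, #shared1, #smem). The out-of-loop A/B tile loads and the
// C tile stores are additionally expressed as amdgpu.buffer_load / amdgpu.buffer_store,
// while the loads inside the scf.for remain masked tt.load ops.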
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
%true = arith.constant true
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked>
%cst_1 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1>
%c1_i32 = arith.constant 1 : i32
%c304_i32 = arith.constant 304 : i32
%c64_i32 = arith.constant 64 : i32
%c63_i32 = arith.constant 63 : i32
%cst_3 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%c128_i32 = arith.constant 128 : i32
%c127_i32 = arith.constant 127 : i32
%c8_i32 = arith.constant 8 : i32
%c256_i32 = arith.constant 256 : i32
%c255_i32 = arith.constant 255 : i32
%c0_i32 = arith.constant 0 : i32
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%1 = tt.get_program_id x : i32
%2 = arith.addi %arg4, %c255_i32 : i32
%3 = arith.divsi %2, %c256_i32 : i32
%4 = arith.muli %3, %c8_i32 : i32
%5 = arith.divsi %1, %4 : i32
%6 = arith.muli %5, %c8_i32 : i32
%7 = arith.addi %arg3, %c127_i32 : i32
%8 = arith.divsi %7, %c128_i32 : i32
%9 = arith.subi %8, %6 : i32
%10 = arith.minsi %9, %c8_i32 : i32
%11 = arith.remsi %1, %10 : i32
%12 = arith.addi %6, %11 : i32
%13 = arith.muli %12, %c128_i32 : i32
%14 = arith.subi %arg3, %13 : i32
%15 = tt.splat %14 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%16 = arith.cmpi slt, %0, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%17 = arith.select %16, %0, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%19 = tt.expand_dims %17 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
%20 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1>
%21 = arith.muli %19, %20 : tensor<128x1xi32, #blocked1>
%22 = tt.broadcast %21 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
%23 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
%24 = tt.broadcast %23 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
%25 = arith.addi %22, %24 : tensor<128x64xi32, #blocked1>
%26 = arith.addi %arg5, %c63_i32 : i32
%27 = arith.divsi %26, %c64_i32 : i32
%28 = arith.muli %8, %3 : i32
%29 = arith.divsi %28, %c304_i32 : i32
%30 = arith.remsi %28, %c304_i32 : i32
%31 = arith.cmpi slt, %1, %30 : i32
%32 = scf.if %31 -> (i32) {
%88 = arith.addi %29, %c1_i32 : i32
scf.yield %88 : i32
} else {
scf.yield %29 : i32
}
%33 = arith.muli %27, %32 : i32
%34 = arith.cmpi sgt, %33, %c0_i32 : i32
%35 = tt.splat %34 : i1 -> tensor<128x64xi1, #blocked1>
%36 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
%37 = tt.expand_dims %36 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
%38 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1>
%39 = arith.cmpi slt, %37, %38 : tensor<1x64xi32, #blocked1>
%40 = tt.broadcast %39 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1>
%41 = arith.andi %35, %40 : tensor<128x64xi1, #blocked1>
%42 = amdgpu.buffer_load %arg0[%25], %41 stride = %arg6 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64xf16, #blocked1>
%43 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%44 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%45 = arith.remsi %1, %4 : i32
%46 = arith.divsi %45, %10 : i32
%47 = arith.muli %46, %c256_i32 : i32
%48 = arith.subi %arg4, %47 : i32
%49 = tt.splat %48 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%50 = arith.cmpi slt, %44, %49 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%51 = arith.select %50, %44, %cst_1 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%52 = tt.expand_dims %43 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
%53 = tt.broadcast %52 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked>
%54 = tt.expand_dims %51 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked>
%55 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked>
%56 = arith.muli %54, %55 : tensor<1x256xi32, #blocked>
%57 = tt.broadcast %56 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked>
%58 = arith.addi %53, %57 : tensor<64x256xi32, #blocked>
%59 = tt.splat %34 : i1 -> tensor<64x256xi1, #blocked>
%60 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%61 = tt.expand_dims %60 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
%62 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked>
%63 = arith.cmpi slt, %61, %62 : tensor<64x1xi32, #blocked>
%64 = tt.broadcast %63 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked>
%65 = arith.andi %59, %64 : tensor<64x256xi1, #blocked>
%66 = amdgpu.buffer_load %arg1[%58], %65 stride = %arg7 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256xf16, #blocked>
%67 = arith.cmpi sgt, %arg3, %c0_i32 : i32
llvm.intr.assume %67 : i1
%68 = arith.cmpi sgt, %arg4, %c0_i32 : i32
llvm.intr.assume %68 : i1
%69 = arith.cmpi sgt, %arg5, %c0_i32 : i32
llvm.intr.assume %69 : i1
%70 = arith.cmpi sgt, %arg6, %c0_i32 : i32
llvm.intr.assume %70 : i1
llvm.intr.assume %true : i1
llvm.intr.assume %true : i1
%71 = arith.cmpi sgt, %arg7, %c0_i32 : i32
llvm.intr.assume %71 : i1
%72 = arith.cmpi sgt, %arg8, %c0_i32 : i32
llvm.intr.assume %72 : i1
llvm.intr.assume %true : i1
%73 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%74 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%75 = arith.subi %27, %c1_i32 : i32
%76 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>
%77 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #shared1, #smem, mutable>
%78 = ttg.memdesc_subview %76[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>
ttg.local_store %42, %78 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>
%79 = ttg.memdesc_subview %77[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>
ttg.local_store %66, %79 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>
%80 = arith.subi %33, %c1_i32 : i32
%81:10 = scf.for %arg9 = %c0_i32 to %80 step %c1_i32 iter_args(%arg10 = %c0_i32, %arg11 = %1, %arg12 = %12, %arg13 = %46, %arg14 = %cst, %arg15 = %17, %arg16 = %51, %arg17 = %c0_i32, %arg18 = %78, %arg19 = %79) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>) : i32 {
%88 = arith.cmpi eq, %arg10, %75 : i32
%89 = arith.addi %arg10, %c1_i32 : i32
%90 = arith.select %88, %c0_i32, %89 : i32
%91 = arith.cmpi eq, %90, %c0_i32 : i32
%92:5 = scf.if %91 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) {
%133 = arith.addi %arg11, %c304_i32 : i32
%134 = arith.divsi %133, %4 : i32
%135 = arith.muli %134, %c8_i32 : i32
%136 = arith.subi %8, %135 : i32
%137 = arith.minsi %136, %c8_i32 : i32
%138 = arith.remsi %133, %137 : i32
%139 = arith.addi %135, %138 : i32
%140 = arith.remsi %133, %4 : i32
%141 = arith.divsi %140, %137 : i32
%142 = arith.muli %139, %c128_i32 : i32
%143 = arith.muli %141, %c256_i32 : i32
%144 = arith.subi %arg3, %142 : i32
%145 = tt.splat %144 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%146 = arith.cmpi slt, %0, %145 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%147 = arith.select %146, %0, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
%148 = arith.subi %arg4, %143 : i32
%149 = tt.splat %148 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%150 = arith.cmpi slt, %44, %149 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%151 = arith.select %150, %44, %cst_1 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
scf.yield %133, %139, %141, %147, %151 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
} else {
scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
}
%93 = arith.muli %90, %c64_i32 : i32
%94 = tt.expand_dims %92#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
%95 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1>
%96 = arith.muli %94, %95 : tensor<128x1xi32, #blocked1>
%97 = tt.broadcast %96 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
%98 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
%99 = tt.broadcast %98 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
%100 = arith.addi %97, %99 : tensor<128x64xi32, #blocked1>
%101 = tt.addptr %arg0, %93 : !tt.ptr<f16>, i32
%102 = arith.subi %arg5, %93 : i32
%103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1>
%104 = arith.cmpi slt, %37, %103 : tensor<1x64xi32, #blocked1>
%105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1>
%106 = tt.splat %101 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
%107 = tt.addptr %106, %100 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
%108 = tt.load %107, %105, %cst_2 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64x!tt.ptr<f16>, #blocked1>
%109 = tt.expand_dims %43 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
%110 = tt.broadcast %109 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked>
%111 = tt.expand_dims %92#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked>
%112 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked>
%113 = arith.muli %111, %112 : tensor<1x256xi32, #blocked>
%114 = tt.broadcast %113 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked>
%115 = arith.addi %110, %114 : tensor<64x256xi32, #blocked>
%116 = tt.addptr %arg1, %93 : !tt.ptr<f16>, i32
%117 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked>
%118 = arith.cmpi slt, %61, %117 : tensor<64x1xi32, #blocked>
%119 = tt.broadcast %118 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked>
%120 = ttg.local_load %arg18 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%121 = ttg.local_load %arg19 : !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%122 = tt.splat %116 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked>
%123 = tt.addptr %122, %115 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
%124 = tt.load %123, %119, %cst_0 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256x!tt.ptr<f16>, #blocked>
%125 = tt.dot %120, %121, %arg14, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x256xf32, #mma>
%126 = arith.cmpi eq, %arg10, %75 : i32
%127 = arith.select %126, %cst, %125 : tensor<128x256xf32, #mma>
scf.if %126 {
%133 = arith.muli %arg12, %c128_i32 : i32
%134 = tt.splat %133 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%135 = arith.addi %134, %73 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%136 = arith.muli %arg13, %c256_i32 : i32
%137 = tt.splat %136 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%138 = arith.addi %137, %74 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%139 = tt.expand_dims %135 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
%140 = tt.expand_dims %73 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
%141 = arith.muli %arg8, %133 : i32
%142 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #mma>
%143 = arith.muli %142, %140 : tensor<128x1xi32, #mma>
%144 = tt.addptr %arg2, %141 : !tt.ptr<f16>, i32
%145 = tt.expand_dims %138 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x256xi32, #mma>
%146 = tt.broadcast %143 : tensor<128x1xi32, #mma> -> tensor<128x256xi32, #mma>
%147 = tt.expand_dims %74 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x256xi32, #mma>
%148 = tt.broadcast %147 : tensor<1x256xi32, #mma> -> tensor<128x256xi32, #mma>
%149 = tt.addptr %144, %136 : !tt.ptr<f16>, i32
%150 = arith.addi %148, %146 : tensor<128x256xi32, #mma>
%151 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #mma>
%152 = arith.cmpi slt, %139, %151 : tensor<128x1xi32, #mma>
%153 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #mma>
%154 = arith.cmpi slt, %145, %153 : tensor<1x256xi32, #mma>
%155 = tt.broadcast %152 : tensor<128x1xi1, #mma> -> tensor<128x256xi1, #mma>
%156 = tt.broadcast %154 : tensor<1x256xi1, #mma> -> tensor<128x256xi1, #mma>
%157 = arith.andi %155, %156 : tensor<128x256xi1, #mma>
%158 = arith.truncf %125 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma>
amdgpu.buffer_store %158, %149[%150], %157 stride = %arg8 : tensor<128x256xf16, #mma>
}
%128 = arith.addi %arg17, %c1_i32 : i32
%129 = arith.cmpi slt, %128, %c1_i32 : i32
%130 = arith.select %129, %128, %c0_i32 : i32
%131 = ttg.memdesc_subview %76[%130, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>
ttg.local_store %108, %131 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>
%132 = ttg.memdesc_subview %77[%130, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>
ttg.local_store %124, %132 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>
scf.yield %90, %92#0, %92#1, %92#2, %127, %92#3, %92#4, %130, %131, %132 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>
}
%82 = arith.cmpi sge, %33, %c1_i32 : i32
%83 = ttg.local_load %81#8 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
%84 = ttg.local_load %81#9 : !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
%85 = scf.if %82 -> (tensor<128x256xf32, #mma>) {
%88 = tt.dot %83, %84, %81#4, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x256xf32, #mma>
scf.yield %88 : tensor<128x256xf32, #mma>
} else {
scf.yield %81#4 : tensor<128x256xf32, #mma>
}
%86 = arith.cmpi eq, %81#0, %75 : i32
%87 = arith.andi %82, %86 : i1
scf.if %87 {
%88 = arith.muli %81#2, %c128_i32 : i32
%89 = tt.splat %88 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%90 = arith.addi %89, %73 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
%91 = arith.muli %81#3, %c256_i32 : i32
%92 = tt.splat %91 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%93 = arith.addi %92, %74 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>>
%94 = tt.expand_dims %90 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
%95 = tt.expand_dims %73 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
%96 = arith.muli %arg8, %88 : i32
%97 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #mma>
%98 = arith.muli %97, %95 : tensor<128x1xi32, #mma>
%99 = tt.addptr %arg2, %96 : !tt.ptr<f16>, i32
%100 = tt.expand_dims %93 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x256xi32, #mma>
%101 = tt.broadcast %98 : tensor<128x1xi32, #mma> -> tensor<128x256xi32, #mma>
%102 = tt.expand_dims %74 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x256xi32, #mma>
%103 = tt.broadcast %102 : tensor<1x256xi32, #mma> -> tensor<128x256xi32, #mma>
%104 = tt.addptr %99, %91 : !tt.ptr<f16>, i32
%105 = arith.addi %103, %101 : tensor<128x256xi32, #mma>
%106 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #mma>
%107 = arith.cmpi slt, %94, %106 : tensor<128x1xi32, #mma>
%108 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #mma>
%109 = arith.cmpi slt, %100, %108 : tensor<1x256xi32, #mma>
%110 = tt.broadcast %107 : tensor<128x1xi1, #mma> -> tensor<128x256xi1, #mma>
%111 = tt.broadcast %109 : tensor<1x256xi1, #mma> -> tensor<128x256xi1, #mma>
%112 = arith.andi %110, %111 : tensor<128x256xi1, #mma>
%113 = arith.truncf %85 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma>
amdgpu.buffer_store %113, %104[%105], %112 stride = %arg8 : tensor<128x256xf16, #mma>
}
ttg.local_dealloc %76 : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>
ttg.local_dealloc %77 : !ttg.memdesc<1x64x256xf16, #shared1, #smem, mutable>
tt.return
}
}