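// TritonGPU IR (TTGIR) for a persistent matmul kernel targeting gfx942 (8 warps, 64 threads per warp).
// In this first listing every layout attribute is spelled out inline at each use.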
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
    %true = arith.constant true
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %cst_1 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %c1_i32 = arith.constant 1 : i32
    %c304_i32 = arith.constant 304 : i32
    %c64_i32 = arith.constant 64 : i32
    %c63_i32 = arith.constant 63 : i32
    %cst_3 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
    %c128_i32 = arith.constant 128 : i32
    %c127_i32 = arith.constant 127 : i32
    %c8_i32 = arith.constant 8 : i32
    %c256_i32 = arith.constant 256 : i32
    %c255_i32 = arith.constant 255 : i32
    %c0_i32 = arith.constant 0 : i32
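    // Grouped tile ordering: map the flat program id to a (pid_m, pid_n) output tile, walking
    // groups of 8 tile rows at a time (a GROUP_SIZE_M-style swizzle that improves L2 reuse).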
    %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
    %1 = tt.get_program_id x : i32
    %2 = arith.addi %arg4, %c255_i32 : i32
    %3 = arith.divsi %2, %c256_i32 : i32
    %4 = arith.muli %3, %c8_i32 : i32
    %5 = arith.divsi %1, %4 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.addi %arg3, %c127_i32 : i32
    %8 = arith.divsi %7, %c128_i32 : i32
    %9 = arith.subi %8, %6 : i32
    %10 = arith.minsi %9, %c8_i32 : i32
    %11 = arith.remsi %1, %10 : i32
    %12 = arith.addi %6, %11 : i32
    %13 = arith.muli %12, %c128_i32 : i32
    %14 = arith.subi %arg3, %13 : i32
    %15 = tt.splat %14 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
    %16 = arith.cmpi slt, %0, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
    %17 = arith.select %16, %0, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
    %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
    %19 = tt.expand_dims %17 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>> -> tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %20 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %21 = arith.muli %19, %20 : tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %22 = tt.broadcast %21 : tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %23 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>> -> tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %24 = tt.broadcast %23 : tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %25 = arith.addi %22, %24 : tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %26 = arith.addi %arg5, %c63_i32 : i32
    %27 = arith.divsi %26, %c64_i32 : i32
    %28 = arith.muli %8, %3 : i32
    %29 = arith.divsi %28, %c304_i32 : i32
    %30 = arith.remsi %28, %c304_i32 : i32
    %31 = arith.cmpi slt, %1, %30 : i32
    %32 = scf.if %31 -> (i32) {
      %92 = arith.addi %29, %c1_i32 : i32
      scf.yield %92 : i32
    } else {
      scf.yield %29 : i32
    }
    %33 = arith.muli %27, %32 : i32
    %34 = arith.cmpi sgt, %33, %c0_i32 : i32
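    // Prologue: build the bounds masks and issue the masked global loads of the first
    // 128x64 A tile and 64x256 B tile.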
    %35 = tt.splat %34 : i1 -> tensor<128x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %36 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
    %37 = tt.expand_dims %36 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>> -> tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %38 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %39 = arith.cmpi slt, %37, %38 : tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %40 = tt.broadcast %39 : tensor<1x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %41 = arith.andi %35, %40 : tensor<128x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %42 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %43 = tt.addptr %42, %25 : tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>, tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %44 = tt.load %43, %41, %cst_2 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
    %45 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
    %46 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
    %47 = arith.remsi %1, %4 : i32
    %48 = arith.divsi %47, %10 : i32
    %49 = arith.muli %48, %c256_i32 : i32
    %50 = arith.subi %arg4, %49 : i32
    %51 = tt.splat %50 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
    %52 = arith.cmpi slt, %46, %51 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
    %53 = arith.select %52, %46, %cst_1 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
    %54 = tt.expand_dims %45 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>> -> tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %55 = tt.broadcast %54 : tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %56 = tt.expand_dims %53 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>> -> tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %57 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %58 = arith.muli %56, %57 : tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %59 = tt.broadcast %58 : tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %60 = arith.addi %55, %59 : tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %61 = tt.splat %34 : i1 -> tensor<64x256xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %62 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
    %63 = tt.expand_dims %62 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>> -> tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %64 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %65 = arith.cmpi slt, %63, %64 : tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %66 = tt.broadcast %65 : tensor<64x1xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %67 = arith.andi %61, %66 : tensor<64x256xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %68 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %69 = tt.addptr %68, %60 : tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>, tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %70 = tt.load %69, %67, %cst_0 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
    %71 = arith.cmpi sgt, %arg3, %c0_i32 : i32
    llvm.intr.assume %71 : i1
    %72 = arith.cmpi sgt, %arg4, %c0_i32 : i32
    llvm.intr.assume %72 : i1
    %73 = arith.cmpi sgt, %arg5, %c0_i32 : i32
    llvm.intr.assume %73 : i1
    %74 = arith.cmpi sgt, %arg6, %c0_i32 : i32
    llvm.intr.assume %74 : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    %75 = arith.cmpi sgt, %arg7, %c0_i32 : i32
    llvm.intr.assume %75 : i1
    %76 = arith.cmpi sgt, %arg8, %c0_i32 : i32
    llvm.intr.assume %76 : i1
    llvm.intr.assume %true : i1
    %77 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
    %78 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
    %79 = arith.subi %27, %c1_i32 : i32
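    // Stage the prologue tiles through shared memory. The staging buffers are depth 1
    // (single-buffered); each iteration's local_store feeds the local_load of the next.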
    %80 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
    %81 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
    %82 = ttg.memdesc_subview %80[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
    ttg.local_store %44, %82 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
    %83 = ttg.memdesc_subview %81[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
    ttg.local_store %70, %83 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256xf16, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
    %84 = arith.subi %33, %c1_i32 : i32
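    // Persistent main loop: %33 = (K steps per tile) x (output tiles owned by this program),
    // and the loop runs %33 - 1 times so the final dot can be peeled into the epilogue.
    // %arg10 tracks the K step and wraps to zero when a new tile begins, at which point the
    // scf.if below recomputes the tile offsets.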
    %85:10 = scf.for %arg9 = %c0_i32 to %84 step %c1_i32 iter_args(%arg10 = %c0_i32, %arg11 = %1, %arg12 = %12, %arg13 = %48, %arg14 = %cst, %arg15 = %17, %arg16 = %53, %arg17 = %c0_i32, %arg18 = %82, %arg19 = %83) -> (i32, i32, i32, i32, tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>, i32, !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>) : i32 {
      %92 = arith.cmpi eq, %arg10, %79 : i32
      %93 = arith.addi %arg10, %c1_i32 : i32
      %94 = arith.select %92, %c0_i32, %93 : i32
      %95 = arith.cmpi eq, %94, %c0_i32 : i32
      %96:5 = scf.if %95 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>) {
        %137 = arith.addi %arg11, %c304_i32 : i32
        %138 = arith.divsi %137, %4 : i32
        %139 = arith.muli %138, %c8_i32 : i32
        %140 = arith.subi %8, %139 : i32
        %141 = arith.minsi %140, %c8_i32 : i32
        %142 = arith.remsi %137, %141 : i32
        %143 = arith.addi %139, %142 : i32
        %144 = arith.remsi %137, %4 : i32
        %145 = arith.divsi %144, %141 : i32
        %146 = arith.muli %143, %c128_i32 : i32
        %147 = arith.muli %145, %c256_i32 : i32
        %148 = arith.subi %arg3, %146 : i32
        %149 = tt.splat %148 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
        %150 = arith.cmpi slt, %0, %149 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
        %151 = arith.select %150, %0, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>
        %152 = arith.subi %arg4, %147 : i32
        %153 = tt.splat %152 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
        %154 = arith.cmpi slt, %46, %153 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
        %155 = arith.select %154, %46, %cst_1 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
        scf.yield %137, %143, %145, %151, %155 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
      } else {
        scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>
      }
      %97 = arith.muli %94, %c64_i32 : i32
      %98 = tt.expand_dims %96#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>> -> tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %99 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %100 = arith.muli %98, %99 : tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %101 = tt.broadcast %100 : tensor<128x1xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %102 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>> -> tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %103 = tt.broadcast %102 : tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %104 = arith.addi %101, %103 : tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %105 = tt.addptr %arg0, %97 : !tt.ptr<f16>, i32
      %106 = arith.subi %arg5, %97 : i32
      %107 = tt.splat %106 : i32 -> tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %108 = arith.cmpi slt, %37, %107 : tensor<1x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %109 = tt.broadcast %108 : tensor<1x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> tensor<128x64xi1, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %110 = tt.splat %105 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %111 = tt.addptr %110, %104 : tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>, tensor<128x64xi32, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %112 = tt.load %111, %109, %cst_2 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>>
      %113 = tt.expand_dims %45 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>> -> tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %114 = tt.broadcast %113 : tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %115 = tt.expand_dims %96#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>> -> tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %116 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %117 = arith.muli %115, %116 : tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %118 = tt.broadcast %117 : tensor<1x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %119 = arith.addi %114, %118 : tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %120 = tt.addptr %arg1, %97 : !tt.ptr<f16>, i32
      %121 = tt.splat %106 : i32 -> tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %122 = arith.cmpi slt, %63, %121 : tensor<64x1xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %123 = tt.broadcast %122 : tensor<64x1xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> tensor<64x256xi1, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %124 = ttg.local_load %arg18 : !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>>
      %125 = ttg.local_load %arg19 : !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>>
      %126 = tt.splat %120 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %127 = tt.addptr %126, %119 : tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>, tensor<64x256xi32, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %128 = tt.load %127, %123, %cst_0 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256x!tt.ptr<f16>, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>>
      %129 = tt.dot %124, %125, %arg14, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>> -> tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %130 = arith.cmpi eq, %arg10, %79 : i32
      %131 = arith.select %130, %cst, %129 : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
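      // On the last K step of a tile, reset the accumulator carried into the next tile and
      // write the finished f32 accumulator, truncated to f16, back to C.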
      scf.if %130 {
        %137 = arith.muli %arg12, %c128_i32 : i32
        %138 = tt.splat %137 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
        %139 = arith.addi %138, %77 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
        %140 = arith.muli %arg13, %c256_i32 : i32
        %141 = tt.splat %140 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
        %142 = arith.addi %141, %78 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
        %143 = tt.expand_dims %139 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %144 = tt.expand_dims %77 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %145 = arith.muli %arg8, %137 : i32
        %146 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %147 = arith.muli %146, %144 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %148 = tt.addptr %arg2, %145 : !tt.ptr<f16>, i32
        %149 = tt.expand_dims %142 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %150 = tt.broadcast %147 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %151 = tt.expand_dims %78 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %152 = tt.broadcast %151 : tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %153 = tt.addptr %148, %140 : !tt.ptr<f16>, i32
        %154 = arith.addi %152, %150 : tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %155 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %156 = arith.cmpi slt, %143, %155 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %157 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %158 = arith.cmpi slt, %149, %157 : tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %159 = tt.broadcast %156 : tensor<128x1xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %160 = tt.broadcast %158 : tensor<1x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %161 = arith.andi %159, %160 : tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %162 = arith.truncf %129 : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> to tensor<128x256xf16, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %163 = tt.splat %153 : !tt.ptr<f16> -> tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        %164 = tt.addptr %163, %154 : tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>, tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
        tt.store %164, %162, %161 : tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      }
      %132 = arith.addi %arg17, %c1_i32 : i32
      %133 = arith.cmpi slt, %132, %c1_i32 : i32
      %134 = arith.select %133, %132, %c0_i32 : i32
      %135 = ttg.memdesc_subview %80[%134, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
      ttg.local_store %112, %135 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>> -> !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
      %136 = ttg.memdesc_subview %81[%134, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
      ttg.local_store %128, %136 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256xf16, #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>> -> !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
      scf.yield %94, %96#0, %96#1, %96#2, %131, %96#3, %96#4, %134, %135, %136 : i32, i32, i32, i32, tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>}>>, i32, !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
    }
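    // Epilogue: the final K step was peeled out of the loop, so issue the last dot here and
    // store the last tile if any work was done.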
    %86 = arith.cmpi sge, %33, %c1_i32 : i32
    %87 = ttg.local_load %85#8 : !ttg.memdesc<128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>>
    %88 = ttg.local_load %85#9 : !ttg.memdesc<64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>>
    %89 = scf.if %86 -> (tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>) {
      %92 = tt.dot %87, %88, %85#4, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>, kWidth = 4}>> -> tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      scf.yield %92 : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
    } else {
      scf.yield %85#4 : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
    }
    %90 = arith.cmpi eq, %85#0, %79 : i32
    %91 = arith.andi %86, %90 : i1
    scf.if %91 {
      %92 = arith.muli %85#2, %c128_i32 : i32
      %93 = tt.splat %92 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
      %94 = arith.addi %93, %77 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
      %95 = arith.muli %85#3, %c256_i32 : i32
      %96 = tt.splat %95 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
      %97 = arith.addi %96, %78 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>>
      %98 = tt.expand_dims %94 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %99 = tt.expand_dims %77 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %100 = arith.muli %arg8, %92 : i32
      %101 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %102 = arith.muli %101, %99 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %103 = tt.addptr %arg2, %100 : !tt.ptr<f16>, i32
      %104 = tt.expand_dims %97 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %105 = tt.broadcast %102 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %106 = tt.expand_dims %78 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>}>> -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %107 = tt.broadcast %106 : tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %108 = tt.addptr %103, %95 : !tt.ptr<f16>, i32
      %109 = arith.addi %107, %105 : tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %110 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %111 = arith.cmpi slt, %98, %110 : tensor<128x1xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %112 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %113 = arith.cmpi slt, %104, %112 : tensor<1x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %114 = tt.broadcast %111 : tensor<128x1xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %115 = tt.broadcast %113 : tensor<1x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> -> tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %116 = arith.andi %114, %115 : tensor<128x256xi1, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %117 = arith.truncf %89 : tensor<128x256xf32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>> to tensor<128x256xf16, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %118 = tt.splat %108 : !tt.ptr<f16> -> tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      %119 = tt.addptr %118, %109 : tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>, tensor<128x256xi32, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
      tt.store %119, %117, %116 : tensor<128x256x!tt.ptr<f16>, #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>>
    }
    ttg.local_dealloc %80 : !ttg.memdesc<1x128x64xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>, #ttg.shared_memory, mutable>
    ttg.local_dealloc %81 : !ttg.memdesc<1x64x256xf16, #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>, #ttg.shared_memory, mutable>
    tt.return
  }
}
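// The same kernel later in the pipeline (the exact pass is not recorded in this dump; the
// buffer addressing suggests AMD's convert-to-buffer-ops stage). The layout attributes are
// now factored into #blocked/#mma/#shared aliases, and the prologue loads use
// amdgpu.buffer_load.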
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @matmul_kernel_persistent(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
    %true = arith.constant true
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked>
    %cst_1 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1>
    %c1_i32 = arith.constant 1 : i32
    %c304_i32 = arith.constant 304 : i32
    %c64_i32 = arith.constant 64 : i32
    %c63_i32 = arith.constant 63 : i32
    %cst_3 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %c128_i32 = arith.constant 128 : i32
    %c127_i32 = arith.constant 127 : i32
    %c8_i32 = arith.constant 8 : i32
    %c256_i32 = arith.constant 256 : i32
    %c255_i32 = arith.constant 255 : i32
    %c0_i32 = arith.constant 0 : i32
    %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %1 = tt.get_program_id x : i32
    %2 = arith.addi %arg4, %c255_i32 : i32
    %3 = arith.divsi %2, %c256_i32 : i32
    %4 = arith.muli %3, %c8_i32 : i32
    %5 = arith.divsi %1, %4 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.addi %arg3, %c127_i32 : i32
    %8 = arith.divsi %7, %c128_i32 : i32
    %9 = arith.subi %8, %6 : i32
    %10 = arith.minsi %9, %c8_i32 : i32
    %11 = arith.remsi %1, %10 : i32
    %12 = arith.addi %6, %11 : i32
    %13 = arith.muli %12, %c128_i32 : i32
    %14 = arith.subi %arg3, %13 : i32
    %15 = tt.splat %14 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %16 = arith.cmpi slt, %0, %15 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %17 = arith.select %16, %0, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %18 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %19 = tt.expand_dims %17 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
    %20 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1>
    %21 = arith.muli %19, %20 : tensor<128x1xi32, #blocked1>
    %22 = tt.broadcast %21 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
    %23 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
    %24 = tt.broadcast %23 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
    %25 = arith.addi %22, %24 : tensor<128x64xi32, #blocked1>
    %26 = arith.addi %arg5, %c63_i32 : i32
    %27 = arith.divsi %26, %c64_i32 : i32
    %28 = arith.muli %8, %3 : i32
    %29 = arith.divsi %28, %c304_i32 : i32
    %30 = arith.remsi %28, %c304_i32 : i32
    %31 = arith.cmpi slt, %1, %30 : i32
    %32 = scf.if %31 -> (i32) {
      %88 = arith.addi %29, %c1_i32 : i32
      scf.yield %88 : i32
    } else {
      scf.yield %29 : i32
    }
    %33 = arith.muli %27, %32 : i32
    %34 = arith.cmpi sgt, %33, %c0_i32 : i32
    %35 = tt.splat %34 : i1 -> tensor<128x64xi1, #blocked1>
    %36 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %37 = tt.expand_dims %36 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
    %38 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked1>
    %39 = arith.cmpi slt, %37, %38 : tensor<1x64xi32, #blocked1>
    %40 = tt.broadcast %39 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1>
    %41 = arith.andi %35, %40 : tensor<128x64xi1, #blocked1>
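    // The prologue global loads are now amdgpu.buffer_load ops (raw-buffer addressing with an
    // explicit stride) instead of tt.splat + tt.addptr + tt.load; the loads inside the main
    // loop below are unchanged.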
%42 = amdgpu.buffer_load %arg0[%25], %41 stride = %arg6 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64xf16, #blocked1> | |
%43 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%44 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%45 = arith.remsi %1, %4 : i32 | |
%46 = arith.divsi %45, %10 : i32 | |
%47 = arith.muli %46, %c256_i32 : i32 | |
%48 = arith.subi %arg4, %47 : i32 | |
%49 = tt.splat %48 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%50 = arith.cmpi slt, %44, %49 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%51 = arith.select %50, %44, %cst_1 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%52 = tt.expand_dims %43 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> | |
%53 = tt.broadcast %52 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked> | |
%54 = tt.expand_dims %51 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked> | |
%55 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked> | |
%56 = arith.muli %54, %55 : tensor<1x256xi32, #blocked> | |
%57 = tt.broadcast %56 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked> | |
%58 = arith.addi %53, %57 : tensor<64x256xi32, #blocked> | |
%59 = tt.splat %34 : i1 -> tensor<64x256xi1, #blocked> | |
%60 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> | |
%61 = tt.expand_dims %60 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> | |
%62 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked> | |
%63 = arith.cmpi slt, %61, %62 : tensor<64x1xi32, #blocked> | |
%64 = tt.broadcast %63 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked> | |
%65 = arith.andi %59, %64 : tensor<64x256xi1, #blocked> | |
%66 = amdgpu.buffer_load %arg1[%58], %65 stride = %arg7 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256xf16, #blocked> | |
%67 = arith.cmpi sgt, %arg3, %c0_i32 : i32 | |
llvm.intr.assume %67 : i1 | |
%68 = arith.cmpi sgt, %arg4, %c0_i32 : i32 | |
llvm.intr.assume %68 : i1 | |
%69 = arith.cmpi sgt, %arg5, %c0_i32 : i32 | |
llvm.intr.assume %69 : i1 | |
%70 = arith.cmpi sgt, %arg6, %c0_i32 : i32 | |
llvm.intr.assume %70 : i1 | |
llvm.intr.assume %true : i1 | |
llvm.intr.assume %true : i1 | |
%71 = arith.cmpi sgt, %arg7, %c0_i32 : i32 | |
llvm.intr.assume %71 : i1 | |
%72 = arith.cmpi sgt, %arg8, %c0_i32 : i32 | |
llvm.intr.assume %72 : i1 | |
llvm.intr.assume %true : i1 | |
%73 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> | |
%74 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>> | |
%75 = arith.subi %27, %c1_i32 : i32 | |
%76 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> | |
%77 = ttg.local_alloc : () -> !ttg.memdesc<1x64x256xf16, #shared1, #smem, mutable> | |
%78 = ttg.memdesc_subview %76[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> | |
ttg.local_store %42, %78 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable> | |
%79 = ttg.memdesc_subview %77[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> | |
ttg.local_store %66, %79 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> | |
%80 = arith.subi %33, %c1_i32 : i32 | |
%81:10 = scf.for %arg9 = %c0_i32 to %80 step %c1_i32 iter_args(%arg10 = %c0_i32, %arg11 = %1, %arg12 = %12, %arg13 = %46, %arg14 = %cst, %arg15 = %17, %arg16 = %51, %arg17 = %c0_i32, %arg18 = %78, %arg19 = %79) -> (i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>) : i32 { | |
%88 = arith.cmpi eq, %arg10, %75 : i32 | |
%89 = arith.addi %arg10, %c1_i32 : i32 | |
%90 = arith.select %88, %c0_i32, %89 : i32 | |
%91 = arith.cmpi eq, %90, %c0_i32 : i32 | |
%92:5 = scf.if %91 -> (i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>) { | |
%133 = arith.addi %arg11, %c304_i32 : i32 | |
%134 = arith.divsi %133, %4 : i32 | |
%135 = arith.muli %134, %c8_i32 : i32 | |
%136 = arith.subi %8, %135 : i32 | |
%137 = arith.minsi %136, %c8_i32 : i32 | |
%138 = arith.remsi %133, %137 : i32 | |
%139 = arith.addi %135, %138 : i32 | |
%140 = arith.remsi %133, %4 : i32 | |
%141 = arith.divsi %140, %137 : i32 | |
%142 = arith.muli %139, %c128_i32 : i32 | |
%143 = arith.muli %141, %c256_i32 : i32 | |
%144 = arith.subi %arg3, %142 : i32 | |
%145 = tt.splat %144 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%146 = arith.cmpi slt, %0, %145 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%147 = arith.select %146, %0, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> | |
%148 = arith.subi %arg4, %143 : i32 | |
%149 = tt.splat %148 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%150 = arith.cmpi slt, %44, %149 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
%151 = arith.select %150, %44, %cst_1 {tt.contiguity = dense<256> : tensor<1xi32>, tt.divisibility = dense<256> : tensor<1xi32>} : tensor<256xi1, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
scf.yield %133, %139, %141, %147, %151 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
} else { | |
scf.yield %arg11, %arg12, %arg13, %arg15, %arg16 : i32, i32, i32, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> | |
} | |
      %93 = arith.muli %90, %c64_i32 : i32
      %94 = tt.expand_dims %92#3 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
      %95 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked1>
      %96 = arith.muli %94, %95 : tensor<128x1xi32, #blocked1>
      %97 = tt.broadcast %96 : tensor<128x1xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
      %98 = tt.expand_dims %18 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1>
      %99 = tt.broadcast %98 : tensor<1x64xi32, #blocked1> -> tensor<128x64xi32, #blocked1>
      %100 = arith.addi %97, %99 : tensor<128x64xi32, #blocked1>
      %101 = tt.addptr %arg0, %93 : !tt.ptr<f16>, i32
      %102 = arith.subi %arg5, %93 : i32
      %103 = tt.splat %102 : i32 -> tensor<1x64xi32, #blocked1>
      %104 = arith.cmpi slt, %37, %103 : tensor<1x64xi32, #blocked1>
      %105 = tt.broadcast %104 : tensor<1x64xi1, #blocked1> -> tensor<128x64xi1, #blocked1>
      %106 = tt.splat %101 : !tt.ptr<f16> -> tensor<128x64x!tt.ptr<f16>, #blocked1>
      %107 = tt.addptr %106, %100 : tensor<128x64x!tt.ptr<f16>, #blocked1>, tensor<128x64xi32, #blocked1>
      %108 = tt.load %107, %105, %cst_2 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64x!tt.ptr<f16>, #blocked1>
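      // Prefetch the matching 64x256 B tile: per-element offsets are the K rows plus
      // col * stride (%arg7), with the same remaining-K mask applied along the rows.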
      %109 = tt.expand_dims %43 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
      %110 = tt.broadcast %109 : tensor<64x1xi32, #blocked> -> tensor<64x256xi32, #blocked>
      %111 = tt.expand_dims %92#4 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x256xi32, #blocked>
      %112 = tt.splat %arg7 : i32 -> tensor<1x256xi32, #blocked>
      %113 = arith.muli %111, %112 : tensor<1x256xi32, #blocked>
      %114 = tt.broadcast %113 : tensor<1x256xi32, #blocked> -> tensor<64x256xi32, #blocked>
      %115 = arith.addi %110, %114 : tensor<64x256xi32, #blocked>
      %116 = tt.addptr %arg1, %93 : !tt.ptr<f16>, i32
      %117 = tt.splat %102 : i32 -> tensor<64x1xi32, #blocked>
      %118 = arith.cmpi slt, %61, %117 : tensor<64x1xi32, #blocked>
      %119 = tt.broadcast %118 : tensor<64x1xi1, #blocked> -> tensor<64x256xi1, #blocked>
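      // Consume the tiles staged in LDS and issue the MFMA dot; the B global load is
      // interleaved here so it overlaps with the local loads.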
      %120 = ttg.local_load %arg18 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
      %121 = ttg.local_load %arg19 : !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
      %122 = tt.splat %116 : !tt.ptr<f16> -> tensor<64x256x!tt.ptr<f16>, #blocked>
      %123 = tt.addptr %122, %115 : tensor<64x256x!tt.ptr<f16>, #blocked>, tensor<64x256xi32, #blocked>
      %124 = tt.load %123, %119, %cst_0 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256x!tt.ptr<f16>, #blocked>
      %125 = tt.dot %120, %121, %arg14, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x256xf32, #mma>
      %126 = arith.cmpi eq, %arg10, %75 : i32
      %127 = arith.select %126, %cst, %125 : tensor<128x256xf32, #mma>
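      // %75 is the last K-step index of a tile (computed earlier); on that step the
      // accumulator is reset for the next tile and the finished 128x256 C tile is
      // truncated to f16 and stored with M/N bounds masks.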
      scf.if %126 {
        %133 = arith.muli %arg12, %c128_i32 : i32
        %134 = tt.splat %133 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
        %135 = arith.addi %134, %73 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
        %136 = arith.muli %arg13, %c256_i32 : i32
        %137 = tt.splat %136 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>>
        %138 = arith.addi %137, %74 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>>
        %139 = tt.expand_dims %135 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
        %140 = tt.expand_dims %73 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
        %141 = arith.muli %arg8, %133 : i32
        %142 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #mma>
        %143 = arith.muli %142, %140 : tensor<128x1xi32, #mma>
        %144 = tt.addptr %arg2, %141 : !tt.ptr<f16>, i32
        %145 = tt.expand_dims %138 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x256xi32, #mma>
        %146 = tt.broadcast %143 : tensor<128x1xi32, #mma> -> tensor<128x256xi32, #mma>
        %147 = tt.expand_dims %74 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x256xi32, #mma>
        %148 = tt.broadcast %147 : tensor<1x256xi32, #mma> -> tensor<128x256xi32, #mma>
        %149 = tt.addptr %144, %136 : !tt.ptr<f16>, i32
        %150 = arith.addi %148, %146 : tensor<128x256xi32, #mma>
        %151 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #mma>
        %152 = arith.cmpi slt, %139, %151 : tensor<128x1xi32, #mma>
        %153 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #mma>
        %154 = arith.cmpi slt, %145, %153 : tensor<1x256xi32, #mma>
        %155 = tt.broadcast %152 : tensor<128x1xi1, #mma> -> tensor<128x256xi1, #mma>
        %156 = tt.broadcast %154 : tensor<1x256xi1, #mma> -> tensor<128x256xi1, #mma>
        %157 = arith.andi %155, %156 : tensor<128x256xi1, #mma>
        %158 = arith.truncf %125 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma>
        amdgpu.buffer_store %158, %149[%150], %157 stride = %arg8 : tensor<128x256xf16, #mma>
      }
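      // Advance the (depth-1) circular buffer index and stage the freshly loaded
      // A/B tiles for the next iteration.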
      %128 = arith.addi %arg17, %c1_i32 : i32
      %129 = arith.cmpi slt, %128, %c1_i32 : i32
      %130 = arith.select %129, %128, %c0_i32 : i32
      %131 = ttg.memdesc_subview %76[%130, %c0_i32, %c0_i32] : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>
      ttg.local_store %108, %131 {OpIdx = #amdgpu.OpIdx<0>} : tensor<128x64xf16, #blocked1> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>
      %132 = ttg.memdesc_subview %77[%130, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x256xf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>
      ttg.local_store %124, %132 {OpIdx = #amdgpu.OpIdx<1>} : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>
      scf.yield %90, %92#0, %92#1, %92#2, %127, %92#3, %92#4, %130, %131, %132 : i32, i32, i32, i32, tensor<128x256xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>, tensor<256xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.memdesc<64x256xf16, #shared1, #smem, mutable>
    }
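    // Drain the pipeline: if the loop ran at all (%33 >= 1, with %33 computed
    // earlier as this program's total iteration count), the last staged tiles
    // still need one final dot.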
    %82 = arith.cmpi sge, %33, %c1_i32 : i32
    %83 = ttg.local_load %81#8 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
    %84 = ttg.local_load %81#9 : !ttg.memdesc<64x256xf16, #shared1, #smem, mutable> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
    %85 = scf.if %82 -> (tensor<128x256xf32, #mma>) {
      %88 = tt.dot %83, %84, %81#4, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<128x256xf32, #mma>
      scf.yield %88 : tensor<128x256xf32, #mma>
    } else {
      scf.yield %81#4 : tensor<128x256xf32, #mma>
    }
    %86 = arith.cmpi eq, %81#0, %75 : i32
    %87 = arith.andi %82, %86 : i1
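    // If the drained step completed a tile, store the final C tile (same epilogue
    // as inside the loop).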
    scf.if %87 {
      %88 = arith.muli %81#2, %c128_i32 : i32
      %89 = tt.splat %88 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
      %90 = arith.addi %89, %73 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>>
      %91 = arith.muli %81#3, %c256_i32 : i32
      %92 = tt.splat %91 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>>
      %93 = arith.addi %92, %74 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>>
      %94 = tt.expand_dims %90 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
      %95 = tt.expand_dims %73 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xi32, #mma>
      %96 = arith.muli %arg8, %88 : i32
      %97 = tt.splat %arg8 : i32 -> tensor<128x1xi32, #mma>
      %98 = arith.muli %97, %95 : tensor<128x1xi32, #mma>
      %99 = tt.addptr %arg2, %96 : !tt.ptr<f16>, i32
      %100 = tt.expand_dims %93 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x256xi32, #mma>
      %101 = tt.broadcast %98 : tensor<128x1xi32, #mma> -> tensor<128x256xi32, #mma>
      %102 = tt.expand_dims %74 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x256xi32, #mma>
      %103 = tt.broadcast %102 : tensor<1x256xi32, #mma> -> tensor<128x256xi32, #mma>
      %104 = tt.addptr %99, %91 : !tt.ptr<f16>, i32
      %105 = arith.addi %103, %101 : tensor<128x256xi32, #mma>
      %106 = tt.splat %arg3 : i32 -> tensor<128x1xi32, #mma>
      %107 = arith.cmpi slt, %94, %106 : tensor<128x1xi32, #mma>
      %108 = tt.splat %arg4 : i32 -> tensor<1x256xi32, #mma>
      %109 = arith.cmpi slt, %100, %108 : tensor<1x256xi32, #mma>
      %110 = tt.broadcast %107 : tensor<128x1xi1, #mma> -> tensor<128x256xi1, #mma>
      %111 = tt.broadcast %109 : tensor<1x256xi1, #mma> -> tensor<128x256xi1, #mma>
      %112 = arith.andi %110, %111 : tensor<128x256xi1, #mma>
      %113 = arith.truncf %85 : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma>
      amdgpu.buffer_store %113, %104[%105], %112 stride = %arg8 : tensor<128x256xf16, #mma>
    }
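    // Release the LDS buffers.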
    ttg.local_dealloc %76 : !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>
    ttg.local_dealloc %77 : !ttg.memdesc<1x64x256xf16, #shared1, #smem, mutable>
    tt.return
  }
}