Last active
May 13, 2025 18:41
-
-
Save davidberard98/f10db5520c96111254e614b53db9f501 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# AOT ID: ['0_backward'] | |
from ctypes import c_void_p, c_long, c_int | |
import torch | |
import math | |
import random | |
import os | |
import tempfile | |
from math import inf, nan | |
from cmath import nanj | |
from torch._inductor.hooks import run_intermediate_hooks | |
from torch._inductor.utils import maybe_profile | |
from torch._inductor.codegen.memory_planning import _align as align | |
from torch import device, empty_strided | |
from torch._inductor.async_compile import AsyncCompile | |
from torch._inductor.select_algorithm import extern_kernels | |
from torch._inductor.codegen.multi_kernel import MultiKernelCall | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime.triton_heuristics import ( | |
grid, | |
split_scan_grid, | |
grid_combo_kernels, | |
start_graph, | |
end_graph, | |
cooperative_reduction_grid, | |
) | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch._inductor.kernel.flex_attention | |
# Module-level aliases referenced by the generated code in this file.
# Short handles for the torch.ops operator namespaces.
aten = torch.ops.aten | |
inductor_ops = torch.ops.inductor | |
_quantized = torch.ops._quantized | |
# Private torch._C helpers which, judging by their names, assert tensor size/stride
# and allocate or reinterpret strided buffers per device (cpu/cuda/xpu). Their users
# are presumably the generated call() later in the file, which is not visible here.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu | |
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda | |
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu | |
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor | |
alloc_from_pool = torch.ops.inductor._alloc_from_pool | |
# Single compiler instance; each Triton kernel below is registered through
# async_compile.triton(kernel_name, kernel_source, device_str='cuda').
async_compile = AsyncCompile() | |
# NOTE(review): this attribute lookup assumes a torch build with c10d (distributed)
# support; without it the lookup would raise at import time — confirm for target builds.
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/mc/cmc66sj3co2qioktlabn5mqewgz4bthhd7gyxhv7ppsvnpgk2niq.py | |
# Topologically Sorted Source Nodes: [loss, logits, mul_176, square_16, add_116, rsqrt, mul_177], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward, aten._to_copy, aten.mul, aten.pow, aten.add, aten.rsqrt, aten._log_softmax, aten._log_softmax_backward_data] | |
# Source node to ATen node mapping: | |
# add_116 => add_179 | |
# logits => convert_element_type_323 | |
# loss => full_default_12, full_default_13, sub_4, sub_5 | |
# mul_176 => mul_239 | |
# mul_177 => mul_240 | |
# rsqrt => rsqrt_63 | |
# square_16 => pow_80 | |
# Graph fragment: | |
# %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%tangents_1, %convert_element_type_324), kwargs = {}) | |
# %ne_3 : [num_users=2] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_164, -100), kwargs = {}) | |
# %full_default_12 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %where_6 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_164, %full_default_12), kwargs = {}) | |
# %scatter_upon_const_tensor : [num_users=1] = call_function[target=torch._inductor.fx_passes.post_grad.scatter_upon_const_tensor](args = (), kwargs = {shape: [65536, 50304], background_val: 0, dtype: torch.float32, dim: 1, selector: %where_6, val: -1.0}) | |
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %where_7 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %div_2, %full_default_13), kwargs = {}) | |
# %mul_241 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter_upon_const_tensor, %where_7), kwargs = {}) | |
# %convert_element_type_323 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_62, torch.float32), kwargs = {}) | |
# %mul_239 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_323, 15), kwargs = {}) | |
# %pow_80 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 2), kwargs = {}) | |
# %add_179 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%pow_80, 225), kwargs = {}) | |
# %rsqrt_63 : [num_users=3] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_179,), kwargs = {}) | |
# %mul_240 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_239, %rsqrt_63), kwargs = {}) | |
# %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_240, %amax), kwargs = {}) | |
# %sub_5 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_4, %log), kwargs = {}) | |
# %exp_1 : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_5,), kwargs = {}) | |
# %sum_10 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_241, [1], True), kwargs = {}) | |
# %mul_242 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp_1, %sum_10), kwargs = {}) | |
# %sub_6 : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_241, %mul_242), kwargs = {}) | |
# %mul_243 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %mul_239), kwargs = {}) | |
# %mul_244 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %rsqrt_63), kwargs = {}) | |
# %pow_81 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt_63, 3), kwargs = {}) | |
# %mul_245 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%mul_243, -0.5), kwargs = {}) | |
# %mul_246 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_245, %pow_81), kwargs = {}) | |
# %pow_82 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 1.0), kwargs = {}) | |
# %mul_247 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_82, 2.0), kwargs = {}) | |
# %mul_248 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_246, %mul_247), kwargs = {}) | |
# %mul_249 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_244, 15), kwargs = {}) | |
# %add_180 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_248, %mul_249), kwargs = {}) | |
# %convert_element_type_325 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_180, torch.bfloat16), kwargs = {}) | |
# Fused backward of nll_loss + log_softmax through the logit soft-cap
# y = 15 * x * rsqrt(x**2 + 225) (mul_239 / pow_80 / add_179 / rsqrt_63 / mul_240 in the
# graph fragment above). The kernel is a reduction over a hard-coded 65536 x 50304 grid
# (xnumel / r0_numel inside the kernel source).
# Argument mapping, read off the graph fragment and the kernel body:
#   in_out_ptr0 (*bf16) - read as mm_62 (source of `logits`, before the soft-cap) and
#                         overwritten in place with the gradient (convert_element_type_325).
#   in_ptr0 (*i64)      - per-row targets (unsqueeze_164); rows whose target is -100 get a
#                         zero gradient, because tmp15/tmp34 are forced to 0 there.
#   in_ptr1 / in_ptr2   - scalars whose quotient is div_2 = tangents_1 / convert_element_type_324.
#   in_ptr3 / in_ptr4   - per-row amax and log terms subtracted before exp (sub_4, sub_5).
# Pass 1 reduces sum_10 (tmp18) per row. Pass 2 recomputes the capped logit, exp(sub_5) and
# the chain-ruled gradient of the cap, then stores it.
# Offsets are int64 (numels typed 'i64', tl.int64 casts) because
# 65536 * 50304 = 3,296,722,944 exceeds the int32 range.
# NOTE(review): the pass-1 summand is nonzero only where r0_1 equals the clamped target,
# so tmp18 equals -where(target != -100, div_2, 0) whenever the target lies in [0, 50304).
# The full-width sweep is simply how inductor generated it.
# NOTE(review): xmask, rnumel, RBLOCK, rbase, roffset and rindex are generated aliases the
# kernel never reads; the loads and stores use mask None or r0_mask.
triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0 = async_compile.triton('triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
size_hints={'x': 65536, 'r0_': 65536}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*i64', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'xnumel': 'i64', 'r0_numel': 'i64'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 8, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
xnumel = 65536 | |
r0_numel = 50304 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0).to(tl.int64) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64) | |
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64) | |
rbase = r0_base | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last') | |
tmp10 = tl.load(in_ptr1 + (0)) | |
tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK]) | |
tmp12 = tl.load(in_ptr2 + (0)) | |
tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK]) | |
_tmp18 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_1 = r0_index | |
tmp1 = tl.full([1, 1], -100, tl.int64) | |
tmp2 = tmp0 != tmp1 | |
tmp3 = tl.full([1, 1], 0, tl.int64) | |
tmp4 = tl.where(tmp2, tmp0, tmp3) | |
tmp5 = r0_1 | |
tmp6 = tmp4 == tmp5 | |
tmp7 = -1.0 | |
tmp8 = 0.0 | |
tmp9 = tl.where(tmp6, tmp7, tmp8) | |
tmp14 = (tmp11 / tmp13) | |
tmp15 = tl.where(tmp2, tmp14, tmp8) | |
tmp16 = tmp9 * tmp15 | |
tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK]) | |
tmp19 = _tmp18 + tmp17 | |
_tmp18 = tl.where(r0_mask, tmp19, _tmp18) | |
tmp18 = tl.sum(_tmp18, 1)[:, None] | |
tmp29 = tl.load(in_ptr1 + (0)) | |
tmp30 = tl.broadcast_to(tmp29, [XBLOCK, R0_BLOCK]) | |
tmp31 = tl.load(in_ptr2 + (0)) | |
tmp32 = tl.broadcast_to(tmp31, [XBLOCK, R0_BLOCK]) | |
tmp45 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last') | |
tmp47 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last') | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_1 = r0_index | |
tmp36 = tl.load(in_out_ptr0 + (r0_1 + 50304*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp20 = tl.full([1, 1], -100, tl.int64) | |
tmp21 = tmp0 != tmp20 | |
tmp22 = tl.full([1, 1], 0, tl.int64) | |
tmp23 = tl.where(tmp21, tmp0, tmp22) | |
tmp24 = r0_1 | |
tmp25 = tmp23 == tmp24 | |
tmp26 = -1.0 | |
tmp27 = 0.0 | |
tmp28 = tl.where(tmp25, tmp26, tmp27) | |
tmp33 = (tmp30 / tmp32) | |
tmp34 = tl.where(tmp21, tmp33, tmp27) | |
tmp35 = tmp28 * tmp34 | |
tmp37 = tmp36.to(tl.float32) | |
tmp38 = 15.0 | |
tmp39 = tmp37 * tmp38 | |
tmp40 = tmp37 * tmp37 | |
tmp41 = 225.0 | |
tmp42 = tmp40 + tmp41 | |
tmp43 = libdevice.rsqrt(tmp42) | |
tmp44 = tmp39 * tmp43 | |
tmp46 = tmp44 - tmp45 | |
tmp48 = tmp46 - tmp47 | |
tmp49 = tl_math.exp(tmp48) | |
tmp50 = tmp49 * tmp18 | |
tmp51 = tmp35 - tmp50 | |
tmp52 = tmp51 * tmp39 | |
tmp53 = -0.5 | |
tmp54 = tmp52 * tmp53 | |
tmp55 = tmp43 * tmp43 | |
tmp56 = tmp55 * tmp43 | |
tmp57 = tmp54 * tmp56 | |
tmp58 = 2.0 | |
tmp59 = tmp37 * tmp58 | |
tmp60 = tmp57 * tmp59 | |
tmp61 = tmp51 * tmp43 | |
tmp62 = tmp61 * tmp38 | |
tmp63 = tmp60 + tmp62 | |
tmp64 = tmp63.to(tl.float32) | |
tl.store(in_out_ptr0 + (r0_1 + 50304*x0), tmp64, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/hc/chcgxykfor4ysb6yjas3ttkjs6ocau7waeacbq3dzddsuruaodll.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %convert_element_type_330 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_63, torch.float32), kwargs = {}) | |
# Elementwise bf16 -> fp32 copy of mm_63 (convert_element_type_330 above).
# xnumel is hard-coded to 51511296 (= 1024 * 50304 = 2**17 * 393).
# NOTE(review): the load and store pass mask None, so correctness presumably relies on the
# launcher choosing an XBLOCK that divides xnumel. Any power of two up to 2**17 does —
# confirm against the call site.
triton_poi_fused__to_copy_1 = async_compile.triton('triton_poi_fused__to_copy_1', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
size_hints={'x': 67108864}, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused__to_copy_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 51511296 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = tl.full([XBLOCK], True, tl.int1) | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
tmp1 = tmp0.to(tl.float32) | |
tl.store(out_ptr0 + (x0), tmp1, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/rj/crjpkoeeqp2o6eema7exap5o6s65g7qiso5aoy7y5k5i6vlpjzmf.py | |
# Topologically Sorted Source Nodes: [x_144], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# x_144 => convert_element_type_318 | |
# Graph fragment: | |
# %convert_element_type_331 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_184, torch.float32), kwargs = {}) | |
# %convert_element_type_318 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_177, torch.float32), kwargs = {}) | |
# %mul_250 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %convert_element_type_318), kwargs = {}) | |
# %mul_251 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %rsqrt_62), kwargs = {}) | |
# %sum_11 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_250, [2], True), kwargs = {}) | |
# %div_3 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_20, 1024), kwargs = {}) | |
# %pow_84 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_318, 1.0), kwargs = {}) | |
# %mul_254 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_84, 2.0), kwargs = {}) | |
# %mul_255 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_3, %mul_254), kwargs = {}) | |
# %add_181 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_251, %mul_255), kwargs = {}) | |
# %convert_element_type_332 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_181, torch.bfloat16), kwargs = {}) | |
# Persistent reduction: one program per row (XBLOCK = 1), 65536 rows of 1024 elements,
# fusing the backward ops listed above for x_144:
#   g = in_out_ptr0 (view_184), x = in_ptr0 (add_177), r = in_ptr1 (rsqrt_62, one per row)
#   out = g * r + (-0.5 * sum(g * x) * r**3 / 1024) * 2 * x
# This matches the backward of y = x * r with r = rsqrt(mean(x**2) + c) over the
# 1024-wide dim, presumably an RMS-style norm; the forward is not visible here.
# 0.0009765625 is 1/1024. The result overwrites in_out_ptr0 and is rounded to bf16 by the
# store (convert_element_type_332).
triton_per_fused__to_copy_add_div_mul_pow_sum_2 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_2', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
size_hints={'x': 65536, 'r0_': 1024}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_out_ptr0, in_ptr0, in_ptr1, xnumel, r0_numel): | |
xnumel = 65536 | |
XBLOCK: tl.constexpr = 1 | |
r0_numel = 1024 | |
R0_BLOCK: tl.constexpr = 1024 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = tl.full([1], xoffset, tl.int32) | |
xmask = tl.full([R0_BLOCK], True, tl.int1) | |
r0_index = tl.arange(0, R0_BLOCK)[:] | |
r0_offset = 0 | |
r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
roffset = r0_offset | |
rindex = r0_index | |
r0_1 = r0_index | |
x0 = xindex | |
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp8 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last') | |
tmp1 = tmp0.to(tl.float32) | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp1 * tmp3 | |
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK]) | |
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0)) | |
tmp9 = tmp1 * tmp8 | |
tmp10 = -0.5 | |
tmp11 = tmp7 * tmp10 | |
tmp12 = tmp8 * tmp8 | |
tmp13 = tmp12 * tmp8 | |
tmp14 = tmp11 * tmp13 | |
tmp15 = 0.0009765625 | |
tmp16 = tmp14 * tmp15 | |
tmp17 = 2.0 | |
tmp18 = tmp3 * tmp17 | |
tmp19 = tmp16 * tmp18 | |
tmp20 = tmp9 + tmp19 | |
tmp21 = tmp20.to(tl.float32) | |
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp21, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/3h/c3hwz76pisnffnc7nnxhxwrt6t4rxytpglcpb6fhj2bl3pieorlm.py | |
# Topologically Sorted Source Nodes: [relu_15], Original ATen: [aten.relu, aten.pow, aten.mul, aten.threshold_backward] | |
# Source node to ATen node mapping: | |
# relu_15 => relu_15 | |
# Graph fragment: | |
# %relu_15 : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%view_180,), kwargs = {}) | |
# %pow_85 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%relu_15, 1.0), kwargs = {}) | |
# %mul_256 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_85, 2.0), kwargs = {}) | |
# %mul_257 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_186, %mul_256), kwargs = {}) | |
# %le_1 : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu_15, 0), kwargs = {}) | |
# %full_default_17 : [num_users=16] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %where_8 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%le_1, %full_default_17, %mul_257), kwargs = {}) | |
# Pointwise backward through relu_15 and the pow/mul ops listed above:
#   in_out_ptr0 (view_180) holds the relu input; in_ptr0 (view_186) holds the upstream grad.
#   out = 0 where relu(x) <= 0, else grad * 2 * relu(x)
# This is consistent with d/dx of relu(x)**2; the forward square is not visible here.
# The gradient overwrites the relu-input buffer in place (mutated_arg_names).
# NOTE(review): xnumel = 2**28 and the loads/stores pass mask None, so XBLOCK presumably
# divides it — confirm.
triton_poi_fused_mul_pow_relu_threshold_backward_3 = async_compile.triton('triton_poi_fused_mul_pow_relu_threshold_backward_3', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
size_hints={'x': 268435456}, | |
filename=__file__, | |
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_pow_relu_threshold_backward_3', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_mul_pow_relu_threshold_backward_3(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 268435456 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = tl.full([XBLOCK], True, tl.int1) | |
x0 = xindex | |
tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32) | |
tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
tmp1 = tl.full([1], 0, tl.int32) | |
tmp2 = triton_helpers.maximum(tmp1, tmp0) | |
tmp3 = 0.0 | |
tmp4 = tmp2 <= tmp3 | |
tmp6 = 2.0 | |
tmp7 = tmp2 * tmp6 | |
tmp8 = tmp5 * tmp7 | |
tmp9 = tl.where(tmp4, tmp3, tmp8) | |
tl.store(in_out_ptr0 + (x0), tmp9, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/ha/chamkr3xbm5j7zmqnyuw7rzcqqcagrpu5phejb47d5uhs5kil4oo.py | |
# Topologically Sorted Source Nodes: [rms_norm_61], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# rms_norm_61 => convert_element_type_312 | |
# Graph fragment: | |
# %convert_element_type_341 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_188, torch.float32), kwargs = {}) | |
# %convert_element_type_312 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_175, torch.float32), kwargs = {}) | |
# %mul_258 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %convert_element_type_312), kwargs = {}) | |
# %mul_259 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %rsqrt_61), kwargs = {}) | |
# %sum_12 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_258, [2], True), kwargs = {}) | |
# %div_4 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_21, 1024), kwargs = {}) | |
# %pow_87 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_312, 1.0), kwargs = {}) | |
# %mul_262 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_87, 2.0), kwargs = {}) | |
# %mul_263 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_4, %mul_262), kwargs = {}) | |
# %add_182 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_259, %mul_263), kwargs = {}) | |
# %convert_element_type_342 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_182, torch.bfloat16), kwargs = {}) | |
# %add_183 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_332, %convert_element_type_342), kwargs = {}) | |
# Persistent reduction: one program per row (XBLOCK = 1), 65536 rows of 1024 elements,
# fusing the backward ops listed above for rms_norm_61 and accumulating the result:
#   g = in_ptr0 (view_188), x = in_ptr1 (add_175), r = in_ptr2 (rsqrt_61, one per row)
#   in_out_ptr0 += g * r + (-0.5 * sum(g * x) * r**3 / 1024) * 2 * x    (add_183)
# On entry in_out_ptr0 holds convert_element_type_332 (the first operand of add_183).
# 0.0009765625 is 1/1024.
# NOTE(review): the graph casts add_182 to bfloat16 (convert_element_type_342) before
# add_183, but the kernel keeps tmp22 in fp32 and rounds only once, at the store. Results
# may therefore differ slightly from eager bf16 rounding — confirm if bitwise parity matters.
triton_per_fused__to_copy_add_div_mul_pow_sum_4 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_4', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
size_hints={'x': 65536, 'r0_': 1024}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_4', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_div_mul_pow_sum_4(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel): | |
xnumel = 65536 | |
XBLOCK: tl.constexpr = 1 | |
r0_numel = 1024 | |
R0_BLOCK: tl.constexpr = 1024 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = tl.full([1], xoffset, tl.int32) | |
xmask = tl.full([R0_BLOCK], True, tl.int1) | |
r0_index = tl.arange(0, R0_BLOCK)[:] | |
r0_offset = 0 | |
r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
roffset = r0_offset | |
rindex = r0_index | |
r0_1 = r0_index | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp2 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp8 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') | |
tmp1 = tmp0.to(tl.float32) | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp1 * tmp3 | |
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK]) | |
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0)) | |
tmp10 = tmp1 * tmp9 | |
tmp11 = -0.5 | |
tmp12 = tmp7 * tmp11 | |
tmp13 = tmp9 * tmp9 | |
tmp14 = tmp13 * tmp9 | |
tmp15 = tmp12 * tmp14 | |
tmp16 = 0.0009765625 | |
tmp17 = tmp15 * tmp16 | |
tmp18 = 2.0 | |
tmp19 = tmp3 * tmp18 | |
tmp20 = tmp17 * tmp19 | |
tmp21 = tmp10 + tmp20 | |
tmp22 = tmp21.to(tl.float32) | |
tmp23 = tmp8 + tmp22 | |
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp23, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/js/cjsujn4ityfatbyjrtktt6lsm5m7xfr7x34cjcfu46bev6xvbhap.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {}) | |
# Persistent reduction over 524288 rows, equal to 8 * 65536, the element count of the
# [1, 8, 65536] full_default_19 listed above. For each row it computes
# sum over 128 elements of in_ptr0 * in_ptr1 (accumulated in fp32) and stores one value
# per row to out_ptr0.
# The "zeros" in the name comes from the aten.zeros node in the header; the body is a
# product-sum reduction. It is presumably the row-wise delta = sum(out * grad_out) fed to
# flex_attention_backward — confirm against the call site.
# NOTE(review): out_ptr0 is declared '*bf16', so the fp32 sum (tmp5) is rounded to bf16 at
# the store — confirm that precision is intended for this reduction.
triton_per_fused_zeros_5 = async_compile.triton('triton_per_fused_zeros_5', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
size_hints={'x': 524288, 'r0_': 128}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_zeros_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused_zeros_5(in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr): | |
xnumel = 524288 | |
r0_numel = 128 | |
R0_BLOCK: tl.constexpr = 128 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
r0_index = tl.arange(0, R0_BLOCK)[None, :] | |
r0_offset = 0 | |
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
roffset = r0_offset | |
rindex = r0_index | |
r0_1 = r0_index | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), None).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (r0_1 + 128*x0), None).to(tl.float32) | |
tmp2 = tmp0 * tmp1 | |
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) | |
tmp5 = tl.sum(tmp3, 1)[:, None] | |
tl.store(out_ptr0 + (x0), tmp5, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/yq/cyqtxyu2pr6foycexsqodih4tqalw2u65pbtgsks4qsewqny5wvg.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {}) | |
# Inductor-generated Triton pointwise kernel, handed to AsyncCompile.triton
# together with device_str='cuda'. The triple-quoted string is the kernel
# *source* (runtime data): do not hand-edit it; regenerate the file instead.
#
# What the kernel computes (read directly from its body), for y in [0, 8) and
# x in [0, 65536):
#     out_ptr0[x + 65536*y] = f32(in_ptr0[y + 8*x]) - 0.0
# i.e. it upcasts a bf16 buffer addressed as [65536, 8] into an fp32 buffer
# addressed as [8, 65536], swapping the two logical axes.
# Only the y axis is bounds-masked (ymask). xmask is all-True, so correctness
# presumably relies on the launch grid tiling x = 65536 exactly. TODO confirm
# against the chosen XBLOCK configs.
# NOTE(review): the "- 0.0" is presumably the constant-folded subtraction of
#   the all-zeros %full_default_19 ([1, 8, 65536], float32) listed in the graph
#   fragment above. The output extent (8 * 65536, fp32) matches it.
triton_poi_fused_zeros_6 = async_compile.triton('triton_poi_fused_zeros_6', ''' | 
import triton | 
import triton.language as tl | 
from triton.compiler.compiler import AttrsDescriptor | 
from torch._inductor.runtime import triton_helpers, triton_heuristics | 
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | 
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | 
triton_helpers.set_driver_to_gpu() | 
@triton_heuristics.pointwise( | 
size_hints={'y': 8, 'x': 65536}, tile_hint=TileHint.SQUARE, | 
filename=__file__, | 
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 3), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | 
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_zeros_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | 
min_elem_per_thread=0 | 
) | 
@triton.jit | 
def triton_poi_fused_zeros_6(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr): | 
ynumel = 8 | 
xnumel = 65536 | 
yoffset = tl.program_id(1) * YBLOCK | 
yindex = yoffset + tl.arange(0, YBLOCK)[None, :] | 
ymask = yindex < ynumel | 
xoffset = tl.program_id(0) * XBLOCK | 
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | 
xmask = tl.full([XBLOCK, YBLOCK], True, tl.int1) | 
x1 = xindex | 
y0 = yindex | 
tmp0 = tl.load(in_ptr0 + (y0 + 8*x1), ymask, eviction_policy='evict_last').to(tl.float32) | 
tmp1 = tmp0.to(tl.float32) | 
tmp2 = 0.0 | 
tmp3 = tmp1 - tmp2 | 
tl.store(out_ptr0 + (x1 + 65536*y0), tmp3, ymask) | 
''', device_str='cuda') | 
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/3y/c3yyb2vkcu6dpcv66oqlrqufxdjxjeocdhudp7s5hgiowvxx7smq.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {}) | |
triton_tem_fused_zeros_7 = async_compile.triton('triton_tem_fused_zeros_7', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
@triton_heuristics.template( | |
num_stages=3, | |
num_warps=8, | |
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'kernel_name': 'triton_tem_fused_zeros_7', 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
) | |
@triton.jit | |
def triton_tem_fused_zeros_7(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): | |
PRESCALE_QK : tl.constexpr = False | |
ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
WRITE_DQ : tl.constexpr = True | |
OUTPUT_LOGSUMEXP : tl.constexpr = True | |
FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
IS_DIVISIBLE : tl.constexpr = True | |
SM_SCALE : tl.constexpr = 0.12 | |
GQA_SHARED_HEADS : tl.constexpr = 1 | |
HAS_FULL_BLOCKS : tl.constexpr = True | |
QK_HEAD_DIM : tl.constexpr = 128 | |
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
V_HEAD_DIM : tl.constexpr = 128 | |
V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
SAFE_HEAD_DIM : tl.constexpr = True | |
BLOCK_M1 : tl.constexpr = 64 | |
BLOCK_N1 : tl.constexpr = 128 | |
BLOCK_M2 : tl.constexpr = 128 | |
BLOCK_N2 : tl.constexpr = 64 | |
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
Q = arg_Q | |
K = arg_K | |
V = arg_V | |
LSE = arg_LSE | |
DELTA = arg_DELTA | |
DO = arg_DO | |
DQ = arg_DQ | |
DV = arg_DV | |
KV_NUM_BLKS = arg_KV_NUM_BLKS | |
KV_IDX = arg_KV_IDX | |
Q_NUM_BLKS = arg_Q_NUM_BLKS | |
Q_IDX = arg_Q_IDX | |
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS | |
FULL_KV_IDX = arg_FULL_KV_IDX | |
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS | |
FULL_Q_IDX = arg_FULL_Q_IDX | |
# Sub notation for this kernel: | |
# | |
# Q: Query, K: Key, V: Value | |
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) | |
# DELTA: Precomputed sum(OUT*DO, axis=-1) | |
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value | |
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with | |
# inductor codegen | |
# M: Number of queries, N: Number of keys/values | |
# QK_HEAD_DIM: The dimension of the query and key embeddings | |
# V_HEAD_DIM: The dimension of the value embeddings | |
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim | |
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. | |
# (Modifiable) Performance tuning options | |
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. | |
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. | |
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. | |
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. | |
# | |
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. | |
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. | |
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. | |
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. | |
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. | |
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. | |
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. | |
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. | |
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. | |
# The below are kernel options that can be applied for certain score_mods, | |
# or involve a numerics vs. perf tradeoff | |
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has | |
# about 20% more numerical error, but slightly faster. | |
# Define strides of inputs | |
stride_qz, stride_qh, stride_qm, stride_qd = 67108864, 128, 1024, 1 | |
stride_kz, stride_kh, stride_kn, stride_kd = 67108864, 128, 1024, 1 | |
stride_vz, stride_vh, stride_vn, stride_vd = 67108864, 128, 1024, 1 | |
stride_doz, stride_doh, stride_dom, stride_dod = 67108864, 128, 1024, 1 | |
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 67108864, 128, 1024, 1 | |
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 67108864, 128, 1024, 1 | |
ZQ = 1 | |
HQ = 8 | |
HKV = 8 | |
Q_LEN = 65536 | |
ZKV = 1 | |
KV_LEN = 65536 | |
MATMUL_PRECISION = Q.dtype.element_ty | |
pid = tl.program_id(0) | |
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) | |
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) | |
off_hz = tl.program_id(2) | |
off_zq = off_hz // HKV # q batch idx | |
off_hkv = off_hz % HKV # kv head idx | |
off_zkv = off_zq % ZKV # kv batch idx | |
SPARSE_Z = 1 | |
SPARSE_HQ = 1 | |
sparse_idx_z = off_zq % SPARSE_Z | |
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) | |
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) | |
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] | |
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] | |
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) | |
# offset K, V, DV pointers for batch/kv-head | |
K += k_adj | |
V += v_adj | |
DV += dv_adj | |
RCP_LN2 = 1.44269504 | |
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) | |
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) | |
if pid >= NUM_KV_BLOCKS: | |
off_pid = pid - NUM_KV_BLOCKS | |
# THIS BLOCK DOES DQ | |
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) | |
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) | |
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS | |
start_m2_block = off_pid % NUM_Q_BLOCKS | |
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE | |
stride_kv_num_blks_h = 512 | |
stride_kv_idx_h = 262144 | |
stride_kv_idx_m = 512 | |
sparse_idx_hq2 = off_hq2 % SPARSE_HQ | |
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 | |
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask | |
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 | |
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. | |
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) | |
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) | |
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) | |
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) | |
Q2 = Q + q_adj2 | |
DO2 = DO + do_adj2 | |
# TODO: This does not work if DQ is not the same layout as Q (for example, | |
# if Q is broadcasted) | |
DQ2 = DQ + dq_adj2 | |
LSE2 = LSE + off_chz2 | |
DELTA2 = DELTA + off_chz2 | |
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) | |
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
start_m2 = start_m2_block * BLOCK_M2 | |
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) | |
# load Q and do: they stay in SRAM throughout the inner loop. | |
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) | |
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) | |
if PRESCALE_QK: | |
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) | |
if IS_DIVISIBLE: | |
Di = tl.load(DELTA2 + offs_m2) | |
lse = tl.load(LSE2 + offs_m2) | |
else: | |
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) | |
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) | |
lse = tl.where(lse == -float("inf"), 0.0, lse) | |
lse = lse[:, None] | |
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
# KV_IDX and KV_NUM_BLKS are always contiguous. | |
kv_indices = KV_IDX + sparse_kv_idx_offset | |
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading | |
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) | |
offs_n2 = kv_start + tl.arange(0, BLOCK_N2) | |
dq = bwd_dq_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
K, V, | |
dq, q, do, Di, lse, | |
off_zq, off_hq2, offs_m2, offs_n2, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS=False, | |
) | |
if HAS_FULL_BLOCKS: | |
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. | |
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset | |
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading | |
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) | |
offs_n2 = kv_start + tl.arange(0, BLOCK_N2) | |
dq = bwd_dq_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
K, V, | |
dq, q, do, Di, lse, | |
off_zq, off_hq2, offs_m2, offs_n2, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS=True, | |
) | |
# Write back dQ. | |
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd | |
dq *= SM_SCALE | |
if IS_DIVISIBLE and SAFE_HEAD_DIM: | |
tl.store(dq_ptrs, dq) | |
else: | |
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) | |
else: | |
# THIS BLOCK DOES DK & DV | |
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) | |
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) | |
pid_mask = pid // SPARSE_KV_MULTIPLE | |
stride_q_num_blks_h = 512 | |
stride_q_idx_h = 262144 | |
stride_q_idx_n = 512 | |
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
start_n1 = pid * BLOCK_N1 | |
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) | |
# load K and V: they stay in SRAM throughout the inner loop. | |
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) | |
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) | |
if PRESCALE_QK: | |
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) | |
for off_g in range(0, GQA_SHARED_HEADS): | |
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g | |
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. | |
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) | |
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) | |
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) | |
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) | |
Q1 = Q + q_adj1 | |
DO1 = DO + do_adj1 | |
# TODO: This does not work if DQ is not the same layout as Q (for example, | |
# if Q is broadcasted) | |
LSE1 = LSE + off_chz1 | |
DELTA1 = DELTA + off_chz1 | |
sparse_idx_hq1 = off_hq1 % SPARSE_HQ | |
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 | |
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask | |
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 | |
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
# Q_IDX and Q_NUM_BLKS are always contiguous. | |
q_indices = Q_IDX + sparse_q_idx_offset | |
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading | |
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) | |
offs_m1 = q_start + tl.arange(0, BLOCK_M1) | |
dk, dv = bwd_dkdv_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
Q1, DO1, DELTA1, LSE1, | |
dk, dv, k, v, | |
off_zq, off_hq1, offs_n1, offs_m1, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS=False, | |
) | |
if HAS_FULL_BLOCKS: | |
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. | |
q_indices = FULL_Q_IDX + sparse_q_idx_offset | |
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading | |
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) | |
offs_m1 = q_start + tl.arange(0, BLOCK_M1) | |
dk, dv = bwd_dkdv_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
Q1, DO1, DELTA1, LSE1, | |
dk, dv, k, v, | |
off_zq, off_hq1, offs_n1, offs_m1, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS=True, | |
) | |
# Write back dV and dK. | |
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd | |
index_n = offs_n1[:, None] | |
index_k = offs_k[None, :] | |
index_v = offs_v[None, :] | |
if IS_DIVISIBLE and SAFE_HEAD_DIM: | |
tl.store(dv_ptrs, dv) | |
else: | |
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) | |
dk *= SM_SCALE | |
if SAFE_HEAD_DIM: | |
mask = index_n < KV_LEN | |
else: | |
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) | |
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] | |
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] | |
xindex = index_k + 128*index_n + 8388608*off_hkv + 67108864*off_zq | |
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) | |
@triton.jit | |
def bwd_dq_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
K, V, # pointers | |
dq, q, do, Di, lse, | |
off_z, off_hq, offs_m2, offs_n2, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS, | |
): | |
PRESCALE_QK : tl.constexpr = False | |
ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
WRITE_DQ : tl.constexpr = True | |
OUTPUT_LOGSUMEXP : tl.constexpr = True | |
FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
IS_DIVISIBLE : tl.constexpr = True | |
SM_SCALE : tl.constexpr = 0.12 | |
GQA_SHARED_HEADS : tl.constexpr = 1 | |
HAS_FULL_BLOCKS : tl.constexpr = True | |
QK_HEAD_DIM : tl.constexpr = 128 | |
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
V_HEAD_DIM : tl.constexpr = 128 | |
V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
SAFE_HEAD_DIM : tl.constexpr = True | |
BLOCK_M1 : tl.constexpr = 64 | |
BLOCK_N1 : tl.constexpr = 128 | |
BLOCK_M2 : tl.constexpr = 128 | |
BLOCK_N2 : tl.constexpr = 64 | |
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) | |
RCP_LN2: tl.constexpr = 1.44269504 | |
Q_LEN = 65536 | |
KV_LEN = 65536 | |
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) | |
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) | |
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd | |
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd | |
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. | |
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) | |
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) | |
if not IS_DIVISIBLE: | |
if hi >= 1: | |
for start_n in range(0, hi - 1): | |
dq = bwd_dq_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, | |
) | |
# Increment pointers. | |
offset = get_offset_for_next_block( | |
start_n, kv_indices, sparse_kv_num_blocks, | |
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS | |
) | |
kT_ptrs += offset * stride_kn | |
vT_ptrs += offset * stride_vn | |
offs_n2 += offset | |
dq = bwd_dq_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, | |
) | |
else: | |
for start_n in range(0, hi): | |
dq = bwd_dq_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, | |
) | |
# Increment pointers. | |
offset = get_offset_for_next_block( | |
start_n, kv_indices, sparse_kv_num_blocks, | |
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS | |
) | |
kT_ptrs += offset * stride_kn | |
vT_ptrs += offset * stride_vn | |
offs_n2 += offset | |
return dq | |
@triton.jit | |
def bwd_dq_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, | |
): | |
PRESCALE_QK : tl.constexpr = False | |
ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
WRITE_DQ : tl.constexpr = True | |
OUTPUT_LOGSUMEXP : tl.constexpr = True | |
FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
IS_DIVISIBLE : tl.constexpr = True | |
SM_SCALE : tl.constexpr = 0.12 | |
GQA_SHARED_HEADS : tl.constexpr = 1 | |
HAS_FULL_BLOCKS : tl.constexpr = True | |
QK_HEAD_DIM : tl.constexpr = 128 | |
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
V_HEAD_DIM : tl.constexpr = 128 | |
V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
SAFE_HEAD_DIM : tl.constexpr = True | |
BLOCK_M1 : tl.constexpr = 64 | |
BLOCK_N1 : tl.constexpr = 128 | |
BLOCK_M2 : tl.constexpr = 128 | |
BLOCK_N2 : tl.constexpr = 64 | |
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
# NB reversed order to since K is transposed | |
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) | |
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) | |
if not PRESCALE_QK: | |
qk *= SM_SCALE | |
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ | |
pre_mod_scores = qk | |
n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) | |
# The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim | |
# that the M reads out of bounds prior to the last loop | |
m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) | |
tmp0 = (qk) | |
post_mod_scores = tmp0 | |
if CHECK_BLOCK_BOUNDARY: | |
# Mask out the elements that are out of the KV_LEN for non divisible seqlen. | |
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) | |
if not IS_FULL_BLOCKS: | |
tmp1 = (m) | |
tmp2 = (n) | |
tmp3 = tmp1 >= tmp2 | |
tmp4 = tl.load(in_ptr16 + tmp1) | |
tmp5 = tl.load(in_ptr16 + tmp2) | |
tmp6 = tmp4 == tmp5 | |
tmp7 = tmp3 & tmp6 | |
mask_mod_output = tmp7 | |
if CHECK_BLOCK_BOUNDARY: | |
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) | |
# apply mask for partial masked block | |
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
if not PRESCALE_QK: | |
post_mod_scores *= RCP_LN2 | |
p = tl.math.exp2(post_mod_scores - lse) | |
# Compute dP and dS. | |
# NB reversed order to since V is transposed | |
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) | |
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) | |
ds = p * (dp - Di[:, None]) | |
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ | |
tmp8 = (ds) | |
grad_scores = tmp8 | |
if CHECK_BLOCK_BOUNDARY: | |
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) | |
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ | |
if WRITE_DQ: | |
scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
ds = grad_scores | |
if not IS_FULL_BLOCKS: | |
if CHECK_BLOCK_BOUNDARY: | |
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) | |
# (grads) apply mask for partially unmasked block | |
ds = tl.where(mask_mod_output, ds, 0.0) | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
ds = ds.to(MATMUL_PRECISION) | |
# Compute dQ. | |
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) | |
return dq | |
@triton.jit | |
def bwd_dkdv_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
Q, DO, DELTA, LSE, # pointers | |
dk, dv, k, v, | |
off_z, off_hq, offs_n1, offs_m1, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS, | |
): | |
PRESCALE_QK : tl.constexpr = False | |
ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
WRITE_DQ : tl.constexpr = True | |
OUTPUT_LOGSUMEXP : tl.constexpr = True | |
FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
IS_DIVISIBLE : tl.constexpr = True | |
SM_SCALE : tl.constexpr = 0.12 | |
GQA_SHARED_HEADS : tl.constexpr = 1 | |
HAS_FULL_BLOCKS : tl.constexpr = True | |
QK_HEAD_DIM : tl.constexpr = 128 | |
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
V_HEAD_DIM : tl.constexpr = 128 | |
V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
SAFE_HEAD_DIM : tl.constexpr = True | |
BLOCK_M1 : tl.constexpr = 64 | |
BLOCK_N1 : tl.constexpr = 128 | |
BLOCK_M2 : tl.constexpr = 128 | |
BLOCK_N2 : tl.constexpr = 64 | |
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) | |
RCP_LN2: tl.constexpr = 1.44269504 | |
Q_LEN = 65536 | |
KV_LEN = 65536 | |
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) | |
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) | |
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd | |
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod | |
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. | |
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) | |
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) | |
if not IS_DIVISIBLE: | |
if hi >= 1: | |
for start_m in range(0, hi - 1): | |
dk, dv = bwd_dkdv_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, | |
) | |
# Increment pointers. | |
offset = get_offset_for_next_block( | |
start_m, q_indices, sparse_q_num_blocks, | |
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS | |
) | |
qT_ptrs += offset * stride_qm | |
do_ptrs += offset * stride_dom | |
offs_m1 += offset | |
dk, dv = bwd_dkdv_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, | |
) | |
else: | |
for start_m in range(0, hi): | |
dk, dv = bwd_dkdv_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, | |
) | |
# Increment pointers. | |
offset = get_offset_for_next_block( | |
start_m, q_indices, sparse_q_num_blocks, | |
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS | |
) | |
qT_ptrs += offset * stride_qm | |
do_ptrs += offset * stride_dom | |
offs_m1 += offset | |
return dk, dv | |
@triton.jit | |
def bwd_dkdv_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, | |
): | |
PRESCALE_QK : tl.constexpr = False | |
ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
WRITE_DQ : tl.constexpr = True | |
OUTPUT_LOGSUMEXP : tl.constexpr = True | |
FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
IS_DIVISIBLE : tl.constexpr = True | |
SM_SCALE : tl.constexpr = 0.12 | |
GQA_SHARED_HEADS : tl.constexpr = 1 | |
HAS_FULL_BLOCKS : tl.constexpr = True | |
QK_HEAD_DIM : tl.constexpr = 128 | |
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
V_HEAD_DIM : tl.constexpr = 128 | |
V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
SAFE_HEAD_DIM : tl.constexpr = True | |
BLOCK_M1 : tl.constexpr = 64 | |
BLOCK_N1 : tl.constexpr = 128 | |
BLOCK_M2 : tl.constexpr = 128 | |
BLOCK_N2 : tl.constexpr = 64 | |
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
# NB reversed order since Q is transposed | |
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) | |
# Load LSE before computing qk to reduce pipeline stall. | |
if IS_DIVISIBLE: | |
lse = tl.load(LSE + offs_m1) | |
else: | |
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) | |
lse = tl.where(lse == -float("inf"), 0.0, lse) | |
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) | |
if not PRESCALE_QK: | |
qkT *= SM_SCALE | |
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ | |
m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) | |
# The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim | |
# that the n reads out of bounds prior to the last loop | |
n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) | |
pre_mod_scores = qkT | |
tmp9 = (qkT) | |
post_mod_scores = tmp9 | |
if CHECK_BLOCK_BOUNDARY: | |
# Mask out the elements that are out of the KV_LEN for non divisible seqlen. | |
post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) | |
if not IS_FULL_BLOCKS: | |
tmp10 = (m) | |
tmp11 = (n) | |
tmp12 = tmp10 >= tmp11 | |
tmp13 = tl.load(in_ptr16 + tmp10) | |
tmp14 = tl.load(in_ptr16 + tmp11) | |
tmp15 = tmp13 == tmp14 | |
tmp16 = tmp12 & tmp15 | |
mask_mod_output = tmp16 | |
if CHECK_BLOCK_BOUNDARY: | |
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) | |
# (grads) apply mask for fully masked block | |
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
if not PRESCALE_QK: | |
post_mod_scores *= RCP_LN2 | |
pT = tl.math.exp2(post_mod_scores - lse[None, :]) | |
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) | |
# Compute dV. | |
ppT = pT | |
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) | |
if IS_DIVISIBLE: | |
Di = tl.load(DELTA + offs_m1) | |
else: | |
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) | |
# Compute dP and dS. | |
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) | |
dsT = pT * (dpT - Di[None, :]) | |
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ | |
tmp17 = (dsT) | |
grad_scores = tmp17 | |
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ | |
if not WRITE_DQ: | |
idx_b = off_z | |
idx_h = off_hq | |
idx_m = m | |
idx_n = n | |
scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
if CHECK_BLOCK_BOUNDARY: | |
grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) | |
dsT = grad_scores | |
if not IS_FULL_BLOCKS: | |
if CHECK_BLOCK_BOUNDARY: | |
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) | |
# (grads) apply mask for partially unmasked block | |
dsT = tl.where(mask_mod_output, dsT, 0.0) | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) | |
return dk, dv | |
@triton.jit | |
def get_offset_for_next_block( | |
loop_iter, col_indices, total_blocks, | |
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, | |
BLOCKS_ARE_CONTIGUOUS: tl.constexpr | |
): | |
if BLOCKS_ARE_CONTIGUOUS: | |
return BLOCK | |
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE | |
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") | |
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) | |
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 | |
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK | |
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK | |
return offset | |
@triton.jit | |
def get_bounded_indices(indices, max_len=None): | |
return indices % max_len if max_len is not None else indices | |
@triton.jit | |
def load_checked_2d( | |
ptr, | |
offs_m, | |
offs_n, | |
stride_m, | |
stride_n, | |
IS_DIVISIBLE_M: tl.constexpr, | |
IS_DIVISIBLE_N: tl.constexpr, | |
M_LEN: tl.constexpr, | |
N_DIM: tl.constexpr, | |
): | |
# Calculate final pointer if strides are provided | |
if stride_m is not None and stride_n is not None: | |
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n | |
# Handle all masking cases | |
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: | |
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) | |
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: | |
return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) | |
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: | |
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) | |
else: # Both divisible | |
return tl.load(ptr) | |
''', device_str='cuda') | |
# Compile-time constants for the Triton template source that closes on the line above | |
# (the source that defines bwd_dkdv_block_mn / get_offset_for_next_block / load_checked_2d). | |
# Every key/value here mirrors a `tl.constexpr` assignment hard-coded in that source | |
# (e.g. BLOCK_M1=64, BLOCK_N1=128, SM_SCALE=0.12, IS_DIVISIBLE=True), so any change must be | |
# made in lockstep with it. FLOAT32_PRECISION is stored as the string "'ieee'" (quotes | |
# included), matching the literal `'ieee'` in the kernel source. | |
# NOTE(review): where meta0 is consumed is not visible in this chunk; presumably it is | |
# passed with the kernel launch - confirm before editing any value. | |
meta0 = {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.12, 'GQA_SHARED_HEADS': 1, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128} | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/zf/czf7pjyultaixltpmtzbkgpcs6y2wothlasa3ypnkt5l3damv72z.py | |
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten.sum, aten._to_copy] | |
# Source node to ATen node mapping: | |
# v_43 => convert_element_type_308, convert_element_type_309, mul_234 | |
# Graph fragment: | |
# %mul_265 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %view_36), kwargs = {}) | |
# %sum_13 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_265,), kwargs = {}) | |
# %convert_element_type_308 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_140, torch.float32), kwargs = {}) | |
# %mul_234 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_308, %rsqrt_60), kwargs = {}) | |
# %convert_element_type_309 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_234, torch.bfloat16), kwargs = {}) | |
# %mul_267 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %convert_element_type_309), kwargs = {}) | |
# %sum_14 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_267,), kwargs = {}) | |
# triton_red_fused__to_copy_mul_sum_8: looped reduction, 512 rows x 131072 elements per row. | |
# For row x, with r in [0, 131072) and all accumulation in fp32 (two [XBLOCK, R0_BLOCK] | |
# partial-sum tiles that are reduced along axis 1 after the loop): | |
#   out_ptr0[x] = sum_r in_ptr0[r + 131072*x] * in_ptr1[r + 131072*x] | |
#   out_ptr1[x] = sum_r in_ptr0[r + 131072*x] | |
#                       * (in_ptr2[3072*(r // 1024) + 393216*x + r % 1024] | |
#                          * in_ptr3[1024*x + r // 128]) | |
# in_ptr2 is a strided gather (1024 contiguous elements every 3072); in_ptr3 supplies one fp32 | |
# scale per 128 elements. Both outputs are fp32 per-row partials ('*fp32' in the signature). | |
# Per the graph comment these presumably feed sum_13 / sum_14, finished by a later reduction | |
# kernel - confirm in the caller. | |
# NOTE(review): the graph rounds mul_234 to bf16 (convert_element_type_309) but the kernel | |
# keeps it in fp32 (tmp10 = tmp9.to(tl.float32)), so out_ptr1 may differ slightly from eager. | |
triton_red_fused__to_copy_mul_sum_8 = async_compile.triton('triton_red_fused__to_copy_mul_sum_8', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
size_hints={'x': 512, 'r0_': 131072}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_mul_sum_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
xnumel = 512 | |
r0_numel = 131072 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
rbase = r0_base | |
x0 = xindex | |
_tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
_tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
roffset = r0_offset | |
rindex = r0_index | |
r0_1 = r0_index | |
tmp0 = tl.load(in_ptr0 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp6 = tl.load(in_ptr2 + (3072*(r0_1 // 1024) + 393216*x0 + ((r0_1 % 1024))), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp8 = tl.load(in_ptr3 + (1024*x0 + (r0_1 // 128)), xmask, eviction_policy='evict_last', other=0.0) | |
tmp2 = tmp0 * tmp1 | |
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK]) | |
tmp5 = _tmp4 + tmp3 | |
_tmp4 = tl.where(xmask, tmp5, _tmp4) | |
tmp7 = tmp6.to(tl.float32) | |
tmp9 = tmp7 * tmp8 | |
tmp10 = tmp9.to(tl.float32) | |
tmp11 = tmp0 * tmp10 | |
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK]) | |
tmp14 = _tmp13 + tmp12 | |
_tmp13 = tl.where(xmask, tmp14, _tmp13) | |
tmp4 = tl.sum(_tmp4, 1)[:, None] | |
tmp13 = tl.sum(_tmp13, 1)[:, None] | |
tl.store(out_ptr0 + (x0), tmp4, xmask) | |
tl.store(out_ptr1 + (x0), tmp13, xmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/id/cidd6fgxlf7gs3tonqx5yreybf2dwl3utnhsnlnvarag43y63rtu.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %mul_265 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %view_36), kwargs = {}) | |
# %sum_13 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_265,), kwargs = {}) | |
# triton_per_fused_mul_sum_9: single-program persistent reduction (xnumel = 1, XBLOCK = 1). | |
# Loads all 512 fp32 values of in_ptr0 as one R0_BLOCK (no loop, no masking needed), sums | |
# them with tl.sum, and stores the scalar to out_ptr0[0]. out_ptr0 is '*bf16' in the | |
# signature, so the fp32 sum is converted to bf16 by tl.store. | |
# Per the graph comment this produces sum_13; presumably in_ptr0 holds 512 per-row partial | |
# sums from a preceding reduction kernel - confirm in the caller. | |
triton_per_fused_mul_sum_9 = async_compile.triton('triton_per_fused_mul_sum_9', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
size_hints={'x': 1, 'r0_': 512}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 3), 'tt.equal_to': (2,)}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_sum_9', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused_mul_sum_9(in_ptr0, out_ptr0, xnumel, r0_numel): | |
xnumel = 1 | |
XBLOCK: tl.constexpr = 1 | |
r0_numel = 512 | |
R0_BLOCK: tl.constexpr = 512 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = tl.full([1], xoffset, tl.int32) | |
xmask = tl.full([R0_BLOCK], True, tl.int1) | |
r0_index = tl.arange(0, R0_BLOCK)[:] | |
r0_offset = 0 | |
r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
roffset = r0_offset | |
rindex = r0_index | |
r0_0 = r0_index | |
tmp0 = tl.load(in_ptr0 + (r0_0), None) | |
tmp1 = tl.broadcast_to(tmp0, [R0_BLOCK]) | |
tmp3 = triton_helpers.promote_to_tensor(tl.sum(tmp1, 0)) | |
tl.store(out_ptr0 + (tl.full([1], 0, tl.int32)), tmp3, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/vc/cvckp5jouod57mmfqeg2deebnbb7ny7dnkuh4ptuho2w6tgsf36c.py | |
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_43 => convert_element_type_308 | |
# Graph fragment: | |
# %mul_266 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %select_101), kwargs = {}) | |
# %convert_element_type_308 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_140, torch.float32), kwargs = {}) | |
# %convert_element_type_349 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_266, torch.float32), kwargs = {}) | |
# %mul_268 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_349, %convert_element_type_308), kwargs = {}) | |
# %mul_269 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_349, %rsqrt_60), kwargs = {}) | |
# %sum_15 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_268, [3], True), kwargs = {}) | |
# %div_5 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_22, 128), kwargs = {}) | |
# %pow_89 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_308, 1.0), kwargs = {}) | |
# %mul_272 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_89, 2.0), kwargs = {}) | |
# %mul_273 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_5, %mul_272), kwargs = {}) | |
# %add_185 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_269, %mul_273), kwargs = {}) | |
# %convert_element_type_350 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_185, torch.bfloat16), kwargs = {}) | |
# triton_red_fused__to_copy_add_div_mul_pow_sum_10: two passes over 524288 rows of 128 elements. | |
# Row x is split as x0 = x % 8, x1 = x // 8. Writing | |
#   g = in_ptr0[r + 128*x] * in_ptr1[78]   (in_ptr1[78] is one fp32 scalar, loaded per pass) | |
#   v = in_ptr2[r + 128*x0 + 3072*x1]       (bf16, upcast to fp32) | |
#   s = in_ptr3[x]                          (one fp32 value per row) | |
# pass 1: acc = sum_r g * v                                             (fp32) | |
# pass 2: out_ptr1[r + 128*x0 + 3072*x1] = g*s + (acc * -0.5 * s**3 * (1/128)) * (2*v) | |
#         stored as bf16 (0.0078125 == 1/128). | |
# This matches the graph's add_185 = mul_269 + div_5 * mul_272, presumably with rsqrt_60 as s; | |
# it has the form of an RMS-norm backward over a last dim of 128 - NOTE(review): interpretation. | |
# The 128*x0 + 3072*x1 addressing touches 8 blocks of 128 inside rows of stride 3072, i.e. a | |
# 1024-wide slice of a wider buffer (presumably a fused-projection layout - confirm). | |
triton_red_fused__to_copy_add_div_mul_pow_sum_10 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_10', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
size_hints={'x': 524288, 'r0_': 128}, | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_10(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
xnumel = 524288 | |
r0_numel = 128 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
rbase = r0_base | |
x3 = xindex | |
tmp1 = tl.load(in_ptr1 + (78)) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
x0 = (xindex % 8) | |
x1 = xindex // 8 | |
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp0 * tmp3 | |
tmp5 = tmp4.to(tl.float32) | |
tmp7 = tmp6.to(tl.float32) | |
tmp8 = tmp5 * tmp7 | |
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
tmp11 = _tmp10 + tmp9 | |
_tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
tmp10 = tl.sum(_tmp10, 1)[:, None] | |
tmp13 = tl.load(in_ptr1 + (78)) | |
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = tmp12 * tmp15 | |
tmp17 = tmp16.to(tl.float32) | |
tmp19 = tmp17 * tmp18 | |
tmp20 = -0.5 | |
tmp21 = tmp10 * tmp20 | |
tmp22 = tmp18 * tmp18 | |
tmp23 = tmp22 * tmp18 | |
tmp24 = tmp21 * tmp23 | |
tmp25 = 0.0078125 | |
tmp26 = tmp24 * tmp25 | |
tmp28 = tmp27.to(tl.float32) | |
tmp29 = 2.0 | |
tmp30 = tmp28 * tmp29 | |
tmp31 = tmp26 * tmp30 | |
tmp32 = tmp19 + tmp31 | |
tmp33 = tmp32.to(tl.float32) | |
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/sw/cswnxerjagcklzy6jfmvlvkyel6hqazhxrakyffns5ponz5c5yzq.py | |
# Topologically Sorted Source Nodes: [k_43, q_43], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# k_43 => convert_element_type_302 | |
# q_43 => convert_element_type_300 | |
# Graph fragment: | |
# %cat_30 : [num_users=2] = call_function[target=torch.ops.aten.cat.default](args = ([%add_187, %add_186], 3), kwargs = {}) | |
# %cat_31 : [num_users=2] = call_function[target=torch.ops.aten.cat.default](args = ([%add_189, %add_188], 3), kwargs = {}) | |
# %convert_element_type_302 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_139, torch.float32), kwargs = {}) | |
# %mul_282 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_30, %convert_element_type_302), kwargs = {}) | |
# %mul_283 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_30, %rsqrt_59), kwargs = {}) | |
# %sum_16 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_282, [3], True), kwargs = {}) | |
# %div_6 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_23, 128), kwargs = {}) | |
# %pow_91 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_302, 1.0), kwargs = {}) | |
# %mul_286 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_91, 2.0), kwargs = {}) | |
# %mul_287 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_6, %mul_286), kwargs = {}) | |
# %add_190 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_283, %mul_287), kwargs = {}) | |
# %convert_element_type_356 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_190, torch.bfloat16), kwargs = {}) | |
# %convert_element_type_300 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_138, torch.float32), kwargs = {}) | |
# %mul_288 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_31, %convert_element_type_300), kwargs = {}) | |
# %mul_289 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_31, %rsqrt_58), kwargs = {}) | |
# %sum_17 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_288, [3], True), kwargs = {}) | |
# %div_7 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_24, 128), kwargs = {}) | |
# %pow_93 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_300, 1.0), kwargs = {}) | |
# %mul_292 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_93, 2.0), kwargs = {}) | |
# %mul_293 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_7, %mul_292), kwargs = {}) | |
# %add_191 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_289, %mul_293), kwargs = {}) | |
# %convert_element_type_358 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_191, torch.bfloat16), kwargs = {}) | |
# triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11: 524288 rows x 128 elements, two | |
# independent 128-wide results per row (x0 = x % 8, x1 = x // 8), four loops over r: | |
# Loop 1 builds a[r] from in_ptr0 and b[r] from in_ptr3 using two 64-entry fp32 tables per x1 | |
# (shared by the 8 rows with that x1): u = in_ptr1[64*x1 + j], w = in_ptr2[64*x1 + j]: | |
#   r <  64 (j = r):      a[r] = in_ptr0[128*x + 64 + j] * -u + in_ptr0[128*x + j] * w | |
#   r >= 64 (j = r - 64): a[r] = in_ptr0[128*x + 64 + j] *  w + in_ptr0[128*x + j] * u | |
# (b identically from in_ptr3). a -> out_ptr0 and b -> out_ptr1 (fp32 scratch, re-read by the | |
# later loops at the same 128*x + r addresses), and acc_a = sum_r a * in_ptr4[r + 128*x0 + 3072*x1]. | |
# Loop 2: out_ptr3[r + 128*x0 + 3072*x1] = | |
#           a*s + (acc_a * -0.5 * s**3 * (1/128)) * (2 * in_ptr4[...]),  s = in_ptr5[x]  (bf16) | |
# Loops 3-4: the same reduction + update for b with in_ptr6 / in_ptr7, written to out_ptr5. | |
# a/b correspond to the graph's cat_30 / cat_31; the construction has the shape of a | |
# rotary-embedding backward with in_ptr1 / in_ptr2 as sin / cos tables, followed by an | |
# RMS-norm backward - NOTE(review): interpretation only; which of a/b is k vs q is not | |
# visible from the pointer names here. | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11 = async_compile.triton('triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
size_hints={'x': 524288, 'r0_': 128}, | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr3': '*bf16', 'out_ptr5': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 21, 'num_reduction': 2, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
xnumel = 524288 | |
r0_numel = 128 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
rbase = r0_base | |
x3 = xindex | |
x1 = xindex // 8 | |
x0 = (xindex % 8) | |
_tmp55 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp51 = tl.load(in_ptr4 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp0 = r0_2 | |
tmp1 = tl.full([1, 1], 0, tl.int64) | |
tmp2 = tmp0 >= tmp1 | |
tmp3 = tl.full([1, 1], 64, tl.int64) | |
tmp4 = tmp0 < tmp3 | |
tmp5 = tl.load(in_ptr0 + (64 + 128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp6 = tmp5.to(tl.float32) | |
tmp7 = tl.load(in_ptr1 + (64*x1 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0) | |
tmp8 = -tmp7 | |
tmp9 = tmp6 * tmp8 | |
tmp10 = tl.load(in_ptr0 + (128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp11 = tmp10.to(tl.float32) | |
tmp12 = tl.load(in_ptr2 + (64*x1 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0) | |
tmp13 = tmp11 * tmp12 | |
tmp14 = tmp9 + tmp13 | |
tmp15 = tl.full(tmp14.shape, 0.0, tmp14.dtype) | |
tmp16 = tl.where(tmp4, tmp14, tmp15) | |
tmp17 = tmp0 >= tmp3 | |
tmp18 = tl.full([1, 1], 128, tl.int64) | |
tmp19 = tmp0 < tmp18 | |
tmp20 = tl.load(in_ptr0 + (64 + 128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp21 = tmp20.to(tl.float32) | |
tmp22 = tl.load(in_ptr2 + (64*x1 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0) | |
tmp23 = tmp21 * tmp22 | |
tmp24 = tl.load(in_ptr0 + (128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp25 = tmp24.to(tl.float32) | |
tmp26 = tl.load(in_ptr1 + (64*x1 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0) | |
tmp27 = tmp25 * tmp26 | |
tmp28 = tmp23 + tmp27 | |
tmp29 = tl.full(tmp28.shape, 0.0, tmp28.dtype) | |
tmp30 = tl.where(tmp17, tmp28, tmp29) | |
tmp31 = tl.where(tmp4, tmp16, tmp30) | |
tmp32 = tl.load(in_ptr3 + (64 + 128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp33 = tmp32.to(tl.float32) | |
tmp34 = tmp33 * tmp8 | |
tmp35 = tl.load(in_ptr3 + (128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp36 = tmp35.to(tl.float32) | |
tmp37 = tmp36 * tmp12 | |
tmp38 = tmp34 + tmp37 | |
tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype) | |
tmp40 = tl.where(tmp4, tmp38, tmp39) | |
tmp41 = tl.load(in_ptr3 + (64 + 128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp42 = tmp41.to(tl.float32) | |
tmp43 = tmp42 * tmp22 | |
tmp44 = tl.load(in_ptr3 + (128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp45 = tmp44.to(tl.float32) | |
tmp46 = tmp45 * tmp26 | |
tmp47 = tmp43 + tmp46 | |
tmp48 = tl.full(tmp47.shape, 0.0, tmp47.dtype) | |
tmp49 = tl.where(tmp17, tmp47, tmp48) | |
tmp50 = tl.where(tmp4, tmp40, tmp49) | |
tmp52 = tmp51.to(tl.float32) | |
tmp53 = tmp31 * tmp52 | |
tmp54 = tl.broadcast_to(tmp53, [XBLOCK, R0_BLOCK]) | |
tmp56 = _tmp55 + tmp54 | |
_tmp55 = tl.where(r0_mask, tmp56, _tmp55) | |
tl.store(out_ptr0 + (r0_2 + 128*x3), tmp31, r0_mask) | |
tl.store(out_ptr1 + (r0_2 + 128*x3), tmp50, r0_mask) | |
tmp55 = tl.sum(_tmp55, 1)[:, None] | |
tmp58 = tl.load(in_ptr5 + (x3), None, eviction_policy='evict_last') | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp57 = tl.load(out_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0) | |
tmp67 = tl.load(in_ptr4 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp59 = tmp57 * tmp58 | |
tmp60 = -0.5 | |
tmp61 = tmp55 * tmp60 | |
tmp62 = tmp58 * tmp58 | |
tmp63 = tmp62 * tmp58 | |
tmp64 = tmp61 * tmp63 | |
tmp65 = 0.0078125 | |
tmp66 = tmp64 * tmp65 | |
tmp68 = tmp67.to(tl.float32) | |
tmp69 = 2.0 | |
tmp70 = tmp68 * tmp69 | |
tmp71 = tmp66 * tmp70 | |
tmp72 = tmp59 + tmp71 | |
tmp73 = tmp72.to(tl.float32) | |
tl.store(out_ptr3 + (r0_2 + 128*x0 + 3072*x1), tmp73, r0_mask) | |
_tmp79 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp74 = tl.load(out_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0) | |
tmp75 = tl.load(in_ptr6 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp76 = tmp75.to(tl.float32) | |
tmp77 = tmp74 * tmp76 | |
tmp78 = tl.broadcast_to(tmp77, [XBLOCK, R0_BLOCK]) | |
tmp80 = _tmp79 + tmp78 | |
_tmp79 = tl.where(r0_mask, tmp80, _tmp79) | |
tmp79 = tl.sum(_tmp79, 1)[:, None] | |
tmp82 = tl.load(in_ptr7 + (x3), None, eviction_policy='evict_last') | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp81 = tl.load(out_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0) | |
tmp91 = tl.load(in_ptr6 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp83 = tmp81 * tmp82 | |
tmp84 = -0.5 | |
tmp85 = tmp79 * tmp84 | |
tmp86 = tmp82 * tmp82 | |
tmp87 = tmp86 * tmp82 | |
tmp88 = tmp85 * tmp87 | |
tmp89 = 0.0078125 | |
tmp90 = tmp88 * tmp89 | |
tmp92 = tmp91.to(tl.float32) | |
tmp93 = 2.0 | |
tmp94 = tmp92 * tmp93 | |
tmp95 = tmp90 * tmp94 | |
tmp96 = tmp83 + tmp95 | |
tmp97 = tmp96.to(tl.float32) | |
tl.store(out_ptr5 + (r0_2 + 128*x0 + 3072*x1), tmp97, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/5d/c5dlh2bjlk4ei2idmqft6cji5oj5hyp7yehj6tk5eb5qsk2hiubd.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %full_default_18 : [num_users=30] = call_function[target=torch.ops.aten.full.default](args = ([4, 1024, 1024], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %select_scatter_default : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_18, %mm_69, 0, 3), kwargs = {}) | |
# %slice_scatter_default : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_18, %view_196, 0, 0, 3), kwargs = {}) | |
# %add_193 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default, %slice_scatter_default), kwargs = {}) | |
# Reviewer notes for triton_poi_fused_add_select_backward_12 (derived from the kernel body below):
#   Pointwise kernel over xnumel = 4194304 = 4 * 1048576 elements of a bf16 output.
#   For flat index x, let x1 = x // 1048576 (leading-dim slab) and x0 = x % 1048576:
#     out_ptr0[x] = (in_ptr0[x0] if x1 == 3 else 0) + (in_ptr1[x] if x1 < 3 else 0)
#   This matches add_193 above. Slab 3 comes from the select_scatter source (mm_69) at
#   index 3 of dim 0, and slabs 0..2 come from the slice_scatter source (view_196) over
#   [0:3] of dim 0. The two scatters into the zeros buffer do not overlap, so the sum
#   behaves like a concatenation along dim 0 and the zeros are never materialized.
#   in_ptr1 is only loaded under the x1 < 3 mask, so it is read over 3 * 1048576
#   elements only.
#   xmask is a constant True and the unmasked loads/stores pass None. The kernel
#   presumably relies on the launch grid covering xnumel exactly (XBLOCK dividing
#   4194304).
#   Math is done after .to(tl.float32). The result is stored through the bf16 out_ptr0
#   (see the signature).
triton_poi_fused_add_select_backward_12 = async_compile.triton('triton_poi_fused_add_select_backward_12', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
size_hints={'x': 4194304}, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_select_backward_12', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_select_backward_12(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 4194304 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = tl.full([XBLOCK], True, tl.int1) | |
x1 = xindex // 1048576 | |
x0 = (xindex % 1048576) | |
x2 = xindex | |
tmp3 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32) | |
tmp0 = x1 | |
tmp1 = tl.full([1], 3, tl.int32) | |
tmp2 = tmp0 == tmp1 | |
tmp4 = 0.0 | |
tmp5 = tl.where(tmp2, tmp3, tmp4) | |
tmp6 = tl.full([1], 3, tl.int64) | |
tmp7 = tmp0 < tmp6 | |
tmp8 = tl.load(in_ptr1 + (x2), tmp7, other=0.0).to(tl.float32) | |
tmp9 = tl.where(tmp7, tmp8, tmp4) | |
tmp10 = tmp5 + tmp9 | |
tl.store(out_ptr0 + (x2), tmp10, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/rh/crhlyhq55hgkmgz65givk2zndj2w4bvvle5lwhpxcz4poewwdmnz.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {}) | |
# %mul_296 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %select_99), kwargs = {}) | |
# Reviewer notes for triton_poi_fused_add_mul_13 (derived from the kernel body below):
#   Pointwise kernel over xnumel = 67108864 elements:
#     out_ptr0[x] = (in_ptr0[x] + in_ptr1[x]) * in_ptr2[46]
#   This matches add_192 followed by mul_296 above. The scale factor is one fp32 element
#   loaded from in_ptr2 at the fixed offset 46 and broadcast to the block (presumably
#   this is select_99; confirm against the caller).
#   The add is done in fp32 on the upcast bf16 inputs and is not rounded to bf16 before
#   the multiply. Only the final product goes through the bf16 out_ptr0.
#   NOTE(review): if add_192 is bf16 in eager (both inputs here are bf16), eager rounds
#   it first, so results can differ in the last bf16 ulp. Confirm this is acceptable.
#   xmask is a constant True (no tail masking), which presumes XBLOCK divides xnumel.
triton_poi_fused_add_mul_13 = async_compile.triton('triton_poi_fused_add_mul_13', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
size_hints={'x': 67108864}, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_13', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_mul_13(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 67108864 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = tl.full([XBLOCK], True, tl.int1) | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | |
tmp3 = tl.load(in_ptr2 + (46)) | |
tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | |
tmp2 = tmp0 + tmp1 | |
tmp5 = tmp4.to(tl.float32) | |
tmp6 = tmp2 * tmp5 | |
tl.store(out_ptr0 + (x0), tmp6, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/hc/chc7hawo3j62h3g5bffii536wpkihnsguegjwg6n26mklfjjcjn3.py | |
# Topologically Sorted Source Nodes: [rms_norm_57], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# rms_norm_57 => convert_element_type_292 | |
# Graph fragment: | |
# %convert_element_type_373 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_200, torch.float32), kwargs = {}) | |
# %convert_element_type_292 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_163, torch.float32), kwargs = {}) | |
# %mul_300 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_373, %convert_element_type_292), kwargs = {}) | |
# %mul_301 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_373, %rsqrt_57), kwargs = {}) | |
# %sum_20 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_300, [2], True), kwargs = {}) | |
# %div_8 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_25, 1024), kwargs = {}) | |
# %pow_96 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_292, 1.0), kwargs = {}) | |
# %mul_304 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_96, 2.0), kwargs = {}) | |
# %mul_305 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_8, %mul_304), kwargs = {}) | |
# %add_195 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_301, %mul_305), kwargs = {}) | |
# %convert_element_type_374 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_195, torch.bfloat16), kwargs = {}) | |
# %add_196 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_296, %convert_element_type_374), kwargs = {}) | |
# Reviewer notes for triton_per_fused__to_copy_add_div_mul_pow_sum_14 (derived from the kernel body below):
#   Persistent reduction with one program per row (XBLOCK = 1). There are
#   xnumel = 65536 rows. Each row's r0_numel = 1024 contiguous elements are handled in
#   a single R0_BLOCK = 1024 tile.
#   For row x0, let g = in_out_ptr0 row (view_200), a = in_ptr0 row (add_163, the
#   rms_norm_57 input) and r = in_ptr2[x0] (fp32, rsqrt_57). The kernel computes:
#     s   = sum_j g[j] * a[j]                               (sum_20 / mul_300)
#     dx  = g * r + (s * -0.5 * r**3 / 1024) * (2 * a)      (add_195; 0.0009765625 == 1/1024)
#     in_out_ptr0[row] = in_ptr1[row] + dx                  (add_196; in_ptr1 is mul_296)
#   dx matches the input gradient of a * rsqrt(mean(a**2) + eps); eps only enters
#   through r. No affine weight is loaded here.
#   in_out_ptr0 is read as g and then overwritten in place (listed in
#   mutated_arg_names).
#   NOTE(review): the graph rounds add_195 to bf16 (convert_element_type_374) before the
#   final add. The kernel instead emits tmp21.to(tl.float32), a no-op on an fp32 value,
#   so dx stays fp32 and only the final sum is stored as bf16. Inductor has an option to
#   emulate such casts (presumably emulate_precision_casts; verify for this torch
#   version). Confirm this precision difference from eager is acceptable.
triton_per_fused__to_copy_add_div_mul_pow_sum_14 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_14', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
size_hints={'x': 65536, 'r0_': 1024}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_14', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_div_mul_pow_sum_14(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel): | |
xnumel = 65536 | |
XBLOCK: tl.constexpr = 1 | |
r0_numel = 1024 | |
R0_BLOCK: tl.constexpr = 1024 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = tl.full([1], xoffset, tl.int32) | |
xmask = tl.full([R0_BLOCK], True, tl.int1) | |
r0_index = tl.arange(0, R0_BLOCK)[:] | |
r0_offset = 0 | |
r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
roffset = r0_offset | |
rindex = r0_index | |
r0_1 = r0_index | |
x0 = xindex | |
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp8 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') | |
tmp1 = tmp0.to(tl.float32) | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp1 * tmp3 | |
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK]) | |
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0)) | |
tmp10 = tmp1 * tmp9 | |
tmp11 = -0.5 | |
tmp12 = tmp7 * tmp11 | |
tmp13 = tmp9 * tmp9 | |
tmp14 = tmp13 * tmp9 | |
tmp15 = tmp12 * tmp14 | |
tmp16 = 0.0009765625 | |
tmp17 = tmp15 * tmp16 | |
tmp18 = 2.0 | |
tmp19 = tmp3 * tmp18 | |
tmp20 = tmp17 * tmp19 | |
tmp21 = tmp10 + tmp20 | |
tmp22 = tmp21.to(tl.float32) | |
tmp23 = tmp8 + tmp22 | |
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp23, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/dq/cdqthdpfu5gimvpxpp775dnqoysmbyfs3mtk4f5xsqfjde6wil3e.py | |
# Topologically Sorted Source Nodes: [v_40], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_40 => convert_element_type_288 | |
# Graph fragment: | |
# %mul_308 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_164, %select_94), kwargs = {}) | |
# %convert_element_type_288 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_131, torch.float32), kwargs = {}) | |
# %convert_element_type_381 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_308, torch.float32), kwargs = {}) | |
# %mul_310 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_381, %convert_element_type_288), kwargs = {}) | |
# %mul_311 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_381, %rsqrt_56), kwargs = {}) | |
# %sum_23 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_310, [3], True), kwargs = {}) | |
# %div_9 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_26, 128), kwargs = {}) | |
# %pow_98 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_288, 1.0), kwargs = {}) | |
# %mul_314 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_98, 2.0), kwargs = {}) | |
# %mul_315 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_9, %mul_314), kwargs = {}) | |
# %add_198 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_311, %mul_315), kwargs = {}) | |
# %convert_element_type_382 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_198, torch.bfloat16), kwargs = {}) | |
# Reviewer notes for triton_red_fused__to_copy_add_div_mul_pow_sum_15 (derived from the kernel body below):
#   Looped (non-persistent) reduction over xnumel = 524288 rows of r0_numel = 128
#   elements. Each row is walked twice in R0_BLOCK-sized chunks.
#   Indexing: row x3, with x0 = x3 % 8 and x1 = x3 // 8.
#     - in_ptr0 is read contiguously at r + 128*x3.
#     - in_ptr2 and out_ptr1 use r + 128*x0 + 3072*x1. Each group of 8 rows of 128 sits
#       at a stride of 3072 elements, so only 1024 of every 3072 are touched.
#       Presumably v is a slice of a wider fused buffer; confirm the layout with the
#       caller.
#   Let c = in_ptr1[76] (fp32 scalar, presumably select_94), g = in_ptr0 row * c
#   (mul_308), v = in_ptr2 row (getitem_131, the v_40 norm input) and r = in_ptr3[x3]
#   (rsqrt_56). The two passes compute:
#     pass 1: s   = sum_j g[j] * v[j]                            (sum_23)
#     pass 2: out = g * r + (s * -0.5 * r**3 / 128) * (2 * v)    (add_198; 0.0078125 == 1/128)
#   Pass 1 loads with evict_last because pass 2 reloads the same addresses. Pass 2 uses
#   evict_first since that is the last read.
#   The scalar in_ptr1[76] is loaded twice (tmp1 and tmp13). This is redundant but
#   harmless generated code.
#   NOTE(review): mul_308 is only cast with .to(tl.float32), so its bf16 rounding in
#   eager (the graph casts it to float32 afterwards) is not emulated. Likewise,
#   convert_element_type_382 happens only at the bf16 store. Confirm this is acceptable.
triton_red_fused__to_copy_add_div_mul_pow_sum_15 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_15', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
size_hints={'x': 524288, 'r0_': 128}, | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_15', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_15(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
xnumel = 524288 | |
r0_numel = 128 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
rbase = r0_base | |
x3 = xindex | |
tmp1 = tl.load(in_ptr1 + (76)) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
x0 = (xindex % 8) | |
x1 = xindex // 8 | |
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp0 * tmp3 | |
tmp5 = tmp4.to(tl.float32) | |
tmp7 = tmp6.to(tl.float32) | |
tmp8 = tmp5 * tmp7 | |
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
tmp11 = _tmp10 + tmp9 | |
_tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
tmp10 = tl.sum(_tmp10, 1)[:, None] | |
tmp13 = tl.load(in_ptr1 + (76)) | |
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = tmp12 * tmp15 | |
tmp17 = tmp16.to(tl.float32) | |
tmp19 = tmp17 * tmp18 | |
tmp20 = -0.5 | |
tmp21 = tmp10 * tmp20 | |
tmp22 = tmp18 * tmp18 | |
tmp23 = tmp22 * tmp18 | |
tmp24 = tmp21 * tmp23 | |
tmp25 = 0.0078125 | |
tmp26 = tmp24 * tmp25 | |
tmp28 = tmp27.to(tl.float32) | |
tmp29 = 2.0 | |
tmp30 = tmp28 * tmp29 | |
tmp31 = tmp26 * tmp30 | |
tmp32 = tmp19 + tmp31 | |
tmp33 = tmp32.to(tl.float32) | |
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/rw/crwt7reaapsh2trvps6lvjbbnnscw3ujc7z7ghzdznxwumm6b43r.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {}) | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %mul_295 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %convert_element_type_11), kwargs = {}) | |
# %sum_18 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_295,), kwargs = {}) | |
# %mul_297 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %add_165), kwargs = {}) | |
# %sum_19 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_297,), kwargs = {}) | |
# %add_205 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_196, %view_207), kwargs = {}) | |
# %mul_337 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %convert_element_type_11), kwargs = {}) | |
# %sum_26 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_337,), kwargs = {}) | |
# %mul_338 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %select_92), kwargs = {}) | |
# %mul_339 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %add_153), kwargs = {}) | |
# %sum_27 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_339,), kwargs = {}) | |
# Reviewer notes for triton_per_fused__to_copy_add_mul_sum_16 (derived from the kernel body below):
#   Persistent reduction with one program per row (XBLOCK = 1) over xnumel = 65536
#   rows of r0_numel = 1024 contiguous elements. For row x0 it forms:
#     A = in_ptr0 + in_ptr1             (add_192, recomputed here rather than loaded)
#     B = in_ptr5 + in_ptr6             (add_205)
#     n = in_ptr2 * in_ptr3[x0]         (convert_element_type_11 = x * rsqrt, per-row fp32 rsqrt)
#   It then writes four fp32 per-row sums and one elementwise tensor:
#     out_ptr0[x0]  = sum_j A * n          (row partial of sum_18)
#     out_ptr1[x0]  = sum_j A * in_ptr4    (row partial of sum_19; in_ptr4 is add_165)
#     out_ptr2[x0]  = sum_j B * n          (row partial of sum_26)
#     out_ptr3[x0]  = sum_j B * in_ptr7    (row partial of sum_27; in_ptr7 is add_153)
#     out_ptr4[row] = B * in_ptr8[44]      (mul_338, bf16; scalar presumably select_92)
#   sum_18/19/26/27 are full scalar sums in the graph, so these 65536 per-row values
#   presumably feed a second-stage reduction such as
#   triton_red_fused__to_copy_add_mul_sum_17 (65536 fp32 -> 1). Confirm at the call site.
#   NOTE(review): convert_element_type_11 (-> bf16) appears as tmp6.to(tl.float32), so n
#   is not rounded to bf16 as it is in eager. The same applies to A and B. Confirm this
#   is acceptable.
triton_per_fused__to_copy_add_mul_sum_16 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_16', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
size_hints={'x': 65536, 'r0_': 1024}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_16', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 9, 'num_reduction': 4, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_mul_sum_16(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, r0_numel): | |
xnumel = 65536 | |
XBLOCK: tl.constexpr = 1 | |
r0_numel = 1024 | |
R0_BLOCK: tl.constexpr = 1024 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = tl.full([1], xoffset, tl.int32) | |
xmask = tl.full([R0_BLOCK], True, tl.int1) | |
r0_index = tl.arange(0, R0_BLOCK)[:] | |
r0_offset = 0 | |
r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
roffset = r0_offset | |
rindex = r0_index | |
r0_1 = r0_index | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last') | |
tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp29 = tl.load(in_ptr8 + (44)) | |
tmp30 = tl.broadcast_to(tmp29, [R0_BLOCK]) | |
tmp2 = tmp0 + tmp1 | |
tmp4 = tmp3.to(tl.float32) | |
tmp6 = tmp4 * tmp5 | |
tmp7 = tmp6.to(tl.float32) | |
tmp8 = tmp2 * tmp7 | |
tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK]) | |
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0)) | |
tmp13 = tmp2 * tmp12 | |
tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK]) | |
tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0)) | |
tmp19 = tmp17 + tmp18 | |
tmp20 = tmp19 * tmp7 | |
tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK]) | |
tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0)) | |
tmp25 = tmp19 * tmp24 | |
tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK]) | |
tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0)) | |
tmp31 = tmp30.to(tl.float32) | |
tmp32 = tmp19 * tmp31 | |
tl.store(out_ptr4 + (r0_1 + 1024*x0), tmp32, None) | |
tl.store(out_ptr0 + (x0), tmp11, None) | |
tl.store(out_ptr1 + (x0), tmp16, None) | |
tl.store(out_ptr2 + (x0), tmp23, None) | |
tl.store(out_ptr3 + (x0), tmp28, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/2e/c2e4k3473pxml3aiowmqgdvsvsjm635x3t5jxglwwuodopxjzrqe.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {}) | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %mul_295 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %convert_element_type_11), kwargs = {}) | |
# %sum_18 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_295,), kwargs = {}) | |
# Reviewer notes for triton_red_fused__to_copy_add_mul_sum_17 (derived from the kernel body below):
#   Single-output looped reduction (xnumel = 1, also specialized via
#   'constants': {'xnumel': 1}).
#   It sums r0_numel = 65536 fp32 values from in_ptr0 in R0_BLOCK-sized chunks with an
#   fp32 accumulator. The total is stored once to out_ptr0[0], which is a bf16 pointer
#   per the signature. Per the graph comment this yields the scalar sum_18.
#   in_ptr0 presumably holds the 65536 per-row partial sums written to out_ptr0 by
#   triton_per_fused__to_copy_add_mul_sum_16. Confirm at the call site.
#   r0_mask is a constant True, so the loop assumes R0_BLOCK divides 65536 exactly.
triton_red_fused__to_copy_add_mul_sum_17 = async_compile.triton('triton_red_fused__to_copy_add_mul_sum_17', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
size_hints={'x': 1, 'r0_': 65536}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 3), 'tt.equal_to': (2,)}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mul_sum_17', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_mul_sum_17(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
xnumel = 1 | |
r0_numel = 65536 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
rbase = r0_base | |
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
roffset = r0_offset | |
rindex = r0_index | |
r0_0 = r0_index | |
tmp0 = tl.load(in_ptr0 + (r0_0), None, eviction_policy='evict_first') | |
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK]) | |
tmp3 = _tmp2 + tmp1 | |
_tmp2 = tmp3 | |
tmp2 = tl.sum(_tmp2, 1)[:, None] | |
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp2, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/yb/cybckfvcqgm55idccl4nb26mpv4fp2n6sltf2mq7jzlrn37bjned.py | |
# Topologically Sorted Source Nodes: [v_37], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_37 => convert_element_type_268 | |
# Graph fragment: | |
# %mul_350 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_184, %select_87), kwargs = {}) | |
# %convert_element_type_268 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_122, torch.float32), kwargs = {}) | |
# %convert_element_type_413 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_350, torch.float32), kwargs = {}) | |
# %mul_352 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_413, %convert_element_type_268), kwargs = {}) | |
# %mul_353 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_413, %rsqrt_52), kwargs = {}) | |
# %sum_31 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_352, [3], True), kwargs = {}) | |
# %div_13 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_30, 128), kwargs = {}) | |
# %pow_107 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_268, 1.0), kwargs = {}) | |
# %mul_356 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_107, 2.0), kwargs = {}) | |
# %mul_357 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_13, %mul_356), kwargs = {}) | |
# %add_214 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_353, %mul_357), kwargs = {}) | |
# %convert_element_type_414 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_214, torch.bfloat16), kwargs = {}) | |
# Generated two-pass Triton reduction: xnumel=524288 rows, r0_numel=128 elements per row.
# Index mapping: row x3 reads in_ptr0 at r + 128*x3; in_ptr2 and out_ptr1 use r + 128*x0 + 3072*x1,
# where x0 = x3 % 8 and x1 = x3 // 8 (i.e. a 3072-element stride per group of 8 rows).
# Pass 1: S[x3] = sum_r (in_ptr0 * in_ptr1[74]) * in_ptr2        (fp32 accumulation)
# Pass 2: out_ptr1 = (in_ptr0 * in_ptr1[74]) * R + (-0.5 * S * R**3 * (1/128)) * (2 * in_ptr2),
#         with R = in_ptr3[x3]; 0.0078125 == 1/128. The result is stored to the bf16 out_ptr1.
# NOTE(review): the arithmetic has the shape of an RMS-norm backward over a 128-wide dim
# (R presumably the per-row rsqrt) -- confirm against this kernel's FX graph comment block.
# rnumel/RBLOCK/roffset/rindex/rbase/xmask are assigned but never read in the body.
triton_red_fused__to_copy_add_div_mul_pow_sum_18 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_18', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_18', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_18(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (74)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (74)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/3w/c3wolzn6xji5tvbc62o57dehpylfh33s7yixvo3nby3vsodrtfpj.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten.mul, aten._to_copy, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {}) | |
# %mul_294 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %select_100), kwargs = {}) | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_205 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_196, %view_207), kwargs = {}) | |
# %mul_336 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %select_93), kwargs = {}) | |
# %add_207 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_294, %mul_336), kwargs = {}) | |
# %add_221 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_212, %view_219), kwargs = {}) | |
# %mul_378 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %select_86), kwargs = {}) | |
# %mul_379 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %convert_element_type_11), kwargs = {}) | |
# %sum_34 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_379,), kwargs = {}) | |
# %add_223 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_207, %mul_378), kwargs = {}) | |
# %mul_380 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %select_85), kwargs = {}) | |
# %mul_381 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %add_141), kwargs = {}) | |
# %sum_35 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_381,), kwargs = {}) | |
# Generated persistent (single-pass) Triton reduction: xnumel=65536 rows, one row per program
# (XBLOCK=1), all r0_numel=1024 elements of a row loaded at once at offset r + 1024*x0.
# With c = in_ptr4 + in_ptr5 and scalars w_k = in_ptr1[k], per element it computes:
#   in_out_ptr0 <- (in_out_ptr0 + in_ptr0)*w47 + (in_ptr2 + in_ptr3)*w45 + c*w43   (in-place update)
#   out_ptr0    <- c * w42
# and per row:
#   out_ptr1[x0] = sum_r c * (in_ptr6 * in_ptr7[x0])
#   out_ptr2[x0] = sum_r c * in_ptr8
# NOTE(review): the FX graph comment above names the matching graph nodes; which pointer
# corresponds to which node is not visible from here -- confirm against the call site.
# rnumel/RBLOCK/roffset/rindex/xmask/r0_offset/r0_mask are assigned but never read in the body.
triton_per_fused__to_copy_add_mul_sum_19 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_19', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
    size_hints={'x': 65536, 'r0_': 1024}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*fp32', 'in_ptr8': '*bf16', 'out_ptr0': '*bf16', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_19', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 2, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_mul_sum_19(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel): | |
    xnumel = 65536 | |
    XBLOCK: tl.constexpr = 1 | |
    r0_numel = 1024 | |
    R0_BLOCK: tl.constexpr = 1024 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = tl.full([1], xoffset, tl.int32) | |
    xmask = tl.full([R0_BLOCK], True, tl.int1) | |
    r0_index = tl.arange(0, R0_BLOCK)[:] | |
    r0_offset = 0 | |
    r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
    roffset = r0_offset | |
    rindex = r0_index | |
    r0_1 = r0_index | |
    x0 = xindex | |
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr1 + (47)) | |
    tmp4 = tl.broadcast_to(tmp3, [R0_BLOCK]) | |
    tmp7 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp8 = tl.load(in_ptr3 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp10 = tl.load(in_ptr1 + (45)) | |
    tmp11 = tl.broadcast_to(tmp10, [R0_BLOCK]) | |
    tmp15 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp16 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp18 = tl.load(in_ptr1 + (43)) | |
    tmp19 = tl.broadcast_to(tmp18, [R0_BLOCK]) | |
    tmp23 = tl.load(in_ptr1 + (42)) | |
    tmp24 = tl.broadcast_to(tmp23, [R0_BLOCK]) | |
    tmp27 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp29 = tl.load(in_ptr7 + (x0), None, eviction_policy='evict_last') | |
    tmp36 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp2 = tmp0 + tmp1 | |
    tmp5 = tmp4.to(tl.float32) | |
    tmp6 = tmp2 * tmp5 | |
    tmp9 = tmp7 + tmp8 | |
    tmp12 = tmp11.to(tl.float32) | |
    tmp13 = tmp9 * tmp12 | |
    tmp14 = tmp6 + tmp13 | |
    tmp17 = tmp15 + tmp16 | |
    tmp20 = tmp19.to(tl.float32) | |
    tmp21 = tmp17 * tmp20 | |
    tmp22 = tmp14 + tmp21 | |
    tmp25 = tmp24.to(tl.float32) | |
    tmp26 = tmp17 * tmp25 | |
    tmp28 = tmp27.to(tl.float32) | |
    tmp30 = tmp28 * tmp29 | |
    tmp31 = tmp30.to(tl.float32) | |
    tmp32 = tmp17 * tmp31 | |
    tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK]) | |
    tmp35 = triton_helpers.promote_to_tensor(tl.sum(tmp33, 0)) | |
    tmp37 = tmp17 * tmp36 | |
    tmp38 = tl.broadcast_to(tmp37, [R0_BLOCK]) | |
    tmp40 = triton_helpers.promote_to_tensor(tl.sum(tmp38, 0)) | |
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp22, None) | |
    tl.store(out_ptr0 + (r0_1 + 1024*x0), tmp26, None) | |
    tl.store(out_ptr1 + (x0), tmp35, None) | |
    tl.store(out_ptr2 + (x0), tmp40, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/vm/cvmnq4yi35zsfxtrc6yd7p7r3wbfy6nb6zghevp246rtyncjsb46.py | |
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
# Source node to ATen node mapping: | |
# v_34 => convert_element_type_248, convert_element_type_249, mul_187 | |
# Graph fragment: | |
# %convert_element_type_248 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_113, torch.float32), kwargs = {}) | |
# %mul_187 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_248, %rsqrt_48), kwargs = {}) | |
# %convert_element_type_249 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_187, torch.bfloat16), kwargs = {}) | |
# %mul_391 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_204, %convert_element_type_249), kwargs = {}) | |
# %sum_37 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_391,), kwargs = {}) | |
# Generated looped Triton reduction: xnumel=512 outputs, each reducing r0_numel=131072 elements.
# out_ptr0[x0] = sum_r in_ptr0[r + 131072*x0]
#                      * (in_ptr1[3072*(r // 1024) + 393216*x0 + r % 1024] * in_ptr2[1024*x0 + r // 128])
# in_ptr1 is read in 1024-element runs with a 3072-element stride between runs
# (393216 == 128 * 3072). in_ptr2 supplies one fp32 factor per 128 consecutive r.
# Rows are guarded by xmask = xindex < xnumel. r0_mask is all-True because 131072 is
# presumably a multiple of every R0_BLOCK the autotuner picks -- confirm against the heuristics.
triton_red_fused__to_copy_mul_sum_20 = async_compile.triton('triton_red_fused__to_copy_mul_sum_20', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 512, 'r0_': 131072}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_20', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_mul_sum_20(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 512 | |
    r0_numel = 131072 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = xindex < xnumel | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x0 = xindex | |
    _tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_1 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp1 = tl.load(in_ptr1 + (3072*(r0_1 // 1024) + 393216*x0 + ((r0_1 % 1024))), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp3 = tl.load(in_ptr2 + (1024*x0 + (r0_1 // 128)), xmask, eviction_policy='evict_last', other=0.0) | |
        tmp2 = tmp1.to(tl.float32) | |
        tmp4 = tmp2 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp6 = tmp0 * tmp5 | |
        tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK]) | |
        tmp9 = _tmp8 + tmp7 | |
        _tmp8 = tl.where(xmask, tmp9, _tmp8) | |
    tmp8 = tl.sum(_tmp8, 1)[:, None] | |
    tl.store(out_ptr0 + (x0), tmp8, xmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/fe/cfepqopoqmnpzbgue7ezqbcqee5l7skpuhz2cft2vnpout7de5lb.py | |
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_34 => convert_element_type_248 | |
# Graph fragment: | |
# %mul_390 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_204, %select_81), kwargs = {}) | |
# %convert_element_type_248 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_113, torch.float32), kwargs = {}) | |
# %convert_element_type_444 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_390, torch.float32), kwargs = {}) | |
# %mul_392 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_444, %convert_element_type_248), kwargs = {}) | |
# %mul_393 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_444, %rsqrt_48), kwargs = {}) | |
# %sum_38 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_392, [3], True), kwargs = {}) | |
# %div_17 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_34, 128), kwargs = {}) | |
# %pow_116 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_248, 1.0), kwargs = {}) | |
# %mul_396 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_116, 2.0), kwargs = {}) | |
# %mul_397 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_17, %mul_396), kwargs = {}) | |
# %add_229 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_393, %mul_397), kwargs = {}) | |
# %convert_element_type_445 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_229, torch.bfloat16), kwargs = {}) | |
# Generated two-pass Triton reduction: xnumel=524288 rows, r0_numel=128 elements per row.
# Index mapping: row x3 reads in_ptr0 at r + 128*x3; in_ptr2 and out_ptr1 use r + 128*x0 + 3072*x1,
# where x0 = x3 % 8 and x1 = x3 // 8.
# Pass 1: S[x3] = sum_r (in_ptr0 * in_ptr1[72]) * in_ptr2        (fp32 accumulation)
# Pass 2: out_ptr1 = (in_ptr0 * in_ptr1[72]) * R + (-0.5 * S * R**3 * (1/128)) * (2 * in_ptr2),
#         with R = in_ptr3[x3]; 0.0078125 == 1/128. The result is stored to the bf16 out_ptr1.
# NOTE(review): the FX graph comment above (mul_393 uses rsqrt_48, div_17 divides by 128,
# mul_396 scales by 2.0) suggests R is rsqrt_48 and in_ptr2 is getitem_113 -- confirm at the call site.
triton_red_fused__to_copy_add_div_mul_pow_sum_21 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_21', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_21', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_21(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (72)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (72)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/yj/cyjkhrb27hvv3l3kua6flocagrnqlqmc5754eswnsklhtleriys4.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {}) | |
# %mul_420 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %select_79), kwargs = {}) | |
# Generated pointwise Triton kernel over xnumel=67108864 elements (no bounds mask is applied;
# xmask is all-True):
#   out_ptr0[x] = (in_ptr0[x] + in_ptr1[x]) * in_ptr2[40]
# Arithmetic is in fp32 and the result is stored to the bf16 out_ptr0.
# NOTE(review): the FX comment above (add_236 = add_228 + view_230; mul_420 = add_236 * select_79)
# suggests in_ptr2[40] is select_79 -- confirm at the call site.
triton_poi_fused_add_mul_22 = async_compile.triton('triton_poi_fused_add_mul_22', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
    size_hints={'x': 67108864}, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_22', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
    min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_mul_22(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
    xnumel = 67108864 | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
    xmask = tl.full([XBLOCK], True, tl.int1) | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (40)) | |
    tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp5 = tmp4.to(tl.float32) | |
    tmp6 = tmp2 * tmp5 | |
    tl.store(out_ptr0 + (x0), tmp6, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/5l/c5l5r3u6uyjim3crjlcztqsca7jjltxx4qahm4o4c3sfeyhnral7.py | |
# Topologically Sorted Source Nodes: [v_31], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_31 => convert_element_type_228 | |
# Graph fragment: | |
# %mul_430 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_224, %select_75), kwargs = {}) | |
# %convert_element_type_228 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_104, torch.float32), kwargs = {}) | |
# %convert_element_type_475 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_430, torch.float32), kwargs = {}) | |
# %mul_432 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_475, %convert_element_type_228), kwargs = {}) | |
# %mul_433 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_475, %rsqrt_44), kwargs = {}) | |
# %sum_45 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_432, [3], True), kwargs = {}) | |
# %div_21 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_38, 128), kwargs = {}) | |
# %pow_125 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_228, 1.0), kwargs = {}) | |
# %mul_436 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_125, 2.0), kwargs = {}) | |
# %mul_437 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_21, %mul_436), kwargs = {}) | |
# %add_244 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_433, %mul_437), kwargs = {}) | |
# %convert_element_type_476 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_244, torch.bfloat16), kwargs = {}) | |
# Generated two-pass Triton reduction: xnumel=524288 rows, r0_numel=128 elements per row.
# Index mapping: row x3 reads in_ptr0 at r + 128*x3; in_ptr2 and out_ptr1 use r + 128*x0 + 3072*x1,
# where x0 = x3 % 8 and x1 = x3 // 8.
# Pass 1: S[x3] = sum_r (in_ptr0 * in_ptr1[70]) * in_ptr2        (fp32 accumulation)
# Pass 2: out_ptr1 = (in_ptr0 * in_ptr1[70]) * R + (-0.5 * S * R**3 * (1/128)) * (2 * in_ptr2),
#         with R = in_ptr3[x3]; 0.0078125 == 1/128. The result is stored to the bf16 out_ptr1.
# NOTE(review): the FX comment above (mul_433 uses rsqrt_44, div_21 divides by 128,
# mul_436 scales by 2.0) suggests R is rsqrt_44 and in_ptr2 is getitem_104 -- confirm at the call site.
triton_red_fused__to_copy_add_div_mul_pow_sum_23 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_23', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_23', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_23(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (70)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (70)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/ee/ceemqyuqc5se34jwuz7av5jovsufciprlj22yvbhuuydkqt6tt3k.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {}) | |
# %mul_419 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %convert_element_type_11), kwargs = {}) | |
# %sum_41 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_419,), kwargs = {}) | |
# %mul_421 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %add_130), kwargs = {}) | |
# %sum_42 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_421,), kwargs = {}) | |
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {}) | |
# %mul_459 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %convert_element_type_11), kwargs = {}) | |
# %sum_48 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_459,), kwargs = {}) | |
# %mul_460 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_73), kwargs = {}) | |
# %mul_461 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %add_119), kwargs = {}) | |
# %sum_49 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_461,), kwargs = {}) | |
# %mul_463 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_460, %add_36), kwargs = {}) | |
# %sum_50 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_463,), kwargs = {}) | |
# Summary (review notes): persistent row reduction over 65536 rows x 1024
# columns, one program per row (XBLOCK = 1, R0_BLOCK = 1024). All loads are
# upcast to fp32. Writing
#   a = in_ptr0 + in_ptr1              (%add_236)
#   n = in_ptr2 * in_ptr3[row]         (%mul -> %convert_element_type_11)
#   b = in_ptr5 + in_ptr6              (%add_251)
#   s = in_ptr8[38]                    (%select_73, one scalar broadcast)
# the kernel stores, per row, into the fp32 outputs:
#   out_ptr0 = sum(a * n)              (%sum_41)
#   out_ptr1 = sum(a * in_ptr4)        (%sum_42)
#   out_ptr2 = sum(b * n)              (%sum_48)
#   out_ptr3 = sum(b * in_ptr7)        (%sum_49)
#   out_ptr4 = sum(b * s * in_ptr9)    (%sum_50)
# and elementwise out_ptr5 = b * s     (%mul_460, bf16 buffer).
# NOTE(review): the graph casts %mul to bfloat16 (convert_element_type_11),
# but the kernel emits `tmp6.to(tl.float32)`, so n is not rounded to bf16
# here and results may differ slightly from eager. Presumably deliberate
# Inductor behavior (precision casts not emulated); confirm if bitwise
# parity with eager matters.
# The kernel body is a runtime string compiled by async_compile. This is
# generated code, so regenerate it via torch.compile rather than editing it.
triton_per_fused__to_copy_add_mul_sum_24 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_24', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
    size_hints={'x': 65536, 'r0_': 1024}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*fp32', 'in_ptr9': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*fp32', 'out_ptr5': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_24', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 10, 'num_reduction': 5, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_mul_sum_24(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, out_ptr5, xnumel, r0_numel): | |
    xnumel = 65536 | |
    XBLOCK: tl.constexpr = 1 | |
    r0_numel = 1024 | |
    R0_BLOCK: tl.constexpr = 1024 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = tl.full([1], xoffset, tl.int32) | |
    xmask = tl.full([R0_BLOCK], True, tl.int1) | |
    r0_index = tl.arange(0, R0_BLOCK)[:] | |
    r0_offset = 0 | |
    r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
    roffset = r0_offset | |
    rindex = r0_index | |
    r0_1 = r0_index | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last') | |
    tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp29 = tl.load(in_ptr8 + (38)) | |
    tmp30 = tl.broadcast_to(tmp29, [R0_BLOCK]) | |
    tmp33 = tl.load(in_ptr9 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp2 = tmp0 + tmp1 | |
    tmp4 = tmp3.to(tl.float32) | |
    tmp6 = tmp4 * tmp5 | |
    tmp7 = tmp6.to(tl.float32) | |
    tmp8 = tmp2 * tmp7 | |
    tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK]) | |
    tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0)) | |
    tmp13 = tmp2 * tmp12 | |
    tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK]) | |
    tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0)) | |
    tmp19 = tmp17 + tmp18 | |
    tmp20 = tmp19 * tmp7 | |
    tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK]) | |
    tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0)) | |
    tmp25 = tmp19 * tmp24 | |
    tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK]) | |
    tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0)) | |
    tmp31 = tmp30.to(tl.float32) | |
    tmp32 = tmp19 * tmp31 | |
    tmp34 = tmp32 * tmp33 | |
    tmp35 = tl.broadcast_to(tmp34, [R0_BLOCK]) | |
    tmp37 = triton_helpers.promote_to_tensor(tl.sum(tmp35, 0)) | |
    tl.store(out_ptr5 + (r0_1 + 1024*x0), tmp32, None) | |
    tl.store(out_ptr0 + (x0), tmp11, None) | |
    tl.store(out_ptr1 + (x0), tmp16, None) | |
    tl.store(out_ptr2 + (x0), tmp23, None) | |
    tl.store(out_ptr3 + (x0), tmp28, None) | |
    tl.store(out_ptr4 + (x0), tmp37, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/jg/cjgcteapcf4w5vv6fwhouzamepnj6zi7b3lj4lab2tysbpwawhfk.py | |
# Topologically Sorted Source Nodes: [v_28], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_28 => convert_element_type_208 | |
# Graph fragment: | |
# %mul_472 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_244, %select_68), kwargs = {}) | |
# %convert_element_type_208 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_95, torch.float32), kwargs = {}) | |
# %convert_element_type_507 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_472, torch.float32), kwargs = {}) | |
# %mul_474 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_507, %convert_element_type_208), kwargs = {}) | |
# %mul_475 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_507, %rsqrt_40), kwargs = {}) | |
# %sum_53 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_474, [3], True), kwargs = {}) | |
# %div_25 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_42, 128), kwargs = {}) | |
# %pow_134 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_208, 1.0), kwargs = {}) | |
# %mul_478 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_134, 2.0), kwargs = {}) | |
# %mul_479 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_25, %mul_478), kwargs = {}) | |
# %add_259 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_475, %mul_479), kwargs = {}) | |
# %convert_element_type_508 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_259, torch.bfloat16), kwargs = {}) | |
# Summary (review notes): two-pass reduction over 524288 rows x 128 columns.
# XBLOCK and R0_BLOCK are constexpr launch parameters supplied by the
# triton_heuristics.reduction configs. Per row x, writing
#   g = in_ptr0 * in_ptr1[68]   (%mul_472; in_ptr1[68] is scalar %select_68)
#   v = in_ptr2                 (%getitem_95, i.e. v_28)
#   r = in_ptr3[x]              (%rsqrt_40)
# the kernel computes:
#   pass 1: t = sum(g * v)                                    (%sum_53)
#   pass 2: out = g * r + (-0.5 * t * r**3 / 128) * (2 * v)   (%add_259)
# and stores out as bf16 to out_ptr1 (%convert_element_type_508).
# This matches the gradient of v * rsqrt(mean(v**2) + eps) over the last
# dim. Presumably this is an RMSNorm applied to v; confirm against the
# forward graph.
# Layout: in_ptr0 is read densely (r + 128*x). in_ptr2 and out_ptr1 are
# indexed r + 128*(x % 8) + 3072*(x // 8): groups of 8 contiguous 128-wide
# rows spaced 3072 elements apart, i.e. a strided view into a wider buffer.
# Pass 2 reloads in_ptr0 and in_ptr2 (evict_first) instead of keeping them
# from pass 1.
# Generated code: regenerate it via torch.compile rather than hand-editing.
triton_red_fused__to_copy_add_div_mul_pow_sum_25 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_25', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_25', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_25(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (68)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (68)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/tg/ctg6mqjt4r3k5i3go5y7jwzewso2vhohvzrh7cb2aqs7s3qv7ndn.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {}) | |
# %mul_418 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %select_80), kwargs = {}) | |
# %add_238 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_223, %mul_418), kwargs = {}) | |
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {}) | |
# %mul_458 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_74), kwargs = {}) | |
# %add_253 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_238, %mul_458), kwargs = {}) | |
# %add_266 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_258, %view_252), kwargs = {}) | |
# %mul_500 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_67), kwargs = {}) | |
# %mul_501 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %convert_element_type_11), kwargs = {}) | |
# %sum_56 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_501,), kwargs = {}) | |
# %add_268 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_253, %mul_500), kwargs = {}) | |
# %mul_502 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_66), kwargs = {}) | |
# %mul_503 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %add_107), kwargs = {}) | |
# %sum_57 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_503,), kwargs = {}) | |
# %mul_505 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_502, %add_58), kwargs = {}) | |
# %sum_58 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_505,), kwargs = {}) | |
# Summary (review notes): persistent row reduction over 65536 rows x 1024
# columns, one program per row (XBLOCK = 1, R0_BLOCK = 1024). All loads are
# upcast to fp32. Writing
#   a    = in_ptr0 + in_ptr1          (%add_266)
#   n    = in_ptr2 * in_ptr3[row]     (%convert_element_type_11; the bf16
#                                      cast is emitted as .to(tl.float32))
#   c[k] = in_ptr5[k], scalars broadcast over the row:
#          36 = %select_66, 37 = %select_67, 39 = %select_74, 41 = %select_80
# the kernel stores, per row, into the fp32 outputs:
#   out_ptr0 = sum(a * n)                  (%sum_56)
#   out_ptr1 = sum(a * in_ptr4)            (%sum_57)
#   out_ptr2 = sum(a * c[36] * in_ptr6)    (%sum_58)
# and elementwise out_ptr3 = a * c[36]     (%mul_502, bf16 buffer).
# It also accumulates in place (in_out_ptr0 starts as %add_223 and ends as
# %add_268, hence 'mutated_arg_names'):
#   in_out_ptr0 += (in_ptr7 + in_ptr8) * c[41]
#                + (in_ptr9 + in_ptr10) * c[39]
#                + a * c[37]
# Generated code: regenerate it via torch.compile rather than hand-editing.
triton_per_fused__to_copy_add_mul_sum_26 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_26', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
    size_hints={'x': 65536, 'r0_': 1024}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_26', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 15, 'num_reduction': 3, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_mul_sum_26(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel): | |
    xnumel = 65536 | |
    XBLOCK: tl.constexpr = 1 | |
    r0_numel = 1024 | |
    R0_BLOCK: tl.constexpr = 1024 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = tl.full([1], xoffset, tl.int32) | |
    xmask = tl.full([R0_BLOCK], True, tl.int1) | |
    r0_index = tl.arange(0, R0_BLOCK)[:] | |
    r0_offset = 0 | |
    r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
    roffset = r0_offset | |
    rindex = r0_index | |
    r0_1 = r0_index | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last') | |
    tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp17 = tl.load(in_ptr5 + (36)) | |
    tmp18 = tl.broadcast_to(tmp17, [R0_BLOCK]) | |
    tmp21 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp26 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp27 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp28 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp30 = tl.load(in_ptr5 + (41)) | |
    tmp31 = tl.broadcast_to(tmp30, [R0_BLOCK]) | |
    tmp35 = tl.load(in_ptr9 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp36 = tl.load(in_ptr10 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp38 = tl.load(in_ptr5 + (39)) | |
    tmp39 = tl.broadcast_to(tmp38, [R0_BLOCK]) | |
    tmp43 = tl.load(in_ptr5 + (37)) | |
    tmp44 = tl.broadcast_to(tmp43, [R0_BLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp4 = tmp3.to(tl.float32) | |
    tmp6 = tmp4 * tmp5 | |
    tmp7 = tmp6.to(tl.float32) | |
    tmp8 = tmp2 * tmp7 | |
    tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK]) | |
    tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0)) | |
    tmp13 = tmp2 * tmp12 | |
    tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK]) | |
    tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0)) | |
    tmp19 = tmp18.to(tl.float32) | |
    tmp20 = tmp2 * tmp19 | |
    tmp22 = tmp20 * tmp21 | |
    tmp23 = tl.broadcast_to(tmp22, [R0_BLOCK]) | |
    tmp25 = triton_helpers.promote_to_tensor(tl.sum(tmp23, 0)) | |
    tmp29 = tmp27 + tmp28 | |
    tmp32 = tmp31.to(tl.float32) | |
    tmp33 = tmp29 * tmp32 | |
    tmp34 = tmp26 + tmp33 | |
    tmp37 = tmp35 + tmp36 | |
    tmp40 = tmp39.to(tl.float32) | |
    tmp41 = tmp37 * tmp40 | |
    tmp42 = tmp34 + tmp41 | |
    tmp45 = tmp44.to(tl.float32) | |
    tmp46 = tmp2 * tmp45 | |
    tmp47 = tmp42 + tmp46 | |
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp47, None) | |
    tl.store(out_ptr3 + (r0_1 + 1024*x0), tmp20, None) | |
    tl.store(out_ptr0 + (x0), tmp11, None) | |
    tl.store(out_ptr1 + (x0), tmp16, None) | |
    tl.store(out_ptr2 + (x0), tmp25, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/am/camyez2exa6htwupvbzdzegpgqt727hcqpldhjcpx3zt5nnd7onq.py | |
# Topologically Sorted Source Nodes: [v_25], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_25 => convert_element_type_188 | |
# Graph fragment: | |
# %mul_514 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_264, %select_61), kwargs = {}) | |
# %convert_element_type_188 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_86, torch.float32), kwargs = {}) | |
# %convert_element_type_539 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_514, torch.float32), kwargs = {}) | |
# %mul_516 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_539, %convert_element_type_188), kwargs = {}) | |
# %mul_517 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_539, %rsqrt_36), kwargs = {}) | |
# %sum_61 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_516, [3], True), kwargs = {}) | |
# %div_29 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_46, 128), kwargs = {}) | |
# %pow_143 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_188, 1.0), kwargs = {}) | |
# %mul_520 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_143, 2.0), kwargs = {}) | |
# %mul_521 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_29, %mul_520), kwargs = {}) | |
# %add_275 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_517, %mul_521), kwargs = {}) | |
# %convert_element_type_540 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_275, torch.bfloat16), kwargs = {}) | |
# Summary (review notes): two-pass reduction over 524288 rows x 128 columns.
# XBLOCK and R0_BLOCK are constexpr launch parameters supplied by the
# triton_heuristics.reduction configs. Per row x, writing
#   g = in_ptr0 * in_ptr1[66]   (%mul_514; in_ptr1[66] is scalar %select_61)
#   v = in_ptr2                 (%getitem_86, i.e. v_25)
#   r = in_ptr3[x]              (%rsqrt_36)
# the kernel computes:
#   pass 1: t = sum(g * v)                                    (%sum_61)
#   pass 2: out = g * r + (-0.5 * t * r**3 / 128) * (2 * v)   (%add_275)
# and stores out as bf16 to out_ptr1 (%convert_element_type_540).
# This matches the gradient of v * rsqrt(mean(v**2) + eps) over the last
# dim. Presumably this is an RMSNorm applied to v; confirm against the
# forward graph.
# Layout: in_ptr0 is read densely (r + 128*x). in_ptr2 and out_ptr1 are
# indexed r + 128*(x % 8) + 3072*(x // 8): groups of 8 contiguous 128-wide
# rows spaced 3072 elements apart, i.e. a strided view into a wider buffer.
# Pass 2 reloads in_ptr0 and in_ptr2 (evict_first) instead of keeping them
# from pass 1.
# Generated code: regenerate it via torch.compile rather than hand-editing.
triton_red_fused__to_copy_add_div_mul_pow_sum_27 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_27', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_27', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_27(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (66)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (66)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/w5/cw5lm2xilanikcb3utse3lrzhou2ow25ytp5um7ilwwffymj7vvv.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_282 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_274, %view_263), kwargs = {}) | |
# %mul_544 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_59), kwargs = {}) | |
# Summary (review notes): pointwise kernel over 67108864 elements:
#   out_ptr0[i] = (in_ptr0[i] + in_ptr1[i]) * in_ptr2[34]
# i.e. %mul_544 = %add_282 * %select_59, where in_ptr2[34] is a single fp32
# scalar broadcast over the block. The math runs in fp32 and the result is
# stored into a bf16 buffer (out_ptr0 is '*bf16' in the signature).
# Generated code: regenerate it via torch.compile rather than hand-editing.
triton_poi_fused_add_mul_28 = async_compile.triton('triton_poi_fused_add_mul_28', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
    size_hints={'x': 67108864}, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_28', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
    min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_mul_28(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
    xnumel = 67108864 | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
    xmask = tl.full([XBLOCK], True, tl.int1) | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (34)) | |
    tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp5 = tmp4.to(tl.float32) | |
    tmp6 = tmp2 * tmp5 | |
    tl.store(out_ptr0 + (x0), tmp6, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/2m/c2mids55xghemi35me3p4us3qt5tzutbm52tf45sv4r3kifthkqn.py | |
# Topologically Sorted Source Nodes: [v_22], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_22 => convert_element_type_168 | |
# Graph fragment: | |
# %mul_556 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_284, %select_54), kwargs = {}) | |
# %convert_element_type_168 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_77, torch.float32), kwargs = {}) | |
# %convert_element_type_571 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_556, torch.float32), kwargs = {}) | |
# %mul_558 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_571, %convert_element_type_168), kwargs = {}) | |
# %mul_559 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_571, %rsqrt_32), kwargs = {}) | |
# %sum_69 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_558, [3], True), kwargs = {}) | |
# %div_33 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_50, 128), kwargs = {}) | |
# %pow_152 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_168, 1.0), kwargs = {}) | |
# %mul_562 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_152, 2.0), kwargs = {}) | |
# %mul_563 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_33, %mul_562), kwargs = {}) | |
# %add_291 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_559, %mul_563), kwargs = {}) | |
# %convert_element_type_572 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_291, torch.bfloat16), kwargs = {}) | |
# Generated by TorchInductor: asynchronously compiles the Triton reduction
# kernel whose source is the triple-quoted string below. That string IS the
# kernel, so any edit inside it changes the compiled code.
# Kernel summary (read from the kernel source):
#   * xnumel = 524288 rows, each reduced over r0_numel = 128 columns, walked
#     twice in R0_BLOCK-sized chunks.
#   * loop 1: tmp10 = sum_r(in_ptr0[r + 128*x] * in_ptr1[64] * in_ptr2[...]).
#     These loads use evict_last; loop 2 re-reads the same addresses with
#     evict_first.
#   * loop 2: out_ptr1[...] = in_ptr0 * in_ptr1[64] * in_ptr3[x]
#                 + tmp10 * -0.5 * in_ptr3[x]**3 * (1/128) * 2 * in_ptr2[...]
#     This is the shape of an RMS-style normalisation backward over the 128
#     columns; presumably in_ptr3 holds per-row rsqrt values (TODO confirm).
#   * in_ptr2 / out_ptr1 are addressed as r + 128*(x % 8) + 3072*(x // 8), so
#     they touch 8*128 of every 3072 elements (a strided view), not a dense
#     buffer.
triton_red_fused__to_copy_add_div_mul_pow_sum_29 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_29', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_29', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_29(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (64)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (64)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/7g/c7ggnywfvwyw3iiwnlm66m4wogxu6ssncnpxmdwolqb7dfrdx2i3.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_298 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_290, %view_274), kwargs = {}) | |
# %mul_586 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %select_52), kwargs = {}) | |
# Generated by TorchInductor: asynchronously compiles the Triton pointwise
# kernel whose source is the triple-quoted string below. That string IS the
# kernel, so any edit inside it changes the compiled code.
# Kernel summary (read from the kernel source): for i in [0, 67108864):
#     out_ptr0[i] = (in_ptr0[i] + in_ptr1[i]) * in_ptr2[32]
# The arithmetic is done in fp32 and stored through a *bf16 pointer.
# in_ptr2[32] is a single fp32 scalar broadcast to every element.
# NOTE(review): loads and stores pass mask None. This presumably relies on
# xnumel being a multiple of every XBLOCK the autotuner picks — confirm.
triton_poi_fused_add_mul_30 = async_compile.triton('triton_poi_fused_add_mul_30', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
    size_hints={'x': 67108864}, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_30', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
    min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_mul_30(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
    xnumel = 67108864 | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
    xmask = tl.full([XBLOCK], True, tl.int1) | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (32)) | |
    tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp5 = tmp4.to(tl.float32) | |
    tmp6 = tmp2 * tmp5 | |
    tl.store(out_ptr0 + (x0), tmp6, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/7d/c7dbizqshojxmody3a5tubxthr47qpwkudrrqa5hfb6qgxo5ekow.py | |
# Topologically Sorted Source Nodes: [x, rms_norm_29], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum, aten.div, aten.pow] | |
# Source node to ATen node mapping: | |
# rms_norm_29 => convert_element_type_152 | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_282 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_274, %view_263), kwargs = {}) | |
# %mul_542 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_60), kwargs = {}) | |
# %mul_543 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %convert_element_type_11), kwargs = {}) | |
# %sum_64 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_543,), kwargs = {}) | |
# %add_284 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_268, %mul_542), kwargs = {}) | |
# %mul_544 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_59), kwargs = {}) | |
# %mul_545 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %add_95), kwargs = {}) | |
# %sum_65 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_545,), kwargs = {}) | |
# %mul_546 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_544, %select_56), kwargs = {}) | |
# %mul_547 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_544, %add_80), kwargs = {}) | |
# %sum_66 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_547,), kwargs = {}) | |
# %add_298 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_290, %view_274), kwargs = {}) | |
# %mul_584 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %select_53), kwargs = {}) | |
# %mul_585 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %convert_element_type_11), kwargs = {}) | |
# %sum_72 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_585,), kwargs = {}) | |
# %add_300 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_284, %mul_584), kwargs = {}) | |
# %mul_587 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %add_83), kwargs = {}) | |
# %sum_73 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_587,), kwargs = {}) | |
# %convert_element_type_595 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_279, torch.float32), kwargs = {}) | |
# %convert_element_type_152 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_81, torch.float32), kwargs = {}) | |
# %mul_590 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_595, %convert_element_type_152), kwargs = {}) | |
# %mul_591 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_595, %rsqrt_29), kwargs = {}) | |
# %sum_74 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_590, [2], True), kwargs = {}) | |
# %div_36 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_53, 1024), kwargs = {}) | |
# %pow_159 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_152, 1.0), kwargs = {}) | |
# %mul_594 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_159, 2.0), kwargs = {}) | |
# %mul_595 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_36, %mul_594), kwargs = {}) | |
# %add_304 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_591, %mul_595), kwargs = {}) | |
# %convert_element_type_596 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_304, torch.bfloat16), kwargs = {}) | |
# %add_305 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_586, %convert_element_type_596), kwargs = {}) | |
# %mul_596 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %select_49), kwargs = {}) | |
# %mul_597 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %convert_element_type_11), kwargs = {}) | |
# %sum_75 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_597,), kwargs = {}) | |
# %add_306 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_300, %mul_596), kwargs = {}) | |
# %mul_598 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %select_48), kwargs = {}) | |
# %mul_599 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %add_80), kwargs = {}) | |
# %sum_76 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_599,), kwargs = {}) | |
# %add_307 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_546, %mul_598), kwargs = {}) | |
# Generated by TorchInductor: asynchronously compiles the Triton persistent
# reduction kernel whose source is the triple-quoted string below. That string
# IS the kernel, so any edit inside it changes the compiled code.
# Kernel summary (read from the kernel source):
#   * One program per row (XBLOCK = 1) over xnumel = 65536 rows.
#   * The whole 1024-column row is loaded at once (R0_BLOCK == r0_numel), so
#     there is no reduction loop.
#   * Notation: g = in_out_ptr0, w = in_ptr0, rs = in_ptr2[x], s[k] = in_ptr5[k]
#     (fp32 scalars), a = in_ptr3 + in_ptr4, b = in_ptr6 + in_ptr7,
#     n = in_ptr8 * in_ptr9[x].
#   * tmp7  = sum_r(g * w)
#   * tmp23 = in_ptr1 + g*rs + tmp7 * -0.5 * rs**3 * (1/1024) * 2 * w
#     This has the shape of an RMS-style normalisation backward over 1024
#     columns; presumably rs is a per-row rsqrt (TODO confirm).
#   * in_out_ptr1 <- in_out_ptr1 + a*s[35] + b*s[33] + tmp23*s[31]   (bf16)
#   * out_ptr1    <- a*s[34]*s[6] + tmp23*s[30]                       (bf16)
#   * Per-row fp32 sums:
#       out_ptr2 = sum(a*n)
#       out_ptr3 = sum(a*in_ptr10)
#       out_ptr4 = sum(a*s[34]*in_ptr11)
#       out_ptr5 = sum(b*n)
#       out_ptr6 = sum(b*in_ptr12)
#       out_ptr7 = sum(tmp23*n)
#       out_ptr8 = sum(tmp23*in_ptr11)
# NOTE(review): 'in_out_ptr0' is listed in mutated_arg_names, but the kernel
# only loads from it and never stores to it. Verify this is intended.
triton_per_fused__to_copy_add_div_mul_pow_sum_31 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_31', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
    size_hints={'x': 65536, 'r0_': 1024}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_out_ptr1': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'out_ptr1': '*bf16', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*fp32', 'out_ptr5': '*fp32', 'out_ptr6': '*fp32', 'out_ptr7': '*fp32', 'out_ptr8': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_31', 'mutated_arg_names': ['in_out_ptr0', 'in_out_ptr1'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 20, 'num_reduction': 8, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_div_mul_pow_sum_31(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, out_ptr1, out_ptr2, out_ptr3, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, xnumel, r0_numel): | |
    xnumel = 65536 | |
    XBLOCK: tl.constexpr = 1 | |
    r0_numel = 1024 | |
    R0_BLOCK: tl.constexpr = 1024 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = tl.full([1], xoffset, tl.int32) | |
    xmask = tl.full([R0_BLOCK], True, tl.int1) | |
    r0_index = tl.arange(0, R0_BLOCK)[:] | |
    r0_offset = 0 | |
    r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
    roffset = r0_offset | |
    rindex = r0_index | |
    r0_1 = r0_index | |
    x0 = xindex | |
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp8 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') | |
    tmp24 = tl.load(in_out_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp25 = tl.load(in_ptr3 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp26 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp28 = tl.load(in_ptr5 + (35)) | |
    tmp29 = tl.broadcast_to(tmp28, [R0_BLOCK]) | |
    tmp33 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp34 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp36 = tl.load(in_ptr5 + (33)) | |
    tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK]) | |
    tmp41 = tl.load(in_ptr5 + (31)) | |
    tmp42 = tl.broadcast_to(tmp41, [R0_BLOCK]) | |
    tmp46 = tl.load(in_ptr5 + (34)) | |
    tmp47 = tl.broadcast_to(tmp46, [R0_BLOCK]) | |
    tmp50 = tl.load(in_ptr5 + (6)) | |
    tmp51 = tl.broadcast_to(tmp50, [R0_BLOCK]) | |
    tmp54 = tl.load(in_ptr5 + (30)) | |
    tmp55 = tl.broadcast_to(tmp54, [R0_BLOCK]) | |
    tmp59 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp61 = tl.load(in_ptr9 + (x0), None, eviction_policy='evict_last') | |
    tmp68 = tl.load(in_ptr10 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp73 = tl.load(in_ptr11 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp82 = tl.load(in_ptr12 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp1 = tmp0.to(tl.float32) | |
    tmp3 = tmp2.to(tl.float32) | |
    tmp4 = tmp1 * tmp3 | |
    tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK]) | |
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0)) | |
    tmp10 = tmp1 * tmp9 | |
    tmp11 = -0.5 | |
    tmp12 = tmp7 * tmp11 | |
    tmp13 = tmp9 * tmp9 | |
    tmp14 = tmp13 * tmp9 | |
    tmp15 = tmp12 * tmp14 | |
    tmp16 = 0.0009765625 | |
    tmp17 = tmp15 * tmp16 | |
    tmp18 = 2.0 | |
    tmp19 = tmp3 * tmp18 | |
    tmp20 = tmp17 * tmp19 | |
    tmp21 = tmp10 + tmp20 | |
    tmp22 = tmp21.to(tl.float32) | |
    tmp23 = tmp8 + tmp22 | |
    tmp27 = tmp25 + tmp26 | |
    tmp30 = tmp29.to(tl.float32) | |
    tmp31 = tmp27 * tmp30 | |
    tmp32 = tmp24 + tmp31 | |
    tmp35 = tmp33 + tmp34 | |
    tmp38 = tmp37.to(tl.float32) | |
    tmp39 = tmp35 * tmp38 | |
    tmp40 = tmp32 + tmp39 | |
    tmp43 = tmp42.to(tl.float32) | |
    tmp44 = tmp23 * tmp43 | |
    tmp45 = tmp40 + tmp44 | |
    tmp48 = tmp47.to(tl.float32) | |
    tmp49 = tmp27 * tmp48 | |
    tmp52 = tmp51.to(tl.float32) | |
    tmp53 = tmp49 * tmp52 | |
    tmp56 = tmp55.to(tl.float32) | |
    tmp57 = tmp23 * tmp56 | |
    tmp58 = tmp53 + tmp57 | |
    tmp60 = tmp59.to(tl.float32) | |
    tmp62 = tmp60 * tmp61 | |
    tmp63 = tmp62.to(tl.float32) | |
    tmp64 = tmp27 * tmp63 | |
    tmp65 = tl.broadcast_to(tmp64, [R0_BLOCK]) | |
    tmp67 = triton_helpers.promote_to_tensor(tl.sum(tmp65, 0)) | |
    tmp69 = tmp27 * tmp68 | |
    tmp70 = tl.broadcast_to(tmp69, [R0_BLOCK]) | |
    tmp72 = triton_helpers.promote_to_tensor(tl.sum(tmp70, 0)) | |
    tmp74 = tmp49 * tmp73 | |
    tmp75 = tl.broadcast_to(tmp74, [R0_BLOCK]) | |
    tmp77 = triton_helpers.promote_to_tensor(tl.sum(tmp75, 0)) | |
    tmp78 = tmp35 * tmp63 | |
    tmp79 = tl.broadcast_to(tmp78, [R0_BLOCK]) | |
    tmp81 = triton_helpers.promote_to_tensor(tl.sum(tmp79, 0)) | |
    tmp83 = tmp35 * tmp82 | |
    tmp84 = tl.broadcast_to(tmp83, [R0_BLOCK]) | |
    tmp86 = triton_helpers.promote_to_tensor(tl.sum(tmp84, 0)) | |
    tmp87 = tmp23 * tmp63 | |
    tmp88 = tl.broadcast_to(tmp87, [R0_BLOCK]) | |
    tmp90 = triton_helpers.promote_to_tensor(tl.sum(tmp88, 0)) | |
    tmp91 = tmp23 * tmp73 | |
    tmp92 = tl.broadcast_to(tmp91, [R0_BLOCK]) | |
    tmp94 = triton_helpers.promote_to_tensor(tl.sum(tmp92, 0)) | |
    tl.store(in_out_ptr1 + (r0_1 + 1024*x0), tmp45, None) | |
    tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp58, None) | |
    tl.store(out_ptr2 + (x0), tmp67, None) | |
    tl.store(out_ptr3 + (x0), tmp72, None) | |
    tl.store(out_ptr4 + (x0), tmp77, None) | |
    tl.store(out_ptr5 + (x0), tmp81, None) | |
    tl.store(out_ptr6 + (x0), tmp86, None) | |
    tl.store(out_ptr7 + (x0), tmp90, None) | |
    tl.store(out_ptr8 + (x0), tmp94, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/tv/ctvx4p64xr75uwoae5pprzrej2pwv7rbwkrsu5y2zaip2qp6b45i.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten._to_copy, aten.add] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %full_default_20 : [num_users=53] = call_function[target=torch.ops.aten.full.default](args = ([2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %convert_element_type_363 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_18, torch.float32), kwargs = {}) | |
# %select_scatter_default_3 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_363, 0, 1), kwargs = {}) | |
# %convert_element_type_364 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_19, torch.float32), kwargs = {}) | |
# %select_scatter_default_4 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_364, 0, 0), kwargs = {}) | |
# %add_194 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_3, %select_scatter_default_4), kwargs = {}) | |
# %full_default_25 : [num_users=31] = call_function[target=torch.ops.aten.full.default](args = ([16, 2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %select_scatter_default_6 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_194, 0, 15), kwargs = {}) | |
# %convert_element_type_395 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_26, torch.float32), kwargs = {}) | |
# %select_scatter_default_10 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_395, 0, 1), kwargs = {}) | |
# %convert_element_type_396 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_27, torch.float32), kwargs = {}) | |
# %select_scatter_default_11 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_396, 0, 0), kwargs = {}) | |
# %add_208 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_10, %select_scatter_default_11), kwargs = {}) | |
# %select_scatter_default_13 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_208, 0, 14), kwargs = {}) | |
# %add_210 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_6, %select_scatter_default_13), kwargs = {}) | |
# %convert_element_type_427 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_34, torch.float32), kwargs = {}) | |
# %select_scatter_default_17 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_427, 0, 1), kwargs = {}) | |
# %convert_element_type_428 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_35, torch.float32), kwargs = {}) | |
# %select_scatter_default_18 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_428, 0, 0), kwargs = {}) | |
# %add_224 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_17, %select_scatter_default_18), kwargs = {}) | |
# %select_scatter_default_20 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_224, 0, 13), kwargs = {}) | |
# %add_226 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_210, %select_scatter_default_20), kwargs = {}) | |
# %convert_element_type_458 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_41, torch.float32), kwargs = {}) | |
# %select_scatter_default_23 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_458, 0, 1), kwargs = {}) | |
# %convert_element_type_459 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_42, torch.float32), kwargs = {}) | |
# %select_scatter_default_24 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_459, 0, 0), kwargs = {}) | |
# %add_239 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_23, %select_scatter_default_24), kwargs = {}) | |
# %select_scatter_default_26 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_239, 0, 12), kwargs = {}) | |
# %add_241 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_226, %select_scatter_default_26), kwargs = {}) | |
# %convert_element_type_489 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_48, torch.float32), kwargs = {}) | |
# %select_scatter_default_29 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_489, 0, 1), kwargs = {}) | |
# %convert_element_type_490 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_49, torch.float32), kwargs = {}) | |
# %select_scatter_default_30 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_490, 0, 0), kwargs = {}) | |
# %add_254 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_29, %select_scatter_default_30), kwargs = {}) | |
# %select_scatter_default_32 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_254, 0, 11), kwargs = {}) | |
# %add_256 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_241, %select_scatter_default_32), kwargs = {}) | |
# %convert_element_type_521 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_56, torch.float32), kwargs = {}) | |
# %select_scatter_default_36 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_521, 0, 1), kwargs = {}) | |
# %convert_element_type_522 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_57, torch.float32), kwargs = {}) | |
# %select_scatter_default_37 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_522, 0, 0), kwargs = {}) | |
# %add_269 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_36, %select_scatter_default_37), kwargs = {}) | |
# %select_scatter_default_39 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_269, 0, 10), kwargs = {}) | |
# %add_271 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_256, %select_scatter_default_39), kwargs = {}) | |
# %convert_element_type_553 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_64, torch.float32), kwargs = {}) | |
# %select_scatter_default_43 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_553, 0, 1), kwargs = {}) | |
# %convert_element_type_554 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_65, torch.float32), kwargs = {}) | |
# %select_scatter_default_44 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_554, 0, 0), kwargs = {}) | |
# %add_285 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_43, %select_scatter_default_44), kwargs = {}) | |
# %select_scatter_default_46 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_285, 0, 9), kwargs = {}) | |
# %add_287 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_271, %select_scatter_default_46), kwargs = {}) | |
# %convert_element_type_585 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_72, torch.float32), kwargs = {}) | |
# %select_scatter_default_50 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_585, 0, 1), kwargs = {}) | |
# %convert_element_type_586 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_73, torch.float32), kwargs = {}) | |
# %select_scatter_default_51 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_586, 0, 0), kwargs = {}) | |
# %add_301 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_50, %select_scatter_default_51), kwargs = {}) | |
# %select_scatter_default_53 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_301, 0, 8), kwargs = {}) | |
# %add_303 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_287, %select_scatter_default_53), kwargs = {}) | |
# %convert_element_type_597 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_75, torch.float32), kwargs = {}) | |
# %select_scatter_default_54 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_597, 0, 1), kwargs = {}) | |
# %convert_element_type_598 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_76, torch.float32), kwargs = {}) | |
# %select_scatter_default_55 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_598, 0, 0), kwargs = {}) | |
# %add_308 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_54, %select_scatter_default_55), kwargs = {}) | |
# %select_scatter_default_56 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_308, 0, 7), kwargs = {}) | |
# %add_309 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_303, %select_scatter_default_56), kwargs = {}) | |
triton_poi_fused__to_copy_add_select_backward_32 = async_compile.triton('triton_poi_fused__to_copy_add_select_backward_32', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
size_hints={'x': 32}, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'in_ptr13': '*bf16', 'in_ptr14': '*bf16', 'in_ptr15': '*bf16', 'in_ptr16': '*bf16', 'in_ptr17': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_select_backward_32', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 18, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused__to_copy_add_select_backward_32(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, in_ptr15, in_ptr16, in_ptr17, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 32 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = xindex < xnumel | |
x1 = xindex // 2 | |
x0 = (xindex % 2) | |
x2 = xindex | |
tmp6 = tl.load(in_ptr0 + (0)).to(tl.float32) | |
tmp7 = tl.broadcast_to(tmp6, [XBLOCK]) | |
tmp13 = tl.load(in_ptr1 + (0)).to(tl.float32) | |
tmp14 = tl.broadcast_to(tmp13, [XBLOCK]) | |
tmp21 = tl.load(in_ptr2 + (0)).to(tl.float32) | |
tmp22 = tl.broadcast_to(tmp21, [XBLOCK]) | |
tmp25 = tl.load(in_ptr3 + (0)).to(tl.float32) | |
tmp26 = tl.broadcast_to(tmp25, [XBLOCK]) | |
tmp34 = tl.load(in_ptr4 + (0)).to(tl.float32) | |
tmp35 = tl.broadcast_to(tmp34, [XBLOCK]) | |
tmp38 = tl.load(in_ptr5 + (0)).to(tl.float32) | |
tmp39 = tl.broadcast_to(tmp38, [XBLOCK]) | |
tmp47 = tl.load(in_ptr6 + (0)).to(tl.float32) | |
tmp48 = tl.broadcast_to(tmp47, [XBLOCK]) | |
tmp51 = tl.load(in_ptr7 + (0)).to(tl.float32) | |
tmp52 = tl.broadcast_to(tmp51, [XBLOCK]) | |
tmp60 = tl.load(in_ptr8 + (0)).to(tl.float32) | |
tmp61 = tl.broadcast_to(tmp60, [XBLOCK]) | |
tmp64 = tl.load(in_ptr9 + (0)).to(tl.float32) | |
tmp65 = tl.broadcast_to(tmp64, [XBLOCK]) | |
tmp73 = tl.load(in_ptr10 + (0)).to(tl.float32) | |
tmp74 = tl.broadcast_to(tmp73, [XBLOCK]) | |
tmp77 = tl.load(in_ptr11 + (0)).to(tl.float32) | |
tmp78 = tl.broadcast_to(tmp77, [XBLOCK]) | |
tmp86 = tl.load(in_ptr12 + (0)).to(tl.float32) | |
tmp87 = tl.broadcast_to(tmp86, [XBLOCK]) | |
tmp90 = tl.load(in_ptr13 + (0)).to(tl.float32) | |
tmp91 = tl.broadcast_to(tmp90, [XBLOCK]) | |
tmp99 = tl.load(in_ptr14 + (0)).to(tl.float32) | |
tmp100 = tl.broadcast_to(tmp99, [XBLOCK]) | |
tmp103 = tl.load(in_ptr15 + (0)).to(tl.float32) | |
tmp104 = tl.broadcast_to(tmp103, [XBLOCK]) | |
tmp112 = tl.load(in_ptr16 + (0)).to(tl.float32) | |
tmp113 = tl.broadcast_to(tmp112, [XBLOCK]) | |
tmp116 = tl.load(in_ptr17 + (0)).to(tl.float32) | |
tmp117 = tl.broadcast_to(tmp116, [XBLOCK]) | |
tmp0 = x1 | |
tmp1 = tl.full([1], 15, tl.int32) | |
tmp2 = tmp0 == tmp1 | |
tmp3 = x0 | |
tmp4 = tl.full([1], 1, tl.int32) | |
tmp5 = tmp3 == tmp4 | |
tmp8 = tmp7.to(tl.float32) | |
tmp9 = 0.0 | |
tmp10 = tl.where(tmp5, tmp8, tmp9) | |
tmp11 = tl.full([1], 0, tl.int32) | |
tmp12 = tmp3 == tmp11 | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = tl.where(tmp12, tmp15, tmp9) | |
tmp17 = tmp10 + tmp16 | |
tmp18 = tl.where(tmp2, tmp17, tmp9) | |
tmp19 = tl.full([1], 14, tl.int32) | |
tmp20 = tmp0 == tmp19 | |
tmp23 = tmp22.to(tl.float32) | |
tmp24 = tl.where(tmp5, tmp23, tmp9) | |
tmp27 = tmp26.to(tl.float32) | |
tmp28 = tl.where(tmp12, tmp27, tmp9) | |
tmp29 = tmp24 + tmp28 | |
tmp30 = tl.where(tmp20, tmp29, tmp9) | |
tmp31 = tmp18 + tmp30 | |
tmp32 = tl.full([1], 13, tl.int32) | |
tmp33 = tmp0 == tmp32 | |
tmp36 = tmp35.to(tl.float32) | |
tmp37 = tl.where(tmp5, tmp36, tmp9) | |
tmp40 = tmp39.to(tl.float32) | |
tmp41 = tl.where(tmp12, tmp40, tmp9) | |
tmp42 = tmp37 + tmp41 | |
tmp43 = tl.where(tmp33, tmp42, tmp9) | |
tmp44 = tmp31 + tmp43 | |
tmp45 = tl.full([1], 12, tl.int32) | |
tmp46 = tmp0 == tmp45 | |
tmp49 = tmp48.to(tl.float32) | |
tmp50 = tl.where(tmp5, tmp49, tmp9) | |
tmp53 = tmp52.to(tl.float32) | |
tmp54 = tl.where(tmp12, tmp53, tmp9) | |
tmp55 = tmp50 + tmp54 | |
tmp56 = tl.where(tmp46, tmp55, tmp9) | |
tmp57 = tmp44 + tmp56 | |
tmp58 = tl.full([1], 11, tl.int32) | |
tmp59 = tmp0 == tmp58 | |
tmp62 = tmp61.to(tl.float32) | |
tmp63 = tl.where(tmp5, tmp62, tmp9) | |
tmp66 = tmp65.to(tl.float32) | |
tmp67 = tl.where(tmp12, tmp66, tmp9) | |
tmp68 = tmp63 + tmp67 | |
tmp69 = tl.where(tmp59, tmp68, tmp9) | |
tmp70 = tmp57 + tmp69 | |
tmp71 = tl.full([1], 10, tl.int32) | |
tmp72 = tmp0 == tmp71 | |
tmp75 = tmp74.to(tl.float32) | |
tmp76 = tl.where(tmp5, tmp75, tmp9) | |
tmp79 = tmp78.to(tl.float32) | |
tmp80 = tl.where(tmp12, tmp79, tmp9) | |
tmp81 = tmp76 + tmp80 | |
tmp82 = tl.where(tmp72, tmp81, tmp9) | |
tmp83 = tmp70 + tmp82 | |
tmp84 = tl.full([1], 9, tl.int32) | |
tmp85 = tmp0 == tmp84 | |
tmp88 = tmp87.to(tl.float32) | |
tmp89 = tl.where(tmp5, tmp88, tmp9) | |
tmp92 = tmp91.to(tl.float32) | |
tmp93 = tl.where(tmp12, tmp92, tmp9) | |
tmp94 = tmp89 + tmp93 | |
tmp95 = tl.where(tmp85, tmp94, tmp9) | |
tmp96 = tmp83 + tmp95 | |
tmp97 = tl.full([1], 8, tl.int32) | |
tmp98 = tmp0 == tmp97 | |
tmp101 = tmp100.to(tl.float32) | |
tmp102 = tl.where(tmp5, tmp101, tmp9) | |
tmp105 = tmp104.to(tl.float32) | |
tmp106 = tl.where(tmp12, tmp105, tmp9) | |
tmp107 = tmp102 + tmp106 | |
tmp108 = tl.where(tmp98, tmp107, tmp9) | |
tmp109 = tmp96 + tmp108 | |
tmp110 = tl.full([1], 7, tl.int32) | |
tmp111 = tmp0 == tmp110 | |
tmp114 = tmp113.to(tl.float32) | |
tmp115 = tl.where(tmp5, tmp114, tmp9) | |
tmp118 = tmp117.to(tl.float32) | |
tmp119 = tl.where(tmp12, tmp118, tmp9) | |
tmp120 = tmp115 + tmp119 | |
tmp121 = tl.where(tmp111, tmp120, tmp9) | |
tmp122 = tmp109 + tmp121 | |
tl.store(out_ptr0 + (x2), tmp122, xmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/qz/cqzbz3hka7het3halgo66rhyco4yixsw5yvj3ce7xtlbrrkl3q44.py | |
# Topologically Sorted Source Nodes: [v_19], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_19 => convert_element_type_142 | |
# Graph fragment: | |
# %mul_608 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_312, %select_45), kwargs = {}) | |
# %convert_element_type_142 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_68, torch.float32), kwargs = {}) | |
# %convert_element_type_614 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_608, torch.float32), kwargs = {}) | |
# %mul_610 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_614, %convert_element_type_142), kwargs = {}) | |
# %mul_611 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_614, %rsqrt_27), kwargs = {}) | |
# %sum_79 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_610, [3], True), kwargs = {}) | |
# %div_38 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_55, 128), kwargs = {}) | |
# %pow_164 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_142, 1.0), kwargs = {}) | |
# %mul_614 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_164, 2.0), kwargs = {}) | |
# %mul_615 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_38, %mul_614), kwargs = {}) | |
# %add_312 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_611, %mul_615), kwargs = {}) | |
# %convert_element_type_615 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_312, torch.bfloat16), kwargs = {}) | |
# Row-wise reduction kernel, compiled from the Triton source string below.
# The source fixes xnumel = 524288 rows and r0_numel = 128 elements per row.
# For row x, with x0 = x % 8 and x1 = x // 8:
#   g[r] = in_ptr0[r + 128*x] * in_ptr1[60]      (bf16 values times one fp32 scalar)
#   v[r] = in_ptr2[r + 128*x0 + 3072*x1]         (bf16, read through a strided layout)
#   rs   = in_ptr3[x]                            (fp32, one value per row)
#   s    = sum_r g[r] * v[r]                     (first r-loop)
#   out_ptr1[r + 128*x0 + 3072*x1] =
#       bf16(g[r]*rs + (s * -0.5 * rs**3 * 0.0078125) * 2.0 * v[r])   (second r-loop)
# 0.0078125 == 1/128, the reduced-dim length.
# The first loop loads with eviction_policy='evict_last' and the second with
# 'evict_first', presumably so the second pass re-reads the same rows from cache.
# xmask is all-True, so x is never bounds-checked; this assumes the launcher's
# XBLOCK divides 524288 (TODO confirm).
# NOTE(review): the formula matches the backward of v * rsqrt(mean(v**2)) over the
# 128-wide dim, consistent with rsqrt_27 / sum over dim 3 in the graph fragment
# above. in_ptr1[60] is presumably %select_45 and in_ptr3 presumably rsqrt_27;
# confirm against the forward graph.
# Editing the string literal changes the compiled kernel.
triton_red_fused__to_copy_add_div_mul_pow_sum_33 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_33', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_33', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_33(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (60)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (60)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/cf/ccfrqdp3jfaejzety3ut7kp3qjdtsb2f57ehfivai7bm4enygzav.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {}) | |
# %mul_638 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %select_43), kwargs = {}) | |
# Pointwise kernel, compiled from the Triton source string below, over a fixed
# xnumel = 67108864 elements:
#   out_ptr0[x] = (in_ptr0[x] + in_ptr1[x]) * in_ptr2[28]
# Both bf16 inputs are upcast to fp32 before the add. The result is stored into
# the bf16 out_ptr0.
# xmask is all-True and every load/store passes mask None, so no element is
# bounds-checked; this assumes the launch grid exactly covers 67108864 (TODO confirm).
# NOTE(review): this corresponds to %mul_638 = (%add_311 + %view_289) * %select_43
# in the graph fragment above. in_ptr2[28] is presumably %select_43; confirm.
# Editing the string literal changes the compiled kernel.
triton_poi_fused_add_mul_34 = async_compile.triton('triton_poi_fused_add_mul_34', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
    size_hints={'x': 67108864}, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_34', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
    min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_mul_34(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
    xnumel = 67108864 | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
    xmask = tl.full([XBLOCK], True, tl.int1) | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (28)) | |
    tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp5 = tmp4.to(tl.float32) | |
    tmp6 = tmp2 * tmp5 | |
    tl.store(out_ptr0 + (x0), tmp6, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/zu/czuup4uodzelkv3jrzpk43tivi7ihymxeqi6lh5xpm4tylw5ufnq.py | |
# Topologically Sorted Source Nodes: [v_16], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_16 => convert_element_type_122 | |
# Graph fragment: | |
# %mul_648 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_332, %select_39), kwargs = {}) | |
# %convert_element_type_122 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_59, torch.float32), kwargs = {}) | |
# %convert_element_type_645 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_648, torch.float32), kwargs = {}) | |
# %mul_650 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_645, %convert_element_type_122), kwargs = {}) | |
# %mul_651 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_645, %rsqrt_23), kwargs = {}) | |
# %sum_86 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_650, [3], True), kwargs = {}) | |
# %div_42 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_59, 128), kwargs = {}) | |
# %pow_173 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_122, 1.0), kwargs = {}) | |
# %mul_654 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_173, 2.0), kwargs = {}) | |
# %mul_655 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_42, %mul_654), kwargs = {}) | |
# %add_327 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_651, %mul_655), kwargs = {}) | |
# %convert_element_type_646 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_327, torch.bfloat16), kwargs = {}) | |
# Row-wise reduction kernel, compiled from the Triton source string below.
# The source fixes xnumel = 524288 rows and r0_numel = 128 elements per row.
# For row x, with x0 = x % 8 and x1 = x // 8:
#   g[r] = in_ptr0[r + 128*x] * in_ptr1[58]      (bf16 values times one fp32 scalar)
#   v[r] = in_ptr2[r + 128*x0 + 3072*x1]         (bf16, read through a strided layout)
#   rs   = in_ptr3[x]                            (fp32, one value per row)
#   s    = sum_r g[r] * v[r]                     (first r-loop)
#   out_ptr1[r + 128*x0 + 3072*x1] =
#       bf16(g[r]*rs + (s * -0.5 * rs**3 * 0.0078125) * 2.0 * v[r])   (second r-loop)
# 0.0078125 == 1/128, the reduced-dim length. Apart from the scalar offset
# (58 here) the body is the same arithmetic as the other
# *_add_div_mul_pow_sum_* kernels in this file.
# xmask is all-True, so x is never bounds-checked; this assumes the launcher's
# XBLOCK divides 524288 (TODO confirm).
# NOTE(review): the formula matches the backward of v * rsqrt(mean(v**2)) over the
# 128-wide dim, consistent with rsqrt_23 / sum over dim 3 in the graph fragment
# above. in_ptr1[58] is presumably %select_39 and in_ptr3 presumably rsqrt_23;
# confirm against the forward graph.
# Editing the string literal changes the compiled kernel.
triton_red_fused__to_copy_add_div_mul_pow_sum_35 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_35', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_35', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_35(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (58)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (58)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/ye/cyevw6rg4ip64gwg4p72a7ympinl3ibkmxfljw4vtadz2pjlhji5.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_266 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_258, %view_252), kwargs = {}) | |
# %mul_502 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_66), kwargs = {}) | |
# %mul_504 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_502, %select_63), kwargs = {}) | |
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {}) | |
# %mul_637 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %convert_element_type_11), kwargs = {}) | |
# %sum_82 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_637,), kwargs = {}) | |
# %mul_639 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %add_69), kwargs = {}) | |
# %sum_83 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_639,), kwargs = {}) | |
# %add_334 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_326, %view_300), kwargs = {}) | |
# %mul_677 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %convert_element_type_11), kwargs = {}) | |
# %sum_89 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_677,), kwargs = {}) | |
# %mul_678 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %select_37), kwargs = {}) | |
# %mul_679 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %add_58), kwargs = {}) | |
# %sum_90 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_679,), kwargs = {}) | |
# %add_337 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_504, %mul_678), kwargs = {}) | |
# Persistent reduction kernel, compiled from the Triton source string below.
# It runs one program per row (XBLOCK = 1) over 65536 rows of 1024 elements, and
# every load/store is unmasked.
# With off = r + 1024*x:
#   a = in_ptr0[off] + in_ptr1[off]
#   n = in_ptr2[off] * in_ptr3[x]                 (bf16 upcast times a per-row fp32 value)
#   b = in_ptr5[off] + in_ptr6[off]
#   out_ptr0[x] = sum_r a * n
#   out_ptr1[x] = sum_r a * in_ptr4[off]
#   out_ptr2[x] = sum_r b * n
#   out_ptr3[x] = sum_r b * in_ptr7[off]
#   in_out_ptr0[off] = (in_out_ptr0[off] + in_ptr8[off]) * in_ptr9[36] * in_ptr9[4]
#                      + b * in_ptr9[26]
# in_out_ptr0 is overwritten in place (mutated_arg_names = ['in_out_ptr0']).
# NOTE(review): following the graph fragment above, a ~ %add_319, b ~ %add_334,
# n ~ %convert_element_type_11, and the in_out_ptr0 result ~ %add_337. The graph
# rounds convert_element_type_11 to bf16, but here it stays fp32
# (tmp7 = tmp6.to(tl.float32)).
# The fragment's sum_82/83/89/90 are full sum.default reductions, while this
# kernel writes one fp32 partial per row. Presumably a later kernel finishes them.
# The mapping in_ptr9[36]/[4]/[26] -> %select_66/%select_63/%select_37 is inferred;
# confirm.
# Editing the string literal changes the compiled kernel.
triton_per_fused__to_copy_add_mul_sum_36 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_36', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
    size_hints={'x': 65536, 'r0_': 1024}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_36', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 4, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_mul_sum_36(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel): | |
    xnumel = 65536 | |
    XBLOCK: tl.constexpr = 1 | |
    r0_numel = 1024 | |
    R0_BLOCK: tl.constexpr = 1024 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = tl.full([1], xoffset, tl.int32) | |
    xmask = tl.full([R0_BLOCK], True, tl.int1) | |
    r0_index = tl.arange(0, R0_BLOCK)[:] | |
    r0_offset = 0 | |
    r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
    roffset = r0_offset | |
    rindex = r0_index | |
    r0_1 = r0_index | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last') | |
    tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp29 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp30 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp32 = tl.load(in_ptr9 + (36)) | |
    tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK]) | |
    tmp36 = tl.load(in_ptr9 + (4)) | |
    tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK]) | |
    tmp40 = tl.load(in_ptr9 + (26)) | |
    tmp41 = tl.broadcast_to(tmp40, [R0_BLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp4 = tmp3.to(tl.float32) | |
    tmp6 = tmp4 * tmp5 | |
    tmp7 = tmp6.to(tl.float32) | |
    tmp8 = tmp2 * tmp7 | |
    tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK]) | |
    tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0)) | |
    tmp13 = tmp2 * tmp12 | |
    tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK]) | |
    tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0)) | |
    tmp19 = tmp17 + tmp18 | |
    tmp20 = tmp19 * tmp7 | |
    tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK]) | |
    tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0)) | |
    tmp25 = tmp19 * tmp24 | |
    tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK]) | |
    tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0)) | |
    tmp31 = tmp29 + tmp30 | |
    tmp34 = tmp33.to(tl.float32) | |
    tmp35 = tmp31 * tmp34 | |
    tmp38 = tmp37.to(tl.float32) | |
    tmp39 = tmp35 * tmp38 | |
    tmp42 = tmp41.to(tl.float32) | |
    tmp43 = tmp19 * tmp42 | |
    tmp44 = tmp39 + tmp43 | |
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp44, None) | |
    tl.store(out_ptr0 + (x0), tmp11, None) | |
    tl.store(out_ptr1 + (x0), tmp16, None) | |
    tl.store(out_ptr2 + (x0), tmp23, None) | |
    tl.store(out_ptr3 + (x0), tmp28, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/ad/caddjc5juf35pqoynq6vglejv2oh6vukmzqdu6mtjvkncst3kqyo.py | |
# Topologically Sorted Source Nodes: [v_13], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_13 => convert_element_type_102 | |
# Graph fragment: | |
# %mul_688 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_352, %select_33), kwargs = {}) | |
# %convert_element_type_102 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_50, torch.float32), kwargs = {}) | |
# %convert_element_type_676 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_688, torch.float32), kwargs = {}) | |
# %mul_690 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_676, %convert_element_type_102), kwargs = {}) | |
# %mul_691 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_676, %rsqrt_19), kwargs = {}) | |
# %sum_93 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_690, [3], True), kwargs = {}) | |
# %div_46 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_63, 128), kwargs = {}) | |
# %pow_182 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_102, 1.0), kwargs = {}) | |
# %mul_694 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_182, 2.0), kwargs = {}) | |
# %mul_695 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_46, %mul_694), kwargs = {}) | |
# %add_343 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_691, %mul_695), kwargs = {}) | |
# %convert_element_type_677 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_343, torch.bfloat16), kwargs = {}) | |
# Inductor-generated fused looped-reduction kernel, compiled asynchronously through
# the module-level ``async_compile`` (AsyncCompile) instance. The Triton source is
# the triple-quoted string literal; it must be valid, correctly indented Python.
#
# For each of the 524288 rows x (128 elements per row; every mask is all-true or
# r0_index < 128):
#   s   = in_ptr1[56]                                 (fp32 scalar)
#   g   = in_ptr0[128*x + r] * s                      (bf16 input, math in fp32)
#   v   = in_ptr2[128*(x % 8) + 3072*(x // 8) + r]    (bf16, strided view)
#   rs  = in_ptr3[x]                                  (fp32, one value per row)
#   acc = sum_r(g * v)                                (first loop)
#   out_ptr1[<same strided index as v>] =
#         g*rs + acc*(-0.5)*rs**3*(1/128)*(2*v)       (second loop, bf16 buffer)
# NOTE(review): this has the shape of the gradient of v * rs with
# rs = rsqrt(mean(v**2) + eps), i.e. an RMS-norm-style backward. This is inferred
# from the arithmetic; confirm against the graph-fragment comments.
# The scalar in_ptr1[56] is loaded twice (tmp1 / tmp13). That is redundant
# generated code but harmless.
triton_red_fused__to_copy_add_div_mul_pow_sum_37 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_37', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 524288, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_37', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_37(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 524288
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x3 = xindex
    tmp1 = tl.load(in_ptr1 + (56))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
    x0 = (xindex % 8)
    x1 = xindex // 8
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp0 * tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp5 * tmp7
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    tmp13 = tl.load(in_ptr1 + (56))
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tmp12 * tmp15
        tmp17 = tmp16.to(tl.float32)
        tmp19 = tmp17 * tmp18
        tmp20 = -0.5
        tmp21 = tmp10 * tmp20
        tmp22 = tmp18 * tmp18
        tmp23 = tmp22 * tmp18
        tmp24 = tmp21 * tmp23
        tmp25 = 0.0078125
        tmp26 = tmp24 * tmp25
        tmp28 = tmp27.to(tl.float32)
        tmp29 = 2.0
        tmp30 = tmp28 * tmp29
        tmp31 = tmp26 * tmp30
        tmp32 = tmp19 + tmp31
        tmp33 = tmp32.to(tl.float32)
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/7f/c7fhg3d33i52ftlp3mhnqlchtl7dk3n3m3n4hihwykvl52gttcka.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {}) | |
# %mul_636 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %select_44), kwargs = {}) | |
# %add_321 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_306, %mul_636), kwargs = {}) | |
# %add_334 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_326, %view_300), kwargs = {}) | |
# %mul_676 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %select_38), kwargs = {}) | |
# %add_336 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_321, %mul_676), kwargs = {}) | |
# %add_350 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_342, %view_311), kwargs = {}) | |
# %mul_716 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %select_32), kwargs = {}) | |
# %add_352 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_336, %mul_716), kwargs = {}) | |
# %mul_718 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %select_31), kwargs = {}) | |
# Inductor-generated fused pointwise kernel, compiled asynchronously through the
# module-level ``async_compile`` (AsyncCompile) instance. The Triton source is the
# triple-quoted string literal; it must be valid, correctly indented Python.
#
# Over 67108864 contiguous elements i (xmask is all-true, so access is unmasked),
# with fp32 scalars read from in_ptr2 and all other buffers bf16:
#   in_out_ptr0[i] += (in_ptr0[i] + in_ptr1[i]) * in_ptr2[29]
#                   + (in_ptr3[i] + in_ptr4[i]) * in_ptr2[27]
#                   + (in_ptr5[i] + in_ptr6[i]) * in_ptr2[25]
#   out_ptr0[i]     =  (in_ptr5[i] + in_ptr6[i]) * in_ptr2[24]
# Arithmetic is done in fp32. Results are stored into the bf16 buffers named in
# the signature.
triton_poi_fused_add_mul_38 = async_compile.triton('triton_poi_fused_add_mul_38', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 67108864},
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_38', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 11, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_38(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 67108864
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
    tmp4 = tl.load(in_ptr2 + (29))
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK])
    tmp9 = tl.load(in_ptr3 + (x0), None).to(tl.float32)
    tmp10 = tl.load(in_ptr4 + (x0), None).to(tl.float32)
    tmp12 = tl.load(in_ptr2 + (27))
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK])
    tmp17 = tl.load(in_ptr5 + (x0), None).to(tl.float32)
    tmp18 = tl.load(in_ptr6 + (x0), None).to(tl.float32)
    tmp20 = tl.load(in_ptr2 + (25))
    tmp21 = tl.broadcast_to(tmp20, [XBLOCK])
    tmp25 = tl.load(in_ptr2 + (24))
    tmp26 = tl.broadcast_to(tmp25, [XBLOCK])
    tmp3 = tmp1 + tmp2
    tmp6 = tmp5.to(tl.float32)
    tmp7 = tmp3 * tmp6
    tmp8 = tmp0 + tmp7
    tmp11 = tmp9 + tmp10
    tmp14 = tmp13.to(tl.float32)
    tmp15 = tmp11 * tmp14
    tmp16 = tmp8 + tmp15
    tmp19 = tmp17 + tmp18
    tmp22 = tmp21.to(tl.float32)
    tmp23 = tmp19 * tmp22
    tmp24 = tmp16 + tmp23
    tmp27 = tmp26.to(tl.float32)
    tmp28 = tmp19 * tmp27
    tl.store(in_out_ptr0 + (x0), tmp24, None)
    tl.store(out_ptr0 + (x0), tmp28, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/kw/ckwwptmpfidft6quaist42sma7dmjhpwkm7ihduw6nsauc6r776j.py | |
# Topologically Sorted Source Nodes: [v_10], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_10 => convert_element_type_82 | |
# Graph fragment: | |
# %mul_728 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_372, %select_27), kwargs = {}) | |
# %convert_element_type_82 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_41, torch.float32), kwargs = {}) | |
# %convert_element_type_707 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_728, torch.float32), kwargs = {}) | |
# %mul_730 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_707, %convert_element_type_82), kwargs = {}) | |
# %mul_731 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_707, %rsqrt_15), kwargs = {}) | |
# %sum_100 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_730, [3], True), kwargs = {}) | |
# %div_50 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_67, 128), kwargs = {}) | |
# %pow_191 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_82, 1.0), kwargs = {}) | |
# %mul_734 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_191, 2.0), kwargs = {}) | |
# %mul_735 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_50, %mul_734), kwargs = {}) | |
# %add_358 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_731, %mul_735), kwargs = {}) | |
# %convert_element_type_708 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_358, torch.bfloat16), kwargs = {}) | |
# Inductor-generated fused looped-reduction kernel, compiled asynchronously through
# the module-level ``async_compile`` (AsyncCompile) instance. The Triton source is
# the triple-quoted string literal; it must be valid, correctly indented Python.
#
# For each of the 524288 rows x (128 elements per row; every mask is all-true or
# r0_index < 128):
#   s   = in_ptr1[54]                                 (fp32 scalar)
#   g   = in_ptr0[128*x + r] * s                      (bf16 input, math in fp32)
#   v   = in_ptr2[128*(x % 8) + 3072*(x // 8) + r]    (bf16, strided view)
#   rs  = in_ptr3[x]                                  (fp32, one value per row)
#   acc = sum_r(g * v)                                (first loop)
#   out_ptr1[<same strided index as v>] =
#         g*rs + acc*(-0.5)*rs**3*(1/128)*(2*v)       (second loop, bf16 buffer)
# Per the graph fragment listed just before this statement, rs corresponds to
# rsqrt_15 and v to convert_element_type_82. The result is the add_358 term
# (mul_731 + mul_735), so this looks like an RMS-norm-style backward.
# NOTE(review): the graph rounds mul_728 to bf16 before upcasting, but the kernel
# keeps it in fp32 (tmp5 = tmp4.to(tl.float32)); confirm this is intended.
triton_red_fused__to_copy_add_div_mul_pow_sum_39 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_39', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 524288, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_39', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_39(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 524288
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x3 = xindex
    tmp1 = tl.load(in_ptr1 + (54))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
    x0 = (xindex % 8)
    x1 = xindex // 8
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp0 * tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp5 * tmp7
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    tmp13 = tl.load(in_ptr1 + (54))
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tmp12 * tmp15
        tmp17 = tmp16.to(tl.float32)
        tmp19 = tmp17 * tmp18
        tmp20 = -0.5
        tmp21 = tmp10 * tmp20
        tmp22 = tmp18 * tmp18
        tmp23 = tmp22 * tmp18
        tmp24 = tmp21 * tmp23
        tmp25 = 0.0078125
        tmp26 = tmp24 * tmp25
        tmp28 = tmp27.to(tl.float32)
        tmp29 = 2.0
        tmp30 = tmp28 * tmp29
        tmp31 = tmp26 * tmp30
        tmp32 = tmp19 + tmp31
        tmp33 = tmp32.to(tl.float32)
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/dy/cdyh5oakmezmdxspcqgny4mt5oeduceikz5mfycxi4la3phv73r5.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {}) | |
# %mul_460 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_73), kwargs = {}) | |
# %mul_462 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_460, %select_70), kwargs = {}) | |
# %add_350 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_342, %view_311), kwargs = {}) | |
# %mul_717 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %convert_element_type_11), kwargs = {}) | |
# %sum_96 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_717,), kwargs = {}) | |
# %mul_719 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %add_47), kwargs = {}) | |
# %sum_97 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_719,), kwargs = {}) | |
# %add_365 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_357, %view_322), kwargs = {}) | |
# %mul_757 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %convert_element_type_11), kwargs = {}) | |
# %sum_103 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_757,), kwargs = {}) | |
# %mul_758 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %select_25), kwargs = {}) | |
# %mul_759 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %add_36), kwargs = {}) | |
# %sum_104 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_759,), kwargs = {}) | |
# %add_368 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_462, %mul_758), kwargs = {}) | |
# Inductor-generated fused persistent-reduction kernel, compiled asynchronously
# through the module-level ``async_compile`` (AsyncCompile) instance. The Triton
# source is the triple-quoted string literal; it must be valid, correctly indented
# Python.
#
# One program per row x in [0, 65536) (XBLOCK == 1). Each row has exactly 1024
# elements (R0_BLOCK == r0_numel), so there is no reduction loop and no masking.
# Buffers are bf16 except in_ptr3 / in_ptr9 (fp32 inputs) and out_ptr0..3 (fp32
# outputs):
#   a = in_ptr0[x, r] + in_ptr1[x, r]
#   n = in_ptr2[x, r] * in_ptr3[x]           (per-row fp32 factor, kept in fp32)
#   b = in_ptr5[x, r] + in_ptr6[x, r]
#   out_ptr0[x] = sum_r(a * n)
#   out_ptr1[x] = sum_r(a * in_ptr4[x, r])
#   out_ptr2[x] = sum_r(b * n)
#   out_ptr3[x] = sum_r(b * in_ptr7[x, r])
#   in_out_ptr0[x, r] = (in_out_ptr0[x, r] + in_ptr8[x, r]) * in_ptr9[38] * in_ptr9[2]
#                       + b * in_ptr9[22]
# NOTE(review): the graph casts n to bf16 (convert_element_type_11), but the kernel
# keeps it in fp32 (tmp7 = tmp6.to(tl.float32)); confirm if eager parity matters.
triton_per_fused__to_copy_add_mul_sum_40 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_40', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
    size_hints={'x': 65536, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_40', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 4, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_40(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel):
    xnumel = 65536
    XBLOCK: tl.constexpr = 1
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[:]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp29 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp30 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp32 = tl.load(in_ptr9 + (38))
    tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK])
    tmp36 = tl.load(in_ptr9 + (2))
    tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK])
    tmp40 = tl.load(in_ptr9 + (22))
    tmp41 = tl.broadcast_to(tmp40, [R0_BLOCK])
    tmp2 = tmp0 + tmp1
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tmp2 * tmp7
    tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
    tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
    tmp13 = tmp2 * tmp12
    tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
    tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
    tmp19 = tmp17 + tmp18
    tmp20 = tmp19 * tmp7
    tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
    tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0))
    tmp25 = tmp19 * tmp24
    tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK])
    tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
    tmp31 = tmp29 + tmp30
    tmp34 = tmp33.to(tl.float32)
    tmp35 = tmp31 * tmp34
    tmp38 = tmp37.to(tl.float32)
    tmp39 = tmp35 * tmp38
    tmp42 = tmp41.to(tl.float32)
    tmp43 = tmp19 * tmp42
    tmp44 = tmp39 + tmp43
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp44, None)
    tl.store(out_ptr0 + (x0), tmp11, None)
    tl.store(out_ptr1 + (x0), tmp16, None)
    tl.store(out_ptr2 + (x0), tmp23, None)
    tl.store(out_ptr3 + (x0), tmp28, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/lv/clvyuwrkiqq5tjp6zq253fl6uzqefz5dja2cdbo5svtjygwezv6b.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy, aten.select_backward, aten.add] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %convert_element_type_347 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_13, torch.float32), kwargs = {}) | |
# %full_default_20 : [num_users=53] = call_function[target=torch.ops.aten.full.default](args = ([2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %select_scatter_default_1 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_347, 0, 1), kwargs = {}) | |
# %convert_element_type_348 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_14, torch.float32), kwargs = {}) | |
# %select_scatter_default_2 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_348, 0, 0), kwargs = {}) | |
# %add_184 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_1, %select_scatter_default_2), kwargs = {}) | |
# %full_default_25 : [num_users=31] = call_function[target=torch.ops.aten.full.default](args = ([16, 2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %select_scatter_default_5 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_184, 0, 15), kwargs = {}) | |
# %convert_element_type_379 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_21, torch.float32), kwargs = {}) | |
# %select_scatter_default_8 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_379, 0, 1), kwargs = {}) | |
# %convert_element_type_380 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_22, torch.float32), kwargs = {}) | |
# %select_scatter_default_9 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_380, 0, 0), kwargs = {}) | |
# %add_197 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_8, %select_scatter_default_9), kwargs = {}) | |
# %select_scatter_default_12 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_197, 0, 14), kwargs = {}) | |
# %add_209 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_5, %select_scatter_default_12), kwargs = {}) | |
# %convert_element_type_411 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_29, torch.float32), kwargs = {}) | |
# %select_scatter_default_15 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_411, 0, 1), kwargs = {}) | |
# %convert_element_type_412 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_30, torch.float32), kwargs = {}) | |
# %select_scatter_default_16 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_412, 0, 0), kwargs = {}) | |
# %add_213 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_15, %select_scatter_default_16), kwargs = {}) | |
# %select_scatter_default_19 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_213, 0, 13), kwargs = {}) | |
# %add_225 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_209, %select_scatter_default_19), kwargs = {}) | |
# %convert_element_type_443 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_37, torch.float32), kwargs = {}) | |
# %select_scatter_default_22 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_443, 0, 0), kwargs = {}) | |
# %select_scatter_default_25 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_22, 0, 12), kwargs = {}) | |
# %add_240 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_225, %select_scatter_default_25), kwargs = {}) | |
# %convert_element_type_474 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_44, torch.float32), kwargs = {}) | |
# %select_scatter_default_28 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_474, 0, 0), kwargs = {}) | |
# %select_scatter_default_31 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_28, 0, 11), kwargs = {}) | |
# %add_255 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_240, %select_scatter_default_31), kwargs = {}) | |
# %convert_element_type_506 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_52, torch.float32), kwargs = {}) | |
# %select_scatter_default_35 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_506, 0, 0), kwargs = {}) | |
# %select_scatter_default_38 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_35, 0, 10), kwargs = {}) | |
# %add_270 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_255, %select_scatter_default_38), kwargs = {}) | |
# %convert_element_type_538 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_60, torch.float32), kwargs = {}) | |
# %select_scatter_default_42 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_538, 0, 0), kwargs = {}) | |
# %select_scatter_default_45 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_42, 0, 9), kwargs = {}) | |
# %add_286 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_270, %select_scatter_default_45), kwargs = {}) | |
# %convert_element_type_570 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_68, torch.float32), kwargs = {}) | |
# %select_scatter_default_49 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_570, 0, 0), kwargs = {}) | |
# %select_scatter_default_52 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_49, 0, 8), kwargs = {}) | |
# %add_302 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_286, %select_scatter_default_52), kwargs = {}) | |
# %convert_element_type_613 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_78, torch.float32), kwargs = {}) | |
# %select_scatter_default_58 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_613, 0, 0), kwargs = {}) | |
# %select_scatter_default_61 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_58, 0, 6), kwargs = {}) | |
# %add_323 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_302, %select_scatter_default_61), kwargs = {}) | |
# %convert_element_type_644 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_85, torch.float32), kwargs = {}) | |
# %select_scatter_default_64 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_644, 0, 0), kwargs = {}) | |
# %select_scatter_default_67 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_64, 0, 5), kwargs = {}) | |
# %add_339 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_323, %select_scatter_default_67), kwargs = {}) | |
# %convert_element_type_675 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_92, torch.float32), kwargs = {}) | |
# %select_scatter_default_70 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_675, 0, 0), kwargs = {}) | |
# %select_scatter_default_73 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_70, 0, 4), kwargs = {}) | |
# %add_354 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_339, %select_scatter_default_73), kwargs = {}) | |
# %convert_element_type_706 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_99, torch.float32), kwargs = {}) | |
# %select_scatter_default_76 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_706, 0, 0), kwargs = {}) | |
# %select_scatter_default_79 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_76, 0, 3), kwargs = {}) | |
# %add_370 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_354, %select_scatter_default_79), kwargs = {}) | |
# Compile a Triton pointwise kernel that fills a 32-element fp32 buffer
# (out_ptr0, xnumel = 32) from 15 bf16 scalars (in_ptr0..in_ptr14, each read
# at offset 0 and upcast to fp32).
# The index math splits each element x into a row x1 = x // 2 (0..15) and a
# column x0 = x % 2 (0..1). Each output is a sum of tl.where selections:
#   row 15: col 1 <- in_ptr0, col 0 <- in_ptr1
#   row 14: col 1 <- in_ptr2, col 0 <- in_ptr3
#   row 13: col 1 <- in_ptr4, col 0 <- in_ptr5
#   rows 12, 11, 10, 9, 8, 6, 5, 4, 3 (in that order): col 0 <- in_ptr6 .. in_ptr14
# Every other element is 0.0: rows 0-2 and 7, plus col 1 of rows <= 12.
# This is presumably the fused form of the select_scatter/add chain in the
# preceding graph-fragment comment. Why rows 0-2 and 7 receive nothing is not
# visible from this chunk.
# NOTE(review): as it appears here, every line of this statement ends in " | |"
# (looks like residue from scraping a rendered web page). That includes the
# Triton source inside the ''' string. The trailing `) | |` is a Python syntax
# error, and the embedded kernel text is not valid Python either. Strip the
# suffix before use.
triton_poi_fused__to_copy_add_select_backward_41 = async_compile.triton('triton_poi_fused__to_copy_add_select_backward_41', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
    size_hints={'x': 32}, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'in_ptr13': '*bf16', 'in_ptr14': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_select_backward_41', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 15, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
    min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused__to_copy_add_select_backward_41(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
    xnumel = 32 | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
    xmask = xindex < xnumel | |
    x1 = xindex // 2 | |
    x0 = (xindex % 2) | |
    x2 = xindex | |
    tmp6 = tl.load(in_ptr0 + (0)).to(tl.float32) | |
    tmp7 = tl.broadcast_to(tmp6, [XBLOCK]) | |
    tmp13 = tl.load(in_ptr1 + (0)).to(tl.float32) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK]) | |
    tmp21 = tl.load(in_ptr2 + (0)).to(tl.float32) | |
    tmp22 = tl.broadcast_to(tmp21, [XBLOCK]) | |
    tmp25 = tl.load(in_ptr3 + (0)).to(tl.float32) | |
    tmp26 = tl.broadcast_to(tmp25, [XBLOCK]) | |
    tmp34 = tl.load(in_ptr4 + (0)).to(tl.float32) | |
    tmp35 = tl.broadcast_to(tmp34, [XBLOCK]) | |
    tmp38 = tl.load(in_ptr5 + (0)).to(tl.float32) | |
    tmp39 = tl.broadcast_to(tmp38, [XBLOCK]) | |
    tmp47 = tl.load(in_ptr6 + (0)).to(tl.float32) | |
    tmp48 = tl.broadcast_to(tmp47, [XBLOCK]) | |
    tmp55 = tl.load(in_ptr7 + (0)).to(tl.float32) | |
    tmp56 = tl.broadcast_to(tmp55, [XBLOCK]) | |
    tmp63 = tl.load(in_ptr8 + (0)).to(tl.float32) | |
    tmp64 = tl.broadcast_to(tmp63, [XBLOCK]) | |
    tmp71 = tl.load(in_ptr9 + (0)).to(tl.float32) | |
    tmp72 = tl.broadcast_to(tmp71, [XBLOCK]) | |
    tmp79 = tl.load(in_ptr10 + (0)).to(tl.float32) | |
    tmp80 = tl.broadcast_to(tmp79, [XBLOCK]) | |
    tmp87 = tl.load(in_ptr11 + (0)).to(tl.float32) | |
    tmp88 = tl.broadcast_to(tmp87, [XBLOCK]) | |
    tmp95 = tl.load(in_ptr12 + (0)).to(tl.float32) | |
    tmp96 = tl.broadcast_to(tmp95, [XBLOCK]) | |
    tmp103 = tl.load(in_ptr13 + (0)).to(tl.float32) | |
    tmp104 = tl.broadcast_to(tmp103, [XBLOCK]) | |
    tmp111 = tl.load(in_ptr14 + (0)).to(tl.float32) | |
    tmp112 = tl.broadcast_to(tmp111, [XBLOCK]) | |
    tmp0 = x1 | |
    tmp1 = tl.full([1], 15, tl.int32) | |
    tmp2 = tmp0 == tmp1 | |
    tmp3 = x0 | |
    tmp4 = tl.full([1], 1, tl.int32) | |
    tmp5 = tmp3 == tmp4 | |
    tmp8 = tmp7.to(tl.float32) | |
    tmp9 = 0.0 | |
    tmp10 = tl.where(tmp5, tmp8, tmp9) | |
    tmp11 = tl.full([1], 0, tl.int32) | |
    tmp12 = tmp3 == tmp11 | |
    tmp15 = tmp14.to(tl.float32) | |
    tmp16 = tl.where(tmp12, tmp15, tmp9) | |
    tmp17 = tmp10 + tmp16 | |
    tmp18 = tl.where(tmp2, tmp17, tmp9) | |
    tmp19 = tl.full([1], 14, tl.int32) | |
    tmp20 = tmp0 == tmp19 | |
    tmp23 = tmp22.to(tl.float32) | |
    tmp24 = tl.where(tmp5, tmp23, tmp9) | |
    tmp27 = tmp26.to(tl.float32) | |
    tmp28 = tl.where(tmp12, tmp27, tmp9) | |
    tmp29 = tmp24 + tmp28 | |
    tmp30 = tl.where(tmp20, tmp29, tmp9) | |
    tmp31 = tmp18 + tmp30 | |
    tmp32 = tl.full([1], 13, tl.int32) | |
    tmp33 = tmp0 == tmp32 | |
    tmp36 = tmp35.to(tl.float32) | |
    tmp37 = tl.where(tmp5, tmp36, tmp9) | |
    tmp40 = tmp39.to(tl.float32) | |
    tmp41 = tl.where(tmp12, tmp40, tmp9) | |
    tmp42 = tmp37 + tmp41 | |
    tmp43 = tl.where(tmp33, tmp42, tmp9) | |
    tmp44 = tmp31 + tmp43 | |
    tmp45 = tl.full([1], 12, tl.int32) | |
    tmp46 = tmp0 == tmp45 | |
    tmp49 = tmp48.to(tl.float32) | |
    tmp50 = tl.where(tmp12, tmp49, tmp9) | |
    tmp51 = tl.where(tmp46, tmp50, tmp9) | |
    tmp52 = tmp44 + tmp51 | |
    tmp53 = tl.full([1], 11, tl.int32) | |
    tmp54 = tmp0 == tmp53 | |
    tmp57 = tmp56.to(tl.float32) | |
    tmp58 = tl.where(tmp12, tmp57, tmp9) | |
    tmp59 = tl.where(tmp54, tmp58, tmp9) | |
    tmp60 = tmp52 + tmp59 | |
    tmp61 = tl.full([1], 10, tl.int32) | |
    tmp62 = tmp0 == tmp61 | |
    tmp65 = tmp64.to(tl.float32) | |
    tmp66 = tl.where(tmp12, tmp65, tmp9) | |
    tmp67 = tl.where(tmp62, tmp66, tmp9) | |
    tmp68 = tmp60 + tmp67 | |
    tmp69 = tl.full([1], 9, tl.int32) | |
    tmp70 = tmp0 == tmp69 | |
    tmp73 = tmp72.to(tl.float32) | |
    tmp74 = tl.where(tmp12, tmp73, tmp9) | |
    tmp75 = tl.where(tmp70, tmp74, tmp9) | |
    tmp76 = tmp68 + tmp75 | |
    tmp77 = tl.full([1], 8, tl.int32) | |
    tmp78 = tmp0 == tmp77 | |
    tmp81 = tmp80.to(tl.float32) | |
    tmp82 = tl.where(tmp12, tmp81, tmp9) | |
    tmp83 = tl.where(tmp78, tmp82, tmp9) | |
    tmp84 = tmp76 + tmp83 | |
    tmp85 = tl.full([1], 6, tl.int32) | |
    tmp86 = tmp0 == tmp85 | |
    tmp89 = tmp88.to(tl.float32) | |
    tmp90 = tl.where(tmp12, tmp89, tmp9) | |
    tmp91 = tl.where(tmp86, tmp90, tmp9) | |
    tmp92 = tmp84 + tmp91 | |
    tmp93 = tl.full([1], 5, tl.int32) | |
    tmp94 = tmp0 == tmp93 | |
    tmp97 = tmp96.to(tl.float32) | |
    tmp98 = tl.where(tmp12, tmp97, tmp9) | |
    tmp99 = tl.where(tmp94, tmp98, tmp9) | |
    tmp100 = tmp92 + tmp99 | |
    tmp101 = tl.full([1], 4, tl.int32) | |
    tmp102 = tmp0 == tmp101 | |
    tmp105 = tmp104.to(tl.float32) | |
    tmp106 = tl.where(tmp12, tmp105, tmp9) | |
    tmp107 = tl.where(tmp102, tmp106, tmp9) | |
    tmp108 = tmp100 + tmp107 | |
    tmp109 = tl.full([1], 3, tl.int32) | |
    tmp110 = tmp0 == tmp109 | |
    tmp113 = tmp112.to(tl.float32) | |
    tmp114 = tl.where(tmp12, tmp113, tmp9) | |
    tmp115 = tl.where(tmp110, tmp114, tmp9) | |
    tmp116 = tmp108 + tmp115 | |
    tl.store(out_ptr0 + (x2), tmp116, xmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/o4/co4wxq5g7foyzmi2ihlzwhjvufk3vmnc4dgakkdaqoxojgejv4d2.py | |
# Topologically Sorted Source Nodes: [v_7], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_7 => convert_element_type_62 | |
# Graph fragment: | |
# %mul_770 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_392, %select_20), kwargs = {}) | |
# %convert_element_type_62 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_32, torch.float32), kwargs = {}) | |
# %convert_element_type_739 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_770, torch.float32), kwargs = {}) | |
# %mul_772 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_739, %convert_element_type_62), kwargs = {}) | |
# %mul_773 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_739, %rsqrt_11), kwargs = {}) | |
# %sum_108 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_772, [3], True), kwargs = {}) | |
# %div_54 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_71, 128), kwargs = {}) | |
# %pow_200 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_62, 1.0), kwargs = {}) | |
# %mul_776 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_200, 2.0), kwargs = {}) | |
# %mul_777 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_54, %mul_776), kwargs = {}) | |
# %add_376 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_773, %mul_777), kwargs = {}) | |
# %convert_element_type_740 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_376, torch.bfloat16), kwargs = {}) | |
# Compile a two-pass Triton reduction kernel over 524288 rows (x3) of 128
# elements each (r0). Each row index is split as x0 = x3 % 8 and x1 = x3 // 8.
# Pass 1 accumulates, in fp32, a per-row value tmp10:
#   tmp10[x3] = sum_r (in_ptr0[r + 128*x3] * in_ptr1[52]) * in_ptr2[r + 128*x0 + 3072*x1]
# Pass 2 sets g = in_ptr0[r + 128*x3] * in_ptr1[52] and s = in_ptr3[x3] (fp32).
# It stores a bf16 result to out_ptr1 at the same strided offset used for in_ptr2:
#   g*s + (tmp10 * -0.5 * s**3 * 0.0078125) * (2.0 * in_ptr2[...])
# The constant 0.0078125 == 1/128, matching the div-by-128 in the graph fragment.
# Per that fragment, s corresponds to rsqrt_11 and in_ptr2 to
# convert_element_type_62. This looks like the input gradient of an
# RMS-norm-style normalization over the 128-wide dim.
# NOTE(review): that interpretation is inferred from the fragment; confirm.
# The 128 / 3072 strides mean in_ptr2 and out_ptr1 address an 8 x 128 slice
# inside a row of stride 3072. That slice is presumably the "v" part of a packed
# buffer (per the v_7 label); TODO confirm.
# Pass 1 loads use evict_last and pass 2 loads use evict_first, because pass 2
# re-reads the same addresses.
# NOTE(review): as it appears here, every line of this statement ends in " | |"
# (looks like scraped web-page residue), including inside the kernel string.
# The trailing `) | |` is a Python syntax error; strip the suffix before use.
triton_red_fused__to_copy_add_div_mul_pow_sum_42 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_42', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_42', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_42(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (52)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (52)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/bo/cbowtcvxfb4pbtd63playhswlkgy4jhcpl4b3bcqkbz3mfy4xrzq.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_383 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_373, %view_334), kwargs = {}) | |
# %mul_800 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %select_18), kwargs = {}) | |
# Compile a Triton pointwise kernel over exactly 67108864 elements:
#   out_ptr0[x] = (in_ptr0[x] + in_ptr1[x]) * in_ptr2[20]
# The inputs and output are bf16; the arithmetic is done after upcasting to fp32.
# The scalar multiplier is read from offset 20 of an fp32 buffer.
# The loads and store are unmasked (xmask is all True), and xnumel equals the
# 'x' size hint.
# Per the graph fragment: add_383 = add_373 + view_334 and
# mul_800 = add_383 * select_18.
# NOTE(review): as it appears here, every line of this statement ends in " | |"
# (looks like scraped web-page residue), including inside the kernel string.
# The trailing `) | |` is a Python syntax error; strip the suffix before use.
triton_poi_fused_add_mul_43 = async_compile.triton('triton_poi_fused_add_mul_43', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
    size_hints={'x': 67108864}, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_43', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
    min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_mul_43(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
    xnumel = 67108864 | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
    xmask = tl.full([XBLOCK], True, tl.int1) | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (20)) | |
    tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp5 = tmp4.to(tl.float32) | |
    tmp6 = tmp2 * tmp5 | |
    tl.store(out_ptr0 + (x0), tmp6, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/p3/cp356rh6ewfn6qtk7pjhx6degjloji6ndbxsbgpu5ffy373kln4d.py | |
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward] | |
# Source node to ATen node mapping: | |
# loss => full_default_13 | |
# Graph fragment: | |
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %add_391 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_204, %view_343), kwargs = {}) | |
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %convert_element_type_827 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_391, torch.float32), kwargs = {}) | |
# %where_26 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_827), kwargs = {}) | |
# %index_put_6 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_26, True), kwargs = {}) | |
# Compile a Triton zero-fill kernel. It writes 0.0 to the first 51463168
# elements of out_ptr0 (fp32), with the store masked by xindex < xnumel.
# 51463168 == 50257 * 1024, which matches full_default_169 ([50257, 1024]) in
# the graph fragment.
# The fused name also lists the add / nll_loss ops, but only the zero
# initialization is materialized here. Presumably the index_put accumulation
# into this buffer happens in a later kernel.
# NOTE(review): confirm at the call site.
# NOTE(review): as it appears here, every line of this statement ends in " | |"
# (looks like scraped web-page residue), including inside the kernel string.
# The trailing `) | |` is a Python syntax error; strip the suffix before use.
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44 = async_compile.triton('triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
    size_hints={'x': 67108864}, | |
    filename=__file__, | |
    triton_meta={'signature': {'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
    min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44(out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
    xnumel = 51463168 | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
    xmask = xindex < xnumel | |
    x0 = xindex | |
    tmp0 = 0.0 | |
    tl.store(out_ptr0 + (x0), tmp0, xmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/qe/cqetkj23jbj4oymf7m4upmjbom7qob3kerhbagz6hn5pzn6li6pu.py | |
# Topologically Sorted Source Nodes: [loss, v_4], Original ATen: [aten.nll_loss_forward, aten.add, aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.embedding_dense_backward] | |
# Source node to ATen node mapping: | |
# loss => full_default_13 | |
# v_4 => convert_element_type_42 | |
# Graph fragment: | |
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %add_391 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_204, %view_343), kwargs = {}) | |
# %mul_812 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_412, %select_13), kwargs = {}) | |
# %convert_element_type_42 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_23, torch.float32), kwargs = {}) | |
# %convert_element_type_771 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_812, torch.float32), kwargs = {}) | |
# %mul_814 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_771, %convert_element_type_42), kwargs = {}) | |
# %mul_815 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_771, %rsqrt_7), kwargs = {}) | |
# %sum_116 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_814, [3], True), kwargs = {}) | |
# %div_58 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_75, 128), kwargs = {}) | |
# %pow_209 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_42, 1.0), kwargs = {}) | |
# %mul_818 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_209, 2.0), kwargs = {}) | |
# %mul_819 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_58, %mul_818), kwargs = {}) | |
# %add_393 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_815, %mul_819), kwargs = {}) | |
# %convert_element_type_772 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_393, torch.bfloat16), kwargs = {}) | |
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %convert_element_type_827 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_391, torch.float32), kwargs = {}) | |
# %where_26 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_827), kwargs = {}) | |
# %index_put_6 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_26, True), kwargs = {}) | |
# Compile a two-pass Triton reduction kernel over 524288 rows (x3) of 128
# elements each (r0). Each row index is split as x0 = x3 % 8 and x1 = x3 // 8.
# Pass 1 accumulates, in fp32, a per-row value tmp10:
#   tmp10[x3] = sum_r (in_ptr0[r + 128*x3] * in_ptr1[50]) * in_ptr2[r + 128*x0 + 3072*x1]
# Pass 2 produces two outputs.
#  (a) out_ptr1 (bf16), at offset r + 128*x0 + 3072*x1:
#        g*s + (tmp10 * -0.5 * s**3 * 0.0078125) * (2.0 * in_ptr2[...])
#      Here g = in_ptr0[r + 128*x3] * in_ptr1[50], s = in_ptr3[x3] (fp32),
#      and 0.0078125 == 1/128. This looks like an RMS-norm-style input gradient.
#      NOTE(review): inferred from the graph fragment (rsqrt_7, div by 128); confirm.
#  (b) out_ptr2 (fp32, listed in mutated_arg_names) receives an
#      embedding-gradient scatter.
#      - The index is idx = in_ptr4[x1] (int32, widened to int64).
#      - Negative idx is wrapped by adding 50257, and a device_assert enforces
#        0 <= idx < 50257.
#      - The value in_ptr5[r + 128*x3] * in_ptr1[77] + in_ptr0[r + 128*x3] * in_ptr1[51]
#        is replaced by 0.0 when the unwrapped idx == -1.
#      - The result is tl.atomic_add'ed (sem='relaxed') into
#        out_ptr2[1024*idx + 128*x0 + r].
#      An idx of -1 wraps to row 50256 but adds only 0.0 there.
#      Atomic accumulation order is not fixed, so fp32 results may not be
#      bit-reproducible (are_deterministic_algorithms_enabled is False).
#      out_ptr2 must already hold the index_put base; per the graph fragment
#      that is full_default_169, i.e. zeros of [50257, 1024].
# NOTE(review): as it appears here, every line of this statement ends in " | |"
# (looks like scraped web-page residue), including inside the kernel string.
# The trailing `) | |` is a Python syntax error; strip the suffix before use.
triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45 = async_compile.triton('triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*i32', 'in_ptr5': '*bf16', 'out_ptr1': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45', 'mutated_arg_names': ['out_ptr2'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 11, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (50)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (50)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    tmp34 = tl.load(in_ptr4 + (x1), None, eviction_policy='evict_last') | |
    tmp44 = tl.load(in_ptr1 + (77)) | |
    tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) | |
    tmp48 = tl.load(in_ptr1 + (51)) | |
    tmp49 = tl.broadcast_to(tmp48, [XBLOCK, R0_BLOCK]) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp43 = tl.load(in_ptr5 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tmp35 = tmp34.to(tl.int64) | |
        tmp36 = tl.full([XBLOCK, R0_BLOCK], 50257, tl.int32) | |
        tmp37 = tmp35 + tmp36 | |
        tmp38 = tmp35 < 0 | |
        tmp39 = tl.where(tmp38, tmp37, tmp35) | |
        tl.device_assert((0 <= tmp39) & (tmp39 < 50257), "index out of bounds: 0 <= tmp39 < 50257") | |
        tmp41 = tl.full([1, 1], -1, tl.int64) | |
        tmp42 = tmp35 == tmp41 | |
        tmp46 = tmp45.to(tl.float32) | |
        tmp47 = tmp43 * tmp46 | |
        tmp50 = tmp49.to(tl.float32) | |
        tmp51 = tmp12 * tmp50 | |
        tmp52 = tmp47 + tmp51 | |
        tmp53 = tmp52.to(tl.float32) | |
        tmp54 = 0.0 | |
        tmp55 = tl.where(tmp42, tmp54, tmp53) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
        tl.atomic_add(out_ptr2 + (r0_2 + 128*x0 + 1024*tmp39), tmp55, r0_mask, sem='relaxed') | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/xd/cxde4n2e7lw45tanddyed35fd2lchomkxmgk6ds234dvjtxmqmgj.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_365 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_357, %view_322), kwargs = {}) | |
# %mul_756 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %select_26), kwargs = {}) | |
# %add_367 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_352, %mul_756), kwargs = {}) | |
# %add_383 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_373, %view_334), kwargs = {}) | |
# %mul_798 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %select_19), kwargs = {}) | |
# %mul_799 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %convert_element_type_11), kwargs = {}) | |
# %sum_111 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_799,), kwargs = {}) | |
# %add_385 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_367, %mul_798), kwargs = {}) | |
# %mul_801 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %add_24), kwargs = {}) | |
# %sum_112 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_801,), kwargs = {}) | |
# %add_400 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_390, %view_346), kwargs = {}) | |
# %mul_840 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_400, %select_12), kwargs = {}) | |
# %mul_841 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_400, %convert_element_type_11), kwargs = {}) | |
# %sum_119 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_841,), kwargs = {}) | |
# %add_402 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_385, %mul_840), kwargs = {}) | |
# %mul_842 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_400, %select_11), kwargs = {}) | |
# %mul_843 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_400, %add_12), kwargs = {}) | |
# %sum_120 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_843,), kwargs = {}) | |
# --- Generated Triton kernel #46 (TorchInductor output; not meant for hand-editing). ---
# The triple-quoted string below is the exact Triton source handed to
# async_compile.triton(); changing any character of it changes the compiled
# kernel, so only Python comments outside the string are safe to edit.
#
# Launch shape: persistent reduction, one program per row (XBLOCK = 1),
# xnumel = 65536 rows x r0_numel = 1024 columns; every row-major operand is
# addressed as r0_1 + 1024*x0 (contiguous rows of 1024 elements).
#
# Per element (bf16 inputs upcast to fp32; c[k] = in_ptr2[k], fp32 scalars):
#   a = in_ptr3 + in_ptr4,   b = in_ptr5 + in_ptr6
#   in_out_ptr0 <- in_out_ptr0 + (in_ptr0 + in_ptr1)*c[23] + a*c[21] + b*c[19]
#                  (updated in place, stored as bf16; see mutated_arg_names)
#   out_ptr0    <- b * c[18]
# Per row x0 (g = in_ptr7 * in_ptr8[x0], in_ptr8 holding one fp32 value per row):
#   out_ptr1[x0] = sum(a * g)        out_ptr2[x0] = sum(a * in_ptr9)
#   out_ptr3[x0] = sum(b * g)        out_ptr4[x0] = sum(b * in_ptr10)
# NOTE(review): the graph fragment above shows full sum.default reductions
# (e.g. sum_111); this kernel only writes per-row partials, so presumably a
# later kernel finishes the reduction -- confirm.
triton_per_fused__to_copy_add_mul_sum_46 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_46', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
    size_hints={'x': 65536, 'r0_': 1024}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*fp32', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'out_ptr0': '*bf16', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_46', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 15, 'num_reduction': 4, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_mul_sum_46(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, r0_numel): | |
    xnumel = 65536 | |
    XBLOCK: tl.constexpr = 1 | |
    r0_numel = 1024 | |
    R0_BLOCK: tl.constexpr = 1024 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = tl.full([1], xoffset, tl.int32) | |
    xmask = tl.full([R0_BLOCK], True, tl.int1) | |
    r0_index = tl.arange(0, R0_BLOCK)[:] | |
    r0_offset = 0 | |
    r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
    roffset = r0_offset | |
    rindex = r0_index | |
    r0_1 = r0_index | |
    x0 = xindex | |
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp2 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp4 = tl.load(in_ptr2 + (23)) | |
    tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK]) | |
    tmp9 = tl.load(in_ptr3 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp10 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp12 = tl.load(in_ptr2 + (21)) | |
    tmp13 = tl.broadcast_to(tmp12, [R0_BLOCK]) | |
    tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp20 = tl.load(in_ptr2 + (19)) | |
    tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK]) | |
    tmp25 = tl.load(in_ptr2 + (18)) | |
    tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK]) | |
    tmp29 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp31 = tl.load(in_ptr8 + (x0), None, eviction_policy='evict_last') | |
    tmp38 = tl.load(in_ptr9 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp47 = tl.load(in_ptr10 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp3 = tmp1 + tmp2 | |
    tmp6 = tmp5.to(tl.float32) | |
    tmp7 = tmp3 * tmp6 | |
    tmp8 = tmp0 + tmp7 | |
    tmp11 = tmp9 + tmp10 | |
    tmp14 = tmp13.to(tl.float32) | |
    tmp15 = tmp11 * tmp14 | |
    tmp16 = tmp8 + tmp15 | |
    tmp19 = tmp17 + tmp18 | |
    tmp22 = tmp21.to(tl.float32) | |
    tmp23 = tmp19 * tmp22 | |
    tmp24 = tmp16 + tmp23 | |
    tmp27 = tmp26.to(tl.float32) | |
    tmp28 = tmp19 * tmp27 | |
    tmp30 = tmp29.to(tl.float32) | |
    tmp32 = tmp30 * tmp31 | |
    tmp33 = tmp32.to(tl.float32) | |
    tmp34 = tmp11 * tmp33 | |
    tmp35 = tl.broadcast_to(tmp34, [R0_BLOCK]) | |
    tmp37 = triton_helpers.promote_to_tensor(tl.sum(tmp35, 0)) | |
    tmp39 = tmp11 * tmp38 | |
    tmp40 = tl.broadcast_to(tmp39, [R0_BLOCK]) | |
    tmp42 = triton_helpers.promote_to_tensor(tl.sum(tmp40, 0)) | |
    tmp43 = tmp19 * tmp33 | |
    tmp44 = tl.broadcast_to(tmp43, [R0_BLOCK]) | |
    tmp46 = triton_helpers.promote_to_tensor(tl.sum(tmp44, 0)) | |
    tmp48 = tmp19 * tmp47 | |
    tmp49 = tl.broadcast_to(tmp48, [R0_BLOCK]) | |
    tmp51 = triton_helpers.promote_to_tensor(tl.sum(tmp49, 0)) | |
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp24, None) | |
    tl.store(out_ptr0 + (r0_1 + 1024*x0), tmp28, None) | |
    tl.store(out_ptr1 + (x0), tmp37, None) | |
    tl.store(out_ptr2 + (x0), tmp42, None) | |
    tl.store(out_ptr3 + (x0), tmp46, None) | |
    tl.store(out_ptr4 + (x0), tmp51, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/d6/cd6bxpx2rotlzolnhqlsypn4c3ofsxep56v7rmvfvdcg45tofh7f.py | |
# Topologically Sorted Source Nodes: [loss, v_1], Original ATen: [aten.nll_loss_forward, aten.add, aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.embedding_dense_backward] | |
# Source node to ATen node mapping: | |
# loss => full_default_13 | |
# v_1 => convert_element_type_22 | |
# Graph fragment: | |
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %add_408 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_216, %view_355), kwargs = {}) | |
# %mul_854 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_432, %select_6), kwargs = {}) | |
# %convert_element_type_22 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_14, torch.float32), kwargs = {}) | |
# %convert_element_type_803 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_854, torch.float32), kwargs = {}) | |
# %mul_856 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_803, %convert_element_type_22), kwargs = {}) | |
# %mul_857 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_803, %rsqrt_3), kwargs = {}) | |
# %sum_124 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_856, [3], True), kwargs = {}) | |
# %div_62 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_79, 128), kwargs = {}) | |
# %pow_218 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_22, 1.0), kwargs = {}) | |
# %mul_860 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_218, 2.0), kwargs = {}) | |
# %mul_861 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_62, %mul_860), kwargs = {}) | |
# %add_410 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_857, %mul_861), kwargs = {}) | |
# %convert_element_type_804 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_410, torch.bfloat16), kwargs = {}) | |
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %convert_element_type_830 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_408, torch.float32), kwargs = {}) | |
# %where_27 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_830), kwargs = {}) | |
# %index_put_7 : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_169, [%convert_element_type_822], %where_27, True), kwargs = {}) | |
# --- Generated Triton kernel #47 (TorchInductor output; not meant for hand-editing). ---
# The triple-quoted string below is the exact Triton source handed to
# async_compile.triton(); only Python comments outside the string are safe
# to edit.
#
# Launch shape: looped reduction with autotuned XBLOCK / R0_BLOCK;
# xnumel = 524288 rows of r0_numel = 128. Row x is split as x0 = x % 8 and
# x1 = x // 8, so in_ptr2 / out_ptr1 are addressed r + 128*x0 + 3072*x1 while
# in_ptr0 / in_ptr5 are addressed contiguously as r + 128*x.
#
# Pass 1 (first r-loop):  s = sum_r (in_ptr0 * c[48]) * in_ptr2,
#   where c[k] = in_ptr1[k] are fp32 scalars.
# Pass 2 (second r-loop), with rs = in_ptr3[x] (one fp32 value per row):
#   out_ptr1 <- (in_ptr0*c[48])*rs + (-0.5 * s * rs**3 / 128) * (2 * in_ptr2)
#   This appears to be the backward of an RMS-style normalization over the
#   128-wide rows (cf. rsqrt_3, div by 128 and pow/mul 2.0 in the graph
#   fragment above) -- NOTE(review): confirm against the model.
#   Embedding-gradient scatter: tok = in_ptr4[x // 8] (int32) is wrapped by
#   +50257 when negative and device-asserted to lie in [0, 50257); then
#   (in_ptr5*c[75] + in_ptr0*c[49]) is atomically added into
#   out_ptr2[tok * 1024 + 128*x0 + r], or 0.0 is added when tok == -1
#   (-1 wraps to row 50256, so such rows still issue an atomic add of 0.0).
# Floating-point atomic_add (sem='relaxed') has no fixed accumulation order,
# so out_ptr2 may differ in low-order bits from run to run; the meta records
# are_deterministic_algorithms_enabled=False.
triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_47 = async_compile.triton('triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_47', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*i32', 'in_ptr5': '*bf16', 'out_ptr1': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_47', 'mutated_arg_names': ['out_ptr2'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 11, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_47(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (48)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (48)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    tmp34 = tl.load(in_ptr4 + (x1), None, eviction_policy='evict_last') | |
    tmp44 = tl.load(in_ptr1 + (75)) | |
    tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) | |
    tmp48 = tl.load(in_ptr1 + (49)) | |
    tmp49 = tl.broadcast_to(tmp48, [XBLOCK, R0_BLOCK]) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp43 = tl.load(in_ptr5 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tmp35 = tmp34.to(tl.int64) | |
        tmp36 = tl.full([XBLOCK, R0_BLOCK], 50257, tl.int32) | |
        tmp37 = tmp35 + tmp36 | |
        tmp38 = tmp35 < 0 | |
        tmp39 = tl.where(tmp38, tmp37, tmp35) | |
        tl.device_assert((0 <= tmp39) & (tmp39 < 50257), "index out of bounds: 0 <= tmp39 < 50257") | |
        tmp41 = tl.full([1, 1], -1, tl.int64) | |
        tmp42 = tmp35 == tmp41 | |
        tmp46 = tmp45.to(tl.float32) | |
        tmp47 = tmp43 * tmp46 | |
        tmp50 = tmp49.to(tl.float32) | |
        tmp51 = tmp12 * tmp50 | |
        tmp52 = tmp47 + tmp51 | |
        tmp53 = tmp52.to(tl.float32) | |
        tmp54 = 0.0 | |
        tmp55 = tl.where(tmp42, tmp54, tmp53) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
        tl.atomic_add(out_ptr2 + (r0_2 + 128*x0 + 1024*tmp39), tmp55, r0_mask, sem='relaxed') | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/ku/cku3y346hkc2rt3ol4ul22b7a2hr3nvz4j76utjfqol6py33tpfs.py | |
# Topologically Sorted Source Nodes: [loss, x], Original ATen: [aten.nll_loss_forward, aten._to_copy, aten.mul, aten.add, aten.sum, aten.embedding_dense_backward] | |
# Source node to ATen node mapping: | |
# loss => full_default_13 | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_417 : [num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_407, %view_358), kwargs = {}) | |
# %mul_882 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_417, %select_5), kwargs = {}) | |
# %mul_883 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_417, %convert_element_type_11), kwargs = {}) | |
# %sum_127 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_883,), kwargs = {}) | |
# %add_419 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_402, %mul_882), kwargs = {}) | |
# %mul_884 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_417, %select_4), kwargs = {}) | |
# %add_420 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_419, %mul_884), kwargs = {}) | |
# %convert_element_type_819 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_420, torch.float32), kwargs = {}) | |
# %mul_886 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_819, %convert_element_type_10), kwargs = {}) | |
# %sum_129 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_886, [2], True), kwargs = {}) | |
# %convert_element_type_821 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%squeeze_1, torch.float32), kwargs = {}) | |
# %where_24 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_821), kwargs = {}) | |
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %index_put_4 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_24, True), kwargs = {}) | |
# --- Generated Triton kernel #48 (TorchInductor output; not meant for hand-editing). ---
# The triple-quoted string below is the exact Triton source handed to
# async_compile.triton(); only Python comments outside the string are safe
# to edit.
#
# Launch shape: persistent reduction, one program per row (XBLOCK = 1),
# xnumel = 65536 rows x r0_numel = 1024 columns, addressed r0_1 + 1024*x0.
#
# Per element (c[k] = in_ptr3[k], fp32 scalars; rs = in_ptr5[x0], one fp32
# value per row; e = in_ptr4 upcast to fp32):
#   a = in_ptr1 + in_ptr2
#   u = in_ptr0 + a*c[17] + a*c[16]
# Per row:
#   out_ptr1[x0] = sum(a * e * rs)   -- per-row partial only; NOTE(review):
#                  the graph shows a full sum.default (sum_127), so this is
#                  presumably reduced further elsewhere -- confirm.
#   t = sum(u * e)                   -- kept in registers, not stored
#   v = u*rs + (-0.5 * t * rs**3 / 1024) * (2 * e)
#   This appears to be an RMS-style normalization backward over the
#   1024-wide rows (cf. %mul = convert_element_type_10 * rsqrt in the graph
#   fragment above) -- NOTE(review): confirm.
# Embedding-gradient scatter: tok = in_ptr6[x0] (int32) is wrapped by +50257
# when negative and device-asserted to lie in [0, 50257); v (or 0.0 when
# tok == -1, which wraps to row 50256) is atomically added into
# out_ptr3[tok * 1024 + r0_1]. Floating-point atomic_add (sem='relaxed') has
# no fixed accumulation order, so out_ptr3 may differ in low-order bits run
# to run; the meta records are_deterministic_algorithms_enabled=False.
triton_per_fused__to_copy_add_embedding_dense_backward_mul_nll_loss_forward_sum_48 = async_compile.triton('triton_per_fused__to_copy_add_embedding_dense_backward_mul_nll_loss_forward_sum_48', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
    size_hints={'x': 65536, 'r0_': 1024}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*i32', 'out_ptr1': '*fp32', 'out_ptr3': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_dense_backward_mul_nll_loss_forward_sum_48', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 8, 'num_reduction': 2, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_embedding_dense_backward_mul_nll_loss_forward_sum_48(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr1, out_ptr3, xnumel, r0_numel): | |
    xnumel = 65536 | |
    XBLOCK: tl.constexpr = 1 | |
    r0_numel = 1024 | |
    R0_BLOCK: tl.constexpr = 1024 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = tl.full([1], xoffset, tl.int32) | |
    xmask = tl.full([R0_BLOCK], True, tl.int1) | |
    r0_index = tl.arange(0, R0_BLOCK)[:] | |
    r0_offset = 0 | |
    r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
    roffset = r0_offset | |
    rindex = r0_index | |
    r0_1 = r0_index | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp2 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp4 = tl.load(in_ptr3 + (17)) | |
    tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK]) | |
    tmp9 = tl.load(in_ptr3 + (16)) | |
    tmp10 = tl.broadcast_to(tmp9, [R0_BLOCK]) | |
    tmp15 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp17 = tl.load(in_ptr5 + (x0), None, eviction_policy='evict_last') | |
    tmp28 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last') | |
    tmp3 = tmp1 + tmp2 | |
    tmp6 = tmp5.to(tl.float32) | |
    tmp7 = tmp3 * tmp6 | |
    tmp8 = tmp0 + tmp7 | |
    tmp11 = tmp10.to(tl.float32) | |
    tmp12 = tmp3 * tmp11 | |
    tmp13 = tmp8 + tmp12 | |
    tmp14 = tmp13.to(tl.float32) | |
    tmp16 = tmp15.to(tl.float32) | |
    tmp18 = tmp16 * tmp17 | |
    tmp19 = tmp18.to(tl.float32) | |
    tmp20 = tmp3 * tmp19 | |
    tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK]) | |
    tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0)) | |
    tmp24 = tmp14 * tmp16 | |
    tmp25 = tl.broadcast_to(tmp24, [R0_BLOCK]) | |
    tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp25, 0)) | |
    tmp29 = tmp28.to(tl.int64) | |
    tmp30 = tl.full([R0_BLOCK], 50257, tl.int32) | |
    tmp31 = tmp29 + tmp30 | |
    tmp32 = tmp29 < 0 | |
    tmp33 = tl.where(tmp32, tmp31, tmp29) | |
    tl.device_assert((0 <= tmp33) & (tmp33 < 50257), "index out of bounds: 0 <= tmp33 < 50257") | |
    tmp35 = tl.full([1], -1, tl.int64) | |
    tmp36 = tmp29 == tmp35 | |
    tmp37 = tmp14 * tmp17 | |
    tmp38 = -0.5 | |
    tmp39 = tmp27 * tmp38 | |
    tmp40 = tmp17 * tmp17 | |
    tmp41 = tmp40 * tmp17 | |
    tmp42 = tmp39 * tmp41 | |
    tmp43 = 0.0009765625 | |
    tmp44 = tmp42 * tmp43 | |
    tmp45 = 2.0 | |
    tmp46 = tmp16 * tmp45 | |
    tmp47 = tmp44 * tmp46 | |
    tmp48 = tmp37 + tmp47 | |
    tmp49 = tmp48.to(tl.float32) | |
    tmp50 = tmp49.to(tl.float32) | |
    tmp51 = 0.0 | |
    tmp52 = tl.where(tmp36, tmp51, tmp50) | |
    tl.atomic_add(out_ptr3 + (tl.broadcast_to(r0_1 + 1024*tmp33, [R0_BLOCK])), tmp52, None, sem='relaxed') | |
    tl.store(out_ptr1 + (x0), tmp23, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/lx/clxwsbf6ichdepzkafcu43p6rfa33ojvhrn7mlshq6ooqw2t4b5b.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy, aten.select_backward, aten.add, aten.slice_backward] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %convert_element_type_491 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_50, torch.float32), kwargs = {}) | |
# %full_default_65 : [num_users=3] = call_function[target=torch.ops.aten.full.default](args = ([16], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %select_scatter_default_33 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_65, %convert_element_type_491, 0, 2), kwargs = {}) | |
# %convert_element_type_523 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_58, torch.float32), kwargs = {}) | |
# %select_scatter_default_40 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_65, %convert_element_type_523, 0, 4), kwargs = {}) | |
# %add_272 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_33, %select_scatter_default_40), kwargs = {}) | |
# %convert_element_type_555 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_66, torch.float32), kwargs = {}) | |
# %select_scatter_default_47 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_65, %convert_element_type_555, 0, 6), kwargs = {}) | |
# %add_288 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_272, %select_scatter_default_47), kwargs = {}) | |
# %full_default_165 : [num_users=3] = call_function[target=torch.ops.aten.full.default](args = ([80], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %slice_scatter_default_15 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_165, %view_360, 0, 48, 80), kwargs = {}) | |
# %slice_scatter_default_16 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_165, %view_361, 0, 16, 48), kwargs = {}) | |
# %add_424 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_15, %slice_scatter_default_16), kwargs = {}) | |
# %slice_scatter_default_17 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_165, %add_288, 0, 0, 16), kwargs = {}) | |
# %add_425 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_424, %slice_scatter_default_17), kwargs = {}) | |
triton_poi_fused__to_copy_add_select_backward_slice_backward_49 = async_compile.triton('triton_poi_fused__to_copy_add_select_backward_slice_backward_49', ''' | |
import triton | |
import triton.language as tl | |
from triton.compiler.compiler import AttrsDescriptor | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
size_hints={'x': 128}, | |
filename=__file__, | |
triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*fp32', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'in_ptr13': '*bf16', 'in_ptr14': '*bf16', 'in_ptr15': '*bf16', 'in_ptr16': '*bf16', 'in_ptr17': '*bf16', 'in_ptr18': '*bf16', 'in_ptr19': '*bf16', 'in_ptr20': '*bf16', 'in_ptr21': '*bf16', 'in_ptr22': '*bf16', 'in_ptr23': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_select_backward_slice_backward_49', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 24, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused__to_copy_add_select_backward_slice_backward_49(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, in_ptr15, in_ptr16, in_ptr17, in_ptr18, in_ptr19, in_ptr20, in_ptr21, in_ptr22, in_ptr23, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 80 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = xindex < xnumel | |
x0 = xindex | |
tmp0 = x0 | |
tmp1 = tl.full([1], 48, tl.int64) | |
tmp2 = tmp0 >= tmp1 | |
tmp3 = tl.load(in_ptr0 + (2*(((((-48) + x0) // 2) % 16)) + ((x0 % 2))), xmask & tmp2, other=0.0) | |
tmp4 = ((((-48) + x0) // 2) % 16) | |
tmp5 = tl.full([1], 2, tl.int32) | |
tmp6 = tmp4 == tmp5 | |
tmp7 = (x0 % 2) | |
tmp8 = tl.full([1], 1, tl.int32) | |
tmp9 = tmp7 == tmp8 | |
tmp10 = tl.load(in_ptr1 + (0)).to(tl.float32) | |
tmp11 = tl.broadcast_to(tmp10, [XBLOCK]) | |
tmp12 = tl.where(tmp2, tmp11, 0.0) | |
tmp13 = tmp12.to(tl.float32) | |
tmp14 = 0.0 | |
tmp15 = tl.where(tmp9, tmp13, tmp14) | |
tmp16 = tl.full([1], 0, tl.int32) | |
tmp17 = tmp7 == tmp16 | |
tmp18 = tl.load(in_ptr2 + (0)).to(tl.float32) | |
tmp19 = tl.broadcast_to(tmp18, [XBLOCK]) | |
tmp20 = tl.where(tmp2, tmp19, 0.0) | |
tmp21 = tmp20.to(tl.float32) | |
tmp22 = tl.where(tmp17, tmp21, tmp14) | |
tmp23 = tmp15 + tmp22 | |
tmp24 = tl.where(tmp6, tmp23, tmp14) | |
tmp25 = tmp3 + tmp24 | |
tmp26 = tmp4 == tmp8 | |
tmp27 = tl.load(in_ptr3 + (0)).to(tl.float32) | |
tmp28 = tl.broadcast_to(tmp27, [XBLOCK]) | |
tmp29 = tl.where(tmp2, tmp28, 0.0) | |
tmp30 = tmp29.to(tl.float32) | |
tmp31 = tl.where(tmp9, tmp30, tmp14) | |
tmp32 = tl.load(in_ptr4 + (0)).to(tl.float32) | |
tmp33 = tl.broadcast_to(tmp32, [XBLOCK]) | |
tmp34 = tl.where(tmp2, tmp33, 0.0) | |
tmp35 = tmp34.to(tl.float32) | |
tmp36 = tl.where(tmp17, tmp35, tmp14) | |
tmp37 = tmp31 + tmp36 | |
tmp38 = tl.where(tmp26, tmp37, tmp14) | |
tmp39 = tmp25 + tmp38 | |
tmp40 = tmp4 == tmp16 | |
tmp41 = tl.load(in_ptr5 + (0)).to(tl.float32) | |
tmp42 = tl.broadcast_to(tmp41, [XBLOCK]) | |
tmp43 = tl.where(tmp2, tmp42, 0.0) | |
tmp44 = tmp43.to(tl.float32) | |
tmp45 = tl.where(tmp9, tmp44, tmp14) | |
tmp46 = tl.load(in_ptr6 + (0)).to(tl.float32) | |
tmp47 = tl.broadcast_to(tmp46, [XBLOCK]) | |
tmp48 = tl.where(tmp2, tmp47, 0.0) | |
tmp49 = tmp48.to(tl.float32) | |
tmp50 = tl.where(tmp17, tmp49, tmp14) | |
tmp51 = tmp45 + tmp50 | |
tmp52 = tl.where(tmp40, tmp51, tmp14) | |
tmp53 = tmp39 + tmp52 | |
tmp54 = tl.full(tmp53.shape, 0.0, tmp53.dtype) | |
tmp55 = tl.where(tmp2, tmp53, tmp54) | |
tmp56 = 0.0 | |
tmp57 = tl.where(tmp2, tmp55, tmp56) | |
tmp58 = tl.full([1], 16, tl.int64) | |
tmp59 = tmp0 >= tmp58 | |
tmp60 = tmp0 < tmp1 | |
tmp61 = tmp59 & tmp60 | |
tmp62 = tl.load(in_ptr7 + (2*(((((-16) + x0) // 2) % 16)) + ((x0 % 2))), xmask & tmp61, other=0.0) | |
tmp63 = ((((-16) + x0) // 2) % 16) | |
tmp64 = tl.full([1], 6, tl.int32) | |
tmp65 = tmp63 == tmp64 | |
tmp66 = (x0 % 2) | |
tmp67 = tl.full([1], 1, tl.int32) | |
tmp68 = tmp66 == tmp67 | |
tmp69 = tl.load(in_ptr8 + (0)).to(tl.float32) | |
tmp70 = tl.broadcast_to(tmp69, [XBLOCK]) | |
tmp71 = tl.where(tmp61, tmp70, 0.0) | |
tmp72 = tmp71.to(tl.float32) | |
tmp73 = 0.0 | |
tmp74 = tl.where(tmp68, tmp72, tmp73) | |
tmp75 = tl.full([1], 0, tl.int32) | |
tmp76 = tmp66 == tmp75 | |
tmp77 = tl.load(in_ptr9 + (0)).to(tl.float32) | |
tmp78 = tl.broadcast_to(tmp77, [XBLOCK]) | |
tmp79 = tl.where(tmp61, tmp78, 0.0) | |
tmp80 = tmp79.to(tl.float32) | |
tmp81 = tl.where(tmp76, tmp80, tmp73) | |
tmp82 = tmp74 + tmp81 | |
tmp83 = tl.where(tmp65, tmp82, tmp73) | |
tmp84 = tmp62 + tmp83 | |
tmp85 = tl.full([1], 5, tl.int32) | |
tmp86 = tmp63 == tmp85 | |
tmp87 = tl.load(in_ptr10 + (0)).to(tl.float32) | |
tmp88 = tl.broadcast_to(tmp87, [XBLOCK]) | |
tmp89 = tl.where(tmp61, tmp88, 0.0) | |
tmp90 = tmp89.to(tl.float32) | |
tmp91 = tl.where(tmp68, tmp90, tmp73) | |
tmp92 = tl.load(in_ptr11 + (0)).to(tl.float32) | |
tmp93 = tl.broadcast_to(tmp92, [XBLOCK]) | |
tmp94 = tl.where(tmp61, tmp93, 0.0) | |
tmp95 = tmp94.to(tl.float32) | |
tmp96 = tl.where(tmp76, tmp95, tmp73) | |
tmp97 = tmp91 + tmp96 | |
tmp98 = tl.where(tmp86, tmp97, tmp73) | |
tmp99 = tmp84 + tmp98 | |
tmp100 = tl.full([1], 4, tl.int32) | |
tmp101 = tmp63 == tmp100 | |
tmp102 = tl.load(in_ptr12 + (0)).to(tl.float32) | |
tmp103 = tl.broadcast_to(tmp102, [XBLOCK]) | |
tmp104 = tl.where(tmp61, tmp103, 0.0) | |
tmp105 = tmp104.to(tl.float32) | |
tmp106 = tl.where(tmp68, tmp105, tmp73) | |
tmp107 = tl.load(in_ptr13 + (0)).to(tl.float32) | |
tmp108 = tl.broadcast_to(tmp107, [XBLOCK]) | |
tmp109 = tl.where(tmp61, tmp108, 0.0) | |
tmp110 = tmp109.to(tl.float32) | |
tmp111 = tl.where(tmp76, tmp110, tmp73) | |
tmp112 = tmp106 + tmp111 | |
tmp113 = tl.where(tmp101, tmp112, tmp73) | |
tmp114 = tmp99 + tmp113 | |
tmp115 = tl.full([1], 3, tl.int32) | |
tmp116 = tmp63 == tmp115 | |
tmp117 = tl.load(in_ptr14 + (0)).to(tl.float32) | |
tmp118 = tl.broadcast_to(tmp117, [XBLOCK]) | |
tmp119 = tl.where(tmp61, tmp118, 0.0) | |
tmp120 = tmp119.to(tl.float32) | |
tmp121 = tl.where(tmp68, tmp120, tmp73) | |
tmp122 = tl.load(in_ptr15 + (0)).to(tl.float32) | |
tmp123 = tl.broadcast_to(tmp122, [XBLOCK]) | |
tmp124 = tl.where(tmp61, tmp123, 0.0) | |
tmp125 = tmp124.to(tl.float32) | |
tmp126 = tl.where(tmp76, tmp125, tmp73) | |
tmp127 = tmp121 + tmp126 | |
tmp128 = tl.where(tmp116, tmp127, tmp73) | |
tmp129 = tmp114 + tmp128 | |
tmp130 = tl.full([1], 2, tl.int32) | |
tmp131 = tmp63 == tmp130 | |
tmp132 = tl.load(in_ptr16 + (0)).to(tl.float32) | |
tmp133 = tl.broadcast_to(tmp132, [XBLOCK]) | |
tmp134 = tl.where(tmp61, tmp133, 0.0) | |
tmp135 = tmp134.to(tl.float32) | |
tmp136 = tl.where(tmp68, tmp135, tmp73) | |
tmp137 = tl.load(in_ptr17 + (0)).to(tl.float32) | |
tmp138 = tl.broadcast_to(tmp137, [XBLOCK]) | |
tmp139 = tl.where(tmp61, tmp138, 0.0) | |
tmp140 = tmp139.to(tl.float32) | |
tmp141 = tl.where(tmp76, tmp140, tmp73) | |
tmp142 = tmp136 + tmp141 | |
tmp143 = tl.where(tmp131, tmp142, tmp73) | |
tmp144 = tmp129 + tmp143 | |
tmp145 = tmp63 == tmp67 | |
tmp146 = tl.load(in_ptr18 + (0)).to(tl.float32) | |
tmp147 = tl.broadcast_to(tmp146, [XBLOCK]) | |
tmp148 = tl.where(tmp61, tmp147, 0.0) | |
tmp149 = tmp148.to(tl.float32) | |
tmp150 = tl.where(tmp68, tmp149, tmp73) | |
tmp151 = tl.load(in_ptr19 + (0)).to(tl.float32) | |
tmp152 = tl.broadcast_to(tmp151, [XBLOCK]) | |
tmp153 = tl.where(tmp61, tmp152, 0.0) | |
tmp154 = tmp153.to(tl.float32) | |
tmp155 = tl.where(tmp76, tmp154, tmp73) | |
tmp156 = tmp150 + tmp155 | |
tmp157 = tl.where(tmp145, tmp156, tmp73) | |
tmp158 = tmp144 + tmp157 | |
tmp159 = tmp63 == tmp75 | |
tmp160 = tl.load(in_ptr20 + (0)).to(tl.float32) | |
tmp161 = tl.broadcast_to(tmp160, [XBLOCK]) | |
tmp162 = tl.where(tmp61, tmp161, 0.0) | |
tmp163 = tmp162.to(tl.float32) | |
tmp164 = tl.where(tmp68, tmp163, tmp73) | |
tmp165 = tl.where(tmp76, tmp163, tmp73) | |
tmp166 = tmp164 + tmp165 | |
tmp167 = tl.where(tmp159, tmp166, tmp73) | |
tmp168 = tmp158 + tmp167 | |
tmp169 = tl.full(tmp168.shape, 0.0, tmp168.dtype) | |
tmp170 = tl.where(tmp61, tmp168, tmp169) | |
tmp171 = tl.where(tmp61, tmp170, tmp56) | |
tmp172 = tmp57 + tmp171 | |
tmp173 = tmp0 < tmp58 | |
tmp174 = x0 | |
tmp175 = tl.full([1], 2, tl.int32) | |
tmp176 = tmp174 == tmp175 | |
tmp177 = tl.load(in_ptr21 + (0)).to(tl.float32) | |
tmp178 = tl.broadcast_to(tmp177, [XBLOCK]) | |
tmp179 = tl.where(tmp173, tmp178, 0.0) | |
tmp180 = tmp179.to(tl.float32) | |
tmp181 = 0.0 | |
tmp182 = tl.where(tmp176, tmp180, tmp181) | |
tmp183 = tl.full([1], 4, tl.int32) | |
tmp184 = tmp174 == tmp183 | |
tmp185 = tl.load(in_ptr22 + (0)).to(tl.float32) | |
tmp186 = tl.broadcast_to(tmp185, [XBLOCK]) | |
tmp187 = tl.where(tmp173, tmp186, 0.0) | |
tmp188 = tmp187.to(tl.float32) | |
tmp189 = tl.where(tmp184, tmp188, tmp181) | |
tmp190 = tmp182 + tmp189 | |
tmp191 = tl.full([1], 6, tl.int32) | |
tmp192 = tmp174 == tmp191 | |
tmp193 = tl.load(in_ptr23 + (0)).to(tl.float32) | |
tmp194 = tl.broadcast_to(tmp193, [XBLOCK]) | |
tmp195 = tl.where(tmp173, tmp194, 0.0) | |
tmp196 = tmp195.to(tl.float32) | |
tmp197 = tl.where(tmp192, tmp196, tmp181) | |
tmp198 = tmp190 + tmp197 | |
tmp199 = tl.full(tmp198.shape, 0.0, tmp198.dtype) | |
tmp200 = tl.where(tmp173, tmp198, tmp199) | |
tmp201 = tl.where(tmp173, tmp200, tmp56) | |
tmp202 = tmp172 + tmp201 | |
tl.store(in_out_ptr0 + (x0), tmp202, xmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/xa/cxacj2l5hljhfhgtubky2uytsytumvxkk5obolbxfdymabkc3jde.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
# Source node to ATen node mapping:
# Graph fragment:
#   %convert_element_type_823 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%index_put_4, torch.bfloat16), kwargs = {})
#
# Pointwise cast kernel for %convert_element_type_823 above:
# out_ptr0[i] = in_ptr0[i] for every flat index i < xnumel.
# It reads fp32 ('*fp32' in the signature) and writes bf16 ('*bf16').
# The `.to(tl.float32)` on an fp32 load is a no-op.
# The fp32 -> bf16 narrowing happens implicitly in tl.store, which casts the
# stored value to the output pointer's element type.
# xnumel is hard-coded to 51463168 (== 50257 * 1024).
# Presumably %index_put_4 is the [50257, 1024] fp32 embedding-gradient buffer
# -- TODO confirm at the launch site.
# The last, partial block is guarded by xmask = xindex < xnumel.
triton_poi_fused_embedding_dense_backward_50 = async_compile.triton('triton_poi_fused_embedding_dense_backward_50', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 67108864},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_50', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_50(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 51463168
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/kg/ckgmdoifhue6qhtyygsdiopbhe342puvu34bhdcnkwnjowz6irqm.py
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward]
# Source node to ATen node mapping:
#   loss => full_default_13
# Graph fragment:
#   %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %add_374 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_192, %view_331), kwargs = {})
#   %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
#   %convert_element_type_824 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_374, torch.float32), kwargs = {})
#   %where_25 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_824), kwargs = {})
#   %index_put_5 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_25, True), kwargs = {})
#
# Scatter-add kernel for the embedding backward: %index_put_5 above, called
# with accumulate=True.
# The flat index space has 67108864 (== 65536 * 1024) elements:
#   x1 = xindex // 1024 selects a source row,
#   x0 = xindex % 1024 selects a column,
#   x2 is the flat element index.
#   * in_ptr0[x1] (int32) is the destination row in the 50257-row output.
#     Negative values are wrapped by adding 50257.
#     The wrapped index is device-asserted to lie in [0, 50257).
#   * The contribution is in_ptr1[x2] * in_ptr2[79] + in_ptr3[x2] * in_ptr2[53].
#     The bf16 tensor loads are upcast to fp32, and the in_ptr2 entries are
#     fp32 scalars broadcast across the block.
#   * Rows whose raw (unwrapped) index equals -1 contribute 0.0 instead; this
#     is the where_25 masking.
#     Those rows still issue an atomic add of 0.0, to row 50256 after wrapping.
#   * The result is atomically added (sem='relaxed') into out_ptr0[row, x0].
#     out_ptr0 is therefore mutated in place (see 'mutated_arg_names').
#     It must arrive initialised; per the graph fragment it starts as zeros
#     (%full_default_169).
#     The order of float atomic accumulation is not fixed, so sums can differ
#     bitwise between runs.
# xmask is all-True and the loads and the atomic pass mask=None, so no bounds
# masking is applied.
# Correctness relies on the launch covering exactly xnumel elements (xnumel a
# multiple of XBLOCK).
# NOTE(review): in_ptr2 is presumably an fp32 vector with >= 80 entries
# (primals_7 is asserted to be (80,) in call()) -- confirm at the launch site.
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_51 = async_compile.triton('triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_51', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 67108864},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_51', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_51(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 67108864
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x1 = xindex // 1024
    x2 = xindex
    x0 = (xindex % 1024)
    tmp0 = tl.load(in_ptr0 + (x1), None, eviction_policy='evict_last')
    tmp9 = tl.load(in_ptr1 + (x2), None).to(tl.float32)
    tmp10 = tl.load(in_ptr2 + (79))
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK])
    tmp14 = tl.load(in_ptr3 + (x2), None).to(tl.float32)
    tmp15 = tl.load(in_ptr2 + (53))
    tmp16 = tl.broadcast_to(tmp15, [XBLOCK])
    tmp1 = tmp0.to(tl.int64)
    tmp2 = tl.full([XBLOCK], 50257, tl.int32)
    tmp3 = tmp1 + tmp2
    tmp4 = tmp1 < 0
    tmp5 = tl.where(tmp4, tmp3, tmp1)
    tl.device_assert((0 <= tmp5) & (tmp5 < 50257), "index out of bounds: 0 <= tmp5 < 50257")
    tmp7 = tl.full([1], -1, tl.int64)
    tmp8 = tmp1 == tmp7
    tmp12 = tmp11.to(tl.float32)
    tmp13 = tmp9 * tmp12
    tmp17 = tmp16.to(tl.float32)
    tmp18 = tmp14 * tmp17
    tmp19 = tmp13 + tmp18
    tmp20 = tmp19.to(tl.float32)
    tmp21 = 0.0
    tmp22 = tl.where(tmp8, tmp21, tmp20)
    tl.atomic_add(out_ptr0 + (x0 + 1024*tmp5), tmp22, None, sem='relaxed')
''', device_str='cuda')
# Finish the asynchronous kernel compilation started by the
# async_compile.triton(...) calls above.
# This module's globals() are passed so the pending kernel objects bound
# there can be resolved; presumably each name is rebound to its compiled
# kernel (confirm against AsyncCompile.wait).
# Order matters: wait first, then drop the module-level AsyncCompile handle,
# which is no longer needed once compilation has completed.
async_compile.wait(globals())
del async_compile
def call(args): | |
primals_1, primals_7, primals_86, embedding, embedding_1, embedding_2, cumsum, unsqueeze_9, unsqueeze_13, clamp_max, clamp_max_1, convert_element_type_2, clone_4, convert_element_type_4, clone_7, clamp_max_2, clamp_max_3, convert_element_type_6, clone_10, convert_element_type_8, clone_13, embedding_3, rsqrt, view_9, getitem_12, getitem_13, getitem_14, rsqrt_1, rsqrt_2, unsqueeze_44, unsqueeze_46, rsqrt_3, permute_5, permute_6, permute_7, getitem_19, getitem_20, add_10, rsqrt_4, view_16, mm_2, view_18, add_12, view_21, getitem_21, getitem_22, getitem_23, rsqrt_5, rsqrt_6, unsqueeze_52, unsqueeze_54, rsqrt_7, permute_13, permute_14, permute_15, getitem_28, getitem_29, add_22, rsqrt_8, view_28, mm_6, view_30, add_24, view_33, getitem_30, getitem_31, getitem_32, rsqrt_9, rsqrt_10, unsqueeze_60, unsqueeze_62, rsqrt_11, permute_21, permute_22, permute_23, getitem_37, getitem_38, add_34, rsqrt_12, view_40, mm_10, view_42, add_36, view_45, getitem_39, getitem_40, getitem_41, rsqrt_13, rsqrt_14, unsqueeze_68, unsqueeze_70, rsqrt_15, permute_29, permute_30, permute_31, getitem_46, getitem_47, add_45, rsqrt_16, view_51, mm_14, view_53, add_47, view_56, getitem_48, getitem_49, getitem_50, rsqrt_17, rsqrt_18, unsqueeze_76, unsqueeze_78, rsqrt_19, permute_37, permute_38, permute_39, getitem_55, getitem_56, add_56, rsqrt_20, view_62, mm_18, view_64, add_58, view_67, getitem_57, getitem_58, getitem_59, rsqrt_21, rsqrt_22, unsqueeze_84, unsqueeze_86, rsqrt_23, permute_45, permute_46, permute_47, getitem_64, getitem_65, add_67, rsqrt_24, view_73, mm_22, view_75, add_69, view_78, getitem_66, getitem_67, getitem_68, rsqrt_25, rsqrt_26, unsqueeze_92, unsqueeze_94, rsqrt_27, permute_53, permute_54, permute_55, getitem_73, getitem_74, add_78, rsqrt_28, view_84, mm_26, view_86, add_80, add_81, rsqrt_29, view_88, mm_28, view_90, add_83, view_93, getitem_75, getitem_76, getitem_77, rsqrt_30, rsqrt_31, unsqueeze_100, unsqueeze_102, rsqrt_32, permute_63, permute_64, permute_65, getitem_82, 
getitem_83, add_92, rsqrt_33, view_99, mm_32, view_101, add_95, view_104, getitem_84, getitem_85, getitem_86, rsqrt_34, rsqrt_35, unsqueeze_108, unsqueeze_110, rsqrt_36, permute_71, permute_72, permute_73, getitem_91, getitem_92, add_104, rsqrt_37, view_110, mm_36, view_112, add_107, view_115, getitem_93, getitem_94, getitem_95, rsqrt_38, rsqrt_39, unsqueeze_116, unsqueeze_118, rsqrt_40, permute_79, permute_80, permute_81, getitem_100, getitem_101, add_116, rsqrt_41, view_121, mm_40, view_123, add_119, view_126, getitem_102, getitem_103, getitem_104, rsqrt_42, rsqrt_43, unsqueeze_124, unsqueeze_126, rsqrt_44, permute_87, permute_88, permute_89, getitem_109, getitem_110, add_128, rsqrt_45, view_132, mm_44, view_134, add_130, view_137, getitem_111, getitem_112, getitem_113, rsqrt_46, rsqrt_47, unsqueeze_132, unsqueeze_134, rsqrt_48, permute_95, permute_96, permute_97, getitem_118, getitem_119, add_139, rsqrt_49, view_143, mm_48, view_145, add_141, view_148, getitem_120, getitem_121, getitem_122, rsqrt_50, rsqrt_51, unsqueeze_140, unsqueeze_142, rsqrt_52, permute_103, permute_104, permute_105, getitem_127, getitem_128, add_151, rsqrt_53, view_155, mm_52, view_157, add_153, view_160, getitem_129, getitem_130, getitem_131, rsqrt_54, rsqrt_55, unsqueeze_148, unsqueeze_150, rsqrt_56, permute_111, permute_112, permute_113, getitem_136, getitem_137, add_163, rsqrt_57, view_167, mm_56, view_169, add_165, view_172, getitem_138, getitem_139, getitem_140, rsqrt_58, rsqrt_59, unsqueeze_156, unsqueeze_158, rsqrt_60, permute_119, permute_120, permute_121, getitem_145, getitem_146, add_175, rsqrt_61, view_179, mm_60, view_181, add_177, rsqrt_62, view_183, mm_62, amax, log, convert_element_type_324, permute_129, permute_133, permute_137, permute_141, permute_149, permute_153, permute_157, permute_161, permute_169, permute_173, permute_177, permute_181, permute_189, permute_193, permute_197, permute_201, permute_209, permute_213, permute_217, permute_221, permute_229, permute_233, 
permute_237, permute_241, permute_249, permute_253, permute_257, permute_261, permute_269, permute_273, permute_277, permute_281, permute_289, permute_293, permute_297, permute_301, permute_305, permute_309, permute_317, permute_321, permute_325, permute_329, permute_337, permute_341, permute_345, permute_349, permute_357, permute_361, permute_365, permute_369, permute_377, permute_381, permute_385, permute_389, permute_397, permute_401, permute_405, permute_409, permute_417, permute_421, permute_425, permute_429, permute_437, tangents_1 = args | |
args.clear() | |
assert_size_stride(primals_1, (65536, ), (1, )) | |
assert_size_stride(primals_7, (80, ), (1, )) | |
assert_size_stride(primals_86, (65536, ), (1, )) | |
assert_size_stride(embedding, (65536, 1024), (1024, 1)) | |
assert_size_stride(embedding_1, (65536, 1024), (1024, 1)) | |
assert_size_stride(embedding_2, (65536, 1024), (1024, 1)) | |
assert_size_stride(cumsum, (65536, ), (1, )) | |
assert_size_stride(unsqueeze_9, (1, 1, 512, 512), (262144, 262144, 512, 1)) | |
assert_size_stride(unsqueeze_13, (1, 1, 512, 512), (262144, 262144, 512, 1)) | |
assert_size_stride(clamp_max, (1, 1, 512), (512, 512, 1)) | |
assert_size_stride(clamp_max_1, (1, 1, 512), (512, 512, 1)) | |
assert_size_stride(convert_element_type_2, (1, 1, 512), (512, 512, 1)) | |
assert_size_stride(clone_4, (1, 1, 512, 512), (262144, 262144, 512, 1)) | |
assert_size_stride(convert_element_type_4, (1, 1, 512), (512, 512, 1)) | |
assert_size_stride(clone_7, (1, 1, 512, 512), (262144, 262144, 512, 1)) | |
assert_size_stride(clamp_max_2, (1, 1, 512), (512, 512, 1)) | |
assert_size_stride(clamp_max_3, (1, 1, 512), (512, 512, 1)) | |
assert_size_stride(convert_element_type_6, (1, 1, 512), (512, 512, 1)) | |
assert_size_stride(clone_10, (1, 1, 512, 512), (262144, 262144, 512, 1)) | |
assert_size_stride(convert_element_type_8, (1, 1, 512), (512, 512, 1)) | |
assert_size_stride(clone_13, (1, 1, 512, 512), (262144, 262144, 512, 1)) | |
assert_size_stride(embedding_3, (65536, 1024), (1024, 1)) | |
assert_size_stride(rsqrt, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_9, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_12, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_13, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_14, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_1, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_2, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_44, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_46, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_3, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_5, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_6, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_7, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_19, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_20, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_10, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_4, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_16, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_2, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_18, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_12, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_21, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_21, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_22, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_23, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_5, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_6, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_52, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_54, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_7, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_13, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_14, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_15, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_28, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_29, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_22, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_8, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_28, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_6, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_30, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_24, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_33, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_30, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_31, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_32, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_9, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_10, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_60, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_62, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_11, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_21, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_22, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_23, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_37, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_38, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_34, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_12, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_40, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_10, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_42, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_36, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_45, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_39, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_40, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_41, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_13, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_14, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_68, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_70, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_15, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_29, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_30, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_31, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_46, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_47, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_45, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_16, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_51, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_14, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_53, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_47, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_56, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_48, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_49, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_50, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_17, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_18, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_76, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_78, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_19, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_37, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_38, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_39, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_55, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_56, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_56, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_20, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_62, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_18, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_64, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_58, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_67, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_57, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_58, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_59, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_21, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_22, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_84, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_86, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_23, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_45, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_46, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_47, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_64, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_65, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_67, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_24, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_73, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_22, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_75, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_69, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_78, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_66, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_67, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_68, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_25, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_26, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_92, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_94, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_27, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_53, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_54, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_55, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_73, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_74, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_78, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_28, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_84, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_26, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_86, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_80, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(add_81, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_29, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_88, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_28, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_90, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_83, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_93, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_75, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_76, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_77, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_30, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_31, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_100, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_102, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_32, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_63, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_64, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_65, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_82, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_83, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_92, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_33, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_99, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_32, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_101, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_95, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_104, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_84, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_85, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_86, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_34, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_35, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_108, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_110, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_36, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_71, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_72, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_73, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_91, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_92, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_104, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_37, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_110, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_36, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_112, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_107, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_115, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_93, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_94, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_95, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_38, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_39, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_116, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_118, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_40, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_79, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_80, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_81, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_100, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_101, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_116, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_41, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_121, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_40, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_123, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_119, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_126, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_102, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_103, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_104, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_42, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_43, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_124, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_126, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_44, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_87, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_88, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_89, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_109, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_110, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_128, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_45, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_132, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_44, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_134, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_130, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_137, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_111, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_112, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_113, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_46, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_47, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_132, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_134, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_48, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_95, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_96, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_97, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_118, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_119, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_139, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_49, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_143, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_48, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_145, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_141, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_148, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_120, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_121, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_122, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_50, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_51, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_140, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_142, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_52, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_103, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_104, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_105, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_127, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_128, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_151, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_53, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_155, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_52, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_157, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_153, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_160, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_129, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_130, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_131, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_54, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_55, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_148, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_150, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_56, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_111, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_112, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_113, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_136, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_137, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_163, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_57, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_167, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_56, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_169, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_165, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(view_172, (65536, 1024), (1024, 1)) | |
assert_size_stride(getitem_138, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_139, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(getitem_140, (1, 65536, 8, 128), (201326592, 3072, 128, 1)) | |
assert_size_stride(rsqrt_58, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(rsqrt_59, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(unsqueeze_156, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(unsqueeze_158, (1, 65536, 1, 64), (16777216, 64, 64, 1)) | |
assert_size_stride(rsqrt_60, (1, 65536, 8, 1), (524288, 8, 1, 1)) | |
assert_size_stride(permute_119, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_120, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(permute_121, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_145, (1, 8, 65536, 128), (67108864, 128, 1024, 1)) | |
assert_size_stride(getitem_146, (1, 8, 65536), (524288, 65536, 1)) | |
assert_size_stride(add_175, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_61, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_179, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_60, (65536, 4096), (4096, 1)) | |
assert_size_stride(view_181, (65536, 4096), (4096, 1)) | |
assert_size_stride(add_177, (1, 65536, 1024), (67108864, 1024, 1)) | |
assert_size_stride(rsqrt_62, (1, 65536, 1), (65536, 1, 1)) | |
assert_size_stride(view_183, (65536, 1024), (1024, 1)) | |
assert_size_stride(mm_62, (65536, 50304), (50304, 1)) | |
assert_size_stride(amax, (65536, 1), (1, 1)) | |
assert_size_stride(log, (65536, 1), (1, 1)) | |
assert_size_stride(convert_element_type_324, (), ()) | |
assert_size_stride(permute_129, (50304, 1024), (1024, 1)) | |
assert_size_stride(permute_133, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_137, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_141, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_149, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_153, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_157, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_161, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_169, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_173, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_177, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_181, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_189, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_193, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_197, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_201, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_209, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_213, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_217, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_221, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_229, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_233, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_237, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_241, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_249, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_253, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_257, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_261, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_269, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_273, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_277, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_281, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_289, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_293, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_297, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_301, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_305, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_309, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_317, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_321, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_325, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_329, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_337, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_341, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_345, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_349, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_357, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_361, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_365, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_369, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_377, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_381, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_385, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_389, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_397, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_401, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_405, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_409, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_417, (3072, 1024), (1024, 1)) | |
assert_size_stride(permute_421, (1024, 4096), (4096, 1)) | |
assert_size_stride(permute_425, (4096, 1024), (1024, 1)) | |
assert_size_stride(permute_429, (1024, 1024), (1024, 1)) | |
assert_size_stride(permute_437, (3072, 1024), (1024, 1)) | |
assert_size_stride(tangents_1, (), ()) | |
with torch.cuda._DeviceGuard(0): | |
torch.cuda.set_device(0) | |
buf2 = mm_62; del mm_62 # reuse | |
# Topologically Sorted Source Nodes: [loss, logits, mul_176, square_16, add_116, rsqrt, mul_177], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward, aten._to_copy, aten.mul, aten.pow, aten.add, aten.rsqrt, aten._log_softmax, aten._log_softmax_backward_data] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0.run(buf2, primals_86, tangents_1, convert_element_type_324, amax, log, 65536, 50304, grid=grid(65536), stream=stream0) | |
del amax | |
del convert_element_type_324 | |
del log | |
del primals_86 | |
del tangents_1 | |
buf3 = empty_strided_cuda((50304, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf2, (50304, 65536), (1, 50304), 0), view_183, out=buf3) | |
del view_183 | |
buf4 = empty_strided_cuda((65536, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(buf2, permute_129, out=buf4) | |
del buf2 | |
del permute_129 | |
buf5 = empty_strided_cuda((50304, 1024), (1024, 1), torch.float32) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused__to_copy_1.run(buf3, buf5, 51511296, grid=grid(51511296), stream=stream0) | |
del buf3 | |
buf7 = reinterpret_tensor(buf4, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf4 # reuse | |
# Topologically Sorted Source Nodes: [x_144], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_2.run(buf7, add_177, rsqrt_62, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_177 | |
del rsqrt_62 | |
buf8 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf7, (1024, 65536), (1, 1024), 0), view_181, out=buf8) | |
del view_181 | |
buf9 = empty_strided_cuda((65536, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf7, (65536, 1024), (1024, 1), 0), permute_133, out=buf9) | |
del permute_133 | |
buf10 = reinterpret_tensor(mm_60, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_60 # reuse | |
# Topologically Sorted Source Nodes: [relu_15], Original ATen: [aten.relu, aten.pow, aten.mul, aten.threshold_backward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf10, buf9, 268435456, grid=grid(268435456), stream=stream0) | |
del buf9 | |
buf11 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf10, (4096, 65536), (1, 4096), 0), view_179, out=buf11) | |
del view_179 | |
buf12 = empty_strided_cuda((65536, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf10, (65536, 4096), (4096, 1), 0), permute_137, out=buf12) | |
del permute_137 | |
buf14 = buf7; del buf7 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_61], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf14, buf12, add_175, rsqrt_61, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_175 | |
del rsqrt_61 | |
buf15 = empty_strided_cuda((1024, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf14, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_145, (65536, 1024), (1024, 1), 0), out=buf15) | |
buf16 = buf12; del buf12 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf14, (65536, 1024), (1024, 1), 0), permute_141, out=buf16) | |
del permute_141 | |
buf17 = empty_strided_cuda((1, 8, 65536), (524288, 1, 8), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_145, buf16, buf17, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_145 | |
buf18 = empty_strided_cuda((1, 8, 65536), (524288, 65536, 1), torch.float32) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf17, buf18, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf20 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
buf21 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
buf22 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_119, permute_120, permute_121, getitem_146, buf18, buf16, buf20, buf21, clamp_max, unsqueeze_9, convert_element_type_2, clone_4, clamp_max_1, unsqueeze_13, convert_element_type_4, clone_7, cumsum, buf22, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_146 | |
del permute_119 | |
del permute_120 | |
del permute_121 | |
buf25 = empty_strided_cuda((512, ), (1, ), torch.float32) | |
buf27 = empty_strided_cuda((512, ), (1, ), torch.float32) | |
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten.sum, aten._to_copy] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_8.run(buf21, embedding_2, getitem_140, rsqrt_60, buf25, buf27, 512, 131072, grid=grid(512), stream=stream0) | |
buf26 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf25, buf26, 1, 512, grid=grid(1), stream=stream0) | |
buf28 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf27, buf28, 1, 512, grid=grid(1), stream=stream0) | |
buf37 = empty_strided_cuda((1, 65536, 24, 128), (201326592, 3072, 128, 1), torch.bfloat16) | |
buf36 = reinterpret_tensor(buf37, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_10.run(buf21, primals_7, getitem_140, rsqrt_60, buf36, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_140 | |
del rsqrt_60 | |
buf30 = empty_strided_cuda((1, 65536, 8, 128), (67108864, 1024, 128, 1), torch.float32) | |
buf31 = empty_strided_cuda((1, 65536, 8, 128), (67108864, 1024, 128, 1), torch.float32) | |
buf35 = reinterpret_tensor(buf37, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf34 = reinterpret_tensor(buf37, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_43, q_43], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf22, unsqueeze_158, unsqueeze_156, buf20, getitem_139, rsqrt_59, getitem_138, rsqrt_58, buf30, buf31, buf35, buf34, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_138 | |
del getitem_139 | |
del rsqrt_58 | |
del rsqrt_59 | |
del unsqueeze_156 | |
del unsqueeze_158 | |
del buf34 | |
del buf35 | |
del buf36 | |
buf38 = empty_strided_cuda((3072, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf37, (3072, 65536), (1, 3072), 0), view_172, out=buf38) | |
del view_172 | |
buf39 = reinterpret_tensor(buf22, (65536, 1024), (1024, 1), 0); del buf22 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf37, (65536, 3072), (3072, 1), 0), permute_149, out=buf39) | |
del permute_149 | |
buf40 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf15, buf38, buf40, 4194304, grid=grid(4194304), stream=stream0) | |
buf45 = reinterpret_tensor(buf20, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf20 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_mul_13.run(buf14, buf39, primals_7, buf45, 67108864, grid=grid(67108864), stream=stream0) | |
buf47 = reinterpret_tensor(buf10, (65536, 4096), (4096, 1), 0); del buf10 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf45, (65536, 1024), (1024, 1), 0), permute_153, out=buf47) | |
del permute_153 | |
buf48 = reinterpret_tensor(mm_56, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_56 # reuse | |
# Topologically Sorted Source Nodes: [relu_14], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf48, buf47, 268435456, grid=grid(268435456), stream=stream0) | |
del buf47 | |
buf50 = buf16; del buf16 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf48, (65536, 4096), (4096, 1), 0), permute_157, out=buf50) | |
del permute_157 | |
buf52 = reinterpret_tensor(buf50, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf50 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_57], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf52, add_163, buf45, rsqrt_57, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_163 | |
del rsqrt_57 | |
buf54 = empty_strided_cuda((65536, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf52, (65536, 1024), (1024, 1), 0), permute_161, out=buf54) | |
del permute_161 | |
buf55 = buf17; del buf17 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_136, buf54, buf55, 524288, 128, grid=grid(524288), stream=stream0) | |
buf56 = buf18; del buf18 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf55, buf56, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf58 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
buf59 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
buf60 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_111, permute_112, permute_113, getitem_137, buf56, buf54, buf58, buf59, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf60, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_137 | |
del permute_111 | |
del permute_112 | |
del permute_113 | |
buf68 = buf31; del buf31 # reuse | |
buf69 = buf30; del buf30 # reuse | |
buf75 = buf37; del buf37 # reuse | |
buf73 = reinterpret_tensor(buf75, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf72 = reinterpret_tensor(buf75, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_40, q_40], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf60, unsqueeze_150, unsqueeze_148, buf58, getitem_130, rsqrt_55, getitem_129, rsqrt_54, buf68, buf69, buf73, buf72, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_129 | |
del getitem_130 | |
del rsqrt_54 | |
del rsqrt_55 | |
del unsqueeze_148 | |
del unsqueeze_150 | |
buf74 = reinterpret_tensor(buf75, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_40], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_15.run(buf59, primals_7, getitem_131, rsqrt_56, buf74, 524288, 128, grid=grid(524288), stream=stream0) | |
del buf72 | |
del buf73 | |
del buf74 | |
buf77 = reinterpret_tensor(buf60, (65536, 1024), (1024, 1), 0); del buf60 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf75, (65536, 3072), (3072, 1), 0), permute_169, out=buf77) | |
del permute_169 | |
buf41 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32) | |
buf43 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32) | |
buf79 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32) | |
buf81 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32) | |
buf83 = reinterpret_tensor(buf58, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf58 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_mul_sum_16.run(buf14, buf39, embedding_3, rsqrt, add_165, buf52, buf77, add_153, primals_7, buf41, buf43, buf79, buf81, buf83, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_153 | |
del add_165 | |
buf42 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf41, buf42, 1, 65536, grid=grid(1), stream=stream0) | |
buf44 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf43, buf44, 1, 65536, grid=grid(1), stream=stream0) | |
buf46 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf45, (1024, 65536), (1, 1024), 0), view_169, out=buf46) | |
del view_169 | |
buf49 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf48, (4096, 65536), (1, 4096), 0), view_167, out=buf49) | |
del view_167 | |
buf53 = buf15; del buf15 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf52, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_136, (65536, 1024), (1024, 1), 0), out=buf53) | |
del getitem_136 | |
buf63 = buf27; del buf27 # reuse | |
buf65 = buf25; del buf25 # reuse | |
# Topologically Sorted Source Nodes: [v_40], Original ATen: [aten.mul, aten.sum, aten._to_copy] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_8.run(buf59, embedding_1, getitem_131, rsqrt_56, buf63, buf65, 512, 131072, grid=grid(512), stream=stream0) | |
del getitem_131 | |
del rsqrt_56 | |
buf64 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf63, buf64, 1, 512, grid=grid(1), stream=stream0) | |
buf66 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [v_40], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf65, buf66, 1, 512, grid=grid(1), stream=stream0) | |
buf76 = buf38; del buf38 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf75, (3072, 65536), (1, 3072), 0), view_160, out=buf76) | |
del view_160 | |
buf78 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf53, buf76, buf78, 4194304, grid=grid(4194304), stream=stream0) | |
buf80 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf79, buf80, 1, 65536, grid=grid(1), stream=stream0) | |
buf82 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf81, buf82, 1, 65536, grid=grid(1), stream=stream0) | |
buf84 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf83, (1024, 65536), (1, 1024), 0), view_157, out=buf84) | |
del view_157 | |
buf85 = reinterpret_tensor(buf48, (65536, 4096), (4096, 1), 0); del buf48 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf83, (65536, 1024), (1024, 1), 0), permute_173, out=buf85) | |
del permute_173 | |
buf86 = reinterpret_tensor(mm_52, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_52 # reuse | |
# Topologically Sorted Source Nodes: [relu_13], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf86, buf85, 268435456, grid=grid(268435456), stream=stream0) | |
del buf85 | |
buf87 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf86, (4096, 65536), (1, 4096), 0), view_155, out=buf87) | |
del view_155 | |
buf88 = reinterpret_tensor(buf45, (65536, 1024), (1024, 1), 0); del buf45 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf86, (65536, 4096), (4096, 1), 0), permute_177, out=buf88) | |
del permute_177 | |
buf90 = buf83; del buf83 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_53], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf90, buf88, add_151, rsqrt_53, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_151 | |
del rsqrt_53 | |
buf91 = buf53; del buf53 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf90, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_127, (65536, 1024), (1024, 1), 0), out=buf91) | |
buf92 = buf88; del buf88 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf90, (65536, 1024), (1024, 1), 0), permute_181, out=buf92) | |
del permute_181 | |
buf93 = buf55; del buf55 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_127, buf92, buf93, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_127 | |
buf94 = buf56; del buf56 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf93, buf94, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf96 = reinterpret_tensor(buf54, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf54 # reuse | |
buf97 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
buf98 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_103, permute_104, permute_105, getitem_128, buf94, buf92, buf96, buf97, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf98, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_128 | |
del permute_103 | |
del permute_104 | |
del permute_105 | |
buf101 = buf65; del buf65 # reuse | |
buf103 = buf63; del buf63 # reuse | |
# Topologically Sorted Source Nodes: [v_37], Original ATen: [aten.mul, aten.sum, aten._to_copy] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_8.run(buf97, embedding, getitem_122, rsqrt_52, buf101, buf103, 512, 131072, grid=grid(512), stream=stream0) | |
buf102 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf101, buf102, 1, 512, grid=grid(1), stream=stream0) | |
buf104 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [v_37], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf103, buf104, 1, 512, grid=grid(1), stream=stream0) | |
buf113 = buf75; del buf75 # reuse | |
buf112 = reinterpret_tensor(buf113, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_37], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_18.run(buf97, primals_7, getitem_122, rsqrt_52, buf112, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_122 | |
del rsqrt_52 | |
buf106 = buf69; del buf69 # reuse | |
buf107 = buf68; del buf68 # reuse | |
buf111 = reinterpret_tensor(buf113, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf110 = reinterpret_tensor(buf113, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_37, q_37], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf98, unsqueeze_142, unsqueeze_140, buf96, getitem_121, rsqrt_51, getitem_120, rsqrt_50, buf106, buf107, buf111, buf110, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_120 | |
del getitem_121 | |
del rsqrt_50 | |
del rsqrt_51 | |
del unsqueeze_140 | |
del unsqueeze_142 | |
del buf110 | |
del buf111 | |
del buf112 | |
buf114 = buf76; del buf76 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf113, (3072, 65536), (1, 3072), 0), view_148, out=buf114) | |
del view_148 | |
buf115 = reinterpret_tensor(buf98, (65536, 1024), (1024, 1), 0); del buf98 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf113, (65536, 3072), (3072, 1), 0), permute_189, out=buf115) | |
del permute_189 | |
buf116 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf91, buf114, buf116, 4194304, grid=grid(4194304), stream=stream0) | |
buf119 = buf14; del buf14 # reuse | |
buf122 = reinterpret_tensor(buf96, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf96 # reuse | |
buf117 = buf81; del buf81 # reuse | |
buf120 = buf79; del buf79 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten.mul, aten._to_copy, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_mul_sum_19.run(buf119, buf39, primals_7, buf52, buf77, buf90, buf115, embedding_3, rsqrt, add_141, buf122, buf117, buf120, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_141 | |
buf118 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf117, buf118, 1, 65536, grid=grid(1), stream=stream0) | |
buf121 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf120, buf121, 1, 65536, grid=grid(1), stream=stream0) | |
buf123 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf122, (1024, 65536), (1, 1024), 0), view_145, out=buf123) | |
del view_145 | |
buf124 = reinterpret_tensor(buf86, (65536, 4096), (4096, 1), 0); del buf86 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf122, (65536, 1024), (1024, 1), 0), permute_193, out=buf124) | |
del permute_193 | |
buf125 = reinterpret_tensor(mm_48, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_48 # reuse | |
# Topologically Sorted Source Nodes: [relu_12], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf125, buf124, 268435456, grid=grid(268435456), stream=stream0) | |
del buf124 | |
buf126 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf125, (4096, 65536), (1, 4096), 0), view_143, out=buf126) | |
del view_143 | |
buf127 = reinterpret_tensor(buf90, (65536, 1024), (1024, 1), 0); del buf90 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf125, (65536, 4096), (4096, 1), 0), permute_197, out=buf127) | |
del permute_197 | |
buf129 = buf122; del buf122 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_49], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf129, buf127, add_139, rsqrt_49, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_139 | |
del rsqrt_49 | |
buf130 = buf91; del buf91 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf129, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_118, (65536, 1024), (1024, 1), 0), out=buf130) | |
buf131 = buf127; del buf127 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf129, (65536, 1024), (1024, 1), 0), permute_201, out=buf131) | |
del permute_201 | |
buf132 = buf93; del buf93 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_118, buf131, buf132, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_118 | |
buf133 = buf94; del buf94 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf132, buf133, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf135 = reinterpret_tensor(buf77, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf77 # reuse | |
buf136 = reinterpret_tensor(buf52, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf52 # reuse | |
buf137 = reinterpret_tensor(buf39, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf39 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_95, permute_96, permute_97, getitem_119, buf133, buf131, buf135, buf136, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf137, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_119 | |
del permute_95 | |
del permute_96 | |
del permute_97 | |
buf140 = buf103; del buf103 # reuse | |
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_20.run(buf136, getitem_113, rsqrt_48, buf140, 512, 131072, grid=grid(512), stream=stream0) | |
buf141 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf140, buf141, 1, 512, grid=grid(1), stream=stream0) | |
buf150 = buf113; del buf113 # reuse | |
buf149 = reinterpret_tensor(buf150, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_21.run(buf136, primals_7, getitem_113, rsqrt_48, buf149, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_113 | |
del rsqrt_48 | |
buf143 = buf107; del buf107 # reuse | |
buf144 = buf106; del buf106 # reuse | |
buf148 = reinterpret_tensor(buf150, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf147 = reinterpret_tensor(buf150, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_34, q_34], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf137, unsqueeze_134, unsqueeze_132, buf135, getitem_112, rsqrt_47, getitem_111, rsqrt_46, buf143, buf144, buf148, buf147, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_111 | |
del getitem_112 | |
del rsqrt_46 | |
del rsqrt_47 | |
del unsqueeze_132 | |
del unsqueeze_134 | |
del buf147 | |
del buf148 | |
del buf149 | |
buf151 = buf114; del buf114 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf150, (3072, 65536), (1, 3072), 0), view_137, out=buf151) | |
del view_137 | |
buf152 = reinterpret_tensor(buf137, (65536, 1024), (1024, 1), 0); del buf137 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf150, (65536, 3072), (3072, 1), 0), permute_209, out=buf152) | |
del permute_209 | |
buf153 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf130, buf151, buf153, 4194304, grid=grid(4194304), stream=stream0) | |
buf158 = reinterpret_tensor(buf135, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf135 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_mul_22.run(buf129, buf152, primals_7, buf158, 67108864, grid=grid(67108864), stream=stream0) | |
buf160 = reinterpret_tensor(buf125, (65536, 4096), (4096, 1), 0); del buf125 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf158, (65536, 1024), (1024, 1), 0), permute_213, out=buf160) | |
del permute_213 | |
buf161 = reinterpret_tensor(mm_44, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_44 # reuse | |
# Topologically Sorted Source Nodes: [relu_11], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf161, buf160, 268435456, grid=grid(268435456), stream=stream0) | |
del buf160 | |
buf163 = reinterpret_tensor(buf136, (65536, 1024), (1024, 1), 0); del buf136 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf161, (65536, 4096), (4096, 1), 0), permute_217, out=buf163) | |
del permute_217 | |
buf165 = reinterpret_tensor(buf163, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf163 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_45], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf165, add_128, buf158, rsqrt_45, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_128 | |
del rsqrt_45 | |
buf167 = buf131; del buf131 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf165, (65536, 1024), (1024, 1), 0), permute_221, out=buf167) | |
del permute_221 | |
buf168 = buf132; del buf132 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_109, buf167, buf168, 524288, 128, grid=grid(524288), stream=stream0) | |
buf169 = buf133; del buf133 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf168, buf169, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf171 = reinterpret_tensor(buf115, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf115 # reuse | |
buf172 = reinterpret_tensor(buf92, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf92 # reuse | |
buf173 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_87, permute_88, permute_89, getitem_110, buf169, buf167, buf171, buf172, clamp_max, unsqueeze_9, convert_element_type_2, clone_4, clamp_max_1, unsqueeze_13, convert_element_type_4, clone_7, cumsum, buf173, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_110 | |
del permute_87 | |
del permute_88 | |
del permute_89 | |
buf179 = buf144; del buf144 # reuse | |
buf180 = buf143; del buf143 # reuse | |
buf186 = buf150; del buf150 # reuse | |
buf184 = reinterpret_tensor(buf186, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf183 = reinterpret_tensor(buf186, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_31, q_31], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf173, unsqueeze_126, unsqueeze_124, buf171, getitem_103, rsqrt_43, getitem_102, rsqrt_42, buf179, buf180, buf184, buf183, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_102 | |
del getitem_103 | |
del rsqrt_42 | |
del rsqrt_43 | |
del unsqueeze_124 | |
del unsqueeze_126 | |
buf185 = reinterpret_tensor(buf186, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_31], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_23.run(buf172, primals_7, getitem_104, rsqrt_44, buf185, 524288, 128, grid=grid(524288), stream=stream0) | |
del buf183 | |
del buf184 | |
del buf185 | |
buf188 = reinterpret_tensor(buf173, (65536, 1024), (1024, 1), 0); del buf173 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf186, (65536, 3072), (3072, 1), 0), permute_229, out=buf188) | |
del permute_229 | |
buf154 = buf120; del buf120 # reuse | |
buf156 = buf117; del buf117 # reuse | |
buf190 = buf43; del buf43 # reuse | |
buf192 = buf41; del buf41 # reuse | |
buf194 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32) | |
buf196 = reinterpret_tensor(buf171, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf171 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_mul_sum_24.run(buf129, buf152, embedding_3, rsqrt, add_130, buf165, buf188, add_119, primals_7, add_36, buf154, buf156, buf190, buf192, buf194, buf196, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_119 | |
del add_130 | |
buf155 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf154, buf155, 1, 65536, grid=grid(1), stream=stream0) | |
buf157 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf156, buf157, 1, 65536, grid=grid(1), stream=stream0) | |
buf159 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf158, (1024, 65536), (1, 1024), 0), view_134, out=buf159) | |
del view_134 | |
buf162 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf161, (4096, 65536), (1, 4096), 0), view_132, out=buf162) | |
del view_132 | |
buf166 = buf130; del buf130 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf165, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_109, (65536, 1024), (1024, 1), 0), out=buf166) | |
del getitem_109 | |
buf176 = buf140; del buf140 # reuse | |
# Topologically Sorted Source Nodes: [v_31], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_20.run(buf172, getitem_104, rsqrt_44, buf176, 512, 131072, grid=grid(512), stream=stream0) | |
del getitem_104 | |
del rsqrt_44 | |
buf177 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [v_31], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf176, buf177, 1, 512, grid=grid(1), stream=stream0) | |
buf187 = buf151; del buf151 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf186, (3072, 65536), (1, 3072), 0), view_126, out=buf187) | |
del view_126 | |
buf189 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf166, buf187, buf189, 4194304, grid=grid(4194304), stream=stream0) | |
buf191 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf190, buf191, 1, 65536, grid=grid(1), stream=stream0) | |
buf193 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf192, buf193, 1, 65536, grid=grid(1), stream=stream0) | |
buf195 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf194, buf195, 1, 65536, grid=grid(1), stream=stream0) | |
buf197 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf196, (1024, 65536), (1, 1024), 0), view_123, out=buf197) | |
del view_123 | |
buf198 = reinterpret_tensor(buf161, (65536, 4096), (4096, 1), 0); del buf161 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf196, (65536, 1024), (1024, 1), 0), permute_233, out=buf198) | |
del permute_233 | |
buf199 = reinterpret_tensor(mm_40, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_40 # reuse | |
# Topologically Sorted Source Nodes: [relu_10], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf199, buf198, 268435456, grid=grid(268435456), stream=stream0) | |
del buf198 | |
buf200 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf199, (4096, 65536), (1, 4096), 0), view_121, out=buf200) | |
del view_121 | |
buf201 = reinterpret_tensor(buf172, (65536, 1024), (1024, 1), 0); del buf172 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf199, (65536, 4096), (4096, 1), 0), permute_237, out=buf201) | |
del permute_237 | |
buf203 = buf196; del buf196 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_41], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf203, buf201, add_116, rsqrt_41, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_116 | |
del rsqrt_41 | |
buf204 = buf166; del buf166 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf203, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_100, (65536, 1024), (1024, 1), 0), out=buf204) | |
buf205 = buf201; del buf201 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf203, (65536, 1024), (1024, 1), 0), permute_241, out=buf205) | |
del permute_241 | |
buf206 = buf168; del buf168 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_100, buf205, buf206, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_100 | |
buf207 = buf169; del buf169 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf206, buf207, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf209 = reinterpret_tensor(buf158, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf158 # reuse | |
buf210 = reinterpret_tensor(buf167, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf167 # reuse | |
buf211 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_79, permute_80, permute_81, getitem_101, buf207, buf205, buf209, buf210, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf211, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_101 | |
del permute_79 | |
del permute_80 | |
del permute_81 | |
buf214 = buf176; del buf176 # reuse | |
# Topologically Sorted Source Nodes: [v_28], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_20.run(buf210, getitem_95, rsqrt_40, buf214, 512, 131072, grid=grid(512), stream=stream0) | |
buf215 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [v_28], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf214, buf215, 1, 512, grid=grid(1), stream=stream0) | |
buf224 = buf186; del buf186 # reuse | |
buf223 = reinterpret_tensor(buf224, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_28], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_25.run(buf210, primals_7, getitem_95, rsqrt_40, buf223, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_95 | |
del rsqrt_40 | |
buf217 = buf180; del buf180 # reuse | |
buf218 = buf179; del buf179 # reuse | |
buf222 = reinterpret_tensor(buf224, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf221 = reinterpret_tensor(buf224, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_28, q_28], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf211, unsqueeze_118, unsqueeze_116, buf209, getitem_94, rsqrt_39, getitem_93, rsqrt_38, buf217, buf218, buf222, buf221, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_93 | |
del getitem_94 | |
del rsqrt_38 | |
del rsqrt_39 | |
del unsqueeze_116 | |
del unsqueeze_118 | |
del buf221 | |
del buf222 | |
del buf223 | |
buf225 = buf187; del buf187 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf224, (3072, 65536), (1, 3072), 0), view_115, out=buf225) | |
del view_115 | |
buf226 = reinterpret_tensor(buf211, (65536, 1024), (1024, 1), 0); del buf211 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf224, (65536, 3072), (3072, 1), 0), permute_249, out=buf226) | |
del permute_249 | |
buf227 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf204, buf225, buf227, 4194304, grid=grid(4194304), stream=stream0) | |
buf228 = buf194; del buf194 # reuse | |
buf231 = buf192; del buf192 # reuse | |
buf233 = buf190; del buf190 # reuse | |
buf230 = buf119; del buf119 # reuse | |
buf235 = reinterpret_tensor(buf209, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf209 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_mul_sum_26.run(buf230, buf203, buf226, embedding_3, rsqrt, add_107, primals_7, add_58, buf129, buf152, buf165, buf188, buf228, buf231, buf233, buf235, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_107 | |
buf229 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf228, buf229, 1, 65536, grid=grid(1), stream=stream0) | |
buf232 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf231, buf232, 1, 65536, grid=grid(1), stream=stream0) | |
buf234 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf233, buf234, 1, 65536, grid=grid(1), stream=stream0) | |
buf236 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf235, (1024, 65536), (1, 1024), 0), view_112, out=buf236) | |
del view_112 | |
buf237 = reinterpret_tensor(buf199, (65536, 4096), (4096, 1), 0); del buf199 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf235, (65536, 1024), (1024, 1), 0), permute_253, out=buf237) | |
del permute_253 | |
buf238 = reinterpret_tensor(mm_36, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_36 # reuse | |
# Topologically Sorted Source Nodes: [relu_9], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf238, buf237, 268435456, grid=grid(268435456), stream=stream0) | |
del buf237 | |
buf239 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf238, (4096, 65536), (1, 4096), 0), view_110, out=buf239) | |
del view_110 | |
buf240 = buf152; del buf152 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf238, (65536, 4096), (4096, 1), 0), permute_257, out=buf240) | |
del permute_257 | |
buf242 = buf235; del buf235 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_37], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf242, buf240, add_104, rsqrt_37, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_104 | |
del rsqrt_37 | |
buf243 = buf204; del buf204 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf242, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_91, (65536, 1024), (1024, 1), 0), out=buf243) | |
buf244 = buf240; del buf240 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf242, (65536, 1024), (1024, 1), 0), permute_261, out=buf244) | |
del permute_261 | |
buf245 = buf206; del buf206 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_91, buf244, buf245, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_91 | |
buf246 = buf207; del buf207 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf245, buf246, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf248 = reinterpret_tensor(buf129, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf129 # reuse | |
buf249 = buf210; del buf210 # reuse | |
buf250 = reinterpret_tensor(buf205, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf205 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_71, permute_72, permute_73, getitem_92, buf246, buf244, buf248, buf249, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf250, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_92 | |
del permute_71 | |
del permute_72 | |
del permute_73 | |
buf253 = buf214; del buf214 # reuse | |
# Topologically Sorted Source Nodes: [v_25], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_20.run(buf249, getitem_86, rsqrt_36, buf253, 512, 131072, grid=grid(512), stream=stream0) | |
buf254 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [v_25], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf253, buf254, 1, 512, grid=grid(1), stream=stream0) | |
buf263 = buf224; del buf224 # reuse | |
buf262 = reinterpret_tensor(buf263, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_25], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_27.run(buf249, primals_7, getitem_86, rsqrt_36, buf262, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_86 | |
del rsqrt_36 | |
buf256 = buf218; del buf218 # reuse | |
buf257 = buf217; del buf217 # reuse | |
buf261 = reinterpret_tensor(buf263, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf260 = reinterpret_tensor(buf263, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_25, q_25], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf250, unsqueeze_110, unsqueeze_108, buf248, getitem_85, rsqrt_35, getitem_84, rsqrt_34, buf256, buf257, buf261, buf260, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_84 | |
del getitem_85 | |
del rsqrt_34 | |
del rsqrt_35 | |
del unsqueeze_108 | |
del unsqueeze_110 | |
del buf260 | |
del buf261 | |
del buf262 | |
buf264 = buf225; del buf225 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf263, (3072, 65536), (1, 3072), 0), view_104, out=buf264) | |
del view_104 | |
buf265 = reinterpret_tensor(buf250, (65536, 1024), (1024, 1), 0); del buf250 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf263, (65536, 3072), (3072, 1), 0), permute_269, out=buf265) | |
del permute_269 | |
buf266 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf243, buf264, buf266, 4194304, grid=grid(4194304), stream=stream0) | |
buf273 = reinterpret_tensor(buf248, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf248 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_mul_28.run(buf242, buf265, primals_7, buf273, 67108864, grid=grid(67108864), stream=stream0) | |
buf275 = reinterpret_tensor(buf238, (65536, 4096), (4096, 1), 0); del buf238 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf273, (65536, 1024), (1024, 1), 0), permute_273, out=buf275) | |
del permute_273 | |
buf276 = reinterpret_tensor(mm_32, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_32 # reuse | |
# Topologically Sorted Source Nodes: [relu_8], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf276, buf275, 268435456, grid=grid(268435456), stream=stream0) | |
buf278 = reinterpret_tensor(buf249, (65536, 1024), (1024, 1), 0); del buf249 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf276, (65536, 4096), (4096, 1), 0), permute_277, out=buf278) | |
del permute_277 | |
buf280 = reinterpret_tensor(buf278, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf278 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_33], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf280, add_92, buf273, rsqrt_33, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_92 | |
del rsqrt_33 | |
buf282 = buf244; del buf244 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf280, (65536, 1024), (1024, 1), 0), permute_281, out=buf282) | |
del permute_281 | |
buf283 = buf245; del buf245 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_82, buf282, buf283, 524288, 128, grid=grid(524288), stream=stream0) | |
buf284 = buf246; del buf246 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf283, buf284, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf286 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
buf287 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
buf288 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_63, permute_64, permute_65, getitem_83, buf284, buf282, buf286, buf287, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf288, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_83 | |
del permute_63 | |
del permute_64 | |
del permute_65 | |
buf294 = buf257; del buf257 # reuse | |
buf295 = buf256; del buf256 # reuse | |
buf301 = buf263; del buf263 # reuse | |
buf299 = reinterpret_tensor(buf301, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf298 = reinterpret_tensor(buf301, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_22, q_22], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf288, unsqueeze_102, unsqueeze_100, buf286, getitem_76, rsqrt_31, getitem_75, rsqrt_30, buf294, buf295, buf299, buf298, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_75 | |
del getitem_76 | |
del rsqrt_30 | |
del rsqrt_31 | |
del unsqueeze_100 | |
del unsqueeze_102 | |
buf300 = reinterpret_tensor(buf301, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_22], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_29.run(buf287, primals_7, getitem_77, rsqrt_32, buf300, 524288, 128, grid=grid(524288), stream=stream0) | |
del buf298 | |
del buf299 | |
del buf300 | |
buf303 = reinterpret_tensor(buf288, (65536, 1024), (1024, 1), 0); del buf288 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf301, (65536, 3072), (3072, 1), 0), permute_289, out=buf303) | |
del permute_289 | |
buf309 = reinterpret_tensor(buf286, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf286 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_mul_30.run(buf280, buf303, primals_7, buf309, 67108864, grid=grid(67108864), stream=stream0) | |
buf311 = buf275; del buf275 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf309, (65536, 1024), (1024, 1), 0), permute_293, out=buf311) | |
del permute_293 | |
buf312 = reinterpret_tensor(mm_28, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_28 # reuse | |
# Topologically Sorted Source Nodes: [relu_7], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf312, buf311, 268435456, grid=grid(268435456), stream=stream0) | |
del buf311 | |
buf314 = buf282; del buf282 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf312, (65536, 4096), (4096, 1), 0), permute_297, out=buf314) | |
del permute_297 | |
buf316 = reinterpret_tensor(buf314, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf314 # reuse | |
buf319 = buf230; del buf230 # reuse | |
buf322 = empty_strided_cuda((1, 65536, 1024), (67108864, 1024, 1), torch.bfloat16) | |
buf267 = buf233; del buf233 # reuse | |
buf269 = buf231; del buf231 # reuse | |
buf271 = buf228; del buf228 # reuse | |
buf305 = buf156; del buf156 # reuse | |
buf307 = buf154; del buf154 # reuse | |
buf317 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32) | |
buf320 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32) | |
# Topologically Sorted Source Nodes: [x, rms_norm_29], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum, aten.div, aten.pow] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_31.run(buf316, buf319, add_81, buf309, rsqrt_29, buf242, buf265, primals_7, buf280, buf303, embedding_3, rsqrt, add_95, add_80, add_83, buf322, buf267, buf269, buf271, buf305, buf307, buf317, buf320, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_80 | |
del add_81 | |
del add_83 | |
del add_95 | |
del buf242 | |
del rsqrt_29 | |
buf268 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf267, buf268, 1, 65536, grid=grid(1), stream=stream0) | |
del buf267 | |
buf270 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf269, buf270, 1, 65536, grid=grid(1), stream=stream0) | |
del buf269 | |
buf272 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf271, buf272, 1, 65536, grid=grid(1), stream=stream0) | |
del buf271 | |
buf274 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf273, (1024, 65536), (1, 1024), 0), view_101, out=buf274) | |
del view_101 | |
buf277 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf276, (4096, 65536), (1, 4096), 0), view_99, out=buf277) | |
del buf276 | |
del view_99 | |
buf281 = buf243; del buf243 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf280, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_82, (65536, 1024), (1024, 1), 0), out=buf281) | |
del getitem_82 | |
buf291 = buf253; del buf253 # reuse | |
# Topologically Sorted Source Nodes: [v_22], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_20.run(buf287, getitem_77, rsqrt_32, buf291, 512, 131072, grid=grid(512), stream=stream0) | |
del getitem_77 | |
del rsqrt_32 | |
buf292 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [v_22], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf291, buf292, 1, 512, grid=grid(1), stream=stream0) | |
buf302 = buf264; del buf264 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf301, (3072, 65536), (1, 3072), 0), view_93, out=buf302) | |
del view_93 | |
buf304 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf281, buf302, buf304, 4194304, grid=grid(4194304), stream=stream0) | |
buf306 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf305, buf306, 1, 65536, grid=grid(1), stream=stream0) | |
buf308 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf307, buf308, 1, 65536, grid=grid(1), stream=stream0) | |
buf310 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf309, (1024, 65536), (1, 1024), 0), view_90, out=buf310) | |
del view_90 | |
buf313 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf312, (4096, 65536), (1, 4096), 0), view_88, out=buf313) | |
del view_88 | |
buf318 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf317, buf318, 1, 65536, grid=grid(1), stream=stream0) | |
buf321 = empty_strided_cuda((), (), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf320, buf321, 1, 65536, grid=grid(1), stream=stream0) | |
buf323 = empty_strided_cuda((16, 2), (2, 1), torch.float32) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten._to_copy, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused__to_copy_add_select_backward_32.run(buf42, buf44, buf80, buf82, buf118, buf121, buf155, buf157, buf191, buf193, buf229, buf232, buf268, buf270, buf306, buf308, buf318, buf321, buf323, 32, grid=grid(32), stream=stream0) | |
del buf118 | |
del buf121 | |
del buf155 | |
del buf157 | |
del buf191 | |
del buf193 | |
buf324 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf322, (1024, 65536), (1, 1024), 0), view_86, out=buf324) | |
del view_86 | |
buf325 = reinterpret_tensor(buf312, (65536, 4096), (4096, 1), 0); del buf312 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf322, (65536, 1024), (1024, 1), 0), permute_301, out=buf325) | |
del permute_301 | |
buf326 = reinterpret_tensor(mm_26, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_26 # reuse | |
# Topologically Sorted Source Nodes: [relu_6], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf326, buf325, 268435456, grid=grid(268435456), stream=stream0) | |
del buf325 | |
buf327 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf326, (4096, 65536), (1, 4096), 0), view_84, out=buf327) | |
del view_84 | |
buf328 = reinterpret_tensor(buf309, (65536, 1024), (1024, 1), 0); del buf309 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf326, (65536, 4096), (4096, 1), 0), permute_305, out=buf328) | |
del permute_305 | |
buf330 = buf322; del buf322 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_28], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf330, buf328, add_78, rsqrt_28, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_78 | |
del rsqrt_28 | |
buf331 = buf281; del buf281 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf330, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_73, (65536, 1024), (1024, 1), 0), out=buf331) | |
buf332 = buf328; del buf328 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf330, (65536, 1024), (1024, 1), 0), permute_309, out=buf332) | |
del permute_309 | |
buf333 = buf283; del buf283 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_73, buf332, buf333, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_73 | |
buf334 = buf284; del buf284 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf333, buf334, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf336 = buf287; del buf287 # reuse | |
buf337 = reinterpret_tensor(buf280, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf280 # reuse | |
buf338 = reinterpret_tensor(buf273, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf273 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_53, permute_54, permute_55, getitem_74, buf334, buf332, buf336, buf337, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf338, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_74 | |
del permute_53 | |
del permute_54 | |
del permute_55 | |
buf341 = buf291; del buf291 # reuse | |
# Topologically Sorted Source Nodes: [v_19], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_20.run(buf337, getitem_68, rsqrt_27, buf341, 512, 131072, grid=grid(512), stream=stream0) | |
buf342 = buf82; del buf82 # reuse | |
# Topologically Sorted Source Nodes: [v_19], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf341, buf342, 1, 512, grid=grid(1), stream=stream0) | |
buf351 = buf301; del buf301 # reuse | |
buf350 = reinterpret_tensor(buf351, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_19], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_33.run(buf337, primals_7, getitem_68, rsqrt_27, buf350, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_68 | |
del rsqrt_27 | |
buf344 = buf295; del buf295 # reuse | |
buf345 = buf294; del buf294 # reuse | |
buf349 = reinterpret_tensor(buf351, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf348 = reinterpret_tensor(buf351, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_19, q_19], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf338, unsqueeze_94, unsqueeze_92, buf336, getitem_67, rsqrt_26, getitem_66, rsqrt_25, buf344, buf345, buf349, buf348, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_66 | |
del getitem_67 | |
del rsqrt_25 | |
del rsqrt_26 | |
del unsqueeze_92 | |
del unsqueeze_94 | |
del buf348 | |
del buf349 | |
del buf350 | |
buf352 = buf302; del buf302 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf351, (3072, 65536), (1, 3072), 0), view_78, out=buf352) | |
del view_78 | |
buf353 = reinterpret_tensor(buf338, (65536, 1024), (1024, 1), 0); del buf338 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf351, (65536, 3072), (3072, 1), 0), permute_317, out=buf353) | |
del permute_317 | |
buf354 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf331, buf352, buf354, 4194304, grid=grid(4194304), stream=stream0) | |
buf359 = reinterpret_tensor(buf336, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf336 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_mul_34.run(buf330, buf353, primals_7, buf359, 67108864, grid=grid(67108864), stream=stream0) | |
buf361 = reinterpret_tensor(buf326, (65536, 4096), (4096, 1), 0); del buf326 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf359, (65536, 1024), (1024, 1), 0), permute_321, out=buf361) | |
del permute_321 | |
buf362 = reinterpret_tensor(mm_22, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_22 # reuse | |
# Topologically Sorted Source Nodes: [relu_5], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf362, buf361, 268435456, grid=grid(268435456), stream=stream0) | |
del buf361 | |
buf364 = reinterpret_tensor(buf337, (65536, 1024), (1024, 1), 0); del buf337 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf362, (65536, 4096), (4096, 1), 0), permute_325, out=buf364) | |
del permute_325 | |
buf366 = reinterpret_tensor(buf364, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf364 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_24], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf366, add_67, buf359, rsqrt_24, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_67 | |
del rsqrt_24 | |
buf368 = buf332; del buf332 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf366, (65536, 1024), (1024, 1), 0), permute_329, out=buf368) | |
del permute_329 | |
buf369 = buf333; del buf333 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_64, buf368, buf369, 524288, 128, grid=grid(524288), stream=stream0) | |
buf370 = buf334; del buf334 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf369, buf370, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf372 = reinterpret_tensor(buf316, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf316 # reuse | |
buf373 = reinterpret_tensor(buf303, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf303 # reuse | |
buf374 = reinterpret_tensor(buf265, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf265 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_45, permute_46, permute_47, getitem_65, buf370, buf368, buf372, buf373, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf374, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del buf368 | |
del getitem_65 | |
del permute_45 | |
del permute_46 | |
del permute_47 | |
buf380 = buf345; del buf345 # reuse | |
buf381 = buf344; del buf344 # reuse | |
buf387 = buf351; del buf351 # reuse | |
buf385 = reinterpret_tensor(buf387, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf384 = reinterpret_tensor(buf387, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_16, q_16], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf374, unsqueeze_86, unsqueeze_84, buf372, getitem_58, rsqrt_22, getitem_57, rsqrt_21, buf380, buf381, buf385, buf384, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_57 | |
del getitem_58 | |
del rsqrt_21 | |
del rsqrt_22 | |
del unsqueeze_84 | |
del unsqueeze_86 | |
buf386 = reinterpret_tensor(buf387, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_16], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_35.run(buf373, primals_7, getitem_59, rsqrt_23, buf386, 524288, 128, grid=grid(524288), stream=stream0) | |
del buf384 | |
del buf385 | |
del buf386 | |
buf389 = reinterpret_tensor(buf374, (65536, 1024), (1024, 1), 0); del buf374 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf387, (65536, 3072), (3072, 1), 0), permute_337, out=buf389) | |
del permute_337 | |
buf355 = buf320; del buf320 # reuse | |
buf357 = buf317; del buf317 # reuse | |
buf391 = buf307; del buf307 # reuse | |
buf393 = buf305; del buf305 # reuse | |
buf395 = buf203; del buf203 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_mul_sum_36.run(buf395, buf330, buf353, embedding_3, rsqrt, add_69, buf366, buf389, add_58, buf226, primals_7, buf355, buf357, buf391, buf393, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_58 | |
del add_69 | |
buf356 = buf80; del buf80 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf355, buf356, 1, 65536, grid=grid(1), stream=stream0) | |
buf358 = buf44; del buf44 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf357, buf358, 1, 65536, grid=grid(1), stream=stream0) | |
buf360 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf359, (1024, 65536), (1, 1024), 0), view_75, out=buf360) | |
del view_75 | |
buf363 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf362, (4096, 65536), (1, 4096), 0), view_73, out=buf363) | |
del view_73 | |
buf367 = buf331; del buf331 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf366, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_64, (65536, 1024), (1024, 1), 0), out=buf367) | |
del getitem_64 | |
buf377 = buf341; del buf341 # reuse | |
# Topologically Sorted Source Nodes: [v_16], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_20.run(buf373, getitem_59, rsqrt_23, buf377, 512, 131072, grid=grid(512), stream=stream0) | |
del getitem_59 | |
del rsqrt_23 | |
buf378 = buf42; del buf42 # reuse | |
# Topologically Sorted Source Nodes: [v_16], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf377, buf378, 1, 512, grid=grid(1), stream=stream0) | |
buf388 = buf352; del buf352 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf387, (3072, 65536), (1, 3072), 0), view_67, out=buf388) | |
del view_67 | |
buf390 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf367, buf388, buf390, 4194304, grid=grid(4194304), stream=stream0) | |
buf392 = buf321; del buf321 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf391, buf392, 1, 65536, grid=grid(1), stream=stream0) | |
buf394 = buf318; del buf318 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf393, buf394, 1, 65536, grid=grid(1), stream=stream0) | |
buf396 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf395, (1024, 65536), (1, 1024), 0), view_64, out=buf396) | |
del view_64 | |
buf397 = reinterpret_tensor(buf362, (65536, 4096), (4096, 1), 0); del buf362 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf395, (65536, 1024), (1024, 1), 0), permute_341, out=buf397) | |
del permute_341 | |
buf398 = reinterpret_tensor(mm_18, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_18 # reuse | |
# Topologically Sorted Source Nodes: [relu_4], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf398, buf397, 268435456, grid=grid(268435456), stream=stream0) | |
del buf397 | |
buf399 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf398, (4096, 65536), (1, 4096), 0), view_62, out=buf399) | |
del view_62 | |
buf400 = reinterpret_tensor(buf373, (65536, 1024), (1024, 1), 0); del buf373 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf398, (65536, 4096), (4096, 1), 0), permute_345, out=buf400) | |
del permute_345 | |
buf402 = buf395; del buf395 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_20], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf402, buf400, add_56, rsqrt_20, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_56 | |
del rsqrt_20 | |
buf403 = buf367; del buf367 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf402, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_55, (65536, 1024), (1024, 1), 0), out=buf403) | |
buf404 = buf400; del buf400 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf402, (65536, 1024), (1024, 1), 0), permute_349, out=buf404) | |
del permute_349 | |
buf405 = buf369; del buf369 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_55, buf404, buf405, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_55 | |
buf406 = buf370; del buf370 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf405, buf406, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf408 = reinterpret_tensor(buf359, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf359 # reuse | |
buf409 = reinterpret_tensor(buf226, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf226 # reuse | |
buf410 = buf372; del buf372 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_37, permute_38, permute_39, getitem_56, buf406, buf404, buf408, buf409, clamp_max, unsqueeze_9, convert_element_type_2, clone_4, clamp_max_1, unsqueeze_13, convert_element_type_4, clone_7, cumsum, buf410, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_56 | |
del permute_37 | |
del permute_38 | |
del permute_39 | |
buf413 = buf377; del buf377 # reuse | |
# Topologically Sorted Source Nodes: [v_13], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_20.run(buf409, getitem_50, rsqrt_19, buf413, 512, 131072, grid=grid(512), stream=stream0) | |
buf414 = buf308; del buf308 # reuse | |
# Topologically Sorted Source Nodes: [v_13], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf413, buf414, 1, 512, grid=grid(1), stream=stream0) | |
buf423 = buf387; del buf387 # reuse | |
buf422 = reinterpret_tensor(buf423, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_13], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_37.run(buf409, primals_7, getitem_50, rsqrt_19, buf422, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_50 | |
del rsqrt_19 | |
buf416 = buf381; del buf381 # reuse | |
buf417 = buf380; del buf380 # reuse | |
buf421 = reinterpret_tensor(buf423, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf420 = reinterpret_tensor(buf423, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_13, q_13], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf410, unsqueeze_78, unsqueeze_76, buf408, getitem_49, rsqrt_18, getitem_48, rsqrt_17, buf416, buf417, buf421, buf420, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_48 | |
del getitem_49 | |
del rsqrt_17 | |
del rsqrt_18 | |
del unsqueeze_76 | |
del unsqueeze_78 | |
del buf420 | |
del buf421 | |
del buf422 | |
buf424 = buf388; del buf388 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf423, (3072, 65536), (1, 3072), 0), view_56, out=buf424) | |
del view_56 | |
buf425 = reinterpret_tensor(buf410, (65536, 1024), (1024, 1), 0); del buf410 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf423, (65536, 3072), (3072, 1), 0), permute_357, out=buf425) | |
del permute_357 | |
buf426 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf403, buf424, buf426, 4194304, grid=grid(4194304), stream=stream0) | |
buf429 = buf319; del buf319 # reuse | |
buf432 = reinterpret_tensor(buf408, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf408 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_mul_38.run(buf429, buf330, buf353, primals_7, buf366, buf389, buf402, buf425, buf432, 67108864, grid=grid(67108864), stream=stream0) | |
buf434 = reinterpret_tensor(buf398, (65536, 4096), (4096, 1), 0); del buf398 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf432, (65536, 1024), (1024, 1), 0), permute_361, out=buf434) | |
del permute_361 | |
buf435 = reinterpret_tensor(mm_14, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_14 # reuse | |
# Topologically Sorted Source Nodes: [relu_3], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf435, buf434, 268435456, grid=grid(268435456), stream=stream0) | |
del buf434 | |
buf437 = buf389; del buf389 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf435, (65536, 4096), (4096, 1), 0), permute_365, out=buf437) | |
del permute_365 | |
buf439 = reinterpret_tensor(buf437, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf437 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_16], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf439, add_45, buf432, rsqrt_16, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_45 | |
del rsqrt_16 | |
buf441 = reinterpret_tensor(buf366, (65536, 1024), (1024, 1), 0); del buf366 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf439, (65536, 1024), (1024, 1), 0), permute_369, out=buf441) | |
del permute_369 | |
buf442 = buf405; del buf405 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_46, buf441, buf442, 524288, 128, grid=grid(524288), stream=stream0) | |
buf443 = buf406; del buf406 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf442, buf443, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf445 = reinterpret_tensor(buf353, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf353 # reuse | |
buf446 = reinterpret_tensor(buf330, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf330 # reuse | |
buf447 = buf409; del buf409 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_29, permute_30, permute_31, getitem_47, buf443, buf441, buf445, buf446, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf447, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_47 | |
del permute_29 | |
del permute_30 | |
del permute_31 | |
buf453 = buf417; del buf417 # reuse | |
buf454 = buf416; del buf416 # reuse | |
buf460 = buf423; del buf423 # reuse | |
buf458 = reinterpret_tensor(buf460, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf457 = reinterpret_tensor(buf460, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_10, q_10], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf447, unsqueeze_70, unsqueeze_68, buf445, getitem_40, rsqrt_14, getitem_39, rsqrt_13, buf453, buf454, buf458, buf457, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_39 | |
del getitem_40 | |
del rsqrt_13 | |
del rsqrt_14 | |
del unsqueeze_68 | |
del unsqueeze_70 | |
buf459 = reinterpret_tensor(buf460, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_10], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_39.run(buf446, primals_7, getitem_41, rsqrt_15, buf459, 524288, 128, grid=grid(524288), stream=stream0) | |
del buf457 | |
del buf458 | |
del buf459 | |
buf462 = reinterpret_tensor(buf447, (65536, 1024), (1024, 1), 0); del buf447 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf460, (65536, 3072), (3072, 1), 0), permute_377, out=buf462) | |
del permute_377 | |
buf427 = buf393; del buf393 # reuse | |
buf430 = buf391; del buf391 # reuse | |
buf464 = buf357; del buf357 # reuse | |
buf466 = buf355; del buf355 # reuse | |
buf468 = buf165; del buf165 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_mul_sum_40.run(buf468, buf402, buf425, embedding_3, rsqrt, add_47, buf439, buf462, add_36, buf188, primals_7, buf427, buf430, buf464, buf466, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_36 | |
del add_47 | |
buf428 = buf306; del buf306 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf427, buf428, 1, 65536, grid=grid(1), stream=stream0) | |
buf431 = buf270; del buf270 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf430, buf431, 1, 65536, grid=grid(1), stream=stream0) | |
buf433 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf432, (1024, 65536), (1, 1024), 0), view_53, out=buf433) | |
del view_53 | |
buf436 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf435, (4096, 65536), (1, 4096), 0), view_51, out=buf436) | |
del view_51 | |
buf440 = buf403; del buf403 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf439, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_46, (65536, 1024), (1024, 1), 0), out=buf440) | |
del getitem_46 | |
buf450 = buf413; del buf413 # reuse | |
# Topologically Sorted Source Nodes: [v_10], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_20.run(buf446, getitem_41, rsqrt_15, buf450, 512, 131072, grid=grid(512), stream=stream0) | |
del getitem_41 | |
del rsqrt_15 | |
buf451 = buf268; del buf268 # reuse | |
# Topologically Sorted Source Nodes: [v_10], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf450, buf451, 1, 512, grid=grid(1), stream=stream0) | |
buf461 = buf424; del buf424 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf460, (3072, 65536), (1, 3072), 0), view_45, out=buf461) | |
del view_45 | |
buf463 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf440, buf461, buf463, 4194304, grid=grid(4194304), stream=stream0) | |
buf465 = buf232; del buf232 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf464, buf465, 1, 65536, grid=grid(1), stream=stream0) | |
buf467 = buf229; del buf229 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf466, buf467, 1, 65536, grid=grid(1), stream=stream0) | |
buf469 = empty_strided_cuda((16, 2), (2, 1), torch.float32) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy, aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused__to_copy_add_select_backward_41.run(buf26, buf28, buf64, buf66, buf102, buf104, buf141, buf177, buf215, buf254, buf292, buf342, buf378, buf414, buf451, buf469, 32, grid=grid(32), stream=stream0) | |
del buf102 | |
del buf104 | |
del buf141 | |
del buf177 | |
buf470 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf468, (1024, 65536), (1, 1024), 0), view_42, out=buf470) | |
del view_42 | |
buf471 = reinterpret_tensor(buf435, (65536, 4096), (4096, 1), 0); del buf435 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf468, (65536, 1024), (1024, 1), 0), permute_381, out=buf471) | |
del permute_381 | |
buf472 = reinterpret_tensor(mm_10, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_10 # reuse | |
# Topologically Sorted Source Nodes: [relu_2], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf472, buf471, 268435456, grid=grid(268435456), stream=stream0) | |
del buf471 | |
buf473 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf472, (4096, 65536), (1, 4096), 0), view_40, out=buf473) | |
del view_40 | |
buf474 = reinterpret_tensor(buf446, (65536, 1024), (1024, 1), 0); del buf446 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf472, (65536, 4096), (4096, 1), 0), permute_385, out=buf474) | |
del permute_385 | |
buf476 = buf468; del buf468 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_12], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf476, buf474, add_34, rsqrt_12, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_34 | |
del rsqrt_12 | |
buf477 = buf440; del buf440 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf476, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_37, (65536, 1024), (1024, 1), 0), out=buf477) | |
buf478 = buf474; del buf474 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf476, (65536, 1024), (1024, 1), 0), permute_389, out=buf478) | |
del permute_389 | |
buf479 = buf442; del buf442 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_37, buf478, buf479, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_37 | |
buf480 = buf443; del buf443 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf479, buf480, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf482 = reinterpret_tensor(buf432, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf432 # reuse | |
buf483 = reinterpret_tensor(buf425, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf425 # reuse | |
buf484 = reinterpret_tensor(buf402, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf402 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_21, permute_22, permute_23, getitem_38, buf480, buf478, buf482, buf483, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf484, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del getitem_38 | |
del permute_21 | |
del permute_22 | |
del permute_23 | |
buf487 = buf450; del buf450 # reuse | |
buf489 = buf101; del buf101 # reuse | |
# Topologically Sorted Source Nodes: [v_7], Original ATen: [aten.mul, aten.sum, aten._to_copy] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_8.run(buf483, embedding_2, getitem_32, rsqrt_11, buf487, buf489, 512, 131072, grid=grid(512), stream=stream0) | |
del embedding_2 | |
buf488 = buf66; del buf66 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf487, buf488, 1, 512, grid=grid(1), stream=stream0) | |
buf490 = buf64; del buf64 # reuse | |
# Topologically Sorted Source Nodes: [v_7], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf489, buf490, 1, 512, grid=grid(1), stream=stream0) | |
buf499 = buf460; del buf460 # reuse | |
buf498 = reinterpret_tensor(buf499, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [v_7], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_mul_pow_sum_42.run(buf483, primals_7, getitem_32, rsqrt_11, buf498, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_32 | |
del rsqrt_11 | |
buf492 = buf454; del buf454 # reuse | |
buf493 = buf453; del buf453 # reuse | |
buf497 = reinterpret_tensor(buf499, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf496 = reinterpret_tensor(buf499, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_7, q_7], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf484, unsqueeze_62, unsqueeze_60, buf482, getitem_31, rsqrt_10, getitem_30, rsqrt_9, buf492, buf493, buf497, buf496, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_30 | |
del getitem_31 | |
del rsqrt_10 | |
del rsqrt_9 | |
del unsqueeze_60 | |
del unsqueeze_62 | |
del buf496 | |
del buf497 | |
del buf498 | |
buf500 = buf461; del buf461 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf499, (3072, 65536), (1, 3072), 0), view_33, out=buf500) | |
del view_33 | |
buf501 = reinterpret_tensor(buf484, (65536, 1024), (1024, 1), 0); del buf484 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf499, (65536, 3072), (3072, 1), 0), permute_397, out=buf501) | |
del permute_397 | |
buf502 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf477, buf500, buf502, 4194304, grid=grid(4194304), stream=stream0) | |
buf507 = reinterpret_tensor(buf482, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf482 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_mul_43.run(buf476, buf501, primals_7, buf507, 67108864, grid=grid(67108864), stream=stream0) | |
buf509 = reinterpret_tensor(buf472, (65536, 4096), (4096, 1), 0); del buf472 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf507, (65536, 1024), (1024, 1), 0), permute_401, out=buf509) | |
del permute_401 | |
buf510 = reinterpret_tensor(mm_6, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_6 # reuse | |
# Topologically Sorted Source Nodes: [relu_1], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf510, buf509, 268435456, grid=grid(268435456), stream=stream0) | |
del buf509 | |
buf512 = buf478; del buf478 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf510, (65536, 4096), (4096, 1), 0), permute_405, out=buf512) | |
del permute_405 | |
buf514 = reinterpret_tensor(buf512, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf512 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_8], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf514, add_22, buf507, rsqrt_8, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_22 | |
del rsqrt_8 | |
buf516 = buf188; del buf188 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf514, (65536, 1024), (1024, 1), 0), permute_409, out=buf516) | |
del permute_409 | |
buf517 = buf479; del buf479 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_28, buf516, buf517, 524288, 128, grid=grid(524288), stream=stream0) | |
buf518 = buf480; del buf480 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf517, buf518, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
buf520 = buf445; del buf445 # reuse | |
buf521 = reinterpret_tensor(buf441, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf441 # reuse | |
buf522 = reinterpret_tensor(buf404, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf404 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_tem_fused_zeros_7.run(permute_13, permute_14, permute_15, getitem_29, buf518, buf516, buf520, buf521, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf522, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del buf516 | |
del clamp_max_2 | |
del clamp_max_3 | |
del clone_10 | |
del clone_13 | |
del convert_element_type_6 | |
del convert_element_type_8 | |
del getitem_29 | |
del permute_13 | |
del permute_14 | |
del permute_15 | |
buf530 = buf493; del buf493 # reuse | |
buf531 = buf492; del buf492 # reuse | |
buf537 = buf499; del buf499 # reuse | |
buf535 = reinterpret_tensor(buf537, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf534 = reinterpret_tensor(buf537, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_4, q_4], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf522, unsqueeze_54, unsqueeze_52, buf520, getitem_22, rsqrt_6, getitem_21, rsqrt_5, buf530, buf531, buf535, buf534, 524288, 128, grid=grid(524288), stream=stream0) | |
del buf520 | |
del getitem_21 | |
del getitem_22 | |
del rsqrt_5 | |
del rsqrt_6 | |
del unsqueeze_52 | |
del unsqueeze_54 | |
buf592 = empty_strided_cuda((50257, 1024), (1024, 1), torch.float32) | |
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44.run(buf592, 51463168, grid=grid(51463168), stream=stream0) | |
buf536 = reinterpret_tensor(buf537, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [loss, v_4], Original ATen: [aten.nll_loss_forward, aten.add, aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45.run(buf521, primals_7, getitem_23, rsqrt_7, primals_1, buf59, buf536, buf592, 524288, 128, grid=grid(524288), stream=stream0) | |
del buf534 | |
del buf535 | |
del buf536 | |
buf539 = reinterpret_tensor(buf59, (65536, 1024), (1024, 1), 0); del buf59 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf537, (65536, 3072), (3072, 1), 0), permute_417, out=buf539) | |
del permute_417 | |
buf543 = buf429; del buf429 # reuse | |
buf546 = reinterpret_tensor(buf522, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf522 # reuse | |
buf503 = buf466; del buf466 # reuse | |
buf505 = buf464; del buf464 # reuse | |
buf541 = buf430; del buf430 # reuse | |
buf544 = buf427; del buf427 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_mul_sum_46.run(buf543, buf439, buf462, primals_7, buf476, buf501, buf514, buf539, embedding_3, rsqrt, add_24, add_12, buf546, buf503, buf505, buf541, buf544, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_12 | |
del add_24 | |
del buf439 | |
del buf462 | |
del buf476 | |
del buf501 | |
buf504 = buf451; del buf451 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf503, buf504, 1, 65536, grid=grid(1), stream=stream0) | |
del buf503 | |
buf506 = buf414; del buf414 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf505, buf506, 1, 65536, grid=grid(1), stream=stream0) | |
del buf505 | |
buf508 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf507, (1024, 65536), (1, 1024), 0), view_30, out=buf508) | |
del view_30 | |
buf511 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf510, (4096, 65536), (1, 4096), 0), view_28, out=buf511) | |
del view_28 | |
buf515 = buf477; del buf477 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf514, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_28, (65536, 1024), (1024, 1), 0), out=buf515) | |
del getitem_28 | |
buf525 = buf489; del buf489 # reuse | |
buf527 = buf487; del buf487 # reuse | |
# Topologically Sorted Source Nodes: [v_4], Original ATen: [aten.mul, aten.sum, aten._to_copy] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_8.run(buf521, embedding_1, getitem_23, rsqrt_7, buf525, buf527, 512, 131072, grid=grid(512), stream=stream0) | |
del embedding_1 | |
del getitem_23 | |
del rsqrt_7 | |
buf526 = buf378; del buf378 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf525, buf526, 1, 512, grid=grid(1), stream=stream0) | |
buf528 = buf342; del buf342 # reuse | |
# Topologically Sorted Source Nodes: [v_4], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf527, buf528, 1, 512, grid=grid(1), stream=stream0) | |
buf538 = buf500; del buf500 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf537, (3072, 65536), (1, 3072), 0), view_21, out=buf538) | |
del view_21 | |
buf540 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf515, buf538, buf540, 4194304, grid=grid(4194304), stream=stream0) | |
buf542 = buf292; del buf292 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf541, buf542, 1, 65536, grid=grid(1), stream=stream0) | |
del buf541 | |
buf545 = buf28; del buf28 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf544, buf545, 1, 65536, grid=grid(1), stream=stream0) | |
buf547 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf546, (1024, 65536), (1, 1024), 0), view_18, out=buf547) | |
del view_18 | |
buf548 = reinterpret_tensor(buf510, (65536, 4096), (4096, 1), 0); del buf510 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf546, (65536, 1024), (1024, 1), 0), permute_421, out=buf548) | |
del permute_421 | |
buf549 = reinterpret_tensor(mm_2, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_2 # reuse | |
# Topologically Sorted Source Nodes: [relu], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf549, buf548, 268435456, grid=grid(268435456), stream=stream0) | |
del buf548 | |
buf550 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf549, (4096, 65536), (1, 4096), 0), view_16, out=buf550) | |
del view_16 | |
buf551 = reinterpret_tensor(buf521, (65536, 1024), (1024, 1), 0); del buf521 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf549, (65536, 4096), (4096, 1), 0), permute_425, out=buf551) | |
del buf549 | |
del permute_425 | |
buf553 = buf546; del buf546 # reuse | |
# Topologically Sorted Source Nodes: [rms_norm_4], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf553, buf551, add_10, rsqrt_4, 65536, 1024, grid=grid(65536), stream=stream0) | |
del add_10 | |
del rsqrt_4 | |
buf554 = buf515; del buf515 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf553, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_19, (65536, 1024), (1024, 1), 0), out=buf554) | |
buf555 = buf551; del buf551 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf553, (65536, 1024), (1024, 1), 0), permute_429, out=buf555) | |
del permute_429 | |
buf556 = buf517; del buf517 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_zeros_5.run(getitem_19, buf555, buf556, 524288, 128, grid=grid(524288), stream=stream0) | |
del getitem_19 | |
buf557 = buf518; del buf518 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_zeros_6.run(buf556, buf557, 8, 65536, grid=grid(8, 65536), stream=stream0) | |
del buf556 | |
buf559 = reinterpret_tensor(buf514, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf514 # reuse | |
buf560 = reinterpret_tensor(buf507, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf507 # reuse | |
buf561 = reinterpret_tensor(buf539, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf539 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
stream0 = get_raw_stream(0) | |
torch.save( | |
(permute_5, permute_6, permute_7, getitem_20, buf557, buf555, buf559, buf560, clamp_max, unsqueeze_9, convert_element_type_2, clone_4, clamp_max_1, unsqueeze_13, convert_element_type_4, clone_7, cumsum, buf561), | |
"intermediate_3.2.pt", | |
) | |
triton_tem_fused_zeros_7.run(permute_5, permute_6, permute_7, getitem_20, buf557, buf555, buf559, buf560, clamp_max, unsqueeze_9, convert_element_type_2, clone_4, clamp_max_1, unsqueeze_13, convert_element_type_4, clone_7, cumsum, buf561, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0) | |
del buf555 | |
del buf557 | |
del clamp_max | |
del clamp_max_1 | |
del clone_4 | |
del clone_7 | |
del convert_element_type_2 | |
del convert_element_type_4 | |
del cumsum | |
del getitem_20 | |
del permute_5 | |
del permute_6 | |
del permute_7 | |
del unsqueeze_13 | |
del unsqueeze_9 | |
buf564 = buf527; del buf527 # reuse | |
buf566 = buf525; del buf525 # reuse | |
# Topologically Sorted Source Nodes: [v_1], Original ATen: [aten.mul, aten.sum, aten._to_copy] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_mul_sum_8.run(buf560, embedding, getitem_14, rsqrt_3, buf564, buf566, 512, 131072, grid=grid(512), stream=stream0) | |
del embedding | |
buf565 = buf26; del buf26 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf564, buf565, 1, 512, grid=grid(1), stream=stream0) | |
del buf564 | |
buf567 = buf254; del buf254 # reuse | |
# Topologically Sorted Source Nodes: [v_1], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_per_fused_mul_sum_9.run(buf566, buf567, 1, 512, grid=grid(1), stream=stream0) | |
del buf566 | |
buf595 = empty_strided_cuda((50257, 1024), (1024, 1), torch.float32) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44.run(buf595, 51463168, grid=grid(51463168), stream=stream0) | |
buf576 = buf537; del buf537 # reuse | |
buf575 = reinterpret_tensor(buf576, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias | |
# Topologically Sorted Source Nodes: [loss, v_1], Original ATen: [aten.nll_loss_forward, aten.add, aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_47.run(buf560, primals_7, getitem_14, rsqrt_3, primals_1, buf97, buf575, buf595, 524288, 128, grid=grid(524288), stream=stream0) | |
del buf560 | |
del buf97 | |
del getitem_14 | |
del rsqrt_3 | |
buf569 = buf531; del buf531 # reuse | |
buf570 = buf530; del buf530 # reuse | |
buf574 = reinterpret_tensor(buf576, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias | |
buf573 = reinterpret_tensor(buf576, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias | |
# Topologically Sorted Source Nodes: [k_1, q_1], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf561, unsqueeze_46, unsqueeze_44, buf559, getitem_13, rsqrt_2, getitem_12, rsqrt_1, buf569, buf570, buf574, buf573, 524288, 128, grid=grid(524288), stream=stream0) | |
del buf559 | |
del buf569 | |
del buf570 | |
del getitem_12 | |
del getitem_13 | |
del rsqrt_1 | |
del rsqrt_2 | |
del unsqueeze_44 | |
del unsqueeze_46 | |
del buf573 | |
del buf574 | |
del buf575 | |
buf577 = buf538; del buf538 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf576, (3072, 65536), (1, 3072), 0), view_9, out=buf577) | |
del view_9 | |
buf578 = reinterpret_tensor(buf561, (65536, 1024), (1024, 1), 0); del buf561 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm] | |
extern_kernels.mm(reinterpret_tensor(buf576, (65536, 3072), (3072, 1), 0), permute_437, out=buf578) | |
del buf576 | |
del permute_437 | |
buf579 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_select_backward_12.run(buf554, buf577, buf579, 4194304, grid=grid(4194304), stream=stream0) | |
del buf554 | |
del buf577 | |
buf586 = empty_strided_cuda((50257, 1024), (1024, 1), torch.float32) | |
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44.run(buf586, 51463168, grid=grid(51463168), stream=stream0) | |
buf580 = buf544; del buf544 # reuse | |
# Topologically Sorted Source Nodes: [loss, x], Original ATen: [aten.nll_loss_forward, aten._to_copy, aten.mul, aten.add, aten.sum, aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_per_fused__to_copy_add_embedding_dense_backward_mul_nll_loss_forward_sum_48.run(buf543, buf553, buf578, primals_7, embedding_3, rsqrt, primals_1, buf580, buf586, 65536, 1024, grid=grid(65536), stream=stream0) | |
del buf543 | |
del buf553 | |
del buf578 | |
del embedding_3 | |
del rsqrt | |
buf581 = buf215; del buf215 # reuse | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
stream0 = get_raw_stream(0) | |
triton_red_fused__to_copy_add_mul_sum_17.run(buf580, buf581, 1, 65536, grid=grid(1), stream=stream0) | |
del buf580 | |
buf582 = empty_strided_cuda((80, ), (1, ), torch.float32) | |
buf583 = buf582; del buf582 # reuse | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy, aten.select_backward, aten.add, aten.slice_backward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused__to_copy_add_select_backward_slice_backward_49.run(buf583, buf469, buf488, buf490, buf526, buf528, buf565, buf567, buf323, buf356, buf358, buf392, buf394, buf428, buf431, buf465, buf467, buf504, buf506, buf542, buf545, buf581, buf195, buf234, buf272, 80, grid=grid(80), stream=stream0) | |
del buf195 | |
del buf234 | |
del buf272 | |
del buf323 | |
del buf356 | |
del buf358 | |
del buf392 | |
del buf394 | |
del buf428 | |
del buf431 | |
del buf465 | |
del buf467 | |
del buf469 | |
del buf488 | |
del buf490 | |
del buf504 | |
del buf506 | |
del buf526 | |
del buf528 | |
del buf542 | |
del buf545 | |
del buf565 | |
del buf567 | |
del buf581 | |
buf588 = empty_strided_cuda((50257, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_embedding_dense_backward_50.run(buf586, buf588, 51463168, grid=grid(51463168), stream=stream0) | |
buf589 = buf586; del buf586 # reuse | |
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44.run(buf589, 51463168, grid=grid(51463168), stream=stream0) | |
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_51.run(primals_1, buf21, primals_7, buf483, buf589, 67108864, grid=grid(67108864), stream=stream0) | |
del buf21 | |
del buf483 | |
del primals_1 | |
del primals_7 | |
buf591 = empty_strided_cuda((50257, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_embedding_dense_backward_50.run(buf589, buf591, 51463168, grid=grid(51463168), stream=stream0) | |
del buf589 | |
buf594 = empty_strided_cuda((50257, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_embedding_dense_backward_50.run(buf592, buf594, 51463168, grid=grid(51463168), stream=stream0) | |
del buf592 | |
buf597 = empty_strided_cuda((50257, 1024), (1024, 1), torch.bfloat16) | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward] | |
stream0 = get_raw_stream(0) | |
triton_poi_fused_embedding_dense_backward_50.run(buf595, buf597, 51463168, grid=grid(51463168), stream=stream0) | |
del buf595 | |
return (None, buf597, buf594, buf591, None, buf588, buf583, buf579, None, None, buf550, buf547, buf540, None, None, buf511, buf508, buf502, None, None, buf473, buf470, buf463, None, None, buf436, buf433, buf426, None, None, buf399, buf396, buf390, None, None, buf363, buf360, buf354, None, None, buf327, buf324, buf313, buf310, buf304, None, None, buf277, buf274, buf266, None, None, buf239, buf236, buf227, None, None, buf200, buf197, buf189, None, None, buf162, buf159, buf153, None, None, buf126, buf123, buf116, None, None, buf87, buf84, buf78, None, None, buf49, buf46, buf40, None, None, buf11, buf8, buf5, None, ) | |
def benchmark_compiled_module(times=10, repeat=10): | |
from torch._dynamo.testing import rand_strided | |
from torch._inductor.utils import print_performance | |
primals_1 = rand_strided((65536, ), (1, ), device='cuda:0', dtype=torch.int32) | |
primals_7 = rand_strided((80, ), (1, ), device='cuda:0', dtype=torch.float32) | |
primals_86 = rand_strided((65536, ), (1, ), device='cuda:0', dtype=torch.int64) | |
embedding = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
embedding_1 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
embedding_2 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
cumsum = rand_strided((65536, ), (1, ), device='cuda:0', dtype=torch.int64) | |
unsqueeze_9 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32) | |
unsqueeze_13 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32) | |
clamp_max = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32) | |
clamp_max_1 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32) | |
convert_element_type_2 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32) | |
clone_4 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32) | |
convert_element_type_4 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32) | |
clone_7 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32) | |
clamp_max_2 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32) | |
clamp_max_3 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32) | |
convert_element_type_6 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32) | |
clone_10 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32) | |
convert_element_type_8 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32) | |
clone_13 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32) | |
embedding_3 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_9 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_12 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_13 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_14 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_1 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_2 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_44 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_46 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_3 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_5 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_6 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_7 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_19 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_20 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_10 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_4 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_16 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_2 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_18 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_12 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_21 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_21 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_22 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_23 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_5 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_6 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_52 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_54 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_7 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_13 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_14 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_15 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_28 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_29 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_22 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_8 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_28 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_6 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_30 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_24 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_33 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_30 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_31 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_32 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_9 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_10 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_60 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_62 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_11 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_21 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_22 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_23 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_37 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_38 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_34 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_12 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_40 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_10 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_42 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_36 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_45 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_39 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_40 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_41 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_13 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_14 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_68 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_70 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_15 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_29 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_30 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_31 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_46 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_47 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_45 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_16 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_51 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_14 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_53 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_47 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_56 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_48 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_49 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_50 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_17 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_18 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_76 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_78 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_19 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_37 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_38 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_39 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_55 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_56 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_56 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_20 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_62 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_18 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_64 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_58 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_67 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_57 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_58 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_59 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_21 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_22 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_84 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_86 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_23 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_45 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_46 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_47 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_64 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_65 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_67 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_24 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_73 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_22 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_75 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_69 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_78 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_66 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_67 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_68 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_25 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_26 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_92 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_94 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_27 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_53 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_54 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_55 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_73 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_74 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_78 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_28 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_84 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_26 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_86 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_80 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_81 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_29 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_88 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_28 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_90 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_83 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_93 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_75 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_76 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_77 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_30 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_31 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_100 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_102 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_32 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_63 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_64 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_65 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_82 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_83 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_92 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_33 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_99 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_32 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_101 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_95 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_104 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_84 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_85 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_86 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_34 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_35 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_108 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_110 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_36 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_71 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_72 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_73 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_91 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_92 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_104 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_37 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_110 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_36 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_112 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_107 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_115 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_93 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_94 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_95 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_38 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_39 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_116 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_118 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_40 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_79 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_80 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_81 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_100 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_101 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_116 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_41 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_121 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_40 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_123 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_119 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_126 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_102 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_103 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_104 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_42 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_43 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_124 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_126 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_44 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_87 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_88 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_89 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_109 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_110 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_128 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_45 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_132 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_44 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_134 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_130 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_137 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_111 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_112 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_113 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_46 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_47 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_132 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_134 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_48 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_95 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_96 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_97 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_118 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_119 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_139 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_49 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_143 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_48 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_145 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_141 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_148 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_120 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_121 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_122 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_50 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_51 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_140 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_142 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_52 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_103 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_104 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_105 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_127 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_128 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_151 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_53 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_155 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_52 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_157 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_153 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_160 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_129 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_130 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_131 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_54 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_55 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_148 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_150 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_56 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_111 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_112 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_113 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_136 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_137 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_163 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_57 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_167 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_56 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_169 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_165 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_172 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_138 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_139 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_140 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_58 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_59 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_156 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
unsqueeze_158 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32) | |
rsqrt_60 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32) | |
permute_119 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_120 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_121 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_145 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
getitem_146 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32) | |
add_175 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_61 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_179 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_60 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
view_181 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
add_177 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
rsqrt_62 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32) | |
view_183 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
mm_62 = rand_strided((65536, 50304), (50304, 1), device='cuda:0', dtype=torch.bfloat16) | |
amax = rand_strided((65536, 1), (1, 1), device='cuda:0', dtype=torch.float32) | |
log = rand_strided((65536, 1), (1, 1), device='cuda:0', dtype=torch.float32) | |
convert_element_type_324 = rand_strided((), (), device='cuda:0', dtype=torch.float32) | |
permute_129 = rand_strided((50304, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_133 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_137 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_141 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_149 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_153 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_157 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_161 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_169 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_173 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_177 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_181 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_189 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_193 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_197 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_201 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_209 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_213 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_217 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_221 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_229 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_233 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_237 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_241 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_249 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_253 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_257 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_261 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_269 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_273 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_277 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_281 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_289 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_293 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_297 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_301 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_305 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_309 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_317 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_321 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_325 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_329 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_337 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_341 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_345 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_349 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_357 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_361 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_365 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_369 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_377 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_381 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_385 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_389 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_397 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_401 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_405 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_409 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_417 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_421 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_425 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_429 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
permute_437 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16) | |
tangents_1 = rand_strided((), (), device='cuda:0', dtype=torch.float32) | |
fn = lambda: call([primals_1, primals_7, primals_86, embedding, embedding_1, embedding_2, cumsum, unsqueeze_9, unsqueeze_13, clamp_max, clamp_max_1, convert_element_type_2, clone_4, convert_element_type_4, clone_7, clamp_max_2, clamp_max_3, convert_element_type_6, clone_10, convert_element_type_8, clone_13, embedding_3, rsqrt, view_9, getitem_12, getitem_13, getitem_14, rsqrt_1, rsqrt_2, unsqueeze_44, unsqueeze_46, rsqrt_3, permute_5, permute_6, permute_7, getitem_19, getitem_20, add_10, rsqrt_4, view_16, mm_2, view_18, add_12, view_21, getitem_21, getitem_22, getitem_23, rsqrt_5, rsqrt_6, unsqueeze_52, unsqueeze_54, rsqrt_7, permute_13, permute_14, permute_15, getitem_28, getitem_29, add_22, rsqrt_8, view_28, mm_6, view_30, add_24, view_33, getitem_30, getitem_31, getitem_32, rsqrt_9, rsqrt_10, unsqueeze_60, unsqueeze_62, rsqrt_11, permute_21, permute_22, permute_23, getitem_37, getitem_38, add_34, rsqrt_12, view_40, mm_10, view_42, add_36, view_45, getitem_39, getitem_40, getitem_41, rsqrt_13, rsqrt_14, unsqueeze_68, unsqueeze_70, rsqrt_15, permute_29, permute_30, permute_31, getitem_46, getitem_47, add_45, rsqrt_16, view_51, mm_14, view_53, add_47, view_56, getitem_48, getitem_49, getitem_50, rsqrt_17, rsqrt_18, unsqueeze_76, unsqueeze_78, rsqrt_19, permute_37, permute_38, permute_39, getitem_55, getitem_56, add_56, rsqrt_20, view_62, mm_18, view_64, add_58, view_67, getitem_57, getitem_58, getitem_59, rsqrt_21, rsqrt_22, unsqueeze_84, unsqueeze_86, rsqrt_23, permute_45, permute_46, permute_47, getitem_64, getitem_65, add_67, rsqrt_24, view_73, mm_22, view_75, add_69, view_78, getitem_66, getitem_67, getitem_68, rsqrt_25, rsqrt_26, unsqueeze_92, unsqueeze_94, rsqrt_27, permute_53, permute_54, permute_55, getitem_73, getitem_74, add_78, rsqrt_28, view_84, mm_26, view_86, add_80, add_81, rsqrt_29, view_88, mm_28, view_90, add_83, view_93, getitem_75, getitem_76, getitem_77, rsqrt_30, rsqrt_31, unsqueeze_100, unsqueeze_102, rsqrt_32, permute_63, permute_64, 
permute_65, getitem_82, getitem_83, add_92, rsqrt_33, view_99, mm_32, view_101, add_95, view_104, getitem_84, getitem_85, getitem_86, rsqrt_34, rsqrt_35, unsqueeze_108, unsqueeze_110, rsqrt_36, permute_71, permute_72, permute_73, getitem_91, getitem_92, add_104, rsqrt_37, view_110, mm_36, view_112, add_107, view_115, getitem_93, getitem_94, getitem_95, rsqrt_38, rsqrt_39, unsqueeze_116, unsqueeze_118, rsqrt_40, permute_79, permute_80, permute_81, getitem_100, getitem_101, add_116, rsqrt_41, view_121, mm_40, view_123, add_119, view_126, getitem_102, getitem_103, getitem_104, rsqrt_42, rsqrt_43, unsqueeze_124, unsqueeze_126, rsqrt_44, permute_87, permute_88, permute_89, getitem_109, getitem_110, add_128, rsqrt_45, view_132, mm_44, view_134, add_130, view_137, getitem_111, getitem_112, getitem_113, rsqrt_46, rsqrt_47, unsqueeze_132, unsqueeze_134, rsqrt_48, permute_95, permute_96, permute_97, getitem_118, getitem_119, add_139, rsqrt_49, view_143, mm_48, view_145, add_141, view_148, getitem_120, getitem_121, getitem_122, rsqrt_50, rsqrt_51, unsqueeze_140, unsqueeze_142, rsqrt_52, permute_103, permute_104, permute_105, getitem_127, getitem_128, add_151, rsqrt_53, view_155, mm_52, view_157, add_153, view_160, getitem_129, getitem_130, getitem_131, rsqrt_54, rsqrt_55, unsqueeze_148, unsqueeze_150, rsqrt_56, permute_111, permute_112, permute_113, getitem_136, getitem_137, add_163, rsqrt_57, view_167, mm_56, view_169, add_165, view_172, getitem_138, getitem_139, getitem_140, rsqrt_58, rsqrt_59, unsqueeze_156, unsqueeze_158, rsqrt_60, permute_119, permute_120, permute_121, getitem_145, getitem_146, add_175, rsqrt_61, view_179, mm_60, view_181, add_177, rsqrt_62, view_183, mm_62, amax, log, convert_element_type_324, permute_129, permute_133, permute_137, permute_141, permute_149, permute_153, permute_157, permute_161, permute_169, permute_173, permute_177, permute_181, permute_189, permute_193, permute_197, permute_201, permute_209, permute_213, permute_217, permute_221, 
permute_229, permute_233, permute_237, permute_241, permute_249, permute_253, permute_257, permute_261, permute_269, permute_273, permute_277, permute_281, permute_289, permute_293, permute_297, permute_301, permute_305, permute_309, permute_317, permute_321, permute_325, permute_329, permute_337, permute_341, permute_345, permute_349, permute_357, permute_361, permute_365, permute_369, permute_377, permute_381, permute_385, permute_389, permute_397, permute_401, permute_405, permute_409, permute_417, permute_421, permute_425, permute_429, permute_437, tangents_1]) | |
return print_performance(fn, times=times, repeat=repeat) | |
# Script entry point: when this generated module is executed directly (not
# imported), hand the module's benchmark function to Inductor's
# compiled_module_main helper.
if __name__ == "__main__":
    # Imported lazily so merely importing this module does not pull in the
    # benchmark tooling.
    from torch._inductor.wrapper_benchmark import compiled_module_main

    # The first argument is the literal string 'None' as emitted by codegen
    # (presumably a placeholder benchmark name -- confirm against
    # compiled_module_main's signature). benchmark_compiled_module is defined
    # earlier in this file.
    compiled_module_main('None', benchmark_compiled_module)
This file has been truncated, but you can view the full file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# AOT ID: ['0_backward'] | |
from ctypes import c_void_p, c_long, c_int | |
import torch | |
import math | |
import random | |
import os | |
import tempfile | |
from math import inf, nan | |
from cmath import nanj | |
from torch._inductor.hooks import run_intermediate_hooks | |
from torch._inductor.utils import maybe_profile | |
from torch._inductor.codegen.memory_planning import _align as align | |
from torch import device, empty_strided | |
from torch._inductor.async_compile import AsyncCompile | |
from torch._inductor.select_algorithm import extern_kernels | |
from torch._inductor.codegen.multi_kernel import MultiKernelCall | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime.triton_heuristics import ( | |
grid, | |
split_scan_grid, | |
grid_combo_kernels, | |
start_graph, | |
end_graph, | |
cooperative_reduction_grid, | |
) | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
from torch._C import _cuda_getCurrentRawStream as get_raw_stream | |
import torch._inductor.kernel.flex_attention | |
aten = torch.ops.aten | |
inductor_ops = torch.ops.inductor | |
_quantized = torch.ops._quantized | |
assert_size_stride = torch._C._dynamo.guards.assert_size_stride | |
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu | |
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda | |
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu | |
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor | |
alloc_from_pool = torch.ops.inductor._alloc_from_pool | |
async_compile = AsyncCompile() | |
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/hh/chhrb4fk5xruv6uuerqegmi4rbdnnzuq2z23xh2gwuvych6si5w7.py | |
# Topologically Sorted Source Nodes: [loss, logits, mul_176, square_16, add_116, rsqrt, mul_177], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward, aten._to_copy, aten.mul, aten.pow, aten.add, aten.rsqrt, aten._log_softmax, aten._log_softmax_backward_data] | |
# Source node to ATen node mapping: | |
# add_116 => add_179 | |
# logits => convert_element_type_323 | |
# loss => full_default_12, full_default_13, sub_4, sub_5 | |
# mul_176 => mul_239 | |
# mul_177 => mul_240 | |
# rsqrt => rsqrt_63 | |
# square_16 => pow_80 | |
# Graph fragment: | |
# %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%tangents_1, %convert_element_type_324), kwargs = {}) | |
# %ne_3 : [num_users=2] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_164, -100), kwargs = {}) | |
# %full_default_12 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %where_6 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_164, %full_default_12), kwargs = {}) | |
# %scatter_upon_const_tensor : [num_users=1] = call_function[target=torch._inductor.fx_passes.post_grad.scatter_upon_const_tensor](args = (), kwargs = {shape: [65536, 50304], background_val: 0, dtype: torch.float32, dim: 1, selector: %where_6, val: -1.0}) | |
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %where_7 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %div_2, %full_default_13), kwargs = {}) | |
# %mul_241 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter_upon_const_tensor, %where_7), kwargs = {}) | |
# %convert_element_type_323 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_62, torch.float32), kwargs = {}) | |
# %mul_239 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_323, 15), kwargs = {}) | |
# %pow_80 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 2), kwargs = {}) | |
# %add_179 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%pow_80, 225), kwargs = {}) | |
# %rsqrt_63 : [num_users=3] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_179,), kwargs = {}) | |
# %mul_240 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_239, %rsqrt_63), kwargs = {}) | |
# %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_240, %amax), kwargs = {}) | |
# %sub_5 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_4, %log), kwargs = {}) | |
# %exp_1 : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_5,), kwargs = {}) | |
# %sum_10 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_241, [1], True), kwargs = {}) | |
# %mul_242 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp_1, %sum_10), kwargs = {}) | |
# %sub_6 : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_241, %mul_242), kwargs = {}) | |
# %mul_243 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %mul_239), kwargs = {}) | |
# %mul_244 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %rsqrt_63), kwargs = {}) | |
# %pow_81 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt_63, 3), kwargs = {}) | |
# %mul_245 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%mul_243, -0.5), kwargs = {}) | |
# %mul_246 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_245, %pow_81), kwargs = {}) | |
# %pow_82 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 1.0), kwargs = {}) | |
# %mul_247 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_82, 2.0), kwargs = {}) | |
# %mul_248 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_246, %mul_247), kwargs = {}) | |
# %mul_249 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_244, 15), kwargs = {}) | |
# %add_180 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_248, %mul_249), kwargs = {}) | |
# %convert_element_type_325 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_180, torch.bfloat16), kwargs = {}) | |
# Fused loss-head backward, compiled from the Triton source string below.
# Sizes are hard-coded in the kernel: xnumel = 65536 rows, r0_numel = 50304 columns.
# Addressing is row-major with stride 50304, and indices are computed in int64.
# Pass 1 (first r0 loop): for each row, sums the nll_loss gradient with respect to
#   the log-softmax output into tmp18. That gradient is:
#   - -(in_ptr1[0] / in_ptr2[0]) at the column equal to the int64 target in in_ptr0[row];
#   - 0 elsewhere;
#   - 0 for the whole row when the target is -100, which is treated as ignored.
# Pass 2 (second r0 loop): reloads the bf16 values x from in_out_ptr0 and recomputes:
#   r = rsqrt(x*x + 225)
#   y = 15 * x * r
#   p = exp(y - in_ptr3[row] - in_ptr4[row])
#   g = grad - p * tmp18              (log_softmax backward)
#   It then chains through dy/dx = 15*r - 15*x*x*r**3 and overwrites in_out_ptr0
#   in place with g * dy/dx.
# NOTE(review): in_ptr1 / in_ptr2 look like grad_output and the nll_loss total
#   weight, and in_ptr3 / in_ptr4 like the saved log_softmax max and log-sum-exp.
#   This is inferred only from the graph-fragment names; confirm against the forward.
triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0 = async_compile.triton('triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 65536, 'r0_': 65536},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*i64', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 8, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 65536
    r0_numel = 50304
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64)
    rbase = r0_base
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr1 + (0))
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
    tmp12 = tl.load(in_ptr2 + (0))
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
    _tmp18 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp1 = tl.full([1, 1], -100, tl.int64)
        tmp2 = tmp0 != tmp1
        tmp3 = tl.full([1, 1], 0, tl.int64)
        tmp4 = tl.where(tmp2, tmp0, tmp3)
        tmp5 = r0_1
        tmp6 = tmp4 == tmp5
        tmp7 = -1.0
        tmp8 = 0.0
        tmp9 = tl.where(tmp6, tmp7, tmp8)
        tmp14 = (tmp11 / tmp13)
        tmp15 = tl.where(tmp2, tmp14, tmp8)
        tmp16 = tmp9 * tmp15
        tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
        tmp19 = _tmp18 + tmp17
        _tmp18 = tl.where(r0_mask, tmp19, _tmp18)
    tmp18 = tl.sum(_tmp18, 1)[:, None]
    tmp29 = tl.load(in_ptr1 + (0))
    tmp30 = tl.broadcast_to(tmp29, [XBLOCK, R0_BLOCK])
    tmp31 = tl.load(in_ptr2 + (0))
    tmp32 = tl.broadcast_to(tmp31, [XBLOCK, R0_BLOCK])
    tmp45 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp47 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp36 = tl.load(in_out_ptr0 + (r0_1 + 50304*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp20 = tl.full([1, 1], -100, tl.int64)
        tmp21 = tmp0 != tmp20
        tmp22 = tl.full([1, 1], 0, tl.int64)
        tmp23 = tl.where(tmp21, tmp0, tmp22)
        tmp24 = r0_1
        tmp25 = tmp23 == tmp24
        tmp26 = -1.0
        tmp27 = 0.0
        tmp28 = tl.where(tmp25, tmp26, tmp27)
        tmp33 = (tmp30 / tmp32)
        tmp34 = tl.where(tmp21, tmp33, tmp27)
        tmp35 = tmp28 * tmp34
        tmp37 = tmp36.to(tl.float32)
        tmp38 = 15.0
        tmp39 = tmp37 * tmp38
        tmp40 = tmp37 * tmp37
        tmp41 = 225.0
        tmp42 = tmp40 + tmp41
        tmp43 = libdevice.rsqrt(tmp42)
        tmp44 = tmp39 * tmp43
        tmp46 = tmp44 - tmp45
        tmp48 = tmp46 - tmp47
        tmp49 = tl_math.exp(tmp48)
        tmp50 = tmp49 * tmp18
        tmp51 = tmp35 - tmp50
        tmp52 = tmp51 * tmp39
        tmp53 = -0.5
        tmp54 = tmp52 * tmp53
        tmp55 = tmp43 * tmp43
        tmp56 = tmp55 * tmp43
        tmp57 = tmp54 * tmp56
        tmp58 = 2.0
        tmp59 = tmp37 * tmp58
        tmp60 = tmp57 * tmp59
        tmp61 = tmp51 * tmp43
        tmp62 = tmp61 * tmp38
        tmp63 = tmp60 + tmp62
        tmp64 = tmp63.to(tl.float32)
        tl.store(in_out_ptr0 + (r0_1 + 50304*x0), tmp64, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/tn/ctnjpswwi3gqhcnwhzij5s7dx5ph7fovddyn7g76xfykplagbfrm.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %convert_element_type_330 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_63, torch.float32), kwargs = {}) | |
# Elementwise dtype conversion, compiled from the Triton source string below.
# Reads 51511296 bf16 values from in_ptr0 and writes them as fp32 to out_ptr0
# at the same flat index.
# No bounds mask is applied: xmask is all-True and the load/store pass None.
# Every launched index must therefore be < xnumel. This assumes the launch
# grid / XBLOCK tile 51511296 exactly; confirm.
triton_poi_fused__to_copy_1 = async_compile.triton('triton_poi_fused__to_copy_1', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 67108864},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 51511296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/nm/cnml7os74h3shqb5iuld3alyjqrwei2qthuqadrqzek7sgkvlydx.py | |
# Topologically Sorted Source Nodes: [x_144], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# x_144 => convert_element_type_318 | |
# Graph fragment: | |
# %convert_element_type_331 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_184, torch.float32), kwargs = {}) | |
# %convert_element_type_318 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_177, torch.float32), kwargs = {}) | |
# %mul_250 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %convert_element_type_318), kwargs = {}) | |
# %mul_251 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %rsqrt_62), kwargs = {}) | |
# %sum_11 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_250, [2], True), kwargs = {}) | |
# %div_3 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_20, 1024), kwargs = {}) | |
# %pow_84 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_318, 1.0), kwargs = {}) | |
# %mul_254 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_84, 2.0), kwargs = {}) | |
# %mul_255 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_3, %mul_254), kwargs = {}) | |
# %add_181 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_251, %mul_255), kwargs = {}) | |
# %convert_element_type_332 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_181, torch.bfloat16), kwargs = {}) | |
# Per-row normalization backward, compiled from the Triton source string below.
# Runs over 65536 rows of 1024 elements. Each program handles one full row:
# XBLOCK = 1, and R0_BLOCK = 1024 covers the row, so no reduction loop is needed.
# Arguments:
#   in_out_ptr0: incoming gradient dy (bf16); overwritten in place with dx.
#   in_ptr0:     saved input x (bf16).
#   in_ptr1:     one fp32 factor r per row (the rsqrt_62 node in the graph fragment).
# Result:
#   dx = dy * r + (-0.5 * sum(dy * x) * r**3 * (1/1024)) * 2 * x
#   where 0.0009765625 == 1/1024.
# This formula appears to be the backward of an unweighted RMS-norm over the
# last dimension; confirm against the forward graph.
triton_per_fused__to_copy_add_div_mul_pow_sum_2 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_2', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
    size_hints={'x': 65536, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_out_ptr0, in_ptr0, in_ptr1, xnumel, r0_numel):
    xnumel = 65536
    XBLOCK: tl.constexpr = 1
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[:]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
    tmp9 = tmp1 * tmp8
    tmp10 = -0.5
    tmp11 = tmp7 * tmp10
    tmp12 = tmp8 * tmp8
    tmp13 = tmp12 * tmp8
    tmp14 = tmp11 * tmp13
    tmp15 = 0.0009765625
    tmp16 = tmp14 * tmp15
    tmp17 = 2.0
    tmp18 = tmp3 * tmp17
    tmp19 = tmp16 * tmp18
    tmp20 = tmp9 + tmp19
    tmp21 = tmp20.to(tl.float32)
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp21, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/62/c62mj7hc5xoylhivn7ul5ogozxwe57duoiwwckqoeymqoszlcalf.py | |
# Topologically Sorted Source Nodes: [relu_15], Original ATen: [aten.relu, aten.pow, aten.mul, aten.threshold_backward] | |
# Source node to ATen node mapping: | |
# relu_15 => relu_15 | |
# Graph fragment: | |
# %relu_15 : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%view_180,), kwargs = {}) | |
# %pow_85 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%relu_15, 1.0), kwargs = {}) | |
# %mul_256 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_85, 2.0), kwargs = {}) | |
# %mul_257 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_186, %mul_256), kwargs = {}) | |
# %le_1 : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu_15, 0), kwargs = {}) | |
# %full_default_17 : [num_users=16] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %where_8 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%le_1, %full_default_17, %mul_257), kwargs = {}) | |
# Elementwise backward of relu(x)**2 over 268435456 elements, computed in place.
# The kernel is compiled from the Triton source string below.
# Arguments:
#   in_out_ptr0: pre-activation x (bf16, the view_180 node); overwritten with the gradient.
#   in_ptr0:     incoming gradient dy (bf16, the view_186 node).
# Result:
#   grad = 0                      where relu(x) <= 0
#   grad = dy * 2 * relu(x)       otherwise (the derivative of relu(x)**2)
triton_poi_fused_mul_pow_relu_threshold_backward_3 = async_compile.triton('triton_poi_fused_mul_pow_relu_threshold_backward_3', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 268435456},
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_pow_relu_threshold_backward_3', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_pow_relu_threshold_backward_3(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 268435456
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = triton_helpers.maximum(tmp1, tmp0)
    tmp3 = 0.0
    tmp4 = tmp2 <= tmp3
    tmp6 = 2.0
    tmp7 = tmp2 * tmp6
    tmp8 = tmp5 * tmp7
    tmp9 = tl.where(tmp4, tmp3, tmp8)
    tl.store(in_out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/ul/culpiuny7rsb7xrfxqlunbb4ygzdfdptdamd6p2d43ojix2u5dvg.py | |
# Topologically Sorted Source Nodes: [rms_norm_61], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# rms_norm_61 => convert_element_type_312 | |
# Graph fragment: | |
# %convert_element_type_341 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_188, torch.float32), kwargs = {}) | |
# %convert_element_type_312 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_175, torch.float32), kwargs = {}) | |
# %mul_258 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %convert_element_type_312), kwargs = {}) | |
# %mul_259 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %rsqrt_61), kwargs = {}) | |
# %sum_12 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_258, [2], True), kwargs = {}) | |
# %div_4 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_21, 1024), kwargs = {}) | |
# %pow_87 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_312, 1.0), kwargs = {}) | |
# %mul_262 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_87, 2.0), kwargs = {}) | |
# %mul_263 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_4, %mul_262), kwargs = {}) | |
# %add_182 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_259, %mul_263), kwargs = {}) | |
# %convert_element_type_342 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_182, torch.bfloat16), kwargs = {}) | |
# %add_183 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_332, %convert_element_type_342), kwargs = {}) | |
# Per-row normalization backward whose result is ADDED into an existing
# gradient buffer. The kernel is compiled from the Triton source string below.
# Runs over 65536 rows of 1024 elements, one full row per program
# (XBLOCK = 1, R0_BLOCK = 1024).
# Arguments:
#   in_ptr0:     incoming gradient dy (bf16).
#   in_ptr1:     saved input x (bf16).
#   in_ptr2:     one fp32 factor r per row (the rsqrt_61 node in the graph fragment).
#   in_out_ptr0: existing gradient (bf16); becomes old + dx.
# Result:
#   dx = dy * r + (-0.5 * sum(dy * x) * r**3 * (1/1024)) * 2 * x
#   where 0.0009765625 == 1/1024.
# This formula appears to be an unweighted RMS-norm backward; confirm against
# the forward graph.
triton_per_fused__to_copy_add_div_mul_pow_sum_4 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_4', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
    size_hints={'x': 65536, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_4', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_4(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel):
    xnumel = 65536
    XBLOCK: tl.constexpr = 1
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[:]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp8 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
    tmp10 = tmp1 * tmp9
    tmp11 = -0.5
    tmp12 = tmp7 * tmp11
    tmp13 = tmp9 * tmp9
    tmp14 = tmp13 * tmp9
    tmp15 = tmp12 * tmp14
    tmp16 = 0.0009765625
    tmp17 = tmp15 * tmp16
    tmp18 = 2.0
    tmp19 = tmp3 * tmp18
    tmp20 = tmp17 * tmp19
    tmp21 = tmp10 + tmp20
    tmp22 = tmp21.to(tl.float32)
    tmp23 = tmp8 + tmp22
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp23, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/3d/c3dqvlchxst3ygwswuqw55ensuczxcndkjpqvcilnha65gtorjio.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {}) | |
# Row-wise dot product, compiled from the Triton source string below.
# For each of 524288 rows of 128 contiguous elements, it computes
#   sum(in_ptr0[row, :] * in_ptr1[row, :])
# with fp32 arithmetic and stores the result to out_ptr0[row] (bf16 per the signature).
# The source-node mapping ties this kernel to flex_attention_backward.
# Presumably it is the DELTA = sum(OUT * DO, axis=-1) precompute; confirm.
triton_per_fused_zeros_5 = async_compile.triton('triton_per_fused_zeros_5', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
    size_hints={'x': 524288, 'r0_': 128},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_zeros_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused_zeros_5(in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 524288
    r0_numel = 128
    R0_BLOCK: tl.constexpr = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 128*x0), None).to(tl.float32)
    tmp2 = tmp0 * tmp1
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
    tmp5 = tl.sum(tmp3, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp5, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/g7/cg7vrix5hzkplkjfm4y5yzz67ap26cysqwhzfniuom4fca22jmrl.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {}) | |
# Transposing copy with an upcast, compiled from the Triton source string below.
# Reads bf16 in_ptr0 at index y0 + 8*x1, i.e. a row-major [65536, 8] layout.
# Writes fp32 out_ptr0 at index x1 + 65536*y0, i.e. a row-major [8, 65536] layout.
# Only the y dimension (ynumel = 8) is masked; x is assumed to tile 65536 exactly.
# The trailing "- 0.0" is a numeric no-op.
# Presumably this fills the fp32 [1, 8, 65536] buffer used by
# flex_attention_backward (see the graph fragment); confirm.
triton_poi_fused_zeros_6 = async_compile.triton('triton_poi_fused_zeros_6', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'y': 8, 'x': 65536}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_zeros_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_zeros_6(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 8
    xnumel = 65536
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, YBLOCK], True, tl.int1)
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (y0 + 8*x1), ymask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.0
    tmp3 = tmp1 - tmp2
    tl.store(out_ptr0 + (x1 + 65536*y0), tmp3, ymask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/g5/cg5pjx7bwly5zrm6xs4clxzfh2xs7znnajmp4og73hop4xabftat.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {}) | |
triton_tem_fused_zeros_7 = async_compile.triton('triton_tem_fused_zeros_7', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
@triton_heuristics.template( | |
num_stages=3, | |
num_warps=8, | |
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, | |
inductor_meta={'kernel_name': 'triton_tem_fused_zeros_7', 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
) | |
@triton.jit | |
def triton_tem_fused_zeros_7(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0): | |
PRESCALE_QK : tl.constexpr = False | |
ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
WRITE_DQ : tl.constexpr = True | |
OUTPUT_LOGSUMEXP : tl.constexpr = True | |
FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
IS_DIVISIBLE : tl.constexpr = True | |
SM_SCALE : tl.constexpr = 0.12 | |
GQA_SHARED_HEADS : tl.constexpr = 1 | |
HAS_FULL_BLOCKS : tl.constexpr = True | |
QK_HEAD_DIM : tl.constexpr = 128 | |
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
V_HEAD_DIM : tl.constexpr = 128 | |
V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
SAFE_HEAD_DIM : tl.constexpr = True | |
BLOCK_M1 : tl.constexpr = 64 | |
BLOCK_N1 : tl.constexpr = 128 | |
BLOCK_M2 : tl.constexpr = 128 | |
BLOCK_N2 : tl.constexpr = 64 | |
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
Q = arg_Q | |
K = arg_K | |
V = arg_V | |
LSE = arg_LSE | |
DELTA = arg_DELTA | |
DO = arg_DO | |
DQ = arg_DQ | |
DV = arg_DV | |
KV_NUM_BLKS = arg_KV_NUM_BLKS | |
KV_IDX = arg_KV_IDX | |
Q_NUM_BLKS = arg_Q_NUM_BLKS | |
Q_IDX = arg_Q_IDX | |
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS | |
FULL_KV_IDX = arg_FULL_KV_IDX | |
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS | |
FULL_Q_IDX = arg_FULL_Q_IDX | |
# Sub notation for this kernel: | |
# | |
# Q: Query, K: Key, V: Value | |
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) | |
# DELTA: Precomputed sum(OUT*DO, axis=-1) | |
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value | |
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with | |
# inductor codegen | |
# M: Number of queries, N: Number of keys/values | |
# QK_HEAD_DIM: The dimension of the query and key embeddings | |
# V_HEAD_DIM: The dimension of the value embeddings | |
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim | |
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. | |
# (Modifiable) Performance tuning options | |
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. | |
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. | |
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. | |
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. | |
# | |
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. | |
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. | |
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. | |
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. | |
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. | |
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. | |
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. | |
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. | |
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. | |
# The below are kernel options that can be applied for certain score_mods, | |
# or involve a numerics vs. perf tradeoff | |
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has | |
# about 20% more numerical error, but slightly faster. | |
# Define strides of inputs | |
stride_qz, stride_qh, stride_qm, stride_qd = 67108864, 128, 1024, 1 | |
stride_kz, stride_kh, stride_kn, stride_kd = 67108864, 128, 1024, 1 | |
stride_vz, stride_vh, stride_vn, stride_vd = 67108864, 128, 1024, 1 | |
stride_doz, stride_doh, stride_dom, stride_dod = 67108864, 128, 1024, 1 | |
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 67108864, 128, 1024, 1 | |
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 67108864, 128, 1024, 1 | |
ZQ = 1 | |
HQ = 8 | |
HKV = 8 | |
Q_LEN = 65536 | |
ZKV = 1 | |
KV_LEN = 65536 | |
MATMUL_PRECISION = Q.dtype.element_ty | |
pid = tl.program_id(0) | |
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) | |
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) | |
off_hz = tl.program_id(2) | |
off_zq = off_hz // HKV # q batch idx | |
off_hkv = off_hz % HKV # kv head idx | |
off_zkv = off_zq % ZKV # kv batch idx | |
SPARSE_Z = 1 | |
SPARSE_HQ = 1 | |
sparse_idx_z = off_zq % SPARSE_Z | |
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) | |
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) | |
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] | |
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] | |
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) | |
# offset K, V, DV pointers for batch/kv-head | |
K += k_adj | |
V += v_adj | |
DV += dv_adj | |
RCP_LN2 = 1.44269504 | |
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) | |
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) | |
if pid >= NUM_KV_BLOCKS: | |
off_pid = pid - NUM_KV_BLOCKS | |
# THIS BLOCK DOES DQ | |
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) | |
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) | |
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS | |
start_m2_block = off_pid % NUM_Q_BLOCKS | |
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE | |
stride_kv_num_blks_h = 512 | |
stride_kv_idx_h = 262144 | |
stride_kv_idx_m = 512 | |
sparse_idx_hq2 = off_hq2 % SPARSE_HQ | |
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 | |
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask | |
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 | |
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. | |
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) | |
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) | |
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) | |
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) | |
Q2 = Q + q_adj2 | |
DO2 = DO + do_adj2 | |
# TODO: This does not work if DQ is not the same layout as Q (for example, | |
# if Q is broadcasted) | |
DQ2 = DQ + dq_adj2 | |
LSE2 = LSE + off_chz2 | |
DELTA2 = DELTA + off_chz2 | |
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) | |
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
start_m2 = start_m2_block * BLOCK_M2 | |
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) | |
# load Q and do: they stay in SRAM throughout the inner loop. | |
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) | |
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) | |
if PRESCALE_QK: | |
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) | |
if IS_DIVISIBLE: | |
Di = tl.load(DELTA2 + offs_m2) | |
lse = tl.load(LSE2 + offs_m2) | |
else: | |
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) | |
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) | |
lse = tl.where(lse == -float("inf"), 0.0, lse) | |
lse = lse[:, None] | |
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
# KV_IDX and KV_NUM_BLKS are always contiguous. | |
kv_indices = KV_IDX + sparse_kv_idx_offset | |
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading | |
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) | |
offs_n2 = kv_start + tl.arange(0, BLOCK_N2) | |
dq = bwd_dq_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
K, V, | |
dq, q, do, Di, lse, | |
off_zq, off_hq2, offs_m2, offs_n2, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS=False, | |
) | |
if HAS_FULL_BLOCKS: | |
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. | |
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset | |
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading | |
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) | |
offs_n2 = kv_start + tl.arange(0, BLOCK_N2) | |
dq = bwd_dq_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
K, V, | |
dq, q, do, Di, lse, | |
off_zq, off_hq2, offs_m2, offs_n2, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS=True, | |
) | |
# Write back dQ. | |
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd | |
dq *= SM_SCALE | |
if IS_DIVISIBLE and SAFE_HEAD_DIM: | |
tl.store(dq_ptrs, dq) | |
else: | |
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) | |
else: | |
# THIS BLOCK DOES DK & DV | |
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) | |
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) | |
pid_mask = pid // SPARSE_KV_MULTIPLE | |
stride_q_num_blks_h = 512 | |
stride_q_idx_h = 262144 | |
stride_q_idx_n = 512 | |
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) | |
start_n1 = pid * BLOCK_N1 | |
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) | |
# load K and V: they stay in SRAM throughout the inner loop. | |
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) | |
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) | |
if PRESCALE_QK: | |
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) | |
for off_g in range(0, GQA_SHARED_HEADS): | |
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g | |
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offseted by query heads. | |
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) | |
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) | |
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) | |
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) | |
Q1 = Q + q_adj1 | |
DO1 = DO + do_adj1 | |
# TODO: This does not work if DQ is not the same layout as Q (for example, | |
# if Q is broadcasted) | |
LSE1 = LSE + off_chz1 | |
DELTA1 = DELTA + off_chz1 | |
sparse_idx_hq1 = off_hq1 % SPARSE_HQ | |
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 | |
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask | |
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 | |
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
# Q_IDX and Q_NUM_BLKS are always contiguous. | |
q_indices = Q_IDX + sparse_q_idx_offset | |
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading | |
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) | |
offs_m1 = q_start + tl.arange(0, BLOCK_M1) | |
dk, dv = bwd_dkdv_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
Q1, DO1, DELTA1, LSE1, | |
dk, dv, k, v, | |
off_zq, off_hq1, offs_n1, offs_m1, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS=False, | |
) | |
if HAS_FULL_BLOCKS: | |
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. | |
q_indices = FULL_Q_IDX + sparse_q_idx_offset | |
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading | |
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) | |
offs_m1 = q_start + tl.arange(0, BLOCK_M1) | |
dk, dv = bwd_dkdv_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
Q1, DO1, DELTA1, LSE1, | |
dk, dv, k, v, | |
off_zq, off_hq1, offs_n1, offs_m1, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS=True, | |
) | |
# Write back dV and dK. | |
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd | |
index_n = offs_n1[:, None] | |
index_k = offs_k[None, :] | |
index_v = offs_v[None, :] | |
if IS_DIVISIBLE and SAFE_HEAD_DIM: | |
tl.store(dv_ptrs, dv) | |
else: | |
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) | |
dk *= SM_SCALE | |
if SAFE_HEAD_DIM: | |
mask = index_n < KV_LEN | |
else: | |
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) | |
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] | |
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] | |
xindex = index_k + 128*index_n + 8388608*off_hkv + 67108864*off_zq | |
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask) | |
@triton.jit | |
def bwd_dq_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
K, V, # pointers | |
dq, q, do, Di, lse, | |
off_z, off_hq, offs_m2, offs_n2, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS, | |
): | |
PRESCALE_QK : tl.constexpr = False | |
ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
WRITE_DQ : tl.constexpr = True | |
OUTPUT_LOGSUMEXP : tl.constexpr = True | |
FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
IS_DIVISIBLE : tl.constexpr = True | |
SM_SCALE : tl.constexpr = 0.12 | |
GQA_SHARED_HEADS : tl.constexpr = 1 | |
HAS_FULL_BLOCKS : tl.constexpr = True | |
QK_HEAD_DIM : tl.constexpr = 128 | |
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
V_HEAD_DIM : tl.constexpr = 128 | |
V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
SAFE_HEAD_DIM : tl.constexpr = True | |
BLOCK_M1 : tl.constexpr = 64 | |
BLOCK_N1 : tl.constexpr = 128 | |
BLOCK_M2 : tl.constexpr = 128 | |
BLOCK_N2 : tl.constexpr = 64 | |
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) | |
RCP_LN2: tl.constexpr = 1.44269504 | |
Q_LEN = 65536 | |
KV_LEN = 65536 | |
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) | |
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) | |
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd | |
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd | |
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. | |
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) | |
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) | |
if not IS_DIVISIBLE: | |
if hi >= 1: | |
for start_n in range(0, hi - 1): | |
dq = bwd_dq_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, | |
) | |
# Increment pointers. | |
offset = get_offset_for_next_block( | |
start_n, kv_indices, sparse_kv_num_blocks, | |
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS | |
) | |
kT_ptrs += offset * stride_kn | |
vT_ptrs += offset * stride_vn | |
offs_n2 += offset | |
dq = bwd_dq_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, | |
) | |
else: | |
for start_n in range(0, hi): | |
dq = bwd_dq_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, | |
) | |
# Increment pointers. | |
offset = get_offset_for_next_block( | |
start_n, kv_indices, sparse_kv_num_blocks, | |
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS | |
) | |
kT_ptrs += offset * stride_kn | |
vT_ptrs += offset * stride_vn | |
offs_n2 += offset | |
return dq | |
@triton.jit | |
def bwd_dq_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, | |
stride_kn, stride_kd, stride_vn, stride_vd, | |
kv_indices, sparse_kv_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, | |
): | |
PRESCALE_QK : tl.constexpr = False | |
ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
WRITE_DQ : tl.constexpr = True | |
OUTPUT_LOGSUMEXP : tl.constexpr = True | |
FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
IS_DIVISIBLE : tl.constexpr = True | |
SM_SCALE : tl.constexpr = 0.12 | |
GQA_SHARED_HEADS : tl.constexpr = 1 | |
HAS_FULL_BLOCKS : tl.constexpr = True | |
QK_HEAD_DIM : tl.constexpr = 128 | |
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
V_HEAD_DIM : tl.constexpr = 128 | |
V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
SAFE_HEAD_DIM : tl.constexpr = True | |
BLOCK_M1 : tl.constexpr = 64 | |
BLOCK_N1 : tl.constexpr = 128 | |
BLOCK_M2 : tl.constexpr = 128 | |
BLOCK_N2 : tl.constexpr = 64 | |
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
# NB reversed order to since K is transposed | |
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) | |
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) | |
if not PRESCALE_QK: | |
qk *= SM_SCALE | |
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ | |
pre_mod_scores = qk | |
n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) | |
# The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim | |
# that the M reads out of bounds prior to the last loop | |
m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) | |
tmp0 = (qk) | |
post_mod_scores = tmp0 | |
if CHECK_BLOCK_BOUNDARY: | |
# Mask out the elements that are out of the KV_LEN for non divisible seqlen. | |
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) | |
if not IS_FULL_BLOCKS: | |
tmp1 = (m) | |
tmp2 = (n) | |
tmp3 = tmp1 >= tmp2 | |
tmp4 = tl.load(in_ptr16 + tmp1) | |
tmp5 = tl.load(in_ptr16 + tmp2) | |
tmp6 = tmp4 == tmp5 | |
tmp7 = tmp3 & tmp6 | |
mask_mod_output = tmp7 | |
if CHECK_BLOCK_BOUNDARY: | |
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) | |
# apply mask for partial masked block | |
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
if not PRESCALE_QK: | |
post_mod_scores *= RCP_LN2 | |
p = tl.math.exp2(post_mod_scores - lse) | |
# Compute dP and dS. | |
# NB reversed order to since V is transposed | |
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) | |
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) | |
ds = p * (dp - Di[:, None]) | |
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ | |
tmp8 = (ds) | |
grad_scores = tmp8 | |
if CHECK_BLOCK_BOUNDARY: | |
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) | |
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ | |
if WRITE_DQ: | |
scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
ds = grad_scores | |
if not IS_FULL_BLOCKS: | |
if CHECK_BLOCK_BOUNDARY: | |
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) | |
# (grads) apply mask for partially unmasked block | |
ds = tl.where(mask_mod_output, ds, 0.0) | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
ds = ds.to(MATMUL_PRECISION) | |
# Compute dQ. | |
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) | |
return dq | |
@triton.jit | |
def bwd_dkdv_inner( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
Q, DO, DELTA, LSE, # pointers | |
dk, dv, k, v, | |
off_z, off_hq, offs_n1, offs_m1, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, | |
IS_FULL_BLOCKS, | |
): | |
PRESCALE_QK : tl.constexpr = False | |
ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
WRITE_DQ : tl.constexpr = True | |
OUTPUT_LOGSUMEXP : tl.constexpr = True | |
FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
IS_DIVISIBLE : tl.constexpr = True | |
SM_SCALE : tl.constexpr = 0.12 | |
GQA_SHARED_HEADS : tl.constexpr = 1 | |
HAS_FULL_BLOCKS : tl.constexpr = True | |
QK_HEAD_DIM : tl.constexpr = 128 | |
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
V_HEAD_DIM : tl.constexpr = 128 | |
V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
SAFE_HEAD_DIM : tl.constexpr = True | |
BLOCK_M1 : tl.constexpr = 64 | |
BLOCK_N1 : tl.constexpr = 128 | |
BLOCK_M2 : tl.constexpr = 128 | |
BLOCK_N2 : tl.constexpr = 64 | |
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) | |
RCP_LN2: tl.constexpr = 1.44269504 | |
Q_LEN = 65536 | |
KV_LEN = 65536 | |
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) | |
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) | |
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd | |
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod | |
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. | |
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) | |
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) | |
if not IS_DIVISIBLE: | |
if hi >= 1: | |
for start_m in range(0, hi - 1): | |
dk, dv = bwd_dkdv_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, | |
) | |
# Increment pointers. | |
offset = get_offset_for_next_block( | |
start_m, q_indices, sparse_q_num_blocks, | |
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS | |
) | |
qT_ptrs += offset * stride_qm | |
do_ptrs += offset * stride_dom | |
offs_m1 += offset | |
dk, dv = bwd_dkdv_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, | |
) | |
else: | |
for start_m in range(0, hi): | |
dk, dv = bwd_dkdv_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, | |
) | |
# Increment pointers. | |
offset = get_offset_for_next_block( | |
start_m, q_indices, sparse_q_num_blocks, | |
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS | |
) | |
qT_ptrs += offset * stride_qm | |
do_ptrs += offset * stride_dom | |
offs_m1 += offset | |
return dk, dv | |
@triton.jit | |
def bwd_dkdv_block_mn( | |
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0, | |
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, | |
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, | |
stride_qm, stride_qd, stride_dom, stride_dod, | |
q_indices, sparse_q_num_blocks, | |
MATMUL_PRECISION, RCP_LN2, | |
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, | |
): | |
PRESCALE_QK : tl.constexpr = False | |
ROWS_GUARANTEED_SAFE : tl.constexpr = False | |
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False | |
WRITE_DQ : tl.constexpr = True | |
OUTPUT_LOGSUMEXP : tl.constexpr = True | |
FLOAT32_PRECISION : tl.constexpr = 'ieee' | |
IS_DIVISIBLE : tl.constexpr = True | |
SM_SCALE : tl.constexpr = 0.12 | |
GQA_SHARED_HEADS : tl.constexpr = 1 | |
HAS_FULL_BLOCKS : tl.constexpr = True | |
QK_HEAD_DIM : tl.constexpr = 128 | |
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
V_HEAD_DIM : tl.constexpr = 128 | |
V_HEAD_DIM_ROUNDED : tl.constexpr = 128 | |
SAFE_HEAD_DIM : tl.constexpr = True | |
BLOCK_M1 : tl.constexpr = 64 | |
BLOCK_N1 : tl.constexpr = 128 | |
BLOCK_M2 : tl.constexpr = 128 | |
BLOCK_N2 : tl.constexpr = 64 | |
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128 | |
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128 | |
# NB reversed order since Q is transposed | |
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) | |
# Load LSE before computing qk to reduce pipeline stall. | |
if IS_DIVISIBLE: | |
lse = tl.load(LSE + offs_m1) | |
else: | |
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) | |
lse = tl.where(lse == -float("inf"), 0.0, lse) | |
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) | |
if not PRESCALE_QK: | |
qkT *= SM_SCALE | |
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ | |
m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) | |
# The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim | |
# that the n reads out of bounds prior to the last loop | |
n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) | |
pre_mod_scores = qkT | |
tmp9 = (qkT) | |
post_mod_scores = tmp9 | |
if CHECK_BLOCK_BOUNDARY: | |
# Mask out the elements that are out of the KV_LEN for non divisible seqlen. | |
post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) | |
if not IS_FULL_BLOCKS: | |
tmp10 = (m) | |
tmp11 = (n) | |
tmp12 = tmp10 >= tmp11 | |
tmp13 = tl.load(in_ptr16 + tmp10) | |
tmp14 = tl.load(in_ptr16 + tmp11) | |
tmp15 = tmp13 == tmp14 | |
tmp16 = tmp12 & tmp15 | |
mask_mod_output = tmp16 | |
if CHECK_BLOCK_BOUNDARY: | |
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) | |
# (grads) apply mask for fully masked block | |
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
if not PRESCALE_QK: | |
post_mod_scores *= RCP_LN2 | |
pT = tl.math.exp2(post_mod_scores - lse[None, :]) | |
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) | |
# Compute dV. | |
ppT = pT | |
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) | |
if IS_DIVISIBLE: | |
Di = tl.load(DELTA + offs_m1) | |
else: | |
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) | |
# Compute dP and dS. | |
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) | |
dsT = pT * (dpT - Di[None, :]) | |
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ | |
tmp17 = (dsT) | |
grad_scores = tmp17 | |
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ | |
if not WRITE_DQ: | |
idx_b = off_z | |
idx_h = off_hq | |
idx_m = m | |
idx_n = n | |
scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
if CHECK_BLOCK_BOUNDARY: | |
grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) | |
dsT = grad_scores | |
if not IS_FULL_BLOCKS: | |
if CHECK_BLOCK_BOUNDARY: | |
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) | |
# (grads) apply mask for partially unmasked block | |
dsT = tl.where(mask_mod_output, dsT, 0.0) | |
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | |
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) | |
return dk, dv | |
@triton.jit | |
def get_offset_for_next_block( | |
loop_iter, col_indices, total_blocks, | |
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, | |
BLOCKS_ARE_CONTIGUOUS: tl.constexpr | |
): | |
if BLOCKS_ARE_CONTIGUOUS: | |
return BLOCK | |
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE | |
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") | |
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) | |
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 | |
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK | |
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK | |
return offset | |
@triton.jit | |
def get_bounded_indices(indices, max_len=None): | |
return indices % max_len if max_len is not None else indices | |
@triton.jit | |
def load_checked_2d( | |
ptr, | |
offs_m, | |
offs_n, | |
stride_m, | |
stride_n, | |
IS_DIVISIBLE_M: tl.constexpr, | |
IS_DIVISIBLE_N: tl.constexpr, | |
M_LEN: tl.constexpr, | |
N_DIM: tl.constexpr, | |
): | |
# Calculate final pointer if strides are provided | |
if stride_m is not None and stride_n is not None: | |
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n | |
# Handle all masking cases | |
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: | |
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) | |
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: | |
return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) | |
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: | |
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) | |
else: # Both divisible | |
return tl.load(ptr) | |
''', device_str='cuda') | |
# Launch-time constants for the template kernel whose source string ends just
# above (its visible tail computes dk/dv and branches on WRITE_DQ and
# IS_DIVISIBLE). The consumer of this dict is not visible in this chunk --
# presumably it is splatted as keyword arguments where that kernel is
# launched; confirm at the call site.
# NOTE(review): 'FLOAT32_PRECISION' deliberately holds a quoted string
# ("'ieee'"), presumably so it can be substituted verbatim into kernel code as
# the `input_precision=` argument; do not "simplify" it to 'ieee'.
meta0 = {
    'PRESCALE_QK': False,
    'ROWS_GUARANTEED_SAFE': False,
    'BLOCKS_ARE_CONTIGUOUS': False,
    'WRITE_DQ': True,
    'OUTPUT_LOGSUMEXP': True,
    'FLOAT32_PRECISION': "'ieee'",
    'IS_DIVISIBLE': True,
    'SM_SCALE': 0.12,
    'GQA_SHARED_HEADS': 1,
    'HAS_FULL_BLOCKS': True,
    'QK_HEAD_DIM': 128,
    'QK_HEAD_DIM_ROUNDED': 128,
    'V_HEAD_DIM': 128,
    'V_HEAD_DIM_ROUNDED': 128,
    'SAFE_HEAD_DIM': True,
    'BLOCK_M1': 64,
    'BLOCK_N1': 128,
    'BLOCK_M2': 128,
    'BLOCK_N2': 64,
    'SPARSE_Q_BLOCK_SIZE': 128,
    'SPARSE_KV_BLOCK_SIZE': 128,
}
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/sl/csl7aubmdgjkdh44byrp4ynq374ifxk76hyvh6pkbszhj2arqklp.py
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten.sum, aten._to_copy]
# Source node to ATen node mapping:
# v_43 => convert_element_type_308, convert_element_type_309, mul_234
# Graph fragment:
# %mul_265 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %view_36), kwargs = {})
# %sum_13 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_265,), kwargs = {})
# %convert_element_type_308 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_140, torch.float32), kwargs = {})
# %mul_234 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_308, %rsqrt_60), kwargs = {})
# %convert_element_type_309 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_234, torch.bfloat16), kwargs = {})
# %mul_267 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %convert_element_type_309), kwargs = {})
# %sum_14 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_267,), kwargs = {})
#
# Reduction over 512 rows (x) of 131072 elements (r) each; every row yields
# two fp32 partial sums (bf16 inputs are upcast to fp32 before multiplying):
#   out_ptr0[x] = sum_r in_ptr0[x, r] * in_ptr1[x, r]
#   out_ptr1[x] = sum_r in_ptr0[x, r] * (in_ptr2[...] * in_ptr3[1024*x + r // 128])
# in_ptr0/in_ptr1 are read contiguously (offset r + 131072*x). in_ptr2 is read
# in 1024-element chunks with a chunk stride of 3072
# (offset 3072*(r // 1024) + 393216*x + r % 1024), so only 1024 of every 3072
# elements are touched -- presumably a slice of a wider buffer; confirm.
# sum_13/sum_14 in the graph are full (scalar) sums, so these 512 partials
# still need a second stage; triton_per_fused_mul_sum_9 (sum of 512 fp32
# values into one) looks like that stage for sum_13 -- confirm at the call site.
# NOTE(review): the graph rounds mul_234 to bfloat16 (convert_element_type_309)
# but the kernel keeps the product in fp32 (`tmp9.to(tl.float32)`); presumably
# Inductor elides that intermediate down-cast -- confirm if bit-exactness
# against eager mode matters.
triton_red_fused__to_copy_mul_sum_8 = async_compile.triton('triton_red_fused__to_copy_mul_sum_8', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 512, 'r0_': 131072},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_mul_sum_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 512
    r0_numel = 131072
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    _tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp1 = tl.load(in_ptr1 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp6 = tl.load(in_ptr2 + (3072*(r0_1 // 1024) + 393216*x0 + ((r0_1 % 1024))), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp8 = tl.load(in_ptr3 + (1024*x0 + (r0_1 // 128)), xmask, eviction_policy='evict_last', other=0.0)
        tmp2 = tmp0 * tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(xmask, tmp5, _tmp4)
        tmp7 = tmp6.to(tl.float32)
        tmp9 = tmp7 * tmp8
        tmp10 = tmp9.to(tl.float32)
        tmp11 = tmp0 * tmp10
        tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
        tmp14 = _tmp13 + tmp12
        _tmp13 = tl.where(xmask, tmp14, _tmp13)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    tmp13 = tl.sum(_tmp13, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp4, xmask)
    tl.store(out_ptr1 + (x0), tmp13, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/ip/cipvpmrkhwlcho5qp6b5hh7tf5gzn2zmmw6oqxmrgh7e6dgbrpl4.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum]
# Source node to ATen node mapping:
# Graph fragment:
# %mul_265 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %view_36), kwargs = {})
# %sum_13 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_265,), kwargs = {})
#
# Single-program persistent reduction (XBLOCK = 1, R0_BLOCK = 512) that sums
# 512 fp32 values into one scalar:
#   out_ptr0[0] = sum(in_ptr0[0:512])
# The sum is computed in fp32. The signature declares out_ptr0 as '*bf16', so
# the result is narrowed to bfloat16 by the store.
# NOTE(review): presumably in_ptr0 is the out_ptr0 partial-sum buffer of
# triton_red_fused__to_copy_mul_sum_8, which finishes sum_13. The launch
# site is not visible here; confirm.
triton_per_fused_mul_sum_9 = async_compile.triton('triton_per_fused_mul_sum_9', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints={'x': 1, 'r0_': 512},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_sum_9', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused_mul_sum_9(in_ptr0, out_ptr0, xnumel, r0_numel):
    xnumel = 1
    XBLOCK: tl.constexpr = 1
    r0_numel = 512
    R0_BLOCK: tl.constexpr = 512
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[:]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_0 = r0_index
    tmp0 = tl.load(in_ptr0 + (r0_0), None)
    tmp1 = tl.broadcast_to(tmp0, [R0_BLOCK])
    tmp3 = triton_helpers.promote_to_tensor(tl.sum(tmp1, 0))
    tl.store(out_ptr0 + (tl.full([1], 0, tl.int32)), tmp3, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/t7/ct7jp2xwhotrjgyyvvxtn4sdhjkd3odeyv5jotixf4j4625b7o64.py
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_43 => convert_element_type_308
# Graph fragment:
# %mul_266 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %select_101), kwargs = {})
# %convert_element_type_308 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_140, torch.float32), kwargs = {})
# %convert_element_type_349 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_266, torch.float32), kwargs = {})
# %mul_268 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_349, %convert_element_type_308), kwargs = {})
# %mul_269 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_349, %rsqrt_60), kwargs = {})
# %sum_15 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_268, [3], True), kwargs = {})
# %div_5 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_22, 128), kwargs = {})
# %pow_89 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_308, 1.0), kwargs = {})
# %mul_272 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_89, 2.0), kwargs = {})
# %mul_273 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_5, %mul_272), kwargs = {})
# %add_185 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_269, %mul_273), kwargs = {})
# %convert_element_type_350 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_185, torch.bfloat16), kwargs = {})
#
# Two passes over 524288 rows x of 128 elements (r), with x0 = x % 8 and
# x1 = x // 8; all arithmetic in fp32:
#   g[r] = in_ptr0[128*x + r] * in_ptr1[78]        (a single scalar, broadcast)
#   v[r] = in_ptr2[128*x0 + 3072*x1 + r]
#   s    = sum_r g[r] * v[r]                        (pass 1)
#   rs   = in_ptr3[x]
#   out_ptr1[128*x0 + 3072*x1 + r] =
#       g[r] * rs + (s * -0.5 * rs**3 * 0.0078125) * (2 * v[r])   (pass 2, bf16)
# With 0.0078125 == 1/128, this has the form of d/dv of v * rs where
# rs = rsqrt(mean(v**2) + eps) over the 128-wide last dim -- presumably an
# RMSNorm backward with in_ptr3 holding the saved rsqrt_60 and in_ptr1[78]
# being select_101; confirm against the call site.
# in_ptr2/out_ptr1 step 3072 elements per group of 8 rows, so they address
# 1024 of every 3072 elements -- presumably a slice of a wider buffer.
# NOTE(review): in_ptr1[78] is loaded twice (tmp1, tmp13) and in_ptr0/in_ptr2
# are re-read in pass 2 rather than cached; generated code, left as-is.
triton_red_fused__to_copy_add_div_mul_pow_sum_10 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_10', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 524288, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_10(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 524288
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x3 = xindex
    tmp1 = tl.load(in_ptr1 + (78))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
    x0 = (xindex % 8)
    x1 = xindex // 8
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp0 * tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp5 * tmp7
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    tmp13 = tl.load(in_ptr1 + (78))
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tmp12 * tmp15
        tmp17 = tmp16.to(tl.float32)
        tmp19 = tmp17 * tmp18
        tmp20 = -0.5
        tmp21 = tmp10 * tmp20
        tmp22 = tmp18 * tmp18
        tmp23 = tmp22 * tmp18
        tmp24 = tmp21 * tmp23
        tmp25 = 0.0078125
        tmp26 = tmp24 * tmp25
        tmp28 = tmp27.to(tl.float32)
        tmp29 = 2.0
        tmp30 = tmp28 * tmp29
        tmp31 = tmp26 * tmp30
        tmp32 = tmp19 + tmp31
        tmp33 = tmp32.to(tl.float32)
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/qm/cqmbpar5nxuou6yq7grwenf2cpvetjywkppq53bzvstjg6xoegcz.py
# Topologically Sorted Source Nodes: [k_43, q_43], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# k_43 => convert_element_type_302
# q_43 => convert_element_type_300
# Graph fragment:
# %cat_30 : [num_users=2] = call_function[target=torch.ops.aten.cat.default](args = ([%add_187, %add_186], 3), kwargs = {})
# %cat_31 : [num_users=2] = call_function[target=torch.ops.aten.cat.default](args = ([%add_189, %add_188], 3), kwargs = {})
# %convert_element_type_302 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_139, torch.float32), kwargs = {})
# %mul_282 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_30, %convert_element_type_302), kwargs = {})
# %mul_283 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_30, %rsqrt_59), kwargs = {})
# %sum_16 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_282, [3], True), kwargs = {})
# %div_6 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_23, 128), kwargs = {})
# %pow_91 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_302, 1.0), kwargs = {})
# %mul_286 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_91, 2.0), kwargs = {})
# %mul_287 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_6, %mul_286), kwargs = {})
# %add_190 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_283, %mul_287), kwargs = {})
# %convert_element_type_356 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_190, torch.bfloat16), kwargs = {})
# %convert_element_type_300 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_138, torch.float32), kwargs = {})
# %mul_288 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_31, %convert_element_type_300), kwargs = {})
# %mul_289 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_31, %rsqrt_58), kwargs = {})
# %sum_17 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_288, [3], True), kwargs = {})
# %div_7 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_24, 128), kwargs = {})
# %pow_93 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_300, 1.0), kwargs = {})
# %mul_292 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_93, 2.0), kwargs = {})
# %mul_293 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_7, %mul_292), kwargs = {})
# %add_191 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_289, %mul_293), kwargs = {})
# %convert_element_type_358 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_191, torch.bfloat16), kwargs = {})
#
# Four passes over 524288 rows x of 128 elements (r), x0 = x % 8, x1 = x // 8;
# all arithmetic in fp32.
# Pass 1 -- half-rotation mix of in_ptr0 (-> a) and in_ptr3 (-> b):
#   a[r] = in_ptr0[128*x + r + 64] * -in_ptr1[64*x1 + r]
#        + in_ptr0[128*x + r]      *  in_ptr2[64*x1 + r]           for r < 64
#   a[r] = in_ptr0[128*x + r]      *  in_ptr2[64*x1 + r - 64]
#        + in_ptr0[128*x + r - 64] *  in_ptr1[64*x1 + r - 64]      for r >= 64
#   b[r] = the same formula with in_ptr3 in place of in_ptr0
#   out_ptr0[128*x + r] = a[r]; out_ptr1[128*x + r] = b[r]     (fp32 scratch)
#   sa = sum_r a[r] * in_ptr4[128*x0 + 3072*x1 + r]
#   (in_ptr1/in_ptr2 hold one 64-entry row per 8 rows of x -- presumably
#   rotary sin/cos tables; confirm.)
# Pass 2 -- with rs = in_ptr5[x] and w[r] = in_ptr4[128*x0 + 3072*x1 + r]:
#   out_ptr3[128*x0 + 3072*x1 + r] = a[r] * rs + (sa * -0.5 * rs**3 * 0.0078125) * (2 * w[r])
# Passes 3/4 -- the same reduction and update for b with in_ptr6 (w) and
#   in_ptr7 (rs), written to out_ptr5.
# Passes 2 and 4 read a/b back from out_ptr0/out_ptr1 at the same offsets this
# program stored them in pass 1 instead of recomputing the concatenation.
# The update has the form of d/dw of w * rsqrt(mean(w**2) + eps) over 128
# elements (0.0078125 == 1/128) -- presumably RMSNorm backward for the k/q
# paths named in the header; confirm.
# NOTE(review): tmp2 (r >= 0) and tmp19 (r < 128) are computed but never used;
# generated dead code with no effect on the result.
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11 = async_compile.triton('triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 524288, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr3': '*bf16', 'out_ptr5': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 21, 'num_reduction': 2, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 524288
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x3 = xindex
    x1 = xindex // 8
    x0 = (xindex % 8)
    _tmp55 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp51 = tl.load(in_ptr4 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp0 = r0_2
        tmp1 = tl.full([1, 1], 0, tl.int64)
        tmp2 = tmp0 >= tmp1
        tmp3 = tl.full([1, 1], 64, tl.int64)
        tmp4 = tmp0 < tmp3
        tmp5 = tl.load(in_ptr0 + (64 + 128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp6 = tmp5.to(tl.float32)
        tmp7 = tl.load(in_ptr1 + (64*x1 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0)
        tmp8 = -tmp7
        tmp9 = tmp6 * tmp8
        tmp10 = tl.load(in_ptr0 + (128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp11 = tmp10.to(tl.float32)
        tmp12 = tl.load(in_ptr2 + (64*x1 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0)
        tmp13 = tmp11 * tmp12
        tmp14 = tmp9 + tmp13
        tmp15 = tl.full(tmp14.shape, 0.0, tmp14.dtype)
        tmp16 = tl.where(tmp4, tmp14, tmp15)
        tmp17 = tmp0 >= tmp3
        tmp18 = tl.full([1, 1], 128, tl.int64)
        tmp19 = tmp0 < tmp18
        tmp20 = tl.load(in_ptr0 + (64 + 128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp21 = tmp20.to(tl.float32)
        tmp22 = tl.load(in_ptr2 + (64*x1 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0)
        tmp23 = tmp21 * tmp22
        tmp24 = tl.load(in_ptr0 + (128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp25 = tmp24.to(tl.float32)
        tmp26 = tl.load(in_ptr1 + (64*x1 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0)
        tmp27 = tmp25 * tmp26
        tmp28 = tmp23 + tmp27
        tmp29 = tl.full(tmp28.shape, 0.0, tmp28.dtype)
        tmp30 = tl.where(tmp17, tmp28, tmp29)
        tmp31 = tl.where(tmp4, tmp16, tmp30)
        tmp32 = tl.load(in_ptr3 + (64 + 128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp33 = tmp32.to(tl.float32)
        tmp34 = tmp33 * tmp8
        tmp35 = tl.load(in_ptr3 + (128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp36 = tmp35.to(tl.float32)
        tmp37 = tmp36 * tmp12
        tmp38 = tmp34 + tmp37
        tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
        tmp40 = tl.where(tmp4, tmp38, tmp39)
        tmp41 = tl.load(in_ptr3 + (64 + 128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp42 = tmp41.to(tl.float32)
        tmp43 = tmp42 * tmp22
        tmp44 = tl.load(in_ptr3 + (128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp45 = tmp44.to(tl.float32)
        tmp46 = tmp45 * tmp26
        tmp47 = tmp43 + tmp46
        tmp48 = tl.full(tmp47.shape, 0.0, tmp47.dtype)
        tmp49 = tl.where(tmp17, tmp47, tmp48)
        tmp50 = tl.where(tmp4, tmp40, tmp49)
        tmp52 = tmp51.to(tl.float32)
        tmp53 = tmp31 * tmp52
        tmp54 = tl.broadcast_to(tmp53, [XBLOCK, R0_BLOCK])
        tmp56 = _tmp55 + tmp54
        _tmp55 = tl.where(r0_mask, tmp56, _tmp55)
        tl.store(out_ptr0 + (r0_2 + 128*x3), tmp31, r0_mask)
        tl.store(out_ptr1 + (r0_2 + 128*x3), tmp50, r0_mask)
    tmp55 = tl.sum(_tmp55, 1)[:, None]
    tmp58 = tl.load(in_ptr5 + (x3), None, eviction_policy='evict_last')
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp57 = tl.load(out_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0)
        tmp67 = tl.load(in_ptr4 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp59 = tmp57 * tmp58
        tmp60 = -0.5
        tmp61 = tmp55 * tmp60
        tmp62 = tmp58 * tmp58
        tmp63 = tmp62 * tmp58
        tmp64 = tmp61 * tmp63
        tmp65 = 0.0078125
        tmp66 = tmp64 * tmp65
        tmp68 = tmp67.to(tl.float32)
        tmp69 = 2.0
        tmp70 = tmp68 * tmp69
        tmp71 = tmp66 * tmp70
        tmp72 = tmp59 + tmp71
        tmp73 = tmp72.to(tl.float32)
        tl.store(out_ptr3 + (r0_2 + 128*x0 + 3072*x1), tmp73, r0_mask)
    _tmp79 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp74 = tl.load(out_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0)
        tmp75 = tl.load(in_ptr6 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp76 = tmp75.to(tl.float32)
        tmp77 = tmp74 * tmp76
        tmp78 = tl.broadcast_to(tmp77, [XBLOCK, R0_BLOCK])
        tmp80 = _tmp79 + tmp78
        _tmp79 = tl.where(r0_mask, tmp80, _tmp79)
    tmp79 = tl.sum(_tmp79, 1)[:, None]
    tmp82 = tl.load(in_ptr7 + (x3), None, eviction_policy='evict_last')
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp81 = tl.load(out_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0)
        tmp91 = tl.load(in_ptr6 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp83 = tmp81 * tmp82
        tmp84 = -0.5
        tmp85 = tmp79 * tmp84
        tmp86 = tmp82 * tmp82
        tmp87 = tmp86 * tmp82
        tmp88 = tmp85 * tmp87
        tmp89 = 0.0078125
        tmp90 = tmp88 * tmp89
        tmp92 = tmp91.to(tl.float32)
        tmp93 = 2.0
        tmp94 = tmp92 * tmp93
        tmp95 = tmp90 * tmp94
        tmp96 = tmp83 + tmp95
        tmp97 = tmp96.to(tl.float32)
        tl.store(out_ptr5 + (r0_2 + 128*x0 + 3072*x1), tmp97, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/k2/ck2d2pyt4c2xy4hbpuabbajwzm2ocm7rodqo6uzeivtga7ibn3pa.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_18 : [num_users=30] = call_function[target=torch.ops.aten.full.default](args = ([4, 1024, 1024], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %select_scatter_default : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_18, %mm_69, 0, 3), kwargs = {})
# %slice_scatter_default : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_18, %view_196, 0, 0, 3), kwargs = {})
# %add_193 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default, %slice_scatter_default), kwargs = {})
#
# Pointwise over 4 * 1048576 elements, with x1 = x // 1048576 (in [0, 4)) and
# x0 = x % 1048576:
#   out_ptr0[x] = (in_ptr0[x0] if x1 == 3 else 0) + (in_ptr1[x] if x1 < 3 else 0)
# This is add_193 from the fragment above -- select_scatter into slot 3 plus
# slice_scatter into slots 0:3 of a zero [4, 1024, 1024] tensor -- computed
# without materialising the zeros. Presumably in_ptr0 is mm_69 and in_ptr1 is
# view_196; confirm at the call site. The in_ptr1 load is masked by x1 < 3,
# so only its first 3 * 1048576 elements are read.
triton_poi_fused_add_select_backward_12 = async_compile.triton('triton_poi_fused_add_select_backward_12', '''
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 4194304},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_select_backward_12', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_select_backward_12(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 4194304
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x1 = xindex // 1048576
    x0 = (xindex % 1048576)
    x2 = xindex
    tmp3 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
    tmp0 = x1
    tmp1 = tl.full([1], 3, tl.int32)
    tmp2 = tmp0 == tmp1
    tmp4 = 0.0
    tmp5 = tl.where(tmp2, tmp3, tmp4)
    tmp6 = tl.full([1], 3, tl.int64)
    tmp7 = tmp0 < tmp6
    tmp8 = tl.load(in_ptr1 + (x2), tmp7, other=0.0).to(tl.float32)
    tmp9 = tl.where(tmp7, tmp8, tmp4)
    tmp10 = tmp5 + tmp9
    tl.store(out_ptr0 + (x2), tmp10, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/a5/ca56au744fshqomiwmgjdngi2et5vflsgkwrpod6zptxlqclh6dv.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {})
# %mul_296 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %select_99), kwargs = {})
# Inductor-generated pointwise kernel over 67108864 elements:
#     out[i] = (in0[i] + in1[i]) * s
# s is the single fp32 element in_ptr2[46] (the select_99 operand of mul_296),
# broadcast to every lane.
# This fuses add_192 = add_183 + view_195 with mul_296. add_192 itself is not
# written out by this kernel.
# The bf16 inputs are upcast, and the add/multiply run in float32 with no
# intermediate bf16 rounding. Only the final product is stored, through the
# *bf16 out_ptr0.
# NOTE(review): if add_183/view_195 are bf16 (their pointers are *bf16), eager
# execution would round add_192 to bf16 before the multiply. Results can
# therefore differ from eager in the last bf16 ulp; confirm this is acceptable.
# Loads and stores are unmasked, so this relies on XBLOCK dividing xnumel.
triton_poi_fused_add_mul_13 = async_compile.triton('triton_poi_fused_add_mul_13', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
    size_hints={'x': 67108864},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_13', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_13(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 67108864
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (46))
    tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
    tmp2 = tmp0 + tmp1
    tmp5 = tmp4.to(tl.float32)
    tmp6 = tmp2 * tmp5
    tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/md/cmdhjwnox4mptoiaadn5oichuh2ehyfrlgqib4p7z5y6kig7caz6.py
# Topologically Sorted Source Nodes: [rms_norm_57], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# rms_norm_57 => convert_element_type_292
# Graph fragment:
# %convert_element_type_373 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_200, torch.float32), kwargs = {})
# %convert_element_type_292 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_163, torch.float32), kwargs = {})
# %mul_300 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_373, %convert_element_type_292), kwargs = {})
# %mul_301 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_373, %rsqrt_57), kwargs = {})
# %sum_20 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_300, [2], True), kwargs = {})
# %div_8 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_25, 1024), kwargs = {})
# %pow_96 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_292, 1.0), kwargs = {})
# %mul_304 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_96, 2.0), kwargs = {})
# %mul_305 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_8, %mul_304), kwargs = {})
# %add_195 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_301, %mul_305), kwargs = {})
# %convert_element_type_374 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_195, torch.bfloat16), kwargs = {})
# %add_196 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_296, %convert_element_type_374), kwargs = {})
# Inductor-generated persistent reduction over 65536 rows of 1024 elements.
# One program handles one row (XBLOCK = 1) in a single R0_BLOCK = 1024 pass.
# Per the graph above, this appears to be the input gradient of rms_norm_57
# (no weight term), accumulated onto another gradient:
#   g  = in_out_ptr0[row, :]   (view_200, upcast to float32)
#   x  = in_ptr0[row, :]       (add_163, the norm input)
#   rs = in_ptr2[row]          (rsqrt_57, fp32, one value per row)
#   s  = sum(g * x)
#   dx = g * rs + (-0.5 * s * rs**3) * (1/1024) * (2 * x)
#   in_out_ptr0[row, :] = in_ptr1[row, :] + dx   (in_ptr1 is mul_296 -> add_196)
# Where the constants come from:
#   -0.5 * rs**3 is d/dv of v**-0.5.
#   2 * x comes from pow(x, 1.0) * 2.0 (mul_304).
#   0.0009765625 == 1/1024 (div_8).
# in_out_ptr0 is both read and overwritten (mutated_arg_names). Each program
# loads its whole row before storing it, and rows are disjoint, so the
# in-place update is safe.
# dx stays float32: tmp22 is a no-op cast where the graph has the bf16
# convert_element_type_374. It is rounded only by the final store through the
# *bf16 pointer.
triton_per_fused__to_copy_add_div_mul_pow_sum_14 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_14', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
    size_hints={'x': 65536, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_14', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_14(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel):
    xnumel = 65536
    XBLOCK: tl.constexpr = 1
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[:]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
    tmp10 = tmp1 * tmp9
    tmp11 = -0.5
    tmp12 = tmp7 * tmp11
    tmp13 = tmp9 * tmp9
    tmp14 = tmp13 * tmp9
    tmp15 = tmp12 * tmp14
    tmp16 = 0.0009765625
    tmp17 = tmp15 * tmp16
    tmp18 = 2.0
    tmp19 = tmp3 * tmp18
    tmp20 = tmp17 * tmp19
    tmp21 = tmp10 + tmp20
    tmp22 = tmp21.to(tl.float32)
    tmp23 = tmp8 + tmp22
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp23, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/xg/cxged3dxs4opgztavvr7fkyy2lg6uqfqee4esalldmdyijnennf5.py
# Topologically Sorted Source Nodes: [v_40], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_40 => convert_element_type_288
# Graph fragment:
# %mul_308 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_164, %select_94), kwargs = {})
# %convert_element_type_288 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_131, torch.float32), kwargs = {})
# %convert_element_type_381 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_308, torch.float32), kwargs = {})
# %mul_310 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_381, %convert_element_type_288), kwargs = {})
# %mul_311 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_381, %rsqrt_56), kwargs = {})
# %sum_23 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_310, [3], True), kwargs = {})
# %div_9 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_26, 128), kwargs = {})
# %pow_98 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_288, 1.0), kwargs = {})
# %mul_314 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_98, 2.0), kwargs = {})
# %mul_315 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_9, %mul_314), kwargs = {})
# %add_198 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_311, %mul_315), kwargs = {})
# %convert_element_type_382 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_198, torch.bfloat16), kwargs = {})
# Inductor-generated looped reduction over 524288 rows of 128 elements, making
# two passes over r0 per row. Per the graph above, it appears to be the input
# gradient of a weightless RMS norm on v_40, with
#   g   = in_ptr0[row, :] * in_ptr1[76]   (permute_164 * select_94; the fp32
#                                          scalar is loaded once per pass)
#   x   = in_ptr2[off(row) + r0]          (getitem_131)
#   rs  = in_ptr3[row]                    (rsqrt_56, fp32, one value per row)
#   off(row) = 128 * (row % 8) + 3072 * (row // 8)
# Pass 1 accumulates s = sum(g * x). Its loads are tagged evict_last because
# pass 2 re-reads the same data.
# Pass 2 computes
#   out = g * rs + (-0.5 * s * rs**3) * (1/128) * (2 * x)
# and stores it through the *bf16 out_ptr1 at the same off(row) + r0 offsets.
# Its loads are tagged evict_first. 0.0078125 == 1/128 (div_9).
# off(row) places each group of 8 x 128 = 1024 contiguous elements at a
# 3072-element stride, for both the read of in_ptr2 and the write of out_ptr1.
# NOTE(review): this looks like a 1024-wide slice of a 3072-wide packed buffer
# (e.g. the v part of a qkv projection); confirm against the caller.
# All arithmetic stays in float32. If permute_164 is bf16 (its pointer is
# *bf16), the bf16 rounding implied by mul_308 before convert_element_type_381
# is not reproduced here.
triton_red_fused__to_copy_add_div_mul_pow_sum_15 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_15', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 524288, 'r0_': 128},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_15', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_15(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 524288
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x3 = xindex
    tmp1 = tl.load(in_ptr1 + (76))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
    x0 = (xindex % 8)
    x1 = xindex // 8
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp3 = tmp2.to(tl.float32)
        tmp4 = tmp0 * tmp3
        tmp5 = tmp4.to(tl.float32)
        tmp7 = tmp6.to(tl.float32)
        tmp8 = tmp5 * tmp7
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]
    tmp13 = tl.load(in_ptr1 + (76))
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp15 = tmp14.to(tl.float32)
        tmp16 = tmp12 * tmp15
        tmp17 = tmp16.to(tl.float32)
        tmp19 = tmp17 * tmp18
        tmp20 = -0.5
        tmp21 = tmp10 * tmp20
        tmp22 = tmp18 * tmp18
        tmp23 = tmp22 * tmp18
        tmp24 = tmp21 * tmp23
        tmp25 = 0.0078125
        tmp26 = tmp24 * tmp25
        tmp28 = tmp27.to(tl.float32)
        tmp29 = 2.0
        tmp30 = tmp28 * tmp29
        tmp31 = tmp26 * tmp30
        tmp32 = tmp19 + tmp31
        tmp33 = tmp32.to(tl.float32)
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/st/cstxqu5wndh2mme6n7qysm6mhvq6r36u3keo4fwrmmgxo5rjyjzu.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {})
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %mul_295 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %convert_element_type_11), kwargs = {})
# %sum_18 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_295,), kwargs = {})
# %mul_297 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %add_165), kwargs = {})
# %sum_19 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_297,), kwargs = {})
# %add_205 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_196, %view_207), kwargs = {})
# %mul_337 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %convert_element_type_11), kwargs = {})
# %sum_26 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_337,), kwargs = {})
# %mul_338 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %select_92), kwargs = {})
# %mul_339 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %add_153), kwargs = {})
# %sum_27 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_339,), kwargs = {})
# Inductor-generated persistent reduction over 65536 rows of 1024 elements,
# one program per row (XBLOCK = 1). It first forms
#   a  = in_ptr0 + in_ptr1              (add_192 = add_183 + view_195)
#   xn = fp32(in_ptr2) * in_ptr3[row]   (the convert_element_type_11 value;
#                                        in_ptr3 is the fp32 per-row rsqrt)
#   b  = in_ptr5 + in_ptr6              (add_205 = add_196 + view_207)
# It then writes four per-row float32 partial sums:
#   out_ptr0[row] = sum(a * xn)         -> sum_18
#   out_ptr1[row] = sum(a * in_ptr4)    -> sum_19 (in_ptr4 presumably add_165)
#   out_ptr2[row] = sum(b * xn)         -> sum_26
#   out_ptr3[row] = sum(b * in_ptr7)    -> sum_27 (in_ptr7 presumably add_153)
# It also writes the elementwise out_ptr4 = b * in_ptr8[44] (mul_338, where
# select_92 is an fp32 scalar), stored as bf16.
# The graph's sums are full-tensor sum.default, but these outputs are only
# per-row partials. They are presumably finished by a follow-up reduction
# (triton_red_fused__to_copy_add_mul_sum_17 sums 65536 fp32 values); confirm.
# xn is kept in float32 even though the graph converts it to bf16
# (convert_element_type_11), so it is not rounded before use.
triton_per_fused__to_copy_add_mul_sum_16 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_16', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
    size_hints={'x': 65536, 'r0_': 1024},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_16', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 9, 'num_reduction': 4, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_16(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, r0_numel):
    xnumel = 65536
    XBLOCK: tl.constexpr = 1
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[:]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp29 = tl.load(in_ptr8 + (44))
    tmp30 = tl.broadcast_to(tmp29, [R0_BLOCK])
    tmp2 = tmp0 + tmp1
    tmp4 = tmp3.to(tl.float32)
    tmp6 = tmp4 * tmp5
    tmp7 = tmp6.to(tl.float32)
    tmp8 = tmp2 * tmp7
    tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
    tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
    tmp13 = tmp2 * tmp12
    tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
    tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
    tmp19 = tmp17 + tmp18
    tmp20 = tmp19 * tmp7
    tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
    tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0))
    tmp25 = tmp19 * tmp24
    tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK])
    tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
    tmp31 = tmp30.to(tl.float32)
    tmp32 = tmp19 * tmp31
    tl.store(out_ptr4 + (r0_1 + 1024*x0), tmp32, None)
    tl.store(out_ptr0 + (x0), tmp11, None)
    tl.store(out_ptr1 + (x0), tmp16, None)
    tl.store(out_ptr2 + (x0), tmp23, None)
    tl.store(out_ptr3 + (x0), tmp28, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/4y/c4ykabbdrwrdol5x6eczezz3t7rsq5fk7zv6xmtnuvk22t2q5jdc.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {})
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %mul_295 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %convert_element_type_11), kwargs = {})
# %sum_18 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_295,), kwargs = {})
# Inductor-generated single-output reduction. xnumel == 1 is passed as a
# constexpr.
# It sums the 65536 float32 values in in_ptr0 into one scalar and stores that
# scalar through the *bf16 out_ptr0[0]. Per the graph above, this completes the
# full-tensor sum_18.
# Each lane accumulates into a float32 [XBLOCK, R0_BLOCK] buffer over chunks of
# R0_BLOCK; the lanes are then reduced with tl.sum.
# NOTE(review): the loads and accumulator update are unmasked (r0_mask is
# all-True). This relies on R0_BLOCK dividing 65536, presumably guaranteed by
# power-of-two block sizes; confirm.
triton_red_fused__to_copy_add_mul_sum_17 = async_compile.triton('triton_red_fused__to_copy_add_mul_sum_17', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 65536},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mul_sum_17', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_mul_sum_17(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 65536
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
        roffset = r0_offset
        rindex = r0_index
        r0_0 = r0_index
        tmp0 = tl.load(in_ptr0 + (r0_0), None, eviction_policy='evict_first')
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tmp3
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp2, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/vf/cvf3n5zrn4y5xotwv7lzpbxws7wmfhbxv2cnzgqiqengfefpfz4g.py
# Topologically Sorted Source Nodes: [v_37], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_37 => convert_element_type_268
# Graph fragment:
# %mul_350 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_184, %select_87), kwargs = {})
# %convert_element_type_268 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_122, torch.float32), kwargs = {})
# %convert_element_type_413 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_350, torch.float32), kwargs = {})
# %mul_352 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_413, %convert_element_type_268), kwargs = {})
# %mul_353 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_413, %rsqrt_52), kwargs = {})
# %sum_31 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_352, [3], True), kwargs = {})
# %div_13 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_30, 128), kwargs = {})
# %pow_107 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_268, 1.0), kwargs = {})
# %mul_356 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_107, 2.0), kwargs = {})
# %mul_357 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_13, %mul_356), kwargs = {})
# %add_214 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_353, %mul_357), kwargs = {})
# %convert_element_type_414 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_214, torch.bfloat16), kwargs = {})
# [review] Generated TorchInductor Triton *reduction* kernel, handed to
# async_compile.triton(...) for compilation on device 'cuda'.
# Kernel body, summarized:
#   - xnumel = 524288 rows and r0_numel = 128 elements per row.
#   - x0 = x % 8 and x1 = x // 8.
#   - s = in_ptr1[74] is one fp32 scalar broadcast to every element.
#   pass 1 (per row x):
#       acc[x] = sum_r (in_ptr0[r + 128*x] * s) * in_ptr2[r + 128*x0 + 3072*x1]
#   pass 2 (per element):
#       out_ptr1[r + 128*x0 + 3072*x1] =
#           (in_ptr0[r + 128*x] * s) * in_ptr3[x]
#           + (acc[x] * -0.5 * in_ptr3[x]**3 * 0.0078125)
#             * (in_ptr2[r + 128*x0 + 3072*x1] * 2.0)
# 0.0078125 == 1/128 == 1/r0_numel. The arithmetic is done in fp32; out_ptr1 is
# declared '*bf16' in triton_meta.
# xmask is all-True (there is no x bounds check), so this presumably relies on
# XBLOCK dividing 524288.
# Kernels _21 and _23 further down are identical apart from the kernel name and
# the in_ptr1 offset (72 and 70 instead of 74).
# NOTE(review): the "-0.5 * r**3 / N * 2x" shape looks like the backward of an
# RMS-style normalization, with in_ptr3 holding per-row rsqrt values. This is
# inferred from the arithmetic, not stated here; confirm against this kernel's
# graph-fragment comment.
# NOTE(review): this copy looks scraped from a rendered web page. Every line ends
# in " | |" and the kernel body lost its indentation, so the string passed to
# async_compile.triton would not compile as written. Regenerate the file
# (e.g. TORCH_LOGS=output_code) rather than hand-editing it.
triton_red_fused__to_copy_add_div_mul_pow_sum_18 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_18', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
size_hints={'x': 524288, 'r0_': 128}, | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_18', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_18(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
xnumel = 524288 | |
r0_numel = 128 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
rbase = r0_base | |
x3 = xindex | |
tmp1 = tl.load(in_ptr1 + (74)) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
x0 = (xindex % 8) | |
x1 = xindex // 8 | |
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp0 * tmp3 | |
tmp5 = tmp4.to(tl.float32) | |
tmp7 = tmp6.to(tl.float32) | |
tmp8 = tmp5 * tmp7 | |
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
tmp11 = _tmp10 + tmp9 | |
_tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
tmp10 = tl.sum(_tmp10, 1)[:, None] | |
tmp13 = tl.load(in_ptr1 + (74)) | |
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = tmp12 * tmp15 | |
tmp17 = tmp16.to(tl.float32) | |
tmp19 = tmp17 * tmp18 | |
tmp20 = -0.5 | |
tmp21 = tmp10 * tmp20 | |
tmp22 = tmp18 * tmp18 | |
tmp23 = tmp22 * tmp18 | |
tmp24 = tmp21 * tmp23 | |
tmp25 = 0.0078125 | |
tmp26 = tmp24 * tmp25 | |
tmp28 = tmp27.to(tl.float32) | |
tmp29 = 2.0 | |
tmp30 = tmp28 * tmp29 | |
tmp31 = tmp26 * tmp30 | |
tmp32 = tmp19 + tmp31 | |
tmp33 = tmp32.to(tl.float32) | |
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/he/chewvuajuq55rjjxwpgjkdibqap2uf4svllmqmaiitre66ivyggh.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten.mul, aten._to_copy, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {}) | |
# %mul_294 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %select_100), kwargs = {}) | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_205 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_196, %view_207), kwargs = {}) | |
# %mul_336 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %select_93), kwargs = {}) | |
# %add_207 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_294, %mul_336), kwargs = {}) | |
# %add_221 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_212, %view_219), kwargs = {}) | |
# %mul_378 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %select_86), kwargs = {}) | |
# %mul_379 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %convert_element_type_11), kwargs = {}) | |
# %sum_34 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_379,), kwargs = {}) | |
# %add_223 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_207, %mul_378), kwargs = {}) | |
# %mul_380 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %select_85), kwargs = {}) | |
# %mul_381 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %add_141), kwargs = {}) | |
# %sum_35 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_381,), kwargs = {}) | |
# [review] Generated TorchInductor Triton *persistent* reduction kernel,
# compiled on 'cuda' via async_compile.triton.
#   - XBLOCK = 1 and R0_BLOCK = r0_numel = 1024: one program per row, with the
#     whole row loaded at once.
#   - xnumel = 65536 rows.
# Let i = r + 1024*x and s_k = in_ptr1[k] (fp32 scalars, k = 47, 45, 43, 42):
#   c              = in_ptr4[i] + in_ptr5[i]
#   in_out_ptr0[i] = (in_out_ptr0[i] + in_ptr0[i]) * s47
#                    + (in_ptr2[i] + in_ptr3[i]) * s45
#                    + c * s43                       (updated in place)
#   out_ptr0[i]    = c * s42
#   out_ptr1[x]    = sum_r c * (in_ptr6[i] * in_ptr7[x])
#   out_ptr2[x]    = sum_r c * in_ptr8[i]
# So out_ptr1 and out_ptr2 hold one partial sum per row. The graph fragment
# above lists sum_34 / sum_35 as aten.sum.default (full reductions), so
# presumably a later kernel finishes them; confirm.
# NOTE(review): the fragment's convert_element_type_11 is a bf16 cast. Its
# apparent counterpart here is tmp31 = (in_ptr6 * in_ptr7).to(tl.float32),
# which does not round to bf16. This presumably reflects Inductor keeping the
# intermediate in fp32, and matters only for bit-exact comparison with eager.
# NOTE(review): this copy looks scraped from a rendered web page. Every line ends
# in " | |" and the kernel body lost its indentation, so the string passed to
# async_compile.triton would not compile as written. Regenerate, don't hand-edit.
triton_per_fused__to_copy_add_mul_sum_19 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_19', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
size_hints={'x': 65536, 'r0_': 1024}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*fp32', 'in_ptr8': '*bf16', 'out_ptr0': '*bf16', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]]}]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_19', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 2, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_mul_sum_19(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel): | |
xnumel = 65536 | |
XBLOCK: tl.constexpr = 1 | |
r0_numel = 1024 | |
R0_BLOCK: tl.constexpr = 1024 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = tl.full([1], xoffset, tl.int32) | |
xmask = tl.full([R0_BLOCK], True, tl.int1) | |
r0_index = tl.arange(0, R0_BLOCK)[:] | |
r0_offset = 0 | |
r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
roffset = r0_offset | |
rindex = r0_index | |
r0_1 = r0_index | |
x0 = xindex | |
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp3 = tl.load(in_ptr1 + (47)) | |
tmp4 = tl.broadcast_to(tmp3, [R0_BLOCK]) | |
tmp7 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp8 = tl.load(in_ptr3 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp10 = tl.load(in_ptr1 + (45)) | |
tmp11 = tl.broadcast_to(tmp10, [R0_BLOCK]) | |
tmp15 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp16 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp18 = tl.load(in_ptr1 + (43)) | |
tmp19 = tl.broadcast_to(tmp18, [R0_BLOCK]) | |
tmp23 = tl.load(in_ptr1 + (42)) | |
tmp24 = tl.broadcast_to(tmp23, [R0_BLOCK]) | |
tmp27 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp29 = tl.load(in_ptr7 + (x0), None, eviction_policy='evict_last') | |
tmp36 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp2 = tmp0 + tmp1 | |
tmp5 = tmp4.to(tl.float32) | |
tmp6 = tmp2 * tmp5 | |
tmp9 = tmp7 + tmp8 | |
tmp12 = tmp11.to(tl.float32) | |
tmp13 = tmp9 * tmp12 | |
tmp14 = tmp6 + tmp13 | |
tmp17 = tmp15 + tmp16 | |
tmp20 = tmp19.to(tl.float32) | |
tmp21 = tmp17 * tmp20 | |
tmp22 = tmp14 + tmp21 | |
tmp25 = tmp24.to(tl.float32) | |
tmp26 = tmp17 * tmp25 | |
tmp28 = tmp27.to(tl.float32) | |
tmp30 = tmp28 * tmp29 | |
tmp31 = tmp30.to(tl.float32) | |
tmp32 = tmp17 * tmp31 | |
tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK]) | |
tmp35 = triton_helpers.promote_to_tensor(tl.sum(tmp33, 0)) | |
tmp37 = tmp17 * tmp36 | |
tmp38 = tl.broadcast_to(tmp37, [R0_BLOCK]) | |
tmp40 = triton_helpers.promote_to_tensor(tl.sum(tmp38, 0)) | |
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp22, None) | |
tl.store(out_ptr0 + (r0_1 + 1024*x0), tmp26, None) | |
tl.store(out_ptr1 + (x0), tmp35, None) | |
tl.store(out_ptr2 + (x0), tmp40, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/zw/czwk3untwmpllx2wu6vrpai3qkimsauvxusvmc74tmu73zz3k6zn.py | |
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten._to_copy, aten.mul, aten.sum] | |
# Source node to ATen node mapping: | |
# v_34 => convert_element_type_248, convert_element_type_249, mul_187 | |
# Graph fragment: | |
# %convert_element_type_248 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_113, torch.float32), kwargs = {}) | |
# %mul_187 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_248, %rsqrt_48), kwargs = {}) | |
# %convert_element_type_249 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_187, torch.bfloat16), kwargs = {}) | |
# %mul_391 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_204, %convert_element_type_249), kwargs = {}) | |
# %sum_37 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_391,), kwargs = {}) | |
# [review] Generated TorchInductor Triton reduction kernel, compiled on 'cuda'
# via async_compile.triton. xnumel = 512 rows and r0_numel = 131072 per row:
#   out_ptr0[x] = sum_r in_ptr0[r + 131072*x]
#                       * (in_ptr1[3072*(r // 1024) + 393216*x + r % 1024]
#                          * in_ptr2[1024*x + r // 128])
# So in_ptr1 is read as 1024-element runs spaced 3072 elements apart, and each
# run of 128 consecutive r shares one in_ptr2 value.
# r0_mask is all-True (there is no r bounds check), so this presumably relies on
# R0_BLOCK dividing 131072. Only x is masked (xindex < xnumel), which is why the
# accumulator update uses tl.where(xmask, ...).
# Each x yields a partial sum. The fragment's sum_37 is aten.sum.default, so
# presumably a later kernel completes the reduction; confirm.
# NOTE(review): the pairing in_ptr0/in_ptr1/in_ptr2 <-> permute_204/getitem_113/
# rsqrt_48 is inferred from the graph fragment above, not stated by the kernel.
# NOTE(review): this copy looks scraped from a rendered web page. Every line ends
# in " | |" and the kernel body lost its indentation, so the string passed to
# async_compile.triton would not compile as written. Regenerate, don't hand-edit.
triton_red_fused__to_copy_mul_sum_20 = async_compile.triton('triton_red_fused__to_copy_mul_sum_20', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
size_hints={'x': 512, 'r0_': 131072}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_20', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_mul_sum_20(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
xnumel = 512 | |
r0_numel = 131072 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = xindex < xnumel | |
r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
rbase = r0_base | |
x0 = xindex | |
_tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
roffset = r0_offset | |
rindex = r0_index | |
r0_1 = r0_index | |
tmp0 = tl.load(in_ptr0 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (3072*(r0_1 // 1024) + 393216*x0 + ((r0_1 % 1024))), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp3 = tl.load(in_ptr2 + (1024*x0 + (r0_1 // 128)), xmask, eviction_policy='evict_last', other=0.0) | |
tmp2 = tmp1.to(tl.float32) | |
tmp4 = tmp2 * tmp3 | |
tmp5 = tmp4.to(tl.float32) | |
tmp6 = tmp0 * tmp5 | |
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK]) | |
tmp9 = _tmp8 + tmp7 | |
_tmp8 = tl.where(xmask, tmp9, _tmp8) | |
tmp8 = tl.sum(_tmp8, 1)[:, None] | |
tl.store(out_ptr0 + (x0), tmp8, xmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/xt/cxt3lc5hr6vgginbc2ivwhx3tof46klook7rqcj5a7asaqgh3yez.py | |
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_34 => convert_element_type_248 | |
# Graph fragment: | |
# %mul_390 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_204, %select_81), kwargs = {}) | |
# %convert_element_type_248 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_113, torch.float32), kwargs = {}) | |
# %convert_element_type_444 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_390, torch.float32), kwargs = {}) | |
# %mul_392 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_444, %convert_element_type_248), kwargs = {}) | |
# %mul_393 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_444, %rsqrt_48), kwargs = {}) | |
# %sum_38 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_392, [3], True), kwargs = {}) | |
# %div_17 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_34, 128), kwargs = {}) | |
# %pow_116 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_248, 1.0), kwargs = {}) | |
# %mul_396 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_116, 2.0), kwargs = {}) | |
# %mul_397 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_17, %mul_396), kwargs = {}) | |
# %add_229 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_393, %mul_397), kwargs = {}) | |
# %convert_element_type_445 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_229, torch.bfloat16), kwargs = {}) | |
# [review] Generated TorchInductor Triton reduction kernel, compiled on 'cuda'
# via async_compile.triton.
# Kernel body, summarized:
#   - xnumel = 524288 rows and r0_numel = 128 per row.
#   - x0 = x % 8 and x1 = x // 8.
#   - s = in_ptr1[72] is one fp32 scalar broadcast everywhere.
#   pass 1 (per row x):
#       acc[x] = sum_r (in_ptr0[r + 128*x] * s) * in_ptr2[r + 128*x0 + 3072*x1]
#   pass 2 (per element):
#       out_ptr1[r + 128*x0 + 3072*x1] =
#           (in_ptr0[r + 128*x] * s) * in_ptr3[x]
#           + (acc[x] * -0.5 * in_ptr3[x]**3 * 0.0078125)
#             * (in_ptr2[r + 128*x0 + 3072*x1] * 2.0)
# 0.0078125 == 1/128 == 1/r0_numel. The math is done in fp32; out_ptr1 is
# declared '*bf16'.
# Reading the graph fragment above, the presumed correspondences are:
#   - in_ptr0 * s  <-> mul_390 (permute_204 * select_81)
#   - in_ptr2      <-> getitem_113
#   - in_ptr3      <-> rsqrt_48
#   - acc          <-> sum_38
#   - pass 2       <-> add_229 = mul_393 + mul_397
# The -0.5 * rsqrt**3 factor belongs to expand_34, whose definition is not
# shown; confirm before relying on this mapping.
# xmask is all-True (there is no x bounds check), so this presumably relies on
# XBLOCK dividing 524288.
# Identical to kernels _18 and _23 apart from the kernel name and the in_ptr1
# offset.
# NOTE(review): this copy looks scraped from a rendered web page. Every line ends
# in " | |" and the kernel body lost its indentation, so the string passed to
# async_compile.triton would not compile as written. Regenerate, don't hand-edit.
triton_red_fused__to_copy_add_div_mul_pow_sum_21 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_21', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
size_hints={'x': 524288, 'r0_': 128}, | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_21', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_21(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
xnumel = 524288 | |
r0_numel = 128 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
rbase = r0_base | |
x3 = xindex | |
tmp1 = tl.load(in_ptr1 + (72)) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
x0 = (xindex % 8) | |
x1 = xindex // 8 | |
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp0 * tmp3 | |
tmp5 = tmp4.to(tl.float32) | |
tmp7 = tmp6.to(tl.float32) | |
tmp8 = tmp5 * tmp7 | |
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
tmp11 = _tmp10 + tmp9 | |
_tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
tmp10 = tl.sum(_tmp10, 1)[:, None] | |
tmp13 = tl.load(in_ptr1 + (72)) | |
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = tmp12 * tmp15 | |
tmp17 = tmp16.to(tl.float32) | |
tmp19 = tmp17 * tmp18 | |
tmp20 = -0.5 | |
tmp21 = tmp10 * tmp20 | |
tmp22 = tmp18 * tmp18 | |
tmp23 = tmp22 * tmp18 | |
tmp24 = tmp21 * tmp23 | |
tmp25 = 0.0078125 | |
tmp26 = tmp24 * tmp25 | |
tmp28 = tmp27.to(tl.float32) | |
tmp29 = 2.0 | |
tmp30 = tmp28 * tmp29 | |
tmp31 = tmp26 * tmp30 | |
tmp32 = tmp19 + tmp31 | |
tmp33 = tmp32.to(tl.float32) | |
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/i5/ci5mcynkhhv2drjtnbviwl5ibwpvhvcagyfiku76l7vlecpg2j4c.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {}) | |
# %mul_420 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %select_79), kwargs = {}) | |
# [review] Generated TorchInductor Triton pointwise kernel, compiled on 'cuda'
# via async_compile.triton. xnumel = 67108864:
#   out_ptr0[i] = (in_ptr0[i] + in_ptr1[i]) * in_ptr2[40]
# The computation is done in fp32 and stored into out_ptr0 (declared '*bf16').
# This lines up with the graph fragment above (add_236 = add_228 + view_230,
# mul_420 = add_236 * select_79), with in_ptr2[40] presumably being select_79;
# confirm.
# xmask is all-True and loads/stores pass mask None, so this presumably relies on
# XBLOCK dividing 67108864.
# NOTE(review): this copy looks scraped from a rendered web page. Every line ends
# in " | |" and the kernel body lost its indentation, so the string passed to
# async_compile.triton would not compile as written. Regenerate, don't hand-edit.
triton_poi_fused_add_mul_22 = async_compile.triton('triton_poi_fused_add_mul_22', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
size_hints={'x': 67108864}, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_22', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_mul_22(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 67108864 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = tl.full([XBLOCK], True, tl.int1) | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | |
tmp3 = tl.load(in_ptr2 + (40)) | |
tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | |
tmp2 = tmp0 + tmp1 | |
tmp5 = tmp4.to(tl.float32) | |
tmp6 = tmp2 * tmp5 | |
tl.store(out_ptr0 + (x0), tmp6, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/2q/c2q426mapna6iekh3hgp2hihn6zqzbqso4gblzbvbonknafgmvyc.py | |
# Topologically Sorted Source Nodes: [v_31], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_31 => convert_element_type_228 | |
# Graph fragment: | |
# %mul_430 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_224, %select_75), kwargs = {}) | |
# %convert_element_type_228 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_104, torch.float32), kwargs = {}) | |
# %convert_element_type_475 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_430, torch.float32), kwargs = {}) | |
# %mul_432 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_475, %convert_element_type_228), kwargs = {}) | |
# %mul_433 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_475, %rsqrt_44), kwargs = {}) | |
# %sum_45 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_432, [3], True), kwargs = {}) | |
# %div_21 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_38, 128), kwargs = {}) | |
# %pow_125 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_228, 1.0), kwargs = {}) | |
# %mul_436 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_125, 2.0), kwargs = {}) | |
# %mul_437 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_21, %mul_436), kwargs = {}) | |
# %add_244 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_433, %mul_437), kwargs = {}) | |
# %convert_element_type_476 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_244, torch.bfloat16), kwargs = {}) | |
# [review] Generated TorchInductor Triton reduction kernel, compiled on 'cuda'
# via async_compile.triton.
# Kernel body, summarized:
#   - xnumel = 524288 rows and r0_numel = 128 per row.
#   - x0 = x % 8 and x1 = x // 8.
#   - s = in_ptr1[70] is one fp32 scalar broadcast everywhere.
#   pass 1 (per row x):
#       acc[x] = sum_r (in_ptr0[r + 128*x] * s) * in_ptr2[r + 128*x0 + 3072*x1]
#   pass 2 (per element):
#       out_ptr1[r + 128*x0 + 3072*x1] =
#           (in_ptr0[r + 128*x] * s) * in_ptr3[x]
#           + (acc[x] * -0.5 * in_ptr3[x]**3 * 0.0078125)
#             * (in_ptr2[r + 128*x0 + 3072*x1] * 2.0)
# 0.0078125 == 1/128 == 1/r0_numel. The math is done in fp32; out_ptr1 is
# declared '*bf16'.
# Reading the graph fragment above, the presumed correspondences are:
#   - in_ptr0 * s  <-> mul_430 (permute_224 * select_75)
#   - in_ptr2      <-> getitem_104
#   - in_ptr3      <-> rsqrt_44
#   - acc          <-> sum_45
#   - pass 2       <-> add_244 = mul_433 + mul_437
# The -0.5 * rsqrt**3 factor belongs to expand_38, whose definition is not
# shown; confirm before relying on this mapping.
# xmask is all-True (there is no x bounds check), so this presumably relies on
# XBLOCK dividing 524288.
# Identical to kernels _18 and _21 apart from the kernel name and the in_ptr1
# offset.
# NOTE(review): this copy looks scraped from a rendered web page. Every line ends
# in " | |" and the kernel body lost its indentation, so the string passed to
# async_compile.triton would not compile as written. Regenerate, don't hand-edit.
triton_red_fused__to_copy_add_div_mul_pow_sum_23 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_23', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
size_hints={'x': 524288, 'r0_': 128}, | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_23', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_23(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
xnumel = 524288 | |
r0_numel = 128 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
rbase = r0_base | |
x3 = xindex | |
tmp1 = tl.load(in_ptr1 + (70)) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
x0 = (xindex % 8) | |
x1 = xindex // 8 | |
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp0 * tmp3 | |
tmp5 = tmp4.to(tl.float32) | |
tmp7 = tmp6.to(tl.float32) | |
tmp8 = tmp5 * tmp7 | |
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
tmp11 = _tmp10 + tmp9 | |
_tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
tmp10 = tl.sum(_tmp10, 1)[:, None] | |
tmp13 = tl.load(in_ptr1 + (70)) | |
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = tmp12 * tmp15 | |
tmp17 = tmp16.to(tl.float32) | |
tmp19 = tmp17 * tmp18 | |
tmp20 = -0.5 | |
tmp21 = tmp10 * tmp20 | |
tmp22 = tmp18 * tmp18 | |
tmp23 = tmp22 * tmp18 | |
tmp24 = tmp21 * tmp23 | |
tmp25 = 0.0078125 | |
tmp26 = tmp24 * tmp25 | |
tmp28 = tmp27.to(tl.float32) | |
tmp29 = 2.0 | |
tmp30 = tmp28 * tmp29 | |
tmp31 = tmp26 * tmp30 | |
tmp32 = tmp19 + tmp31 | |
tmp33 = tmp32.to(tl.float32) | |
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/7j/c7j54wiarheofqb4gbyajl45ernnjlddmdz77p5ehb46w46cawmo.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {}) | |
# %mul_419 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %convert_element_type_11), kwargs = {}) | |
# %sum_41 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_419,), kwargs = {}) | |
# %mul_421 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %add_130), kwargs = {}) | |
# %sum_42 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_421,), kwargs = {}) | |
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {}) | |
# %mul_459 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %convert_element_type_11), kwargs = {}) | |
# %sum_48 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_459,), kwargs = {}) | |
# %mul_460 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_73), kwargs = {}) | |
# %mul_461 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %add_119), kwargs = {}) | |
# %sum_49 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_461,), kwargs = {}) | |
# %mul_463 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_460, %add_36), kwargs = {}) | |
# %sum_50 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_463,), kwargs = {}) | |
# Persistent-reduction kernel: each program handles one row x0 (XBLOCK = 1) of
# r0_numel = 1024 contiguous elements, over xnumel = 65536 rows.
# Row data in_ptr0..2, in_ptr4..7 and in_ptr9 are bf16, read at r0_1 + 1024*x0
# and upcast to fp32; in_ptr3 (fp32) holds one value per row; in_ptr8 (fp32) is
# read only at element 38 and broadcast across the row.
# Per row, in fp32:
#   a = in0 + in1,   b = in5 + in6,   n = in2 * in3[x0]
#   out_ptr0[x0] = sum(a * n)          out_ptr1[x0] = sum(a * in4)
#   out_ptr2[x0] = sum(b * n)          out_ptr3[x0] = sum(b * in7)
#   m = b * in8[38], stored elementwise to out_ptr5 (declared '*bf16')
#   out_ptr4[x0] = sum(m * in9)
# The graph fragment lists these sums as aten.sum.default (full reductions),
# but this kernel writes only one partial sum per row; presumably a later
# kernel finishes the reduction -- not visible in this chunk, confirm.
# NOTE(review): the graph rounds x * rsqrt to bfloat16 (convert_element_type_11),
# while the kernel keeps n in fp32 (`tmp6.to(tl.float32)` is a no-op on an fp32
# value), so results can differ slightly from eager bf16 -- presumably
# intended Inductor behaviour; confirm.
# xmask / r0_mask are constant True (no bounds checks), so the launch must not
# exceed 65536 programs.
triton_per_fused__to_copy_add_mul_sum_24 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_24', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
    size_hints={'x': 65536, 'r0_': 1024}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*fp32', 'in_ptr9': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*fp32', 'out_ptr5': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_24', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 10, 'num_reduction': 5, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_mul_sum_24(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, out_ptr5, xnumel, r0_numel): | |
    xnumel = 65536 | |
    XBLOCK: tl.constexpr = 1 | |
    r0_numel = 1024 | |
    R0_BLOCK: tl.constexpr = 1024 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = tl.full([1], xoffset, tl.int32) | |
    xmask = tl.full([R0_BLOCK], True, tl.int1) | |
    r0_index = tl.arange(0, R0_BLOCK)[:] | |
    r0_offset = 0 | |
    r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
    roffset = r0_offset | |
    rindex = r0_index | |
    r0_1 = r0_index | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last') | |
    tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp29 = tl.load(in_ptr8 + (38)) | |
    tmp30 = tl.broadcast_to(tmp29, [R0_BLOCK]) | |
    tmp33 = tl.load(in_ptr9 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp2 = tmp0 + tmp1 | |
    tmp4 = tmp3.to(tl.float32) | |
    tmp6 = tmp4 * tmp5 | |
    tmp7 = tmp6.to(tl.float32) | |
    tmp8 = tmp2 * tmp7 | |
    tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK]) | |
    tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0)) | |
    tmp13 = tmp2 * tmp12 | |
    tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK]) | |
    tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0)) | |
    tmp19 = tmp17 + tmp18 | |
    tmp20 = tmp19 * tmp7 | |
    tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK]) | |
    tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0)) | |
    tmp25 = tmp19 * tmp24 | |
    tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK]) | |
    tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0)) | |
    tmp31 = tmp30.to(tl.float32) | |
    tmp32 = tmp19 * tmp31 | |
    tmp34 = tmp32 * tmp33 | |
    tmp35 = tl.broadcast_to(tmp34, [R0_BLOCK]) | |
    tmp37 = triton_helpers.promote_to_tensor(tl.sum(tmp35, 0)) | |
    tl.store(out_ptr5 + (r0_1 + 1024*x0), tmp32, None) | |
    tl.store(out_ptr0 + (x0), tmp11, None) | |
    tl.store(out_ptr1 + (x0), tmp16, None) | |
    tl.store(out_ptr2 + (x0), tmp23, None) | |
    tl.store(out_ptr3 + (x0), tmp28, None) | |
    tl.store(out_ptr4 + (x0), tmp37, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/f6/cf6ojfz6bfezrzis2b5mf55ez4x3btg3ufefdfa5vevfoiqfwice.py | |
# Topologically Sorted Source Nodes: [v_28], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_28 => convert_element_type_208 | |
# Graph fragment: | |
# %mul_472 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_244, %select_68), kwargs = {}) | |
# %convert_element_type_208 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_95, torch.float32), kwargs = {}) | |
# %convert_element_type_507 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_472, torch.float32), kwargs = {}) | |
# %mul_474 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_507, %convert_element_type_208), kwargs = {}) | |
# %mul_475 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_507, %rsqrt_40), kwargs = {}) | |
# %sum_53 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_474, [3], True), kwargs = {}) | |
# %div_25 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_42, 128), kwargs = {}) | |
# %pow_134 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_208, 1.0), kwargs = {}) | |
# %mul_478 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_134, 2.0), kwargs = {}) | |
# %mul_479 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_25, %mul_478), kwargs = {}) | |
# %add_259 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_475, %mul_479), kwargs = {}) | |
# %convert_element_type_508 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_259, torch.bfloat16), kwargs = {}) | |
# Looped-reduction kernel over xnumel = 524288 rows of r0_numel = 128 elements.
# in_ptr0 (bf16) is read contiguously at r0_2 + 128*x3. in_ptr2 (bf16) and
# out_ptr1 (bf16) use the strided address r0_2 + 128*(x3 % 8) + 3072*(x3 // 8),
# i.e. groups of 8 rows of 128 elements spaced 3072 elements apart.
# in_ptr1 (fp32) is read only at element 68 (broadcast scalar s); in_ptr3 (fp32)
# holds one value r per row.
# Pass 1: t = sum over the row of (in0 * s) * in2.
# Pass 2: out = (in0 * s) * r + (t * -0.5 * r**3 * 0.0078125) * (2 * in2),
#         where 0.0078125 == 1/128 (the graph's div by 128).
# The -0.5 * r**3, 1/128 and 2 * x terms are consistent with the backward of
# x * rsqrt(mean(x**2)) over the 128-wide last dim, with in_ptr3 presumably
# holding the rsqrt values -- inferred from the constants and graph; confirm.
# Pass 1 loads with eviction_policy='evict_last', hinting that in0/in2 should
# stay cached for pass 2, which re-reads them with 'evict_first'.
# The scalar in_ptr1[68] is loaded twice (tmp1 and tmp13); redundant but harmless.
# xmask is constant True (no row bounds check), so the launch must not exceed
# 524288 rows.
triton_red_fused__to_copy_add_div_mul_pow_sum_25 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_25', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_25', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_25(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (68)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (68)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/yt/cytha5ygxd572kkbk6gsg75fiya7avvgj346towpq75mnegn35vm.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {}) | |
# %mul_418 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %select_80), kwargs = {}) | |
# %add_238 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_223, %mul_418), kwargs = {}) | |
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {}) | |
# %mul_458 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_74), kwargs = {}) | |
# %add_253 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_238, %mul_458), kwargs = {}) | |
# %add_266 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_258, %view_252), kwargs = {}) | |
# %mul_500 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_67), kwargs = {}) | |
# %mul_501 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %convert_element_type_11), kwargs = {}) | |
# %sum_56 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_501,), kwargs = {}) | |
# %add_268 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_253, %mul_500), kwargs = {}) | |
# %mul_502 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_66), kwargs = {}) | |
# %mul_503 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %add_107), kwargs = {}) | |
# %sum_57 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_503,), kwargs = {}) | |
# %mul_505 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_502, %add_58), kwargs = {}) | |
# %sum_58 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_505,), kwargs = {}) | |
# Persistent-reduction kernel: each program handles one row x0 (XBLOCK = 1) of
# r0_numel = 1024 contiguous elements, over xnumel = 65536 rows.
# All row data are read at r0_1 + 1024*x0 and upcast to fp32; in_ptr3 (fp32)
# holds one value per row; in_ptr5 (fp32) is read only at elements 36, 37, 39
# and 41, each broadcast across the row.
# Per row, in fp32:
#   a = in0 + in1,   n = in2 * in3[x0]
#   out_ptr0[x0] = sum(a * n)          out_ptr1[x0] = sum(a * in4)
#   m = a * in5[36], stored elementwise to out_ptr3 (declared '*bf16')
#   out_ptr2[x0] = sum(m * in6)
#   in_out_ptr0 = ((in_out_ptr0 + (in7 + in8) * in5[41])
#                  + (in9 + in10) * in5[39]) + a * in5[37]
# in_out_ptr0 is updated in place (it is listed in mutated_arg_names); each of
# its elements is loaded once and stored once, by the same program.
# The graph fragment lists the sums as aten.sum.default (full reductions), but
# out_ptr0..2 get only one partial sum per row; presumably a later kernel
# finishes the reduction -- not visible in this chunk, confirm.
# NOTE(review): the graph rounds x * rsqrt to bfloat16 (convert_element_type_11),
# while the kernel keeps n in fp32 (`tmp6.to(tl.float32)` is a no-op), so
# results can differ slightly from eager bf16 -- presumably intended; confirm.
# xmask / r0_mask are constant True (no bounds checks), so the launch must not
# exceed 65536 programs.
triton_per_fused__to_copy_add_mul_sum_26 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_26', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
    size_hints={'x': 65536, 'r0_': 1024}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_26', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 15, 'num_reduction': 3, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_mul_sum_26(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel): | |
    xnumel = 65536 | |
    XBLOCK: tl.constexpr = 1 | |
    r0_numel = 1024 | |
    R0_BLOCK: tl.constexpr = 1024 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = tl.full([1], xoffset, tl.int32) | |
    xmask = tl.full([R0_BLOCK], True, tl.int1) | |
    r0_index = tl.arange(0, R0_BLOCK)[:] | |
    r0_offset = 0 | |
    r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
    roffset = r0_offset | |
    rindex = r0_index | |
    r0_1 = r0_index | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last') | |
    tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp17 = tl.load(in_ptr5 + (36)) | |
    tmp18 = tl.broadcast_to(tmp17, [R0_BLOCK]) | |
    tmp21 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp26 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp27 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp28 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp30 = tl.load(in_ptr5 + (41)) | |
    tmp31 = tl.broadcast_to(tmp30, [R0_BLOCK]) | |
    tmp35 = tl.load(in_ptr9 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp36 = tl.load(in_ptr10 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp38 = tl.load(in_ptr5 + (39)) | |
    tmp39 = tl.broadcast_to(tmp38, [R0_BLOCK]) | |
    tmp43 = tl.load(in_ptr5 + (37)) | |
    tmp44 = tl.broadcast_to(tmp43, [R0_BLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp4 = tmp3.to(tl.float32) | |
    tmp6 = tmp4 * tmp5 | |
    tmp7 = tmp6.to(tl.float32) | |
    tmp8 = tmp2 * tmp7 | |
    tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK]) | |
    tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0)) | |
    tmp13 = tmp2 * tmp12 | |
    tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK]) | |
    tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0)) | |
    tmp19 = tmp18.to(tl.float32) | |
    tmp20 = tmp2 * tmp19 | |
    tmp22 = tmp20 * tmp21 | |
    tmp23 = tl.broadcast_to(tmp22, [R0_BLOCK]) | |
    tmp25 = triton_helpers.promote_to_tensor(tl.sum(tmp23, 0)) | |
    tmp29 = tmp27 + tmp28 | |
    tmp32 = tmp31.to(tl.float32) | |
    tmp33 = tmp29 * tmp32 | |
    tmp34 = tmp26 + tmp33 | |
    tmp37 = tmp35 + tmp36 | |
    tmp40 = tmp39.to(tl.float32) | |
    tmp41 = tmp37 * tmp40 | |
    tmp42 = tmp34 + tmp41 | |
    tmp45 = tmp44.to(tl.float32) | |
    tmp46 = tmp2 * tmp45 | |
    tmp47 = tmp42 + tmp46 | |
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp47, None) | |
    tl.store(out_ptr3 + (r0_1 + 1024*x0), tmp20, None) | |
    tl.store(out_ptr0 + (x0), tmp11, None) | |
    tl.store(out_ptr1 + (x0), tmp16, None) | |
    tl.store(out_ptr2 + (x0), tmp25, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/7j/c7julr3b2yeaunzxy3p4uhxnmaq2mkueth2riqtaxxiceiks7e4f.py | |
# Topologically Sorted Source Nodes: [v_25], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_25 => convert_element_type_188 | |
# Graph fragment: | |
# %mul_514 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_264, %select_61), kwargs = {}) | |
# %convert_element_type_188 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_86, torch.float32), kwargs = {}) | |
# %convert_element_type_539 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_514, torch.float32), kwargs = {}) | |
# %mul_516 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_539, %convert_element_type_188), kwargs = {}) | |
# %mul_517 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_539, %rsqrt_36), kwargs = {}) | |
# %sum_61 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_516, [3], True), kwargs = {}) | |
# %div_29 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_46, 128), kwargs = {}) | |
# %pow_143 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_188, 1.0), kwargs = {}) | |
# %mul_520 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_143, 2.0), kwargs = {}) | |
# %mul_521 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_29, %mul_520), kwargs = {}) | |
# %add_275 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_517, %mul_521), kwargs = {}) | |
# %convert_element_type_540 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_275, torch.bfloat16), kwargs = {}) | |
# Looped-reduction kernel over xnumel = 524288 rows of r0_numel = 128 elements.
# in_ptr0 (bf16) is read contiguously at r0_2 + 128*x3. in_ptr2 (bf16) and
# out_ptr1 (bf16) use the strided address r0_2 + 128*(x3 % 8) + 3072*(x3 // 8),
# i.e. groups of 8 rows of 128 elements spaced 3072 elements apart.
# in_ptr1 (fp32) is read only at element 66 (broadcast scalar s); in_ptr3 (fp32)
# holds one value r per row.
# Pass 1: t = sum over the row of (in0 * s) * in2.
# Pass 2: out = (in0 * s) * r + (t * -0.5 * r**3 * 0.0078125) * (2 * in2),
#         where 0.0078125 == 1/128 (the graph's div by 128).
# The -0.5 * r**3, 1/128 and 2 * x terms are consistent with the backward of
# x * rsqrt(mean(x**2)) over the 128-wide last dim, with in_ptr3 presumably
# holding the rsqrt values -- inferred from the constants and graph; confirm.
# Pass 1 loads with eviction_policy='evict_last', hinting that in0/in2 should
# stay cached for pass 2, which re-reads them with 'evict_first'.
# The scalar in_ptr1[66] is loaded twice (tmp1 and tmp13); redundant but harmless.
# xmask is constant True (no row bounds check), so the launch must not exceed
# 524288 rows.
triton_red_fused__to_copy_add_div_mul_pow_sum_27 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_27', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_27', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_27(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (66)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (66)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/px/cpxzk6lleynsgxocbapsdzouctqzdj734tcdpezlppd4n46yo5ov.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_282 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_274, %view_263), kwargs = {}) | |
# %mul_544 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_59), kwargs = {}) | |
# Pointwise kernel over xnumel = 67108864 elements:
#   out_ptr0[x] = (in_ptr0[x] + in_ptr1[x]) * in_ptr2[34]
# in_ptr0 / in_ptr1 / out_ptr0 are declared '*bf16'. The inputs are upcast and
# the arithmetic is done in fp32 before the store to the bf16 pointer.
# in_ptr2 (fp32) is read only at element 34 and broadcast across the block.
# xmask is constant True (no bounds check), so XBLOCK * grid size must not
# exceed 67108864, i.e. there must be no partial final block.
triton_poi_fused_add_mul_28 = async_compile.triton('triton_poi_fused_add_mul_28', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
    size_hints={'x': 67108864}, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_28', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
    min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_mul_28(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
    xnumel = 67108864 | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
    xmask = tl.full([XBLOCK], True, tl.int1) | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (34)) | |
    tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp5 = tmp4.to(tl.float32) | |
    tmp6 = tmp2 * tmp5 | |
    tl.store(out_ptr0 + (x0), tmp6, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/ha/chasxebvoqyu5qi3wjqxf4oumcj4fcqploj745fbxu6kdwm3tfoh.py | |
# Topologically Sorted Source Nodes: [v_22], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_22 => convert_element_type_168 | |
# Graph fragment: | |
# %mul_556 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_284, %select_54), kwargs = {}) | |
# %convert_element_type_168 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_77, torch.float32), kwargs = {}) | |
# %convert_element_type_571 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_556, torch.float32), kwargs = {}) | |
# %mul_558 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_571, %convert_element_type_168), kwargs = {}) | |
# %mul_559 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_571, %rsqrt_32), kwargs = {}) | |
# %sum_69 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_558, [3], True), kwargs = {}) | |
# %div_33 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_50, 128), kwargs = {}) | |
# %pow_152 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_168, 1.0), kwargs = {}) | |
# %mul_562 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_152, 2.0), kwargs = {}) | |
# %mul_563 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_33, %mul_562), kwargs = {}) | |
# %add_291 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_559, %mul_563), kwargs = {}) | |
# %convert_element_type_572 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_291, torch.bfloat16), kwargs = {}) | |
# Fused two-pass reduction kernel. The triple-quoted string below is the Triton
# source handed verbatim to async_compile.triton(), so any edit inside it
# changes the compiled kernel; keep documentation outside the string.
# NOTE(review): every line here, including those inside the kernel string,
# ends in " | |" and has lost its leading indentation. This looks like an
# artifact of copying from a rendered web page rather than real Inductor
# output, so re-fetch the raw file before trying to run it.
#
# What the kernel body computes (xnumel = 524288 rows, r0_numel = 128 columns,
# all arithmetic in fp32):
#   g[x, r]   = in_ptr0[r + 128*x] * in_ptr1[64]
#   v[x, r]   = in_ptr2[r + 128*(x % 8) + 3072*(x // 8)]
#   s[x]      = sum_r g[x, r] * v[x, r]                          (first loop)
#   out[x, r] = g * in_ptr3[x]
#               + (s * -0.5 * in_ptr3[x]**3 * (1/128)) * (2 * v)  (second loop)
# The result is stored through the *bf16 out_ptr1 at the same strided offsets
# used for in_ptr2.
# Matching against the graph fragment above, in_ptr0 appears to be
# %permute_284, in_ptr1[64] %select_54, in_ptr2 %getitem_77 and in_ptr3
# %rsqrt_32. The whole thing looks like an RMS-norm backward over a size-128
# last dim -- confirm.
# The offsets r + 128*(x % 8) + 3072*(x // 8) touch only 1024 of every 3072
# elements, so in_ptr2/out_ptr1 are presumably views into a wider buffer --
# confirm.
# The scalar in_ptr1[64] is loaded twice (tmp1, tmp13); this is redundant but
# harmless.
triton_red_fused__to_copy_add_div_mul_pow_sum_29 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_29', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
size_hints={'x': 524288, 'r0_': 128}, | |
reduction_hint=ReductionHint.DEFAULT, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_29', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_29(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
xnumel = 524288 | |
r0_numel = 128 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
rbase = r0_base | |
x3 = xindex | |
tmp1 = tl.load(in_ptr1 + (64)) | |
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
x0 = (xindex % 8) | |
x1 = xindex // 8 | |
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp0 * tmp3 | |
tmp5 = tmp4.to(tl.float32) | |
tmp7 = tmp6.to(tl.float32) | |
tmp8 = tmp5 * tmp7 | |
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
tmp11 = _tmp10 + tmp9 | |
_tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
tmp10 = tl.sum(_tmp10, 1)[:, None] | |
tmp13 = tl.load(in_ptr1 + (64)) | |
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
for r0_offset in range(0, r0_numel, R0_BLOCK): | |
r0_index = r0_offset + r0_base | |
r0_mask = r0_index < r0_numel | |
roffset = r0_offset | |
rindex = r0_index | |
r0_2 = r0_index | |
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
tmp15 = tmp14.to(tl.float32) | |
tmp16 = tmp12 * tmp15 | |
tmp17 = tmp16.to(tl.float32) | |
tmp19 = tmp17 * tmp18 | |
tmp20 = -0.5 | |
tmp21 = tmp10 * tmp20 | |
tmp22 = tmp18 * tmp18 | |
tmp23 = tmp22 * tmp18 | |
tmp24 = tmp21 * tmp23 | |
tmp25 = 0.0078125 | |
tmp26 = tmp24 * tmp25 | |
tmp28 = tmp27.to(tl.float32) | |
tmp29 = 2.0 | |
tmp30 = tmp28 * tmp29 | |
tmp31 = tmp26 * tmp30 | |
tmp32 = tmp19 + tmp31 | |
tmp33 = tmp32.to(tl.float32) | |
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/hi/chiekkfjcumovdpaptfvgzwej74uquxfu45x5g5on3x3zr6rjrc2.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_298 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_290, %view_274), kwargs = {}) | |
# %mul_586 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %select_52), kwargs = {}) | |
# Pointwise kernel over xnumel = 67108864 elements:
#   out_ptr0[x] = (in_ptr0[x] + in_ptr1[x]) * in_ptr2[32]
# The two *bf16 inputs are upcast to fp32 on load, the scalar in_ptr2[32] comes
# from an *fp32 buffer, and the product is written through the *bf16 out_ptr0.
# Per the graph fragment above, this appears to be
# %mul_586 = (%add_290 + %view_274) * %select_52. The intermediate %add_298
# itself is not written out by this kernel.
# Loads and stores use mask None (xmask is all-True), so the launch must not
# run past xnumel. This is presumably guaranteed because 67108864 is a multiple
# of every XBLOCK the pointwise heuristics pick -- confirm.
# The string is the kernel source passed verbatim to async_compile.triton(), so
# edits inside it change the compiled kernel.
triton_poi_fused_add_mul_30 = async_compile.triton('triton_poi_fused_add_mul_30', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
size_hints={'x': 67108864}, | |
filename=__file__, | |
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_30', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_mul_30(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
xnumel = 67108864 | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
xmask = tl.full([XBLOCK], True, tl.int1) | |
x0 = xindex | |
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | |
tmp3 = tl.load(in_ptr2 + (32)) | |
tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | |
tmp2 = tmp0 + tmp1 | |
tmp5 = tmp4.to(tl.float32) | |
tmp6 = tmp2 * tmp5 | |
tl.store(out_ptr0 + (x0), tmp6, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/a6/ca646ft4uesvxwm7jyd4ahfawp3vmjghfta3dwmjfxoyizhpowhp.py | |
# Topologically Sorted Source Nodes: [x, rms_norm_29], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum, aten.div, aten.pow] | |
# Source node to ATen node mapping: | |
# rms_norm_29 => convert_element_type_152 | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_282 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_274, %view_263), kwargs = {}) | |
# %mul_542 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_60), kwargs = {}) | |
# %mul_543 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %convert_element_type_11), kwargs = {}) | |
# %sum_64 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_543,), kwargs = {}) | |
# %add_284 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_268, %mul_542), kwargs = {}) | |
# %mul_544 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_59), kwargs = {}) | |
# %mul_545 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %add_95), kwargs = {}) | |
# %sum_65 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_545,), kwargs = {}) | |
# %mul_546 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_544, %select_56), kwargs = {}) | |
# %mul_547 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_544, %add_80), kwargs = {}) | |
# %sum_66 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_547,), kwargs = {}) | |
# %add_298 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_290, %view_274), kwargs = {}) | |
# %mul_584 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %select_53), kwargs = {}) | |
# %mul_585 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %convert_element_type_11), kwargs = {}) | |
# %sum_72 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_585,), kwargs = {}) | |
# %add_300 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_284, %mul_584), kwargs = {}) | |
# %mul_587 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %add_83), kwargs = {}) | |
# %sum_73 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_587,), kwargs = {}) | |
# %convert_element_type_595 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_279, torch.float32), kwargs = {}) | |
# %convert_element_type_152 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_81, torch.float32), kwargs = {}) | |
# %mul_590 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_595, %convert_element_type_152), kwargs = {}) | |
# %mul_591 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_595, %rsqrt_29), kwargs = {}) | |
# %sum_74 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_590, [2], True), kwargs = {}) | |
# %div_36 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_53, 1024), kwargs = {}) | |
# %pow_159 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_152, 1.0), kwargs = {}) | |
# %mul_594 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_159, 2.0), kwargs = {}) | |
# %mul_595 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_36, %mul_594), kwargs = {}) | |
# %add_304 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_591, %mul_595), kwargs = {}) | |
# %convert_element_type_596 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_304, torch.bfloat16), kwargs = {}) | |
# %add_305 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_586, %convert_element_type_596), kwargs = {}) | |
# %mul_596 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %select_49), kwargs = {}) | |
# %mul_597 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %convert_element_type_11), kwargs = {}) | |
# %sum_75 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_597,), kwargs = {}) | |
# %add_306 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_300, %mul_596), kwargs = {}) | |
# %mul_598 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %select_48), kwargs = {}) | |
# %mul_599 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %add_80), kwargs = {}) | |
# %sum_76 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_599,), kwargs = {}) | |
# %add_307 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_546, %mul_598), kwargs = {}) | |
# Persistent reduction kernel: one program per row (XBLOCK = 1) over
# xnumel = 65536 rows. Each row is reduced over all r0_numel = 1024 columns at
# once (R0_BLOCK = 1024, no reduction loop, no masks). Full-size operands are
# addressed as ptr[r + 1024*x]. in_ptr2 and in_ptr9 are read once per row
# (ptr[x]). in_ptr5 is read only at the fixed scalar indices 6, 30, 31, 33, 34
# and 35.
#
# Core computation (fp32):
#   tmp21 = in_out_ptr0 * in_ptr2[x]
#           + (sum_r(in_out_ptr0 * in_ptr0) * -0.5 * in_ptr2[x]**3 / 1024)
#             * (2 * in_ptr0)
#   tmp23 = in_ptr1 + tmp21
#   tmp27 = in_ptr3 + in_ptr4
#   tmp35 = in_ptr6 + in_ptr7
#   tmp63 = in_ptr8 * in_ptr9[x]
# Writes:
#   in_out_ptr1 += tmp27*in_ptr5[35] + tmp35*in_ptr5[33] + tmp23*in_ptr5[31]
#   out_ptr1     = tmp27*in_ptr5[34]*in_ptr5[6] + tmp23*in_ptr5[30]
#   out_ptr2..8  = per-row sums (fp32, one value per row) of, in order:
#                  tmp27*tmp63, tmp27*in_ptr10, tmp27*in_ptr5[34]*in_ptr11,
#                  tmp35*tmp63, tmp35*in_ptr12, tmp23*tmp63, tmp23*in_ptr11
# Matching against the graph fragment above, the outputs appear to be:
# in_out_ptr1 -> %add_306, out_ptr1 -> %add_307, and out_ptr2..8 -> %sum_64,
# %sum_65, %sum_66, %sum_72, %sum_73, %sum_75, %sum_76. On the same reading,
# tmp27 ~ %add_282, tmp35 ~ %add_298, tmp23 ~ %add_305 and
# tmp63 ~ %convert_element_type_11 -- confirm.
# NOTE(review): in_out_ptr0 is listed in mutated_arg_names, but this body only
# loads from it (tmp0) and never stores to it. Presumably this is a leftover of
# Inductor's in-place buffer reuse -- confirm before relying on its contents.
# The string is the kernel source passed verbatim to async_compile.triton(), so
# edits inside it change the compiled kernel.
triton_per_fused__to_copy_add_div_mul_pow_sum_31 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_31', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
size_hints={'x': 65536, 'r0_': 1024}, | |
reduction_hint=ReductionHint.INNER, | |
filename=__file__, | |
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_out_ptr1': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'out_ptr1': '*bf16', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*fp32', 'out_ptr5': '*fp32', 'out_ptr6': '*fp32', 'out_ptr7': '*fp32', 'out_ptr8': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]], (18,): [['tt.divisibility', 16]], (19,): [['tt.divisibility', 16]], (20,): [['tt.divisibility', 16]], (21,): [['tt.divisibility', 16]], (22,): [['tt.divisibility', 16]], (23,): [['tt.divisibility', 16]], (24,): [['tt.divisibility', 16]]}]}, | |
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_31', 'mutated_arg_names': ['in_out_ptr0', 'in_out_ptr1'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 20, 'num_reduction': 8, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_div_mul_pow_sum_31(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, out_ptr1, out_ptr2, out_ptr3, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, xnumel, r0_numel): | |
xnumel = 65536 | |
XBLOCK: tl.constexpr = 1 | |
r0_numel = 1024 | |
R0_BLOCK: tl.constexpr = 1024 | |
rnumel = r0_numel | |
RBLOCK: tl.constexpr = R0_BLOCK | |
xoffset = tl.program_id(0) * XBLOCK | |
xindex = tl.full([1], xoffset, tl.int32) | |
xmask = tl.full([R0_BLOCK], True, tl.int1) | |
r0_index = tl.arange(0, R0_BLOCK)[:] | |
r0_offset = 0 | |
r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
roffset = r0_offset | |
rindex = r0_index | |
r0_1 = r0_index | |
x0 = xindex | |
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp8 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last') | |
tmp24 = tl.load(in_out_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp25 = tl.load(in_ptr3 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp26 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp28 = tl.load(in_ptr5 + (35)) | |
tmp29 = tl.broadcast_to(tmp28, [R0_BLOCK]) | |
tmp33 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp34 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp36 = tl.load(in_ptr5 + (33)) | |
tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK]) | |
tmp41 = tl.load(in_ptr5 + (31)) | |
tmp42 = tl.broadcast_to(tmp41, [R0_BLOCK]) | |
tmp46 = tl.load(in_ptr5 + (34)) | |
tmp47 = tl.broadcast_to(tmp46, [R0_BLOCK]) | |
tmp50 = tl.load(in_ptr5 + (6)) | |
tmp51 = tl.broadcast_to(tmp50, [R0_BLOCK]) | |
tmp54 = tl.load(in_ptr5 + (30)) | |
tmp55 = tl.broadcast_to(tmp54, [R0_BLOCK]) | |
tmp59 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp61 = tl.load(in_ptr9 + (x0), None, eviction_policy='evict_last') | |
tmp68 = tl.load(in_ptr10 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp73 = tl.load(in_ptr11 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp82 = tl.load(in_ptr12 + (r0_1 + 1024*x0), None).to(tl.float32) | |
tmp1 = tmp0.to(tl.float32) | |
tmp3 = tmp2.to(tl.float32) | |
tmp4 = tmp1 * tmp3 | |
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK]) | |
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0)) | |
tmp10 = tmp1 * tmp9 | |
tmp11 = -0.5 | |
tmp12 = tmp7 * tmp11 | |
tmp13 = tmp9 * tmp9 | |
tmp14 = tmp13 * tmp9 | |
tmp15 = tmp12 * tmp14 | |
tmp16 = 0.0009765625 | |
tmp17 = tmp15 * tmp16 | |
tmp18 = 2.0 | |
tmp19 = tmp3 * tmp18 | |
tmp20 = tmp17 * tmp19 | |
tmp21 = tmp10 + tmp20 | |
tmp22 = tmp21.to(tl.float32) | |
tmp23 = tmp8 + tmp22 | |
tmp27 = tmp25 + tmp26 | |
tmp30 = tmp29.to(tl.float32) | |
tmp31 = tmp27 * tmp30 | |
tmp32 = tmp24 + tmp31 | |
tmp35 = tmp33 + tmp34 | |
tmp38 = tmp37.to(tl.float32) | |
tmp39 = tmp35 * tmp38 | |
tmp40 = tmp32 + tmp39 | |
tmp43 = tmp42.to(tl.float32) | |
tmp44 = tmp23 * tmp43 | |
tmp45 = tmp40 + tmp44 | |
tmp48 = tmp47.to(tl.float32) | |
tmp49 = tmp27 * tmp48 | |
tmp52 = tmp51.to(tl.float32) | |
tmp53 = tmp49 * tmp52 | |
tmp56 = tmp55.to(tl.float32) | |
tmp57 = tmp23 * tmp56 | |
tmp58 = tmp53 + tmp57 | |
tmp60 = tmp59.to(tl.float32) | |
tmp62 = tmp60 * tmp61 | |
tmp63 = tmp62.to(tl.float32) | |
tmp64 = tmp27 * tmp63 | |
tmp65 = tl.broadcast_to(tmp64, [R0_BLOCK]) | |
tmp67 = triton_helpers.promote_to_tensor(tl.sum(tmp65, 0)) | |
tmp69 = tmp27 * tmp68 | |
tmp70 = tl.broadcast_to(tmp69, [R0_BLOCK]) | |
tmp72 = triton_helpers.promote_to_tensor(tl.sum(tmp70, 0)) | |
tmp74 = tmp49 * tmp73 | |
tmp75 = tl.broadcast_to(tmp74, [R0_BLOCK]) | |
tmp77 = triton_helpers.promote_to_tensor(tl.sum(tmp75, 0)) | |
tmp78 = tmp35 * tmp63 | |
tmp79 = tl.broadcast_to(tmp78, [R0_BLOCK]) | |
tmp81 = triton_helpers.promote_to_tensor(tl.sum(tmp79, 0)) | |
tmp83 = tmp35 * tmp82 | |
tmp84 = tl.broadcast_to(tmp83, [R0_BLOCK]) | |
tmp86 = triton_helpers.promote_to_tensor(tl.sum(tmp84, 0)) | |
tmp87 = tmp23 * tmp63 | |
tmp88 = tl.broadcast_to(tmp87, [R0_BLOCK]) | |
tmp90 = triton_helpers.promote_to_tensor(tl.sum(tmp88, 0)) | |
tmp91 = tmp23 * tmp73 | |
tmp92 = tl.broadcast_to(tmp91, [R0_BLOCK]) | |
tmp94 = triton_helpers.promote_to_tensor(tl.sum(tmp92, 0)) | |
tl.store(in_out_ptr1 + (r0_1 + 1024*x0), tmp45, None) | |
tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp58, None) | |
tl.store(out_ptr2 + (x0), tmp67, None) | |
tl.store(out_ptr3 + (x0), tmp72, None) | |
tl.store(out_ptr4 + (x0), tmp77, None) | |
tl.store(out_ptr5 + (x0), tmp81, None) | |
tl.store(out_ptr6 + (x0), tmp86, None) | |
tl.store(out_ptr7 + (x0), tmp90, None) | |
tl.store(out_ptr8 + (x0), tmp94, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/bw/cbwa3hunkne3dnuzwzdtrx3cnl2c43w6e2yxbyvp2aetgwpzbih3.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten._to_copy, aten.add] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %full_default_20 : [num_users=53] = call_function[target=torch.ops.aten.full.default](args = ([2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %convert_element_type_363 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_18, torch.float32), kwargs = {}) | |
# %select_scatter_default_3 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_363, 0, 1), kwargs = {}) | |
# %convert_element_type_364 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_19, torch.float32), kwargs = {}) | |
# %select_scatter_default_4 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_364, 0, 0), kwargs = {}) | |
# %add_194 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_3, %select_scatter_default_4), kwargs = {}) | |
# %full_default_25 : [num_users=31] = call_function[target=torch.ops.aten.full.default](args = ([16, 2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %select_scatter_default_6 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_194, 0, 15), kwargs = {}) | |
# %convert_element_type_395 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_26, torch.float32), kwargs = {}) | |
# %select_scatter_default_10 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_395, 0, 1), kwargs = {}) | |
# %convert_element_type_396 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_27, torch.float32), kwargs = {}) | |
# %select_scatter_default_11 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_396, 0, 0), kwargs = {}) | |
# %add_208 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_10, %select_scatter_default_11), kwargs = {}) | |
# %select_scatter_default_13 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_208, 0, 14), kwargs = {}) | |
# %add_210 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_6, %select_scatter_default_13), kwargs = {}) | |
# %convert_element_type_427 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_34, torch.float32), kwargs = {}) | |
# %select_scatter_default_17 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_427, 0, 1), kwargs = {}) | |
# %convert_element_type_428 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_35, torch.float32), kwargs = {}) | |
# %select_scatter_default_18 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_428, 0, 0), kwargs = {}) | |
# %add_224 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_17, %select_scatter_default_18), kwargs = {}) | |
# %select_scatter_default_20 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_224, 0, 13), kwargs = {}) | |
# %add_226 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_210, %select_scatter_default_20), kwargs = {}) | |
# %convert_element_type_458 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_41, torch.float32), kwargs = {}) | |
# %select_scatter_default_23 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_458, 0, 1), kwargs = {}) | |
# %convert_element_type_459 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_42, torch.float32), kwargs = {}) | |
# %select_scatter_default_24 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_459, 0, 0), kwargs = {}) | |
# %add_239 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_23, %select_scatter_default_24), kwargs = {}) | |
# %select_scatter_default_26 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_239, 0, 12), kwargs = {}) | |
# %add_241 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_226, %select_scatter_default_26), kwargs = {}) | |
# %convert_element_type_489 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_48, torch.float32), kwargs = {}) | |
# %select_scatter_default_29 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_489, 0, 1), kwargs = {}) | |
# %convert_element_type_490 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_49, torch.float32), kwargs = {}) | |
# %select_scatter_default_30 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_490, 0, 0), kwargs = {}) | |
# %add_254 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_29, %select_scatter_default_30), kwargs = {}) | |
# %select_scatter_default_32 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_254, 0, 11), kwargs = {}) | |
# %add_256 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_241, %select_scatter_default_32), kwargs = {}) | |
# %convert_element_type_521 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_56, torch.float32), kwargs = {}) | |
# %select_scatter_default_36 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_521, 0, 1), kwargs = {}) | |
# %convert_element_type_522 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_57, torch.float32), kwargs = {}) | |
# %select_scatter_default_37 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_522, 0, 0), kwargs = {}) | |
# %add_269 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_36, %select_scatter_default_37), kwargs = {}) | |
# %select_scatter_default_39 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_269, 0, 10), kwargs = {}) | |
# %add_271 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_256, %select_scatter_default_39), kwargs = {}) | |
# %convert_element_type_553 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_64, torch.float32), kwargs = {}) | |
# %select_scatter_default_43 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_553, 0, 1), kwargs = {}) | |
# %convert_element_type_554 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_65, torch.float32), kwargs = {}) | |
# %select_scatter_default_44 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_554, 0, 0), kwargs = {}) | |
# %add_285 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_43, %select_scatter_default_44), kwargs = {}) | |
# %select_scatter_default_46 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_285, 0, 9), kwargs = {}) | |
# %add_287 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_271, %select_scatter_default_46), kwargs = {}) | |
# %convert_element_type_585 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_72, torch.float32), kwargs = {}) | |
# %select_scatter_default_50 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_585, 0, 1), kwargs = {}) | |
# %convert_element_type_586 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_73, torch.float32), kwargs = {}) | |
# %select_scatter_default_51 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_586, 0, 0), kwargs = {}) | |
# %add_301 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_50, %select_scatter_default_51), kwargs = {}) | |
# %select_scatter_default_53 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_301, 0, 8), kwargs = {}) | |
# %add_303 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_287, %select_scatter_default_53), kwargs = {}) | |
# %convert_element_type_597 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_75, torch.float32), kwargs = {}) | |
# %select_scatter_default_54 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_597, 0, 1), kwargs = {}) | |
# %convert_element_type_598 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_76, torch.float32), kwargs = {}) | |
# %select_scatter_default_55 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_598, 0, 0), kwargs = {}) | |
# %add_308 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_54, %select_scatter_default_55), kwargs = {}) | |
# %select_scatter_default_56 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_308, 0, 7), kwargs = {}) | |
# %add_309 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_303, %select_scatter_default_56), kwargs = {}) | |
# Pointwise kernel over xnumel = 32 fp32 outputs, indexed as row x1 = xindex // 2
# (0..15) and column x0 = xindex % 2. Each of the 18 bf16 inputs contributes a
# single scalar (element 0), upcast to fp32. Rows 15 down to 7 each take one
# input pair: row 15 uses (in_ptr0, in_ptr1), row 14 (in_ptr2, in_ptr3), ...,
# row 7 (in_ptr16, in_ptr17); within a pair the first input fills column 1 and
# the second fills column 0. Rows 0..6 are written as 0.0. Because the row
# conditions are disjoint, the running tmp sums just select one value per
# output element.
triton_poi_fused__to_copy_add_select_backward_32 = async_compile.triton('triton_poi_fused__to_copy_add_select_backward_32', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
    size_hints={'x': 32}, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'in_ptr13': '*bf16', 'in_ptr14': '*bf16', 'in_ptr15': '*bf16', 'in_ptr16': '*bf16', 'in_ptr17': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]], (18,): [['tt.divisibility', 16]], (19,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_select_backward_32', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 18, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
    min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused__to_copy_add_select_backward_32(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, in_ptr15, in_ptr16, in_ptr17, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
    xnumel = 32 | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
    xmask = xindex < xnumel | |
    x1 = xindex // 2 | |
    x0 = (xindex % 2) | |
    x2 = xindex | |
    tmp6 = tl.load(in_ptr0 + (0)).to(tl.float32) | |
    tmp7 = tl.broadcast_to(tmp6, [XBLOCK]) | |
    tmp13 = tl.load(in_ptr1 + (0)).to(tl.float32) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK]) | |
    tmp21 = tl.load(in_ptr2 + (0)).to(tl.float32) | |
    tmp22 = tl.broadcast_to(tmp21, [XBLOCK]) | |
    tmp25 = tl.load(in_ptr3 + (0)).to(tl.float32) | |
    tmp26 = tl.broadcast_to(tmp25, [XBLOCK]) | |
    tmp34 = tl.load(in_ptr4 + (0)).to(tl.float32) | |
    tmp35 = tl.broadcast_to(tmp34, [XBLOCK]) | |
    tmp38 = tl.load(in_ptr5 + (0)).to(tl.float32) | |
    tmp39 = tl.broadcast_to(tmp38, [XBLOCK]) | |
    tmp47 = tl.load(in_ptr6 + (0)).to(tl.float32) | |
    tmp48 = tl.broadcast_to(tmp47, [XBLOCK]) | |
    tmp51 = tl.load(in_ptr7 + (0)).to(tl.float32) | |
    tmp52 = tl.broadcast_to(tmp51, [XBLOCK]) | |
    tmp60 = tl.load(in_ptr8 + (0)).to(tl.float32) | |
    tmp61 = tl.broadcast_to(tmp60, [XBLOCK]) | |
    tmp64 = tl.load(in_ptr9 + (0)).to(tl.float32) | |
    tmp65 = tl.broadcast_to(tmp64, [XBLOCK]) | |
    tmp73 = tl.load(in_ptr10 + (0)).to(tl.float32) | |
    tmp74 = tl.broadcast_to(tmp73, [XBLOCK]) | |
    tmp77 = tl.load(in_ptr11 + (0)).to(tl.float32) | |
    tmp78 = tl.broadcast_to(tmp77, [XBLOCK]) | |
    tmp86 = tl.load(in_ptr12 + (0)).to(tl.float32) | |
    tmp87 = tl.broadcast_to(tmp86, [XBLOCK]) | |
    tmp90 = tl.load(in_ptr13 + (0)).to(tl.float32) | |
    tmp91 = tl.broadcast_to(tmp90, [XBLOCK]) | |
    tmp99 = tl.load(in_ptr14 + (0)).to(tl.float32) | |
    tmp100 = tl.broadcast_to(tmp99, [XBLOCK]) | |
    tmp103 = tl.load(in_ptr15 + (0)).to(tl.float32) | |
    tmp104 = tl.broadcast_to(tmp103, [XBLOCK]) | |
    tmp112 = tl.load(in_ptr16 + (0)).to(tl.float32) | |
    tmp113 = tl.broadcast_to(tmp112, [XBLOCK]) | |
    tmp116 = tl.load(in_ptr17 + (0)).to(tl.float32) | |
    tmp117 = tl.broadcast_to(tmp116, [XBLOCK]) | |
    tmp0 = x1 | |
    tmp1 = tl.full([1], 15, tl.int32) | |
    tmp2 = tmp0 == tmp1 | |
    tmp3 = x0 | |
    tmp4 = tl.full([1], 1, tl.int32) | |
    tmp5 = tmp3 == tmp4 | |
    tmp8 = tmp7.to(tl.float32) | |
    tmp9 = 0.0 | |
    tmp10 = tl.where(tmp5, tmp8, tmp9) | |
    tmp11 = tl.full([1], 0, tl.int32) | |
    tmp12 = tmp3 == tmp11 | |
    tmp15 = tmp14.to(tl.float32) | |
    tmp16 = tl.where(tmp12, tmp15, tmp9) | |
    tmp17 = tmp10 + tmp16 | |
    tmp18 = tl.where(tmp2, tmp17, tmp9) | |
    tmp19 = tl.full([1], 14, tl.int32) | |
    tmp20 = tmp0 == tmp19 | |
    tmp23 = tmp22.to(tl.float32) | |
    tmp24 = tl.where(tmp5, tmp23, tmp9) | |
    tmp27 = tmp26.to(tl.float32) | |
    tmp28 = tl.where(tmp12, tmp27, tmp9) | |
    tmp29 = tmp24 + tmp28 | |
    tmp30 = tl.where(tmp20, tmp29, tmp9) | |
    tmp31 = tmp18 + tmp30 | |
    tmp32 = tl.full([1], 13, tl.int32) | |
    tmp33 = tmp0 == tmp32 | |
    tmp36 = tmp35.to(tl.float32) | |
    tmp37 = tl.where(tmp5, tmp36, tmp9) | |
    tmp40 = tmp39.to(tl.float32) | |
    tmp41 = tl.where(tmp12, tmp40, tmp9) | |
    tmp42 = tmp37 + tmp41 | |
    tmp43 = tl.where(tmp33, tmp42, tmp9) | |
    tmp44 = tmp31 + tmp43 | |
    tmp45 = tl.full([1], 12, tl.int32) | |
    tmp46 = tmp0 == tmp45 | |
    tmp49 = tmp48.to(tl.float32) | |
    tmp50 = tl.where(tmp5, tmp49, tmp9) | |
    tmp53 = tmp52.to(tl.float32) | |
    tmp54 = tl.where(tmp12, tmp53, tmp9) | |
    tmp55 = tmp50 + tmp54 | |
    tmp56 = tl.where(tmp46, tmp55, tmp9) | |
    tmp57 = tmp44 + tmp56 | |
    tmp58 = tl.full([1], 11, tl.int32) | |
    tmp59 = tmp0 == tmp58 | |
    tmp62 = tmp61.to(tl.float32) | |
    tmp63 = tl.where(tmp5, tmp62, tmp9) | |
    tmp66 = tmp65.to(tl.float32) | |
    tmp67 = tl.where(tmp12, tmp66, tmp9) | |
    tmp68 = tmp63 + tmp67 | |
    tmp69 = tl.where(tmp59, tmp68, tmp9) | |
    tmp70 = tmp57 + tmp69 | |
    tmp71 = tl.full([1], 10, tl.int32) | |
    tmp72 = tmp0 == tmp71 | |
    tmp75 = tmp74.to(tl.float32) | |
    tmp76 = tl.where(tmp5, tmp75, tmp9) | |
    tmp79 = tmp78.to(tl.float32) | |
    tmp80 = tl.where(tmp12, tmp79, tmp9) | |
    tmp81 = tmp76 + tmp80 | |
    tmp82 = tl.where(tmp72, tmp81, tmp9) | |
    tmp83 = tmp70 + tmp82 | |
    tmp84 = tl.full([1], 9, tl.int32) | |
    tmp85 = tmp0 == tmp84 | |
    tmp88 = tmp87.to(tl.float32) | |
    tmp89 = tl.where(tmp5, tmp88, tmp9) | |
    tmp92 = tmp91.to(tl.float32) | |
    tmp93 = tl.where(tmp12, tmp92, tmp9) | |
    tmp94 = tmp89 + tmp93 | |
    tmp95 = tl.where(tmp85, tmp94, tmp9) | |
    tmp96 = tmp83 + tmp95 | |
    tmp97 = tl.full([1], 8, tl.int32) | |
    tmp98 = tmp0 == tmp97 | |
    tmp101 = tmp100.to(tl.float32) | |
    tmp102 = tl.where(tmp5, tmp101, tmp9) | |
    tmp105 = tmp104.to(tl.float32) | |
    tmp106 = tl.where(tmp12, tmp105, tmp9) | |
    tmp107 = tmp102 + tmp106 | |
    tmp108 = tl.where(tmp98, tmp107, tmp9) | |
    tmp109 = tmp96 + tmp108 | |
    tmp110 = tl.full([1], 7, tl.int32) | |
    tmp111 = tmp0 == tmp110 | |
    tmp114 = tmp113.to(tl.float32) | |
    tmp115 = tl.where(tmp5, tmp114, tmp9) | |
    tmp118 = tmp117.to(tl.float32) | |
    tmp119 = tl.where(tmp12, tmp118, tmp9) | |
    tmp120 = tmp115 + tmp119 | |
    tmp121 = tl.where(tmp111, tmp120, tmp9) | |
    tmp122 = tmp109 + tmp121 | |
    tl.store(out_ptr0 + (x2), tmp122, xmask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/4i/c4i2c47wbii5qwc7mntsdhih6p2hh5tnkpctkkyukymyjyzxkagu.py | |
# Topologically Sorted Source Nodes: [v_19], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_19 => convert_element_type_142 | |
# Graph fragment: | |
# %mul_608 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_312, %select_45), kwargs = {}) | |
# %convert_element_type_142 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_68, torch.float32), kwargs = {}) | |
# %convert_element_type_614 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_608, torch.float32), kwargs = {}) | |
# %mul_610 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_614, %convert_element_type_142), kwargs = {}) | |
# %mul_611 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_614, %rsqrt_27), kwargs = {}) | |
# %sum_79 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_610, [3], True), kwargs = {}) | |
# %div_38 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_55, 128), kwargs = {}) | |
# %pow_164 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_142, 1.0), kwargs = {}) | |
# %mul_614 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_164, 2.0), kwargs = {}) | |
# %mul_615 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_38, %mul_614), kwargs = {}) | |
# %add_312 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_611, %mul_615), kwargs = {}) | |
# %convert_element_type_615 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_312, torch.bfloat16), kwargs = {}) | |
# Two-pass reduction kernel: xnumel = 524288 rows, each reduced over r0 in
# [0, 128). With x3 = xindex, x0 = xindex % 8 and x1 = xindex // 8:
#   g  = in_ptr0[128*x3 + r] * in_ptr1[60]   (in_ptr1 is fp32; one scalar used)
#   v  = in_ptr2[128*x0 + 3072*x1 + r]        (bf16, upcast to fp32)
#   rs = in_ptr3[x3]                          (one fp32 value per row)
# Pass 1 (loads marked evict_last) computes s = sum_r g * v.
# Pass 2 (loads marked evict_first) writes bf16 results to the same strided
# offsets that v was read from:
#   out_ptr1[...] = g * rs + (s * -0.5 * rs**3 * 0.0078125) * (2 * v)
# 0.0078125 == 1/128 (the div by 128 in the graph fragment). The 3072 stride on
# x1 suggests the 8 x 128 slab lives inside a wider buffer -- TODO confirm.
# The formula matches the gradient of v * rsqrt(mean(v**2) + eps) over a last
# dim of 128, i.e. presumably an RMSNorm-style backward (eps is not visible
# here; verify against the forward graph).
triton_red_fused__to_copy_add_div_mul_pow_sum_33 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_33', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_33', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_33(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (60)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (60)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/t4/ct4e3vdtvk4aq7yblppd4mvqfalkrljcmgsro5womvtonc2jbebt.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {}) | |
# %mul_638 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %select_43), kwargs = {}) | |
# Elementwise kernel over xnumel = 67108864 (2**26) elements:
#   out_ptr0[x] = (in_ptr0[x] + in_ptr1[x]) * in_ptr2[28]
# in_ptr0/in_ptr1 are bf16 (upcast to fp32 for the arithmetic), in_ptr2 is fp32
# and only its element 28 is used; the result is stored to a bf16 output.
# Loads and the store are unmasked (xmask is all True); this presumably relies
# on 2**26 being an exact multiple of every XBLOCK the heuristics choose --
# TODO confirm.
triton_poi_fused_add_mul_34 = async_compile.triton('triton_poi_fused_add_mul_34', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
    size_hints={'x': 67108864}, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_34', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
    min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_mul_34(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
    xnumel = 67108864 | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
    xmask = tl.full([XBLOCK], True, tl.int1) | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (28)) | |
    tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp5 = tmp4.to(tl.float32) | |
    tmp6 = tmp2 * tmp5 | |
    tl.store(out_ptr0 + (x0), tmp6, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/6h/c6h57xn6lpkcocdgjcl3vgnzzktkctorpl7zvtirarkpn7qgycrr.py | |
# Topologically Sorted Source Nodes: [v_16], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_16 => convert_element_type_122 | |
# Graph fragment: | |
# %mul_648 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_332, %select_39), kwargs = {}) | |
# %convert_element_type_122 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_59, torch.float32), kwargs = {}) | |
# %convert_element_type_645 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_648, torch.float32), kwargs = {}) | |
# %mul_650 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_645, %convert_element_type_122), kwargs = {}) | |
# %mul_651 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_645, %rsqrt_23), kwargs = {}) | |
# %sum_86 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_650, [3], True), kwargs = {}) | |
# %div_42 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_59, 128), kwargs = {}) | |
# %pow_173 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_122, 1.0), kwargs = {}) | |
# %mul_654 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_173, 2.0), kwargs = {}) | |
# %mul_655 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_42, %mul_654), kwargs = {}) | |
# %add_327 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_651, %mul_655), kwargs = {}) | |
# %convert_element_type_646 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_327, torch.bfloat16), kwargs = {}) | |
# Two-pass reduction kernel: xnumel = 524288 rows, each reduced over r0 in
# [0, 128). With x3 = xindex, x0 = xindex % 8 and x1 = xindex // 8:
#   g  = in_ptr0[128*x3 + r] * in_ptr1[58]   (in_ptr1 is fp32; one scalar used)
#   v  = in_ptr2[128*x0 + 3072*x1 + r]        (bf16, upcast to fp32)
#   rs = in_ptr3[x3]                          (one fp32 value per row)
# Pass 1 (loads marked evict_last) computes s = sum_r g * v.
# Pass 2 (loads marked evict_first) writes bf16 results to the same strided
# offsets that v was read from:
#   out_ptr1[...] = g * rs + (s * -0.5 * rs**3 * 0.0078125) * (2 * v)
# 0.0078125 == 1/128 (the div by 128 in the graph fragment). The 3072 stride on
# x1 suggests the 8 x 128 slab lives inside a wider buffer -- TODO confirm.
# The formula matches the gradient of v * rsqrt(mean(v**2) + eps) over a last
# dim of 128, i.e. presumably an RMSNorm-style backward (eps is not visible
# here; verify against the forward graph).
triton_red_fused__to_copy_add_div_mul_pow_sum_35 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_35', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_35', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_35(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (58)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (58)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/wl/cwlkwk35d2q4uxlip2fikqjyurywrohrez5ylmucjva6zxu36fx4.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_266 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_258, %view_252), kwargs = {}) | |
# %mul_502 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_66), kwargs = {}) | |
# %mul_504 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_502, %select_63), kwargs = {}) | |
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {}) | |
# %mul_637 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %convert_element_type_11), kwargs = {}) | |
# %sum_82 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_637,), kwargs = {}) | |
# %mul_639 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %add_69), kwargs = {}) | |
# %sum_83 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_639,), kwargs = {}) | |
# %add_334 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_326, %view_300), kwargs = {}) | |
# %mul_677 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %convert_element_type_11), kwargs = {}) | |
# %sum_89 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_677,), kwargs = {}) | |
# %mul_678 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %select_37), kwargs = {}) | |
# %mul_679 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %add_58), kwargs = {}) | |
# %sum_90 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_679,), kwargs = {}) | |
# %add_337 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_504, %mul_678), kwargs = {}) | |
# Persistent reduction: one program per row (XBLOCK = 1, xindex = program_id),
# 65536 rows of 1024 elements. All row-wise inputs are read at 1024*x0 + r.
# With
#   a = in_ptr0 + in_ptr1
#   b = in_ptr5 + in_ptr6
#   n = in_ptr2 * in_ptr3[x0]          (in_ptr3 holds one fp32 value per row)
# it writes four fp32 per-row sums:
#   out_ptr0[x0] = sum_r a * n         out_ptr1[x0] = sum_r a * in_ptr4
#   out_ptr2[x0] = sum_r b * n         out_ptr3[x0] = sum_r b * in_ptr7
# and overwrites in_out_ptr0 in place (listed in mutated_arg_names):
#   in_out_ptr0 = (in_out_ptr0 + in_ptr8) * in_ptr9[36] * in_ptr9[4] + b * in_ptr9[26]
# The graph fragment's sum.default nodes are full reductions; these per-row
# partial sums are presumably combined by a later kernel -- TODO confirm.
triton_per_fused__to_copy_add_mul_sum_36 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_36', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
    size_hints={'x': 65536, 'r0_': 1024}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_36', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 4, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_mul_sum_36(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel): | |
    xnumel = 65536 | |
    XBLOCK: tl.constexpr = 1 | |
    r0_numel = 1024 | |
    R0_BLOCK: tl.constexpr = 1024 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = tl.full([1], xoffset, tl.int32) | |
    xmask = tl.full([R0_BLOCK], True, tl.int1) | |
    r0_index = tl.arange(0, R0_BLOCK)[:] | |
    r0_offset = 0 | |
    r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
    roffset = r0_offset | |
    rindex = r0_index | |
    r0_1 = r0_index | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last') | |
    tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp29 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp30 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp32 = tl.load(in_ptr9 + (36)) | |
    tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK]) | |
    tmp36 = tl.load(in_ptr9 + (4)) | |
    tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK]) | |
    tmp40 = tl.load(in_ptr9 + (26)) | |
    tmp41 = tl.broadcast_to(tmp40, [R0_BLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp4 = tmp3.to(tl.float32) | |
    tmp6 = tmp4 * tmp5 | |
    tmp7 = tmp6.to(tl.float32) | |
    tmp8 = tmp2 * tmp7 | |
    tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK]) | |
    tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0)) | |
    tmp13 = tmp2 * tmp12 | |
    tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK]) | |
    tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0)) | |
    tmp19 = tmp17 + tmp18 | |
    tmp20 = tmp19 * tmp7 | |
    tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK]) | |
    tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0)) | |
    tmp25 = tmp19 * tmp24 | |
    tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK]) | |
    tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0)) | |
    tmp31 = tmp29 + tmp30 | |
    tmp34 = tmp33.to(tl.float32) | |
    tmp35 = tmp31 * tmp34 | |
    tmp38 = tmp37.to(tl.float32) | |
    tmp39 = tmp35 * tmp38 | |
    tmp42 = tmp41.to(tl.float32) | |
    tmp43 = tmp19 * tmp42 | |
    tmp44 = tmp39 + tmp43 | |
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp44, None) | |
    tl.store(out_ptr0 + (x0), tmp11, None) | |
    tl.store(out_ptr1 + (x0), tmp16, None) | |
    tl.store(out_ptr2 + (x0), tmp23, None) | |
    tl.store(out_ptr3 + (x0), tmp28, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/wb/cwbn6yy6ck5xt3fzhic3wvkkrf6xfqeppzluritiwvusplzfnr6d.py | |
# Topologically Sorted Source Nodes: [v_13], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_13 => convert_element_type_102 | |
# Graph fragment: | |
# %mul_688 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_352, %select_33), kwargs = {}) | |
# %convert_element_type_102 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_50, torch.float32), kwargs = {}) | |
# %convert_element_type_676 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_688, torch.float32), kwargs = {}) | |
# %mul_690 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_676, %convert_element_type_102), kwargs = {}) | |
# %mul_691 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_676, %rsqrt_19), kwargs = {}) | |
# %sum_93 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_690, [3], True), kwargs = {}) | |
# %div_46 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_63, 128), kwargs = {}) | |
# %pow_182 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_102, 1.0), kwargs = {}) | |
# %mul_694 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_182, 2.0), kwargs = {}) | |
# %mul_695 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_46, %mul_694), kwargs = {}) | |
# %add_343 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_691, %mul_695), kwargs = {}) | |
# %convert_element_type_677 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_343, torch.bfloat16), kwargs = {}) | |
# Review notes (hand-written). The kernel below is inductor-generated source held in a
# string literal; keep it byte-identical, since any edit changes the compiled kernel.
# - Two-pass row reduction over xnumel = 524288 rows of r0_numel = 128 elements.
#   Let g = in_ptr0[128*x + r] * in_ptr1[56] and
#   v = in_ptr2[128*(x % 8) + 3072*(x // 8) + r], both upcast to fp32.
#   Pass 1 accumulates tmp10 = sum_r(g * v) for each row x.
#   Pass 2 writes each element as
#     g * in_ptr3[x] + (tmp10 * -0.5 * in_ptr3[x]**3 * 0.0078125) * (2.0 * v).
#   0.0078125 == 1/128, matching the div-by-128 (div_46) in the graph fragment.
#   The result goes to out_ptr1, which triton_meta declares '*bf16'.
# - NOTE(review): the rsqrt / -0.5 / cube / (2*v)/128 structure looks like the backward
#   of an RMS-style normalization over a 128-wide last dim. in_ptr3 would then be the
#   saved per-row rsqrt_19 and in_ptr1[56] select_33. This is inferred; confirm.
# - in_ptr2 and out_ptr1 share the strided addressing 128*(x % 8) + 3072*(x // 8).
#   Presumably this is a view over 8 of 24 128-wide slices of a wider buffer; confirm.
# - Pass 1 loads with evict_last and pass 2 with evict_first. Pass 2 re-reads the same
#   in_ptr0/in_ptr2 addresses, and it is their last read in this kernel.
# - xmask is all-true and x-indexed loads pass mask None, so no x bounds check is
#   emitted. Presumably xnumel is a multiple of every XBLOCK tried; confirm.
# - tmp4 is already fp32, so tmp5 = tmp4.to(tl.float32) is a no-op. If mul_688 is bf16
#   in the graph (the explicit convert_element_type_676 to float32 suggests so), eager
#   mode rounds it to bf16 first, and this kernel does not.
#   NOTE(review): confirm that numeric drift is acceptable.
# - in_ptr1[56] is loaded twice (tmp1, tmp13). rnumel/RBLOCK/rbase/roffset/rindex are
#   unused codegen aliases. Regenerate rather than hand-edit.
triton_red_fused__to_copy_add_div_mul_pow_sum_37 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_37', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_37', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_37(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (56)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (56)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/bv/cbvz6ftrc4j32z2qzbtvnccdakox773dpuzkgduiwvzrk5muc5rw.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {}) | |
# %mul_636 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %select_44), kwargs = {}) | |
# %add_321 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_306, %mul_636), kwargs = {}) | |
# %add_334 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_326, %view_300), kwargs = {}) | |
# %mul_676 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %select_38), kwargs = {}) | |
# %add_336 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_321, %mul_676), kwargs = {}) | |
# %add_350 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_342, %view_311), kwargs = {}) | |
# %mul_716 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %select_32), kwargs = {}) | |
# %add_352 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_336, %mul_716), kwargs = {}) | |
# %mul_718 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %select_31), kwargs = {}) | |
# Review notes (hand-written). The kernel below is inductor-generated source held in a
# string literal; keep it byte-identical, since any edit changes the compiled kernel.
# - Elementwise over xnumel = 67108864 contiguous elements. No masking is emitted:
#   xmask is all-true, and loads/stores pass mask None.
# - in_ptr2 ('*fp32') supplies four scalars s[i], read at indices 29, 27, 25 and 24 and
#   broadcast to the block.
# - in_out_ptr0 is updated in place (it appears in mutated_arg_names):
#     in_out = in_out + (in0 + in1)*s[29] + (in3 + in4)*s[27] + (in5 + in6)*s[25]
#   The sum is evaluated left to right, in fp32, after upcasting the bf16 loads.
# - out_ptr0 receives (in5 + in6) * s[24].
# - NOTE(review): by operand order this appears to match the graph's
#   add_321 -> add_336 -> add_352 chain and mul_718. That would make
#   s[29]/s[27]/s[25]/s[24] correspond to select_44/38/32/31. This is inferred, not
#   stated in the code; confirm.
triton_poi_fused_add_mul_38 = async_compile.triton('triton_poi_fused_add_mul_38', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.pointwise( | |
    size_hints={'x': 67108864}, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_38', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 11, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | |
    min_elem_per_thread=0 | |
) | |
@triton.jit | |
def triton_poi_fused_add_mul_38(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr0, xnumel, XBLOCK : tl.constexpr): | |
    xnumel = 67108864 | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | |
    xmask = tl.full([XBLOCK], True, tl.int1) | |
    x0 = xindex | |
    tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | |
    tmp2 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | |
    tmp4 = tl.load(in_ptr2 + (29)) | |
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK]) | |
    tmp9 = tl.load(in_ptr3 + (x0), None).to(tl.float32) | |
    tmp10 = tl.load(in_ptr4 + (x0), None).to(tl.float32) | |
    tmp12 = tl.load(in_ptr2 + (27)) | |
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK]) | |
    tmp17 = tl.load(in_ptr5 + (x0), None).to(tl.float32) | |
    tmp18 = tl.load(in_ptr6 + (x0), None).to(tl.float32) | |
    tmp20 = tl.load(in_ptr2 + (25)) | |
    tmp21 = tl.broadcast_to(tmp20, [XBLOCK]) | |
    tmp25 = tl.load(in_ptr2 + (24)) | |
    tmp26 = tl.broadcast_to(tmp25, [XBLOCK]) | |
    tmp3 = tmp1 + tmp2 | |
    tmp6 = tmp5.to(tl.float32) | |
    tmp7 = tmp3 * tmp6 | |
    tmp8 = tmp0 + tmp7 | |
    tmp11 = tmp9 + tmp10 | |
    tmp14 = tmp13.to(tl.float32) | |
    tmp15 = tmp11 * tmp14 | |
    tmp16 = tmp8 + tmp15 | |
    tmp19 = tmp17 + tmp18 | |
    tmp22 = tmp21.to(tl.float32) | |
    tmp23 = tmp19 * tmp22 | |
    tmp24 = tmp16 + tmp23 | |
    tmp27 = tmp26.to(tl.float32) | |
    tmp28 = tmp19 * tmp27 | |
    tl.store(in_out_ptr0 + (x0), tmp24, None) | |
    tl.store(out_ptr0 + (x0), tmp28, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/ae/caeu7pgawm6k6vk7vpmr2zuhpdwebvifn4flr2bzxrarimqrm2t3.py | |
# Topologically Sorted Source Nodes: [v_10], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_10 => convert_element_type_82 | |
# Graph fragment: | |
# %mul_728 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_372, %select_27), kwargs = {}) | |
# %convert_element_type_82 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_41, torch.float32), kwargs = {}) | |
# %convert_element_type_707 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_728, torch.float32), kwargs = {}) | |
# %mul_730 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_707, %convert_element_type_82), kwargs = {}) | |
# %mul_731 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_707, %rsqrt_15), kwargs = {}) | |
# %sum_100 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_730, [3], True), kwargs = {}) | |
# %div_50 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_67, 128), kwargs = {}) | |
# %pow_191 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_82, 1.0), kwargs = {}) | |
# %mul_734 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_191, 2.0), kwargs = {}) | |
# %mul_735 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_50, %mul_734), kwargs = {}) | |
# %add_358 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_731, %mul_735), kwargs = {}) | |
# %convert_element_type_708 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_358, torch.bfloat16), kwargs = {}) | |
# Review notes (hand-written). The kernel below is inductor-generated source held in a
# string literal; keep it byte-identical, since any edit changes the compiled kernel.
# - It is structurally identical to triton_red_fused__to_copy_add_div_mul_pow_sum_37.
#   Only the in_ptr1 scalar index differs: 54 here.
# - Two-pass row reduction over xnumel = 524288 rows of r0_numel = 128 elements.
#   Let g = in_ptr0[128*x + r] * in_ptr1[54] and
#   v = in_ptr2[128*(x % 8) + 3072*(x // 8) + r], both upcast to fp32.
#   Pass 1 accumulates tmp10 = sum_r(g * v) for each row x.
#   Pass 2 writes each element as
#     g * in_ptr3[x] + (tmp10 * -0.5 * in_ptr3[x]**3 * 0.0078125) * (2.0 * v).
#   0.0078125 == 1/128, matching div_50 in the graph fragment.
#   The result goes to out_ptr1, which triton_meta declares '*bf16'.
# - NOTE(review): this looks like the backward of an RMS-style normalization over a
#   128-wide last dim. in_ptr3 would then be the saved per-row rsqrt_15 and
#   in_ptr1[54] select_27. This is inferred from the graph fragment; confirm.
# - in_ptr2 and out_ptr1 share the strided addressing 128*(x % 8) + 3072*(x // 8).
#   Presumably this is a view over 8 of 24 128-wide slices of a wider buffer; confirm.
# - Pass 1 loads with evict_last and pass 2 with evict_first. Pass 2 re-reads the same
#   addresses, and it is their last read in this kernel.
# - No x bounds check is emitted (xmask is all-true, x loads pass mask None).
#   Presumably xnumel is a multiple of every XBLOCK tried; confirm.
# - tmp5 = tmp4.to(tl.float32) is a no-op on an fp32 value. Any bf16 rounding that
#   eager mode applies to mul_728 is not reproduced.
#   NOTE(review): confirm that numeric drift is acceptable.
triton_red_fused__to_copy_add_div_mul_pow_sum_39 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_39', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.reduction( | |
    size_hints={'x': 524288, 'r0_': 128}, | |
    reduction_hint=ReductionHint.DEFAULT, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_39', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_red_fused__to_copy_add_div_mul_pow_sum_39(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | |
    xnumel = 524288 | |
    r0_numel = 128 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | |
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | |
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | |
    rbase = r0_base | |
    x3 = xindex | |
    tmp1 = tl.load(in_ptr1 + (54)) | |
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | |
    x0 = (xindex % 8) | |
    x1 = xindex // 8 | |
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | |
        tmp3 = tmp2.to(tl.float32) | |
        tmp4 = tmp0 * tmp3 | |
        tmp5 = tmp4.to(tl.float32) | |
        tmp7 = tmp6.to(tl.float32) | |
        tmp8 = tmp5 * tmp7 | |
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | |
        tmp11 = _tmp10 + tmp9 | |
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | |
    tmp10 = tl.sum(_tmp10, 1)[:, None] | |
    tmp13 = tl.load(in_ptr1 + (54)) | |
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | |
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | |
    for r0_offset in range(0, r0_numel, R0_BLOCK): | |
        r0_index = r0_offset + r0_base | |
        r0_mask = r0_index < r0_numel | |
        roffset = r0_offset | |
        rindex = r0_index | |
        r0_2 = r0_index | |
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | |
        tmp15 = tmp14.to(tl.float32) | |
        tmp16 = tmp12 * tmp15 | |
        tmp17 = tmp16.to(tl.float32) | |
        tmp19 = tmp17 * tmp18 | |
        tmp20 = -0.5 | |
        tmp21 = tmp10 * tmp20 | |
        tmp22 = tmp18 * tmp18 | |
        tmp23 = tmp22 * tmp18 | |
        tmp24 = tmp21 * tmp23 | |
        tmp25 = 0.0078125 | |
        tmp26 = tmp24 * tmp25 | |
        tmp28 = tmp27.to(tl.float32) | |
        tmp29 = 2.0 | |
        tmp30 = tmp28 * tmp29 | |
        tmp31 = tmp26 * tmp30 | |
        tmp32 = tmp19 + tmp31 | |
        tmp33 = tmp32.to(tl.float32) | |
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/e7/ce7jis7nrcnkxdvhijgd3ls2ioixwrag75qnb3dbukxtzj6su45m.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {}) | |
# %mul_460 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_73), kwargs = {}) | |
# %mul_462 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_460, %select_70), kwargs = {}) | |
# %add_350 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_342, %view_311), kwargs = {}) | |
# %mul_717 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %convert_element_type_11), kwargs = {}) | |
# %sum_96 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_717,), kwargs = {}) | |
# %mul_719 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %add_47), kwargs = {}) | |
# %sum_97 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_719,), kwargs = {}) | |
# %add_365 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_357, %view_322), kwargs = {}) | |
# %mul_757 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %convert_element_type_11), kwargs = {}) | |
# %sum_103 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_757,), kwargs = {}) | |
# %mul_758 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %select_25), kwargs = {}) | |
# %mul_759 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %add_36), kwargs = {}) | |
# %sum_104 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_759,), kwargs = {}) | |
# %add_368 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_462, %mul_758), kwargs = {}) | |
# Review notes (hand-written). The kernel below is inductor-generated source held in a
# string literal; keep it byte-identical, since any edit changes the compiled kernel.
# - Persistent reduction: each program handles one row (XBLOCK = 1) of
#   r0_numel = 1024 contiguous elements and loads the whole row in one block
#   (R0_BLOCK = 1024). xnumel = 65536 rows. No masks are emitted.
# - Define a = in0 + in1, b = in5 + in6 and c = in2 * in_ptr3[row], where in_ptr3 is a
#   per-row '*fp32' scale. All math is fp32 after upcasting. The kernel writes four
#   per-row fp32 partial sums:
#     out_ptr0 = sum(a * c)    out_ptr1 = sum(a * in4)
#     out_ptr2 = sum(b * c)    out_ptr3 = sum(b * in7)
#   It also overwrites in_out_ptr0 in place (listed in mutated_arg_names) with
#     (in_out + in8) * in_ptr9[38] * in_ptr9[2] + b * in_ptr9[22]
# - NOTE(review): by operand order these appear to be sum_96, sum_97, sum_103, sum_104
#   and add_368 from the graph fragment, with c standing in for
#   convert_element_type_11. This mapping is inferred; confirm.
# - The graph casts convert_element_type_11 to bfloat16. Here tmp7 = tmp6.to(tl.float32)
#   is a no-op on an fp32 value, so that bf16 rounding is not reproduced.
#   NOTE(review): confirm the numeric difference vs eager is acceptable.
triton_per_fused__to_copy_add_mul_sum_40 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_40', ''' | |
import triton | |
import triton.language as tl | |
from torch._inductor.runtime import triton_helpers, triton_heuristics | |
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | |
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | |
triton_helpers.set_driver_to_gpu() | |
@triton_heuristics.persistent_reduction( | |
    size_hints={'x': 65536, 'r0_': 1024}, | |
    reduction_hint=ReductionHint.INNER, | |
    filename=__file__, | |
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, | |
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_40', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 4, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | |
) | |
@triton.jit | |
def triton_per_fused__to_copy_add_mul_sum_40(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel): | |
    xnumel = 65536 | |
    XBLOCK: tl.constexpr = 1 | |
    r0_numel = 1024 | |
    R0_BLOCK: tl.constexpr = 1024 | |
    rnumel = r0_numel | |
    RBLOCK: tl.constexpr = R0_BLOCK | |
    xoffset = tl.program_id(0) * XBLOCK | |
    xindex = tl.full([1], xoffset, tl.int32) | |
    xmask = tl.full([R0_BLOCK], True, tl.int1) | |
    r0_index = tl.arange(0, R0_BLOCK)[:] | |
    r0_offset = 0 | |
    r0_mask = tl.full([R0_BLOCK], True, tl.int1) | |
    roffset = r0_offset | |
    rindex = r0_index | |
    r0_1 = r0_index | |
    x0 = xindex | |
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last') | |
    tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp29 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp30 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32) | |
    tmp32 = tl.load(in_ptr9 + (38)) | |
    tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK]) | |
    tmp36 = tl.load(in_ptr9 + (2)) | |
    tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK]) | |
    tmp40 = tl.load(in_ptr9 + (22)) | |
    tmp41 = tl.broadcast_to(tmp40, [R0_BLOCK]) | |
    tmp2 = tmp0 + tmp1 | |
    tmp4 = tmp3.to(tl.float32) | |
    tmp6 = tmp4 * tmp5 | |
    tmp7 = tmp6.to(tl.float32) | |
    tmp8 = tmp2 * tmp7 | |
    tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK]) | |
    tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0)) | |
    tmp13 = tmp2 * tmp12 | |
    tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK]) | |
    tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0)) | |
    tmp19 = tmp17 + tmp18 | |
    tmp20 = tmp19 * tmp7 | |
    tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK]) | |
    tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0)) | |
    tmp25 = tmp19 * tmp24 | |
    tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK]) | |
    tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0)) | |
    tmp31 = tmp29 + tmp30 | |
    tmp34 = tmp33.to(tl.float32) | |
    tmp35 = tmp31 * tmp34 | |
    tmp38 = tmp37.to(tl.float32) | |
    tmp39 = tmp35 * tmp38 | |
    tmp42 = tmp41.to(tl.float32) | |
    tmp43 = tmp19 * tmp42 | |
    tmp44 = tmp39 + tmp43 | |
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp44, None) | |
    tl.store(out_ptr0 + (x0), tmp11, None) | |
    tl.store(out_ptr1 + (x0), tmp16, None) | |
    tl.store(out_ptr2 + (x0), tmp23, None) | |
    tl.store(out_ptr3 + (x0), tmp28, None) | |
''', device_str='cuda') | |
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/rm/crm2pq6wksbw6bjoc7vc22gi6qau7lq7ec3klu22afni6b5bl6ck.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy, aten.select_backward, aten.add] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %convert_element_type_347 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_13, torch.float32), kwargs = {}) | |
# %full_default_20 : [num_users=53] = call_function[target=torch.ops.aten.full.default](args = ([2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %select_scatter_default_1 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_347, 0, 1), kwargs = {}) | |
# %convert_element_type_348 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_14, torch.float32), kwargs = {}) | |
# %select_scatter_default_2 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_348, 0, 0), kwargs = {}) | |
# %add_184 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_1, %select_scatter_default_2), kwargs = {}) | |
# %full_default_25 : [num_users=31] = call_function[target=torch.ops.aten.full.default](args = ([16, 2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %select_scatter_default_5 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_184, 0, 15), kwargs = {}) | |
# %convert_element_type_379 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_21, torch.float32), kwargs = {}) | |
# %select_scatter_default_8 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_379, 0, 1), kwargs = {}) | |
# %convert_element_type_380 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_22, torch.float32), kwargs = {}) | |
# %select_scatter_default_9 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_380, 0, 0), kwargs = {}) | |
# %add_197 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_8, %select_scatter_default_9), kwargs = {}) | |
# %select_scatter_default_12 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_197, 0, 14), kwargs = {}) | |
# %add_209 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_5, %select_scatter_default_12), kwargs = {}) | |
# %convert_element_type_411 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_29, torch.float32), kwargs = {}) | |
# %select_scatter_default_15 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_411, 0, 1), kwargs = {}) | |
# %convert_element_type_412 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_30, torch.float32), kwargs = {}) | |
# %select_scatter_default_16 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_412, 0, 0), kwargs = {}) | |
# %add_213 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_15, %select_scatter_default_16), kwargs = {}) | |
# %select_scatter_default_19 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_213, 0, 13), kwargs = {}) | |
# %add_225 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_209, %select_scatter_default_19), kwargs = {}) | |
# %convert_element_type_443 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_37, torch.float32), kwargs = {}) | |
# %select_scatter_default_22 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_443, 0, 0), kwargs = {}) | |
# %select_scatter_default_25 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_22, 0, 12), kwargs = {}) | |
# %add_240 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_225, %select_scatter_default_25), kwargs = {}) | |
# %convert_element_type_474 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_44, torch.float32), kwargs = {}) | |
# %select_scatter_default_28 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_474, 0, 0), kwargs = {}) | |
# %select_scatter_default_31 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_28, 0, 11), kwargs = {}) | |
# %add_255 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_240, %select_scatter_default_31), kwargs = {}) | |
# %convert_element_type_506 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_52, torch.float32), kwargs = {}) | |
# %select_scatter_default_35 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_506, 0, 0), kwargs = {}) | |
# %select_scatter_default_38 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_35, 0, 10), kwargs = {}) | |
# %add_270 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_255, %select_scatter_default_38), kwargs = {}) | |
# %convert_element_type_538 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_60, torch.float32), kwargs = {}) | |
# %select_scatter_default_42 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_538, 0, 0), kwargs = {}) | |
# %select_scatter_default_45 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_42, 0, 9), kwargs = {}) | |
# %add_286 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_270, %select_scatter_default_45), kwargs = {}) | |
# %convert_element_type_570 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_68, torch.float32), kwargs = {}) | |
# %select_scatter_default_49 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_570, 0, 0), kwargs = {}) | |
# %select_scatter_default_52 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_49, 0, 8), kwargs = {}) | |
# %add_302 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_286, %select_scatter_default_52), kwargs = {}) | |
# %convert_element_type_613 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_78, torch.float32), kwargs = {}) | |
# %select_scatter_default_58 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_613, 0, 0), kwargs = {}) | |
# %select_scatter_default_61 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_58, 0, 6), kwargs = {}) | |
# %add_323 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_302, %select_scatter_default_61), kwargs = {}) | |
# %convert_element_type_644 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_85, torch.float32), kwargs = {}) | |
# %select_scatter_default_64 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_644, 0, 0), kwargs = {}) | |
# %select_scatter_default_67 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_64, 0, 5), kwargs = {}) | |
# %add_339 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_323, %select_scatter_default_67), kwargs = {}) | |
# %convert_element_type_675 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_92, torch.float32), kwargs = {}) | |
# %select_scatter_default_70 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_675, 0, 0), kwargs = {}) | |
# %select_scatter_default_73 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_70, 0, 4), kwargs = {}) | |
# %add_354 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_339, %select_scatter_default_73), kwargs = {}) | |
# %convert_element_type_706 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_99, torch.float32), kwargs = {}) | |
# %select_scatter_default_76 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_706, 0, 0), kwargs = {}) | |
# %select_scatter_default_79 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_76, 0, 3), kwargs = {}) | |
# %add_370 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_354, %select_scatter_default_79), kwargs = {}) | |
# Pointwise kernel producing 32 fp32 outputs in out_ptr0.
# It places up to 15 scalar bf16 inputs at fixed positions and writes zero everywhere else.
# Each scalar is read from element 0 of in_ptr0..in_ptr14 and upcast to fp32.
# Each output index x is split into x1 = x // 2 (row, 0..15) and x0 = x % 2 (column, 0..1).
# The value written to out_ptr0[x] is:
#   row 15: col 1 <- in_ptr0,  col 0 <- in_ptr1
#   row 14: col 1 <- in_ptr2,  col 0 <- in_ptr3
#   row 13: col 1 <- in_ptr4,  col 0 <- in_ptr5
#   rows 12, 11, 10, 9, 8: col 0 <- in_ptr6 .. in_ptr10 (col 1 stays 0.0)
#   rows 6, 5, 4, 3:       col 0 <- in_ptr11 .. in_ptr14 (col 1 stays 0.0)
#   rows 7, 2, 1, 0:       0.0 (no input targets them)
# Each per-row term is masked with tl.where and all terms are summed.
# Rows and columns are disjoint, so at most one term is non-zero for any element.
# This mirrors the chain of select_scatter/add nodes in the preceding graph-fragment comment.
# The scalar loads are unmasked. Only the store is masked (xindex < 32).
triton_poi_fused__to_copy_add_select_backward_41 = async_compile.triton('triton_poi_fused__to_copy_add_select_backward_41', ''' | 
import triton | 
import triton.language as tl | 
from torch._inductor.runtime import triton_helpers, triton_heuristics | 
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | 
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | 
triton_helpers.set_driver_to_gpu() | 
@triton_heuristics.pointwise( | 
size_hints={'x': 32}, | 
filename=__file__, | 
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'in_ptr13': '*bf16', 'in_ptr14': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]}, | 
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_select_backward_41', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 15, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | 
min_elem_per_thread=0 | 
) | 
@triton.jit | 
def triton_poi_fused__to_copy_add_select_backward_41(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, out_ptr0, xnumel, XBLOCK : tl.constexpr): | 
    xnumel = 32 | 
    xoffset = tl.program_id(0) * XBLOCK | 
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | 
    xmask = xindex < xnumel | 
    x1 = xindex // 2 | 
    x0 = (xindex % 2) | 
    x2 = xindex | 
    tmp6 = tl.load(in_ptr0 + (0)).to(tl.float32) | 
    tmp7 = tl.broadcast_to(tmp6, [XBLOCK]) | 
    tmp13 = tl.load(in_ptr1 + (0)).to(tl.float32) | 
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK]) | 
    tmp21 = tl.load(in_ptr2 + (0)).to(tl.float32) | 
    tmp22 = tl.broadcast_to(tmp21, [XBLOCK]) | 
    tmp25 = tl.load(in_ptr3 + (0)).to(tl.float32) | 
    tmp26 = tl.broadcast_to(tmp25, [XBLOCK]) | 
    tmp34 = tl.load(in_ptr4 + (0)).to(tl.float32) | 
    tmp35 = tl.broadcast_to(tmp34, [XBLOCK]) | 
    tmp38 = tl.load(in_ptr5 + (0)).to(tl.float32) | 
    tmp39 = tl.broadcast_to(tmp38, [XBLOCK]) | 
    tmp47 = tl.load(in_ptr6 + (0)).to(tl.float32) | 
    tmp48 = tl.broadcast_to(tmp47, [XBLOCK]) | 
    tmp55 = tl.load(in_ptr7 + (0)).to(tl.float32) | 
    tmp56 = tl.broadcast_to(tmp55, [XBLOCK]) | 
    tmp63 = tl.load(in_ptr8 + (0)).to(tl.float32) | 
    tmp64 = tl.broadcast_to(tmp63, [XBLOCK]) | 
    tmp71 = tl.load(in_ptr9 + (0)).to(tl.float32) | 
    tmp72 = tl.broadcast_to(tmp71, [XBLOCK]) | 
    tmp79 = tl.load(in_ptr10 + (0)).to(tl.float32) | 
    tmp80 = tl.broadcast_to(tmp79, [XBLOCK]) | 
    tmp87 = tl.load(in_ptr11 + (0)).to(tl.float32) | 
    tmp88 = tl.broadcast_to(tmp87, [XBLOCK]) | 
    tmp95 = tl.load(in_ptr12 + (0)).to(tl.float32) | 
    tmp96 = tl.broadcast_to(tmp95, [XBLOCK]) | 
    tmp103 = tl.load(in_ptr13 + (0)).to(tl.float32) | 
    tmp104 = tl.broadcast_to(tmp103, [XBLOCK]) | 
    tmp111 = tl.load(in_ptr14 + (0)).to(tl.float32) | 
    tmp112 = tl.broadcast_to(tmp111, [XBLOCK]) | 
    tmp0 = x1 | 
    tmp1 = tl.full([1], 15, tl.int32) | 
    tmp2 = tmp0 == tmp1 | 
    tmp3 = x0 | 
    tmp4 = tl.full([1], 1, tl.int32) | 
    tmp5 = tmp3 == tmp4 | 
    tmp8 = tmp7.to(tl.float32) | 
    tmp9 = 0.0 | 
    tmp10 = tl.where(tmp5, tmp8, tmp9) | 
    tmp11 = tl.full([1], 0, tl.int32) | 
    tmp12 = tmp3 == tmp11 | 
    tmp15 = tmp14.to(tl.float32) | 
    tmp16 = tl.where(tmp12, tmp15, tmp9) | 
    tmp17 = tmp10 + tmp16 | 
    tmp18 = tl.where(tmp2, tmp17, tmp9) | 
    tmp19 = tl.full([1], 14, tl.int32) | 
    tmp20 = tmp0 == tmp19 | 
    tmp23 = tmp22.to(tl.float32) | 
    tmp24 = tl.where(tmp5, tmp23, tmp9) | 
    tmp27 = tmp26.to(tl.float32) | 
    tmp28 = tl.where(tmp12, tmp27, tmp9) | 
    tmp29 = tmp24 + tmp28 | 
    tmp30 = tl.where(tmp20, tmp29, tmp9) | 
    tmp31 = tmp18 + tmp30 | 
    tmp32 = tl.full([1], 13, tl.int32) | 
    tmp33 = tmp0 == tmp32 | 
    tmp36 = tmp35.to(tl.float32) | 
    tmp37 = tl.where(tmp5, tmp36, tmp9) | 
    tmp40 = tmp39.to(tl.float32) | 
    tmp41 = tl.where(tmp12, tmp40, tmp9) | 
    tmp42 = tmp37 + tmp41 | 
    tmp43 = tl.where(tmp33, tmp42, tmp9) | 
    tmp44 = tmp31 + tmp43 | 
    tmp45 = tl.full([1], 12, tl.int32) | 
    tmp46 = tmp0 == tmp45 | 
    tmp49 = tmp48.to(tl.float32) | 
    tmp50 = tl.where(tmp12, tmp49, tmp9) | 
    tmp51 = tl.where(tmp46, tmp50, tmp9) | 
    tmp52 = tmp44 + tmp51 | 
    tmp53 = tl.full([1], 11, tl.int32) | 
    tmp54 = tmp0 == tmp53 | 
    tmp57 = tmp56.to(tl.float32) | 
    tmp58 = tl.where(tmp12, tmp57, tmp9) | 
    tmp59 = tl.where(tmp54, tmp58, tmp9) | 
    tmp60 = tmp52 + tmp59 | 
    tmp61 = tl.full([1], 10, tl.int32) | 
    tmp62 = tmp0 == tmp61 | 
    tmp65 = tmp64.to(tl.float32) | 
    tmp66 = tl.where(tmp12, tmp65, tmp9) | 
    tmp67 = tl.where(tmp62, tmp66, tmp9) | 
    tmp68 = tmp60 + tmp67 | 
    tmp69 = tl.full([1], 9, tl.int32) | 
    tmp70 = tmp0 == tmp69 | 
    tmp73 = tmp72.to(tl.float32) | 
    tmp74 = tl.where(tmp12, tmp73, tmp9) | 
    tmp75 = tl.where(tmp70, tmp74, tmp9) | 
    tmp76 = tmp68 + tmp75 | 
    tmp77 = tl.full([1], 8, tl.int32) | 
    tmp78 = tmp0 == tmp77 | 
    tmp81 = tmp80.to(tl.float32) | 
    tmp82 = tl.where(tmp12, tmp81, tmp9) | 
    tmp83 = tl.where(tmp78, tmp82, tmp9) | 
    tmp84 = tmp76 + tmp83 | 
    tmp85 = tl.full([1], 6, tl.int32) | 
    tmp86 = tmp0 == tmp85 | 
    tmp89 = tmp88.to(tl.float32) | 
    tmp90 = tl.where(tmp12, tmp89, tmp9) | 
    tmp91 = tl.where(tmp86, tmp90, tmp9) | 
    tmp92 = tmp84 + tmp91 | 
    tmp93 = tl.full([1], 5, tl.int32) | 
    tmp94 = tmp0 == tmp93 | 
    tmp97 = tmp96.to(tl.float32) | 
    tmp98 = tl.where(tmp12, tmp97, tmp9) | 
    tmp99 = tl.where(tmp94, tmp98, tmp9) | 
    tmp100 = tmp92 + tmp99 | 
    tmp101 = tl.full([1], 4, tl.int32) | 
    tmp102 = tmp0 == tmp101 | 
    tmp105 = tmp104.to(tl.float32) | 
    tmp106 = tl.where(tmp12, tmp105, tmp9) | 
    tmp107 = tl.where(tmp102, tmp106, tmp9) | 
    tmp108 = tmp100 + tmp107 | 
    tmp109 = tl.full([1], 3, tl.int32) | 
    tmp110 = tmp0 == tmp109 | 
    tmp113 = tmp112.to(tl.float32) | 
    tmp114 = tl.where(tmp12, tmp113, tmp9) | 
    tmp115 = tl.where(tmp110, tmp114, tmp9) | 
    tmp116 = tmp108 + tmp115 | 
    tl.store(out_ptr0 + (x2), tmp116, xmask) | 
''', device_str='cuda') | 
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/tm/ctmmn45fi4yzwoybueak7k3sy6rxrcom6pr6q6vovoawwcixsrod.py | |
# Topologically Sorted Source Nodes: [v_7], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add] | |
# Source node to ATen node mapping: | |
# v_7 => convert_element_type_62 | |
# Graph fragment: | |
# %mul_770 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_392, %select_20), kwargs = {}) | |
# %convert_element_type_62 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_32, torch.float32), kwargs = {}) | |
# %convert_element_type_739 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_770, torch.float32), kwargs = {}) | |
# %mul_772 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_739, %convert_element_type_62), kwargs = {}) | |
# %mul_773 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_739, %rsqrt_11), kwargs = {}) | |
# %sum_108 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_772, [3], True), kwargs = {}) | |
# %div_54 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_71, 128), kwargs = {}) | |
# %pow_200 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_62, 1.0), kwargs = {}) | |
# %mul_776 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_200, 2.0), kwargs = {}) | |
# %mul_777 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_54, %mul_776), kwargs = {}) | |
# %add_376 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_773, %mul_777), kwargs = {}) | |
# %convert_element_type_740 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_376, torch.bfloat16), kwargs = {}) | |
# Two-pass reduction kernel over 524288 rows x 128 columns.
# Both xnumel and r0_numel are hard-coded in the kernel.
# For row x3, with x0 = x3 % 8 and x1 = x3 // 8:
#   s    = in_ptr1[52]                        (fp32 scalar)
#   g[r] = in_ptr0[128*x3 + r] * s            (in_ptr0 is bf16; math in fp32)
#   v[r] = in_ptr2[128*x0 + 3072*x1 + r]      (bf16, upcast to fp32)
#   rs   = in_ptr3[x3]                        (fp32, one value per row)
#   dot  = sum_r g[r] * v[r]                  (first loop)
#   out_ptr1[128*x0 + 3072*x1 + r] =          (second loop, stored as bf16)
#       g[r]*rs + (dot * -0.5 * rs**3) * (1/128) * (2 * v[r])
# 0.0078125 is the constant 1/128.
# The mul/sum/div-by-128/pow/add chain in the graph fragment has the shape of
# an RMS-norm backward over a 128-wide last dim. That reading is inferred from
# the fragment; TODO confirm against the forward graph.
# in_ptr2 and out_ptr1 advance 3072 elements per x1, while x0 only spans 8*128 = 1024.
# They therefore address a strided slice of a wider buffer.
# Presumably this is one of three 1024-wide chunks; TODO confirm at the call site.
# xmask is all-True, so the kernel assumes the launch covers exactly xnumel rows.
# The first loop loads with evict_last because the second loop re-reads the same
# in_ptr0/in_ptr2 addresses. The second loop, as the final pass, uses evict_first.
triton_red_fused__to_copy_add_div_mul_pow_sum_42 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_42', ''' | 
import triton | 
import triton.language as tl | 
from torch._inductor.runtime import triton_helpers, triton_heuristics | 
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | 
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | 
triton_helpers.set_driver_to_gpu() | 
@triton_heuristics.reduction( | 
size_hints={'x': 524288, 'r0_': 128}, | 
reduction_hint=ReductionHint.DEFAULT, | 
filename=__file__, | 
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]}, | 
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_42', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | 
) | 
@triton.jit | 
def triton_red_fused__to_copy_add_div_mul_pow_sum_42(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | 
    xnumel = 524288 | 
    r0_numel = 128 | 
    rnumel = r0_numel | 
    RBLOCK: tl.constexpr = R0_BLOCK | 
    xoffset = tl.program_id(0) * XBLOCK | 
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | 
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | 
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | 
    rbase = r0_base | 
    x3 = xindex | 
    tmp1 = tl.load(in_ptr1 + (52)) | 
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | 
    x0 = (xindex % 8) | 
    x1 = xindex // 8 | 
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | 
    for r0_offset in range(0, r0_numel, R0_BLOCK): | 
        r0_index = r0_offset + r0_base | 
        r0_mask = r0_index < r0_numel | 
        roffset = r0_offset | 
        rindex = r0_index | 
        r0_2 = r0_index | 
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | 
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | 
        tmp3 = tmp2.to(tl.float32) | 
        tmp4 = tmp0 * tmp3 | 
        tmp5 = tmp4.to(tl.float32) | 
        tmp7 = tmp6.to(tl.float32) | 
        tmp8 = tmp5 * tmp7 | 
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | 
        tmp11 = _tmp10 + tmp9 | 
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | 
    tmp10 = tl.sum(_tmp10, 1)[:, None] | 
    tmp13 = tl.load(in_ptr1 + (52)) | 
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | 
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | 
    for r0_offset in range(0, r0_numel, R0_BLOCK): | 
        r0_index = r0_offset + r0_base | 
        r0_mask = r0_index < r0_numel | 
        roffset = r0_offset | 
        rindex = r0_index | 
        r0_2 = r0_index | 
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | 
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | 
        tmp15 = tmp14.to(tl.float32) | 
        tmp16 = tmp12 * tmp15 | 
        tmp17 = tmp16.to(tl.float32) | 
        tmp19 = tmp17 * tmp18 | 
        tmp20 = -0.5 | 
        tmp21 = tmp10 * tmp20 | 
        tmp22 = tmp18 * tmp18 | 
        tmp23 = tmp22 * tmp18 | 
        tmp24 = tmp21 * tmp23 | 
        tmp25 = 0.0078125 | 
        tmp26 = tmp24 * tmp25 | 
        tmp28 = tmp27.to(tl.float32) | 
        tmp29 = 2.0 | 
        tmp30 = tmp28 * tmp29 | 
        tmp31 = tmp26 * tmp30 | 
        tmp32 = tmp19 + tmp31 | 
        tmp33 = tmp32.to(tl.float32) | 
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | 
''', device_str='cuda') | 
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/iy/ciyfyhl44ar5iouy2c6tng6yyeqta2nvvqv2qkv5v47eqdfdiequ.py | |
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul] | |
# Source node to ATen node mapping: | |
# Graph fragment: | |
# %add_383 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_373, %view_334), kwargs = {}) | |
# %mul_800 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %select_18), kwargs = {}) | |
# Pointwise kernel over 67108864 elements:
#   out_ptr0[x] = (in_ptr0[x] + in_ptr1[x]) * in_ptr2[20]
# in_ptr0 and in_ptr1 are bf16 and are upcast to fp32 before the math.
# Only element 20 of the fp32 buffer in_ptr2 is read, and it is broadcast as a scalar.
# Presumably that element is one entry of a per-layer parameter vector; confirm at the call site.
# out_ptr0 is declared *bf16, so the fp32 result is converted on store.
# xmask is all-True, so there is no bounds check; the launch must cover exactly xnumel elements.
triton_poi_fused_add_mul_43 = async_compile.triton('triton_poi_fused_add_mul_43', ''' | 
import triton | 
import triton.language as tl | 
from torch._inductor.runtime import triton_helpers, triton_heuristics | 
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | 
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | 
triton_helpers.set_driver_to_gpu() | 
@triton_heuristics.pointwise( | 
size_hints={'x': 67108864}, | 
filename=__file__, | 
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]}, | 
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_43', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | 
min_elem_per_thread=0 | 
) | 
@triton.jit | 
def triton_poi_fused_add_mul_43(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr): | 
    xnumel = 67108864 | 
    xoffset = tl.program_id(0) * XBLOCK | 
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | 
    xmask = tl.full([XBLOCK], True, tl.int1) | 
    x0 = xindex | 
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) | 
    tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32) | 
    tmp3 = tl.load(in_ptr2 + (20)) | 
    tmp4 = tl.broadcast_to(tmp3, [XBLOCK]) | 
    tmp2 = tmp0 + tmp1 | 
    tmp5 = tmp4.to(tl.float32) | 
    tmp6 = tmp2 * tmp5 | 
    tl.store(out_ptr0 + (x0), tmp6, None) | 
''', device_str='cuda') | 
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/de/cdez6a43hw7ae3gtzsdmit5bwldgblxmtsegoexxdboxn64ojfhd.py | |
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward] | |
# Source node to ATen node mapping: | |
# loss => full_default_13 | |
# Graph fragment: | |
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %add_391 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_204, %view_343), kwargs = {}) | |
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %convert_element_type_827 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_391, torch.float32), kwargs = {}) | |
# %where_26 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_827), kwargs = {}) | |
# %index_put_6 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_26, True), kwargs = {}) | |
# Zero-fill kernel that writes 0.0 to the first 51463168 fp32 elements of out_ptr0.
# 51463168 = 50257 * 1024, which matches full_default_169 in the graph fragment:
# a [50257, 1024] fp32 zeros tensor that index_put_6 later accumulates into.
# Presumably this same buffer is then passed as out_ptr2 to the fused
# embedding-backward kernel ..._45; confirm at the call site.
# The store is masked by xindex < xnumel, so lanes past the end of the last program write nothing.
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44 = async_compile.triton('triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44', ''' | 
import triton | 
import triton.language as tl | 
from torch._inductor.runtime import triton_helpers, triton_heuristics | 
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | 
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | 
triton_helpers.set_driver_to_gpu() | 
@triton_heuristics.pointwise( | 
size_hints={'x': 67108864}, | 
filename=__file__, | 
triton_meta={'signature': {'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]}, | 
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}, | 
min_elem_per_thread=0 | 
) | 
@triton.jit | 
def triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44(out_ptr0, xnumel, XBLOCK : tl.constexpr): | 
    xnumel = 51463168 | 
    xoffset = tl.program_id(0) * XBLOCK | 
    xindex = xoffset + tl.arange(0, XBLOCK)[:] | 
    xmask = xindex < xnumel | 
    x0 = xindex | 
    tmp0 = 0.0 | 
    tl.store(out_ptr0 + (x0), tmp0, xmask) | 
''', device_str='cuda') | 
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/wc/cwcnc3zg2mhcjo34dem5rdh6dr3jzso4l2wbr3kinwnai7rrc2xj.py | |
# Topologically Sorted Source Nodes: [loss, v_4], Original ATen: [aten.nll_loss_forward, aten.add, aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.embedding_dense_backward] | |
# Source node to ATen node mapping: | |
# loss => full_default_13 | |
# v_4 => convert_element_type_42 | |
# Graph fragment: | |
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %add_391 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_204, %view_343), kwargs = {}) | |
# %mul_812 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_412, %select_13), kwargs = {}) | |
# %convert_element_type_42 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_23, torch.float32), kwargs = {}) | |
# %convert_element_type_771 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_812, torch.float32), kwargs = {}) | |
# %mul_814 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_771, %convert_element_type_42), kwargs = {}) | |
# %mul_815 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_771, %rsqrt_7), kwargs = {}) | |
# %sum_116 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_814, [3], True), kwargs = {}) | |
# %div_58 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_75, 128), kwargs = {}) | |
# %pow_209 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_42, 1.0), kwargs = {}) | |
# %mul_818 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_209, 2.0), kwargs = {}) | |
# %mul_819 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_58, %mul_818), kwargs = {}) | |
# %add_393 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_815, %mul_819), kwargs = {}) | |
# %convert_element_type_772 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_393, torch.bfloat16), kwargs = {}) | |
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False}) | |
# %convert_element_type_827 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_391, torch.float32), kwargs = {}) | |
# %where_26 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_827), kwargs = {}) | |
# %index_put_6 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_26, True), kwargs = {}) | |
# Fused reduction kernel over 524288 rows x 128 columns that does two things.
# Index helpers: x3 is the row, x0 = x3 % 8, x1 = x3 // 8.
#
# 1) A norm-style backward (two loops, writes out_ptr1):
#      g[r] = in_ptr0[128*x3 + r] * in_ptr1[50]     (bf16 input, fp32 math)
#      v[r] = in_ptr2[128*x0 + 3072*x1 + r]          (bf16, upcast)
#      rs   = in_ptr3[x3]                            (fp32, one value per row)
#      dot  = sum_r g[r] * v[r]                      (first loop)
#      out_ptr1[128*x0 + 3072*x1 + r] =              (second loop, stored as bf16)
#          g[r]*rs + (dot * -0.5 * rs**3) * (1/128) * (2 * v[r])
#    0.0078125 is the constant 1/128.
#    The graph fragment suggests an RMS-norm backward over a 128-wide last dim;
#    TODO confirm.
#
# 2) An embedding-gradient scatter (second loop, writes out_ptr2).
#    The token index is read per row as idx = int64(in_ptr4[x1]).
#      val[r] = in_ptr5[128*x3 + r] * in_ptr1[77] + in_ptr0[128*x3 + r] * in_ptr1[51]
#      val[r] = 0.0 where idx == -1                  (where_26 in the graph fragment)
#      atomic_add(out_ptr2[1024*idx' + 128*x0 + r], val[r])
#    idx' equals idx + 50257 when idx < 0 and idx otherwise.
#    A device_assert checks that idx' lies in [0, 50257).
#    The 8 consecutive rows that share x1 add into one 1024-wide row idx' of out_ptr2.
#    out_ptr2 is accumulated in place (it appears in mutated_arg_names).
#    Per the graph fragment, it starts as the zeros tensor full_default_169 [50257, 1024].
#
# Notes:
#  - An index of -1 wraps to 50256 and still issues an atomic add, but the value added is 0.0.
#  - The atomics use sem='relaxed' and many rows can target the same destination.
#    fp32 accumulation order is therefore not fixed, and low-order bits may vary run to run
#    (are_deterministic_algorithms_enabled is False).
#  - xmask is all-True, so there is no bounds check on the row index.
#  - First-loop loads use evict_last because the second loop re-reads the same addresses.
triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45 = async_compile.triton('triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45', ''' | 
import triton | 
import triton.language as tl | 
from torch._inductor.runtime import triton_helpers, triton_heuristics | 
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math | 
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties | 
triton_helpers.set_driver_to_gpu() | 
@triton_heuristics.reduction( | 
size_hints={'x': 524288, 'r0_': 128}, | 
reduction_hint=ReductionHint.DEFAULT, | 
filename=__file__, | 
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*i32', 'in_ptr5': '*bf16', 'out_ptr1': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]}, | 
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45', 'mutated_arg_names': ['out_ptr2'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 11, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False} | 
) | 
@triton.jit | 
def triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): | 
    xnumel = 524288 | 
    r0_numel = 128 | 
    rnumel = r0_numel | 
    RBLOCK: tl.constexpr = R0_BLOCK | 
    xoffset = tl.program_id(0) * XBLOCK | 
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None] | 
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1) | 
    r0_base = tl.arange(0, R0_BLOCK)[None, :] | 
    rbase = r0_base | 
    x3 = xindex | 
    tmp1 = tl.load(in_ptr1 + (50)) | 
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK]) | 
    x0 = (xindex % 8) | 
    x1 = xindex // 8 | 
    _tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) | 
    for r0_offset in range(0, r0_numel, R0_BLOCK): | 
        r0_index = r0_offset + r0_base | 
        r0_mask = r0_index < r0_numel | 
        roffset = r0_offset | 
        rindex = r0_index | 
        r0_2 = r0_index | 
        tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | 
        tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32) | 
        tmp3 = tmp2.to(tl.float32) | 
        tmp4 = tmp0 * tmp3 | 
        tmp5 = tmp4.to(tl.float32) | 
        tmp7 = tmp6.to(tl.float32) | 
        tmp8 = tmp5 * tmp7 | 
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK]) | 
        tmp11 = _tmp10 + tmp9 | 
        _tmp10 = tl.where(r0_mask, tmp11, _tmp10) | 
    tmp10 = tl.sum(_tmp10, 1)[:, None] | 
    tmp13 = tl.load(in_ptr1 + (50)) | 
    tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK]) | 
    tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last') | 
    tmp34 = tl.load(in_ptr4 + (x1), None, eviction_policy='evict_last') | 
    tmp44 = tl.load(in_ptr1 + (77)) | 
    tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK]) | 
    tmp48 = tl.load(in_ptr1 + (51)) | 
    tmp49 = tl.broadcast_to(tmp48, [XBLOCK, R0_BLOCK]) | 
    for r0_offset in range(0, r0_numel, R0_BLOCK): | 
        r0_index = r0_offset + r0_base | 
        r0_mask = r0_index < r0_numel | 
        roffset = r0_offset | 
        rindex = r0_index | 
        r0_2 = r0_index | 
        tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | 
        tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | 
        tmp43 = tl.load(in_ptr5 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32) | 
        tmp15 = tmp14.to(tl.float32) | 
        tmp16 = tmp12 * tmp15 | 
        tmp17 = tmp16.to(tl.float32) | 
        tmp19 = tmp17 * tmp18 | 
        tmp20 = -0.5 | 
        tmp21 = tmp10 * tmp20 | 
        tmp22 = tmp18 * tmp18 | 
        tmp23 = tmp22 * tmp18 | 
        tmp24 = tmp21 * tmp23 | 
        tmp25 = 0.0078125 | 
        tmp26 = tmp24 * tmp25 | 
        tmp28 = tmp27.to(tl.float32) | 
        tmp29 = 2.0 | 
        tmp30 = tmp28 * tmp29 | 
        tmp31 = tmp26 * tmp30 | 
        tmp32 = tmp19 + tmp31 | 
        tmp33 = tmp32.to(tl.float32) | 
        tmp35 = tmp34.to(tl.int64) | 
        tmp36 = tl.full([XBLOCK, R0_BLOCK], 50257, tl.int32) | 
        tmp37 = tmp35 + tmp36 | 
        tmp38 = tmp35 < 0 | 
        tmp39 = tl.where(tmp38, tmp37, tmp35) | 
        tl.device_assert((0 <= tmp39) & (tmp39 < 50257), "index out of bounds: 0 <= tmp39 < 50257") | 
        tmp41 = tl.full([1, 1], -1, tl.int64) | 
        tmp42 = tmp35 == tmp41 | 
        tmp46 = tmp45.to(tl.float32) | 
        tmp47 = tmp43 * tmp46 | 
        tmp50 = tmp49.to(tl.float32) | 
        tmp51 = tmp12 * tmp50 | 
        tmp52 = tmp47 + tmp51 | 
        tmp53 = tmp52.to(tl.float32) | 
        tmp54 = 0.0 | 
        tmp55 = tl.where(tmp42, tmp54, tmp53) | 
        tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask) | 
        tl.atomic_add(out_ptr2 + (r0_2 + 128*x0 + 1024*tmp39), tmp55, r0_mask, sem='relaxed') | 
''', device_str='cuda') | 
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/ty/ctyl36nsnpehqd4df5ksk7bsvoehmhqsdu3iojfgd3yky7vkrutv.py | |
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum] | |
# Source node to ATen node mapping: | |
# x => convert_element_type_10, convert_element_type_11, mul | |
# Graph fragment: | |
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {}) | |
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {}) | |
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {}) | |
# %add_365 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_357, %view_322), kwargs = {}) | |
# %mul_756 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %select_26), kwargs = {}) | |
# %add_367 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_352, %mul_756), kwargs = {}) | |
# %add_383 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_373, %view_334), kwargs = {}) | |
# %mul_798 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %select_19), kwargs = {}) | |
# %mul_799 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %convert_element_type_11), kwargs = {}) | |
# %sum_111 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_799,), kwargs = {}) | |
# %add_385 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_367, %mul_798), kwargs = {}) | |
# %m |
View raw
(Sorry about that, but we can’t show files that are this big right now.)
View raw
(Sorry about that, but we can’t show files that are this big right now.)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment