# AOT ID: ['0_backward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
    grid,
    split_scan_grid,
    grid_combo_kernels,
    start_graph,
    end_graph,
    cooperative_reduction_grid,
)
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch._inductor.kernel.flex_attention
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
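# Everything below is Inductor-generated output for the compiled backward graph
# ('0_backward' in the AOT ID header above). A hedged sketch of how a dump like
# this is typically produced; `model` and `inp` are illustrative placeholders,
# not taken from this file. Running under TORCH_LOGS="output_code" (or
# TORCH_COMPILE_DEBUG=1) makes Inductor print generated files like this one,
# including the /tmp/torchinductor_* kernel paths referenced below.
def _reproduce_dump_sketch(model, inp):
    compiled = torch.compile(model)
    # Forward + backward triggers compilation of the '0_backward' graph.
    compiled(inp).sum().backward()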
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/mc/cmc66sj3co2qioktlabn5mqewgz4bthhd7gyxhv7ppsvnpgk2niq.py
# Topologically Sorted Source Nodes: [loss, logits, mul_176, square_16, add_116, rsqrt, mul_177], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward, aten._to_copy, aten.mul, aten.pow, aten.add, aten.rsqrt, aten._log_softmax, aten._log_softmax_backward_data]
# Source node to ATen node mapping:
# add_116 => add_179
# logits => convert_element_type_323
# loss => full_default_12, full_default_13, sub_4, sub_5
# mul_176 => mul_239
# mul_177 => mul_240
# rsqrt => rsqrt_63
# square_16 => pow_80
# Graph fragment:
# %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%tangents_1, %convert_element_type_324), kwargs = {})
# %ne_3 : [num_users=2] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_164, -100), kwargs = {})
# %full_default_12 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_6 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_164, %full_default_12), kwargs = {})
# %scatter_upon_const_tensor : [num_users=1] = call_function[target=torch._inductor.fx_passes.post_grad.scatter_upon_const_tensor](args = (), kwargs = {shape: [65536, 50304], background_val: 0, dtype: torch.float32, dim: 1, selector: %where_6, val: -1.0})
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_7 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %div_2, %full_default_13), kwargs = {})
# %mul_241 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter_upon_const_tensor, %where_7), kwargs = {})
# %convert_element_type_323 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_62, torch.float32), kwargs = {})
# %mul_239 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_323, 15), kwargs = {})
# %pow_80 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 2), kwargs = {})
# %add_179 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%pow_80, 225), kwargs = {})
# %rsqrt_63 : [num_users=3] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_179,), kwargs = {})
# %mul_240 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_239, %rsqrt_63), kwargs = {})
# %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_240, %amax), kwargs = {})
# %sub_5 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_4, %log), kwargs = {})
# %exp_1 : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_5,), kwargs = {})
# %sum_10 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_241, [1], True), kwargs = {})
# %mul_242 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp_1, %sum_10), kwargs = {})
# %sub_6 : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_241, %mul_242), kwargs = {})
# %mul_243 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %mul_239), kwargs = {})
# %mul_244 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %rsqrt_63), kwargs = {})
# %pow_81 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt_63, 3), kwargs = {})
# %mul_245 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%mul_243, -0.5), kwargs = {})
# %mul_246 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_245, %pow_81), kwargs = {})
# %pow_82 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 1.0), kwargs = {})
# %mul_247 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_82, 2.0), kwargs = {})
# %mul_248 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_246, %mul_247), kwargs = {})
# %mul_249 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_244, 15), kwargs = {})
# %add_180 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_248, %mul_249), kwargs = {})
# %convert_element_type_325 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_180, torch.bfloat16), kwargs = {})
triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0 = async_compile.triton('triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 65536, 'r0_': 65536},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*i64', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'xnumel': 'i64', 'r0_numel': 'i64'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 8, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 65536
    r0_numel = 50304
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64)
    rbase = r0_base
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr1 + (0))
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
    tmp12 = tl.load(in_ptr2 + (0))
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
    _tmp18 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    # Pass 1 over the 50304-class dim: accumulate the log_softmax-backward sum
    # term, i.e. sum_j grad_j, where grad_j is -(tangent / num_valid) at the
    # target class and 0 elsewhere (ignore_index == -100).
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp1 = tl.full([1, 1], -100, tl.int64)
        tmp2 = tmp0 != tmp1
        tmp3 = tl.full([1, 1], 0, tl.int64)
        tmp4 = tl.where(tmp2, tmp0, tmp3)
        tmp5 = r0_1
        tmp6 = tmp4 == tmp5
        tmp7 = -1.0
        tmp8 = 0.0
        tmp9 = tl.where(tmp6, tmp7, tmp8)
        tmp14 = (tmp11 / tmp13)
        tmp15 = tl.where(tmp2, tmp14, tmp8)
        tmp16 = tmp9 * tmp15
        tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
        tmp19 = _tmp18 + tmp17
        _tmp18 = tl.where(r0_mask, tmp19, _tmp18)
    tmp18 = tl.sum(_tmp18, 1)[:, None]
    tmp29 = tl.load(in_ptr1 + (0))
    tmp30 = tl.broadcast_to(tmp29, [XBLOCK, R0_BLOCK])
    tmp31 = tl.load(in_ptr2 + (0))
    tmp32 = tl.broadcast_to(tmp31, [XBLOCK, R0_BLOCK])
    tmp45 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp47 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    # Pass 2: recompute the soft-capped logits (15 * x * rsqrt(x*x + 225)) and
    # apply the fused log_softmax + soft-cap backward, storing the bf16
    # gradient back into in_out_ptr0 in place.
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp36 = tl.load(in_out_ptr0 + (r0_1 + 50304*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp20 = tl.full([1, 1], -100, tl.int64)
        tmp21 = tmp0 != tmp20
        tmp22 = tl.full([1, 1], 0, tl.int64)
        tmp23 = tl.where(tmp21, tmp0, tmp22)
        tmp24 = r0_1
        tmp25 = tmp23 == tmp24
        tmp26 = -1.0
        tmp27 = 0.0
        tmp28 = tl.where(tmp25, tmp26, tmp27)
        tmp33 = (tmp30 / tmp32)
        tmp34 = tl.where(tmp21, tmp33, tmp27)
        tmp35 = tmp28 * tmp34
        tmp37 = tmp36.to(tl.float32)
        tmp38 = 15.0
        tmp39 = tmp37 * tmp38
        tmp40 = tmp37 * tmp37
        tmp41 = 225.0
        tmp42 = tmp40 + tmp41
        tmp43 = libdevice.rsqrt(tmp42)
        tmp44 = tmp39 * tmp43
        tmp46 = tmp44 - tmp45
        tmp48 = tmp46 - tmp47
        tmp49 = tl_math.exp(tmp48)
        tmp50 = tmp49 * tmp18
        tmp51 = tmp35 - tmp50
        tmp52 = tmp51 * tmp39
        tmp53 = -0.5
        tmp54 = tmp52 * tmp53
        tmp55 = tmp43 * tmp43
        tmp56 = tmp55 * tmp43
        tmp57 = tmp54 * tmp56
        tmp58 = 2.0
        tmp59 = tmp37 * tmp58
        tmp60 = tmp57 * tmp59
        tmp61 = tmp51 * tmp43
        tmp62 = tmp61 * tmp38
        tmp63 = tmp60 + tmp62
        tmp64 = tmp63.to(tl.float32)
        tl.store(in_out_ptr0 + (r0_1 + 50304*x0), tmp64, r0_mask)
''', device_str='cuda')
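# The reduction kernel above fuses the backward of a soft-capped cross-entropy:
# the forward soft-caps the fp32 logits as y = 15 * x * rsqrt(x*x + 225),
# equivalently y = x / sqrt(1 + (x/15)**2), bounding them to (-15, 15), then
# applies log_softmax + nll_loss with ignore_index=-100 and mean reduction
# (in_ptr1 holds the scalar tangent, in_ptr2 the valid-target count). A hedged
# eager-mode sketch of the same gradient, using autograd instead of the
# hand-fused math; the function name and signature are illustrative:
def _softcap_xent_bwd_reference(logits_bf16, targets, grad_out):
    x = logits_bf16.float().detach().requires_grad_(True)
    y = 15.0 * x * torch.rsqrt(x * x + 225.0)  # soft cap, matches tmp39 * tmp43
    loss = torch.nn.functional.cross_entropy(y, targets, ignore_index=-100)
    (grad,) = torch.autograd.grad(loss, x, grad_outputs=grad_out)
    return grad.to(torch.bfloat16)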
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/hc/chcgxykfor4ysb6yjas3ttkjs6ocau7waeacbq3dzddsuruaodll.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# Graph fragment:
# %convert_element_type_330 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_63, torch.float32), kwargs = {})
triton_poi_fused__to_copy_1 = async_compile.triton('triton_poi_fused__to_copy_1', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 51511296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/rj/crjpkoeeqp2o6eema7exap5o6s65g7qiso5aoy7y5k5i6vlpjzmf.py
# Topologically Sorted Source Nodes: [x_144], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# x_144 => convert_element_type_318
# Graph fragment:
# %convert_element_type_331 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_184, torch.float32), kwargs = {})
# %convert_element_type_318 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_177, torch.float32), kwargs = {})
# %mul_250 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %convert_element_type_318), kwargs = {})
# %mul_251 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %rsqrt_62), kwargs = {})
# %sum_11 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_250, [2], True), kwargs = {})
# %div_3 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_20, 1024), kwargs = {})
# %pow_84 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_318, 1.0), kwargs = {})
# %mul_254 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_84, 2.0), kwargs = {})
# %mul_255 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_3, %mul_254), kwargs = {})
# %add_181 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_251, %mul_255), kwargs = {})
# %convert_element_type_332 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_181, torch.bfloat16), kwargs = {})
triton_per_fused__to_copy_add_div_mul_pow_sum_2 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_out_ptr0, in_ptr0, in_ptr1, xnumel, r0_numel):
    xnumel = 65536
    XBLOCK: tl.constexpr = 1
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[:]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
    tmp9 = tmp1 * tmp8
    tmp10 = -0.5
    tmp11 = tmp7 * tmp10
    tmp12 = tmp8 * tmp8
    tmp13 = tmp12 * tmp8
    tmp14 = tmp11 * tmp13
    tmp15 = 0.0009765625
    tmp16 = tmp14 * tmp15
    tmp17 = 2.0
    tmp18 = tmp3 * tmp17
    tmp19 = tmp16 * tmp18
    tmp20 = tmp9 + tmp19
    tmp21 = tmp20.to(tl.float32)
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp21, None)
''', device_str='cuda')
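# triton_per_fused__to_copy_add_div_mul_pow_sum_2 above is the classic
# (weight-free) RMS-norm backward over the size-1024 hidden dim: with r the
# rsqrt saved from the forward (in_ptr1) and g the incoming grad, it computes
# dx = g*r - x * r**3 * mean(g*x), where 0.0009765625 == 1/1024. A hedged
# eager sketch of the same math (names are illustrative). The later kernel
# triton_per_fused__to_copy_add_div_mul_pow_sum_4 repeats this pattern and
# additionally accumulates the residual-branch gradient in place:
def _rms_norm_bwd_reference(grad_bf16, x_bf16, rsqrt_saved):
    g = grad_bf16.float()
    x = x_bf16.float()
    r = rsqrt_saved.unsqueeze(-1)  # rsqrt(mean(x**2) + eps) from the forward
    dx = g * r - x * r**3 * (g * x).mean(dim=-1, keepdim=True)
    return dx.to(torch.bfloat16)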
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/3h/c3hwz76pisnffnc7nnxhxwrt6t4rxytpglcpb6fhj2bl3pieorlm.py
# Topologically Sorted Source Nodes: [relu_15], Original ATen: [aten.relu, aten.pow, aten.mul, aten.threshold_backward]
# Source node to ATen node mapping:
# relu_15 => relu_15
# Graph fragment:
# %relu_15 : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%view_180,), kwargs = {})
# %pow_85 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%relu_15, 1.0), kwargs = {})
# %mul_256 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_85, 2.0), kwargs = {})
# %mul_257 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_186, %mul_256), kwargs = {})
# %le_1 : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu_15, 0), kwargs = {})
# %full_default_17 : [num_users=16] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_8 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%le_1, %full_default_17, %mul_257), kwargs = {})
triton_poi_fused_mul_pow_relu_threshold_backward_3 = async_compile.triton('triton_poi_fused_mul_pow_relu_threshold_backward_3', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 268435456},
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_pow_relu_threshold_backward_3', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_pow_relu_threshold_backward_3(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 268435456
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = triton_helpers.maximum(tmp1, tmp0)
    tmp3 = 0.0
    tmp4 = tmp2 <= tmp3
    tmp6 = 2.0
    tmp7 = tmp2 * tmp6
    tmp8 = tmp5 * tmp7
    tmp9 = tl.where(tmp4, tmp3, tmp8)
    tl.store(in_out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')
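# triton_poi_fused_mul_pow_relu_threshold_backward_3 above is the backward of a
# squared-ReLU activation: for u = relu(x) the forward computed u**2, so the
# gradient is g * 2*u where u > 0 and 0 elsewhere. The kernel recomputes
# relu from the saved pre-activation in in_out_ptr0 and overwrites it with the
# gradient. A hedged eager sketch (names are illustrative):
def _relu_squared_bwd_reference(grad_bf16, pre_act_bf16):
    u = torch.relu(pre_act_bf16)
    return torch.where(u <= 0, torch.zeros_like(u), grad_bf16 * 2.0 * u)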
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/ha/chamkr3xbm5j7zmqnyuw7rzcqqcagrpu5phejb47d5uhs5kil4oo.py
# Topologically Sorted Source Nodes: [rms_norm_61], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# rms_norm_61 => convert_element_type_312
# Graph fragment:
# %convert_element_type_341 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_188, torch.float32), kwargs = {})
# %convert_element_type_312 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_175, torch.float32), kwargs = {})
# %mul_258 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %convert_element_type_312), kwargs = {})
# %mul_259 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %rsqrt_61), kwargs = {})
# %sum_12 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_258, [2], True), kwargs = {})
# %div_4 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_21, 1024), kwargs = {})
# %pow_87 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_312, 1.0), kwargs = {})
# %mul_262 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_87, 2.0), kwargs = {})
# %mul_263 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_4, %mul_262), kwargs = {})
# %add_182 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_259, %mul_263), kwargs = {})
# %convert_element_type_342 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_182, torch.bfloat16), kwargs = {})
# %add_183 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_332, %convert_element_type_342), kwargs = {})
triton_per_fused__to_copy_add_div_mul_pow_sum_4 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_4', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_4', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_4(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel):
    xnumel = 65536
    XBLOCK: tl.constexpr = 1
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[:]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp2 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp8 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
    tmp10 = tmp1 * tmp9
    tmp11 = -0.5
    tmp12 = tmp7 * tmp11
    tmp13 = tmp9 * tmp9
    tmp14 = tmp13 * tmp9
    tmp15 = tmp12 * tmp14
    tmp16 = 0.0009765625
    tmp17 = tmp15 * tmp16
    tmp18 = 2.0
    tmp19 = tmp3 * tmp18
    tmp20 = tmp17 * tmp19
    tmp21 = tmp10 + tmp20
    tmp22 = tmp21.to(tl.float32)
    tmp23 = tmp8 + tmp22
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp23, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/js/cjsujn4ityfatbyjrtktt6lsm5m7xfr7x34cjcfu46bev6xvbhap.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {})
triton_per_fused_zeros_5 = async_compile.triton('triton_per_fused_zeros_5', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_zeros_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused_zeros_5(in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 524288
    r0_numel = 128
    R0_BLOCK: tl.constexpr = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), None).to(tl.float32)
    tmp1 = tl.load(in_ptr1 + (r0_1 + 128*x0), None).to(tl.float32)
    tmp2 = tmp0 * tmp1
    tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
    tmp5 = tl.sum(tmp3, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp5, None)
''', device_str='cuda')
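# triton_per_fused_zeros_5 above precomputes DELTA = sum(OUT * dOUT, dim=-1)
# over the head dim (128) for the flex-attention backward template further
# below (see the "DELTA: Precomputed sum(OUT*DO, axis=-1)" note in its header
# comments); the following triton_poi_fused_zeros_6 then transposes the result
# from [SEQ, H] to [H, SEQ] layout and casts it to fp32. A hedged eager sketch,
# assuming out/dout carry head_dim as the last axis (names are illustrative):
def _delta_reference(out_bf16, dout_bf16):
    return (out_bf16.float() * dout_bf16.float()).sum(dim=-1)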
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/yq/cyqtxyu2pr6foycexsqodih4tqalw2u65pbtgsks4qsewqny5wvg.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {})
triton_poi_fused_zeros_6 = async_compile.triton('triton_poi_fused_zeros_6', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'y': 8, 'x': 65536}, tile_hint=TileHint.SQUARE,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 3), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_zeros_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_zeros_6(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 8
    xnumel = 65536
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, YBLOCK], True, tl.int1)
    x1 = xindex
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (y0 + 8*x1), ymask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tmp2 = 0.0
    tmp3 = tmp1 - tmp2
    tl.store(out_ptr0 + (x1 + 65536*y0), tmp3, ymask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/3y/c3yyb2vkcu6dpcv66oqlrqufxdjxjeocdhudp7s5hgiowvxx7smq.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {})
triton_tem_fused_zeros_7 = async_compile.triton('triton_tem_fused_zeros_7', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
@triton_heuristics.template(
num_stages=3,
num_warps=8,
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'kernel_name': 'triton_tem_fused_zeros_7', 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
)
@triton.jit
def triton_tem_fused_zeros_7(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.12
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    Q = arg_Q
    K = arg_K
    V = arg_V
    LSE = arg_LSE
    DELTA = arg_DELTA
    DO = arg_DO
    DQ = arg_DQ
    DV = arg_DV
    KV_NUM_BLKS = arg_KV_NUM_BLKS
    KV_IDX = arg_KV_IDX
    Q_NUM_BLKS = arg_Q_NUM_BLKS
    Q_IDX = arg_Q_IDX
    FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
    FULL_KV_IDX = arg_FULL_KV_IDX
    FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
    FULL_Q_IDX = arg_FULL_Q_IDX
    # Sub notation for this kernel:
    #
    # Q: Query, K: Key, V: Value
    # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
    # DELTA: Precomputed sum(OUT*DO, axis=-1)
    # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
    # DK: Derivative of Key, is written to via the store_output call due to some limitations with
    # inductor codegen
    # M: Number of queries, N: Number of keys/values
    # QK_HEAD_DIM: The dimension of the query and key embeddings
    # V_HEAD_DIM: The dimension of the value embeddings
    # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
    # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
    # (Modifiable) Performance tuning options
    # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
    # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
    # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
    # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
    #
    # The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
    # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
    # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
    # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
    # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
    # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
    # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
    # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
    # The below are kernel options that can be applied for certain score_mods,
    # or involve a numerics vs. perf tradeoff
    # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
    # about 20% more numerical error, but slightly faster.
    # Define strides of inputs
    stride_qz, stride_qh, stride_qm, stride_qd = 67108864, 128, 1024, 1
    stride_kz, stride_kh, stride_kn, stride_kd = 67108864, 128, 1024, 1
    stride_vz, stride_vh, stride_vn, stride_vd = 67108864, 128, 1024, 1
    stride_doz, stride_doh, stride_dom, stride_dod = 67108864, 128, 1024, 1
    stride_dqz, stride_dqh, stride_dqm, stride_dqd = 67108864, 128, 1024, 1
    stride_dvz, stride_dvh, stride_dvm, stride_dvd = 67108864, 128, 1024, 1
    ZQ = 1
    HQ = 8
    HKV = 8
    Q_LEN = 65536
    ZKV = 1
    KV_LEN = 65536
    MATMUL_PRECISION = Q.dtype.element_ty
    pid = tl.program_id(0)
    NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
    NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
    off_hz = tl.program_id(2)
    off_zq = off_hz // HKV # q batch idx
    off_hkv = off_hz % HKV # kv head idx
    off_zkv = off_zq % ZKV # kv batch idx
    SPARSE_Z = 1
    SPARSE_HQ = 1
    sparse_idx_z = off_zq % SPARSE_Z
    k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
    v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
    # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
    # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
    dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
    # offset K, V, DV pointers for batch/kv-head
    K += k_adj
    V += v_adj
    DV += dv_adj
    RCP_LN2 = 1.44269504
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    if pid >= NUM_KV_BLOCKS:
        off_pid = pid - NUM_KV_BLOCKS
        # THIS BLOCK DOES DQ
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
        off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
        start_m2_block = off_pid % NUM_Q_BLOCKS
        off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
        stride_kv_num_blks_h = 512
        stride_kv_idx_h = 262144
        stride_kv_idx_m = 512
        sparse_idx_hq2 = off_hq2 % SPARSE_HQ
        sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
        sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
        sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
        # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
        q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
        do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
        dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
        off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
        Q2 = Q + q_adj2
        DO2 = DO + do_adj2
        # TODO: This does not work if DQ is not the same layout as Q (for example,
        # if Q is broadcasted)
        DQ2 = DQ + dq_adj2
        LSE2 = LSE + off_chz2
        DELTA2 = DELTA + off_chz2
        # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
        dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
        start_m2 = start_m2_block * BLOCK_M2
        offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
        # load Q and do: they stay in SRAM throughout the inner loop.
        q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
        do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
        if PRESCALE_QK:
            q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
        if IS_DIVISIBLE:
            Di = tl.load(DELTA2 + offs_m2)
            lse = tl.load(LSE2 + offs_m2)
        else:
            Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
            lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
        lse = tl.where(lse == -float("inf"), 0.0, lse)
        lse = lse[:, None]
        # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # KV_IDX and KV_NUM_BLKS are always contiguous.
        kv_indices = KV_IDX + sparse_kv_idx_offset
        kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
        sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
        offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
        dq = bwd_dq_inner(
            arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
            K, V,
            dq, q, do, Di, lse,
            off_zq, off_hq2, offs_m2, offs_n2,
            stride_kn, stride_kd, stride_vn, stride_vd,
            kv_indices, sparse_kv_num_blocks,
            MATMUL_PRECISION,
            IS_FULL_BLOCKS=False,
        )
        if HAS_FULL_BLOCKS:
            # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
            kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
            kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
            sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
            offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
            dq = bwd_dq_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                K, V,
                dq, q, do, Di, lse,
                off_zq, off_hq2, offs_m2, offs_n2,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=True,
            )
        # Write back dQ.
        dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
        dq *= SM_SCALE
        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dq_ptrs, dq)
        else:
            tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
    else:
        # THIS BLOCK DOES DK & DV
        SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
        SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
        pid_mask = pid // SPARSE_KV_MULTIPLE
        stride_q_num_blks_h = 512
        stride_q_idx_h = 262144
        stride_q_idx_n = 512
        dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
        dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
        start_n1 = pid * BLOCK_N1
        offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
        # load K and V: they stay in SRAM throughout the inner loop.
        k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
        v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
        if PRESCALE_QK:
            k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
        for off_g in range(0, GQA_SHARED_HEADS):
            off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
            # Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
            q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
            do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
            dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
            off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
            Q1 = Q + q_adj1
            DO1 = DO + do_adj1
            # TODO: This does not work if DQ is not the same layout as Q (for example,
            # if Q is broadcasted)
            LSE1 = LSE + off_chz1
            DELTA1 = DELTA + off_chz1
            sparse_idx_hq1 = off_hq1 % SPARSE_HQ
            sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
            sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
            sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
            # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Q_IDX and Q_NUM_BLKS are always contiguous.
            q_indices = Q_IDX + sparse_q_idx_offset
            q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
            sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
            offs_m1 = q_start + tl.arange(0, BLOCK_M1)
            dk, dv = bwd_dkdv_inner(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                Q1, DO1, DELTA1, LSE1,
                dk, dv, k, v,
                off_zq, off_hq1, offs_n1, offs_m1,
                stride_qm, stride_qd, stride_dom, stride_dod,
                q_indices, sparse_q_num_blocks,
                MATMUL_PRECISION,
                IS_FULL_BLOCKS=False,
            )
            if HAS_FULL_BLOCKS:
                # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
                q_indices = FULL_Q_IDX + sparse_q_idx_offset
                q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
                sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
                offs_m1 = q_start + tl.arange(0, BLOCK_M1)
                dk, dv = bwd_dkdv_inner(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                    Q1, DO1, DELTA1, LSE1,
                    dk, dv, k, v,
                    off_zq, off_hq1, offs_n1, offs_m1,
                    stride_qm, stride_qd, stride_dom, stride_dod,
                    q_indices, sparse_q_num_blocks,
                    MATMUL_PRECISION,
                    IS_FULL_BLOCKS=True,
                )
        # Write back dV and dK.
        dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
        index_n = offs_n1[:, None]
        index_k = offs_k[None, :]
        index_v = offs_v[None, :]
        if IS_DIVISIBLE and SAFE_HEAD_DIM:
            tl.store(dv_ptrs, dv)
        else:
            tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
        dk *= SM_SCALE
        if SAFE_HEAD_DIM:
            mask = index_n < KV_LEN
        else:
            mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
        # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
        # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
        xindex = index_k + 128*index_n + 8388608*off_hkv + 67108864*off_zq
        tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
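# Program layout of triton_tem_fused_zeros_7 above, as a hedged sketch (the
# helper name is illustrative): along grid axis 0, programs [0, NUM_KV_BLOCKS)
# each own one BLOCK_N1 slice of K/V and compute dK/dV, while programs at or
# beyond NUM_KV_BLOCKS each own one BLOCK_M2 slice of Q and compute dQ
# (times GQA_SHARED_HEADS query heads per kv head; 1 here). Axis 2 enumerates
# (batch, kv-head) pairs.
def _bwd_grid_size_sketch(Q_LEN=65536, KV_LEN=65536, BLOCK_N1=128, BLOCK_M2=128):
    num_kv_blocks = -(-KV_LEN // BLOCK_N1)  # cdiv
    num_q_blocks = -(-Q_LEN // BLOCK_M2)    # cdiv
    return num_kv_blocks + num_q_blocks     # programs along grid axis 0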
@triton.jit
def bwd_dq_inner(
    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
    K, V, # pointers
    dq, q, do, Di, lse,
    off_z, off_hq, offs_m2, offs_n2,
    stride_kn, stride_kd, stride_vn, stride_vd,
    kv_indices, sparse_kv_num_blocks,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS,
):
    PRESCALE_QK : tl.constexpr = False
    ROWS_GUARANTEED_SAFE : tl.constexpr = False
    BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
    WRITE_DQ : tl.constexpr = True
    OUTPUT_LOGSUMEXP : tl.constexpr = True
    FLOAT32_PRECISION : tl.constexpr = 'ieee'
    IS_DIVISIBLE : tl.constexpr = True
    SM_SCALE : tl.constexpr = 0.12
    GQA_SHARED_HEADS : tl.constexpr = 1
    HAS_FULL_BLOCKS : tl.constexpr = True
    QK_HEAD_DIM : tl.constexpr = 128
    QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
    V_HEAD_DIM : tl.constexpr = 128
    V_HEAD_DIM_ROUNDED : tl.constexpr = 128
    SAFE_HEAD_DIM : tl.constexpr = True
    BLOCK_M1 : tl.constexpr = 64
    BLOCK_N1 : tl.constexpr = 128
    BLOCK_M2 : tl.constexpr = 128
    BLOCK_N2 : tl.constexpr = 64
    SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
    SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
    RCP_LN2: tl.constexpr = 1.44269504
    Q_LEN = 65536
    KV_LEN = 65536
    offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
    offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
    kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
    vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
    # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
    tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
    hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
    if not IS_DIVISIBLE:
        if hi >= 1:
            for start_n in range(0, hi - 1):
                dq = bwd_dq_block_mn(
                    arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                    dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
                    off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
                    stride_kn, stride_kd, stride_vn, stride_vd,
                    kv_indices, sparse_kv_num_blocks,
                    MATMUL_PRECISION, RCP_LN2,
                    IS_FULL_BLOCKS,
                )
                # Increment pointers.
                offset = get_offset_for_next_block(
                    start_n, kv_indices, sparse_kv_num_blocks,
                    SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
                )
                kT_ptrs += offset * stride_kn
                vT_ptrs += offset * stride_vn
                offs_n2 += offset
            dq = bwd_dq_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
                off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
            )
    else:
        for start_n in range(0, hi):
            dq = bwd_dq_block_mn(
                arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
                dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
                off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
                stride_kn, stride_kd, stride_vn, stride_vd,
                kv_indices, sparse_kv_num_blocks,
                MATMUL_PRECISION, RCP_LN2,
                IS_FULL_BLOCKS,
            )
            # Increment pointers.
            offset = get_offset_for_next_block(
                start_n, kv_indices, sparse_kv_num_blocks,
                SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
            )
            kT_ptrs += offset * stride_kn
            vT_ptrs += offset * stride_vn
            offs_n2 += offset
    return dq
@triton.jit
def bwd_dq_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'ieee'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.12
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 128
BLOCK_M2 : tl.constexpr = 128
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
# NB reversed order to since K is transposed
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
pre_mod_scores = qk
n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None)
# The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
# that the M reads out of bounds prior to the last loop
m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
tmp0 = (qk)
post_mod_scores = tmp0
if CHECK_BLOCK_BOUNDARY:
# Mask out elements that lie beyond KV_LEN when the seqlen is not divisible.
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
tmp1 = (m)
tmp2 = (n)
tmp3 = tmp1 >= tmp2
tmp4 = tl.load(in_ptr16 + tmp1)
tmp5 = tl.load(in_ptr16 + tmp2)
tmp6 = tmp4 == tmp5
tmp7 = tmp3 & tmp6
mask_mod_output = tmp7
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# apply mask for partially masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
p = tl.math.exp2(post_mod_scores - lse)
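# NOTE: the recompute works in base 2: the scores were scaled by RCP_LN2 = log2(e)
# above, and lse is assumed to be stored in matching log2 units, so
# exp2(scores - lse) reproduces the forward softmax probabilities using the
# faster exp2 hardware instruction instead of exp.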
# Compute dP and dS.
# NB reversed order since V is transposed
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
ds = p * (dp - Di[:, None])
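# NOTE: classic attention-backward identity: dS = P * (dP - Delta), where Di holds
# the precomputed Delta_i = rowsum(dO_i * O_i), so the softmax Jacobian is applied
# row-wise without materializing the full row of P outside this tile.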
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
tmp8 = (ds)
grad_scores = tmp8
if CHECK_BLOCK_BOUNDARY:
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
if WRITE_DQ:
scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = grad_scores
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially masked block
ds = tl.where(mask_mod_output, ds, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = ds.to(MATMUL_PRECISION)
# Compute dQ.
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
return dq
@triton.jit
def bwd_dkdv_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
Q, DO, DELTA, LSE, # pointers
dk, dv, k, v,
off_z, off_hq, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'ieee'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.12
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 128
BLOCK_M2 : tl.constexpr = 128
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = 65536
KV_LEN = 65536
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
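# NOTE: visit at most sparse_q_num_blocks expanded sparse blocks, clamped to the
# total number of Q tiles (and at least one iteration).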
if not IS_DIVISIBLE:
if hi >= 1:
for start_m in range(0, hi - 1):
dk, dv = bwd_dkdv_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
dk, dv = bwd_dkdv_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
else:
for start_m in range(0, hi):
dk, dv = bwd_dkdv_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
return dk, dv
@triton.jit
def bwd_dkdv_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'ieee'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.12
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 128
BLOCK_M2 : tl.constexpr = 128
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
# NB reversed order since Q is transposed
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
# Load LSE before computing qk to reduce pipeline stall.
if IS_DIVISIBLE:
lse = tl.load(LSE + offs_m1)
else:
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
lse = tl.where(lse == -float("inf"), 0.0, lse)
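# NOTE: lse == -inf marks rows that were fully masked in the forward; resetting it
# to 0.0 avoids (-inf) - (-inf) = NaN in the exp2 recompute, which then yields p = 0.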
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
if not PRESCALE_QK:
qkT *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None)
# The boundary check is done for the outer loop, but since we iterate across the M dim
# here, the N reads can go out of bounds before the last iteration
n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
pre_mod_scores = qkT
tmp9 = (qkT)
post_mod_scores = tmp9
if CHECK_BLOCK_BOUNDARY:
# Mask out elements that lie beyond KV_LEN when the seqlen is not divisible.
post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
tmp10 = (m)
tmp11 = (n)
tmp12 = tmp10 >= tmp11
tmp13 = tl.load(in_ptr16 + tmp10)
tmp14 = tl.load(in_ptr16 + tmp11)
tmp15 = tmp13 == tmp14
tmp16 = tmp12 & tmp15
mask_mod_output = tmp16
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# apply mask for partially masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
pT = tl.math.exp2(post_mod_scores - lse[None, :])
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
# Compute dV.
ppT = pT
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
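# NOTE: dV accumulates P^T @ dO across the M loop; pT is the recomputed transposed
# probability tile, so no probabilities need to be saved from the forward.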
if IS_DIVISIBLE:
Di = tl.load(DELTA + offs_m1)
else:
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
tmp17 = (dsT)
grad_scores = tmp17
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
if not WRITE_DQ:
idx_b = off_z
idx_h = off_hq
idx_m = m
idx_n = n
scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if CHECK_BLOCK_BOUNDARY:
grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)
dsT = grad_scores
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially masked block
dsT = tl.where(mask_mod_output, dsT, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
return dk, dv
@triton.jit
def get_offset_for_next_block(
loop_iter, col_indices, total_blocks,
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
if BLOCKS_ARE_CONTIGUOUS:
return BLOCK
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
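# NOTE: within a sparse block we advance by one BLOCK; after SPARSE_BLOCK_MULTIPLE
# iterations we jump to the next indexed sparse block, net of the
# (SPARSE_BLOCK_MULTIPLE - 1) BLOCKs already walked inside the current one.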
return offset
@triton.jit
def get_bounded_indices(indices, max_len=None):
return indices % max_len if max_len is not None else indices
@triton.jit
def load_checked_2d(
ptr,
offs_m,
offs_n,
stride_m,
stride_n,
IS_DIVISIBLE_M: tl.constexpr,
IS_DIVISIBLE_N: tl.constexpr,
M_LEN: tl.constexpr,
N_DIM: tl.constexpr,
):
# Calculate final pointer if strides are provided
if stride_m is not None and stride_n is not None:
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
# Handle all masking cases
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0)
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0)
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
else: # Both divisible
return tl.load(ptr)
''', device_str='cuda')
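# A minimal eager-mode sketch (not part of the generated output) of what the
# flex-attention backward kernels above recompute. Assumptions: in_ptr16 holds a
# per-token document id and the mask_mod is document-causal (m >= n and same doc),
# matching the tmp3/tmp6 fragments; names like doc_ids and sm_scale are illustrative.
# torch is already imported at the top of this file.
def reference_flex_attention_bwd(q, k, v, do, doc_ids, sm_scale=0.12):
    # q, k, v, do: [Z, H, L, D] float tensors; doc_ids: [L] integer tensor.
    L = q.shape[-2]
    idx = torch.arange(L, device=q.device)
    mask = (idx[:, None] >= idx[None, :]) & (doc_ids[:, None] == doc_ids[None, :])
    scores = (q @ k.transpose(-2, -1)) * sm_scale
    p = scores.masked_fill(~mask, float("-inf")).softmax(dim=-1)
    o = p @ v
    # Delta_i = rowsum(dO_i * O_i); the kernels load this precomputed as DELTA/Di.
    delta = (do * o).sum(dim=-1, keepdim=True)
    dp = do @ v.transpose(-2, -1)
    ds = p * (dp - delta)                        # softmax Jacobian applied row-wise
    dq = (ds @ k) * sm_scale                     # chain rule through scores = sm_scale * q @ k^T
    dk = (ds.transpose(-2, -1) @ q) * sm_scale
    dv = p.transpose(-2, -1) @ do
    return dq, dk, dv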
meta0 = {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.12, 'GQA_SHARED_HEADS': 1, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/zf/czf7pjyultaixltpmtzbkgpcs6y2wothlasa3ypnkt5l3damv72z.py
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten.sum, aten._to_copy]
# Source node to ATen node mapping:
# v_43 => convert_element_type_308, convert_element_type_309, mul_234
# Graph fragment:
# %mul_265 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %view_36), kwargs = {})
# %sum_13 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_265,), kwargs = {})
# %convert_element_type_308 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_140, torch.float32), kwargs = {})
# %mul_234 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_308, %rsqrt_60), kwargs = {})
# %convert_element_type_309 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_234, torch.bfloat16), kwargs = {})
# %mul_267 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %convert_element_type_309), kwargs = {})
# %sum_14 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_267,), kwargs = {})
triton_red_fused__to_copy_mul_sum_8 = async_compile.triton('triton_red_fused__to_copy_mul_sum_8', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 512, 'r0_': 131072},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_mul_sum_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 512
r0_numel = 131072
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
_tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
_tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (3072*(r0_1 // 1024) + 393216*x0 + ((r0_1 % 1024))), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp8 = tl.load(in_ptr3 + (1024*x0 + (r0_1 // 128)), xmask, eviction_policy='evict_last', other=0.0)
tmp2 = tmp0 * tmp1
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
tmp5 = _tmp4 + tmp3
_tmp4 = tl.where(xmask, tmp5, _tmp4)
tmp7 = tmp6.to(tl.float32)
tmp9 = tmp7 * tmp8
tmp10 = tmp9.to(tl.float32)
tmp11 = tmp0 * tmp10
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
tmp14 = _tmp13 + tmp12
_tmp13 = tl.where(xmask, tmp14, _tmp13)
tmp4 = tl.sum(_tmp4, 1)[:, None]
tmp13 = tl.sum(_tmp13, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp4, xmask)
tl.store(out_ptr1 + (x0), tmp13, xmask)
''', device_str='cuda')
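# Illustrative eager equivalent of the two-stage reduction: the kernel above emits
# 512 partial sums (it actually produces two such reductions, out_ptr0/out_ptr1; one
# is shown), and triton_per_fused_mul_sum_9 below collapses them to a scalar. Shapes
# (512 x 131072) mirror xnumel/r0_numel; the tensor names are made up.
def reference_split_sum(a, b):
    partial = (a.float() * b.float()).reshape(512, 131072).sum(dim=1)  # stage 1
    return partial.sum()                                               # stage 2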
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/id/cidd6fgxlf7gs3tonqx5yreybf2dwl3utnhsnlnvarag43y63rtu.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum]
# Source node to ATen node mapping:
# Graph fragment:
# %mul_265 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %view_36), kwargs = {})
# %sum_13 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_265,), kwargs = {})
triton_per_fused_mul_sum_9 = async_compile.triton('triton_per_fused_mul_sum_9', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1, 'r0_': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 3), 'tt.equal_to': (2,)}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_sum_9', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused_mul_sum_9(in_ptr0, out_ptr0, xnumel, r0_numel):
xnumel = 1
XBLOCK: tl.constexpr = 1
r0_numel = 512
R0_BLOCK: tl.constexpr = 512
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_0 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_0), None)
tmp1 = tl.broadcast_to(tmp0, [R0_BLOCK])
tmp3 = triton_helpers.promote_to_tensor(tl.sum(tmp1, 0))
tl.store(out_ptr0 + (tl.full([1], 0, tl.int32)), tmp3, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/vc/cvckp5jouod57mmfqeg2deebnbb7ny7dnkuh4ptuho2w6tgsf36c.py
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_43 => convert_element_type_308
# Graph fragment:
# %mul_266 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %select_101), kwargs = {})
# %convert_element_type_308 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_140, torch.float32), kwargs = {})
# %convert_element_type_349 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_266, torch.float32), kwargs = {})
# %mul_268 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_349, %convert_element_type_308), kwargs = {})
# %mul_269 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_349, %rsqrt_60), kwargs = {})
# %sum_15 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_268, [3], True), kwargs = {})
# %div_5 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_22, 128), kwargs = {})
# %pow_89 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_308, 1.0), kwargs = {})
# %mul_272 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_89, 2.0), kwargs = {})
# %mul_273 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_5, %mul_272), kwargs = {})
# %add_185 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_269, %mul_273), kwargs = {})
# %convert_element_type_350 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_185, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_10 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_10', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_10(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (78))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (78))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
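# A minimal eager sketch of the RMSNorm input gradient computed by the fused kernel
# above (N = 128 is the head dim, matching 0.0078125 = 1/128; the incoming grad is
# assumed pre-scaled by the learned scalar loaded from in_ptr1[78]):
#   r  = rsqrt(mean(x**2))
#   dx = dy * r - x * mean(dy * x) * r**3
def reference_rmsnorm_bwd(dy, x):
    # The kernel reloads the saved rsqrt (in_ptr3); it is recomputed here for clarity.
    r = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True))
    s = (dy.float() * x.float()).mean(dim=-1, keepdim=True)
    return dy.float() * r - x.float() * s * r.pow(3)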
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/sw/cswnxerjagcklzy6jfmvlvkyel6hqazhxrakyffns5ponz5c5yzq.py
# Topologically Sorted Source Nodes: [k_43, q_43], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# k_43 => convert_element_type_302
# q_43 => convert_element_type_300
# Graph fragment:
# %cat_30 : [num_users=2] = call_function[target=torch.ops.aten.cat.default](args = ([%add_187, %add_186], 3), kwargs = {})
# %cat_31 : [num_users=2] = call_function[target=torch.ops.aten.cat.default](args = ([%add_189, %add_188], 3), kwargs = {})
# %convert_element_type_302 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_139, torch.float32), kwargs = {})
# %mul_282 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_30, %convert_element_type_302), kwargs = {})
# %mul_283 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_30, %rsqrt_59), kwargs = {})
# %sum_16 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_282, [3], True), kwargs = {})
# %div_6 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_23, 128), kwargs = {})
# %pow_91 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_302, 1.0), kwargs = {})
# %mul_286 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_91, 2.0), kwargs = {})
# %mul_287 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_6, %mul_286), kwargs = {})
# %add_190 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_283, %mul_287), kwargs = {})
# %convert_element_type_356 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_190, torch.bfloat16), kwargs = {})
# %convert_element_type_300 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_138, torch.float32), kwargs = {})
# %mul_288 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_31, %convert_element_type_300), kwargs = {})
# %mul_289 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_31, %rsqrt_58), kwargs = {})
# %sum_17 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_288, [3], True), kwargs = {})
# %div_7 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_24, 128), kwargs = {})
# %pow_93 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_300, 1.0), kwargs = {})
# %mul_292 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_93, 2.0), kwargs = {})
# %mul_293 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_7, %mul_292), kwargs = {})
# %add_191 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_289, %mul_293), kwargs = {})
# %convert_element_type_358 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_191, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11 = async_compile.triton('triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr3': '*bf16', 'out_ptr5': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 21, 'num_reduction': 2, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
x1 = xindex // 8
x0 = (xindex % 8)
_tmp55 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp51 = tl.load(in_ptr4 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp0 = r0_2
tmp1 = tl.full([1, 1], 0, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = tl.full([1, 1], 64, tl.int64)
tmp4 = tmp0 < tmp3
tmp5 = tl.load(in_ptr0 + (64 + 128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tmp5.to(tl.float32)
tmp7 = tl.load(in_ptr1 + (64*x1 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0)
tmp8 = -tmp7
tmp9 = tmp6 * tmp8
tmp10 = tl.load(in_ptr0 + (128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp11 = tmp10.to(tl.float32)
tmp12 = tl.load(in_ptr2 + (64*x1 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0)
tmp13 = tmp11 * tmp12
tmp14 = tmp9 + tmp13
tmp15 = tl.full(tmp14.shape, 0.0, tmp14.dtype)
tmp16 = tl.where(tmp4, tmp14, tmp15)
tmp17 = tmp0 >= tmp3
tmp18 = tl.full([1, 1], 128, tl.int64)
tmp19 = tmp0 < tmp18
tmp20 = tl.load(in_ptr0 + (64 + 128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp21 = tmp20.to(tl.float32)
tmp22 = tl.load(in_ptr2 + (64*x1 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0)
tmp23 = tmp21 * tmp22
tmp24 = tl.load(in_ptr0 + (128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp25 = tmp24.to(tl.float32)
tmp26 = tl.load(in_ptr1 + (64*x1 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0)
tmp27 = tmp25 * tmp26
tmp28 = tmp23 + tmp27
tmp29 = tl.full(tmp28.shape, 0.0, tmp28.dtype)
tmp30 = tl.where(tmp17, tmp28, tmp29)
tmp31 = tl.where(tmp4, tmp16, tmp30)
tmp32 = tl.load(in_ptr3 + (64 + 128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp33 = tmp32.to(tl.float32)
tmp34 = tmp33 * tmp8
tmp35 = tl.load(in_ptr3 + (128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp36 = tmp35.to(tl.float32)
tmp37 = tmp36 * tmp12
tmp38 = tmp34 + tmp37
tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
tmp40 = tl.where(tmp4, tmp38, tmp39)
tmp41 = tl.load(in_ptr3 + (64 + 128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp42 = tmp41.to(tl.float32)
tmp43 = tmp42 * tmp22
tmp44 = tl.load(in_ptr3 + (128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp45 = tmp44.to(tl.float32)
tmp46 = tmp45 * tmp26
tmp47 = tmp43 + tmp46
tmp48 = tl.full(tmp47.shape, 0.0, tmp47.dtype)
tmp49 = tl.where(tmp17, tmp47, tmp48)
tmp50 = tl.where(tmp4, tmp40, tmp49)
tmp52 = tmp51.to(tl.float32)
tmp53 = tmp31 * tmp52
tmp54 = tl.broadcast_to(tmp53, [XBLOCK, R0_BLOCK])
tmp56 = _tmp55 + tmp54
_tmp55 = tl.where(r0_mask, tmp56, _tmp55)
tl.store(out_ptr0 + (r0_2 + 128*x3), tmp31, r0_mask)
tl.store(out_ptr1 + (r0_2 + 128*x3), tmp50, r0_mask)
tmp55 = tl.sum(_tmp55, 1)[:, None]
tmp58 = tl.load(in_ptr5 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp57 = tl.load(out_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0)
tmp67 = tl.load(in_ptr4 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp59 = tmp57 * tmp58
tmp60 = -0.5
tmp61 = tmp55 * tmp60
tmp62 = tmp58 * tmp58
tmp63 = tmp62 * tmp58
tmp64 = tmp61 * tmp63
tmp65 = 0.0078125
tmp66 = tmp64 * tmp65
tmp68 = tmp67.to(tl.float32)
tmp69 = 2.0
tmp70 = tmp68 * tmp69
tmp71 = tmp66 * tmp70
tmp72 = tmp59 + tmp71
tmp73 = tmp72.to(tl.float32)
tl.store(out_ptr3 + (r0_2 + 128*x0 + 3072*x1), tmp73, r0_mask)
_tmp79 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp74 = tl.load(out_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0)
tmp75 = tl.load(in_ptr6 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp76 = tmp75.to(tl.float32)
tmp77 = tmp74 * tmp76
tmp78 = tl.broadcast_to(tmp77, [XBLOCK, R0_BLOCK])
tmp80 = _tmp79 + tmp78
_tmp79 = tl.where(r0_mask, tmp80, _tmp79)
tmp79 = tl.sum(_tmp79, 1)[:, None]
tmp82 = tl.load(in_ptr7 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp81 = tl.load(out_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0)
tmp91 = tl.load(in_ptr6 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp83 = tmp81 * tmp82
tmp84 = -0.5
tmp85 = tmp79 * tmp84
tmp86 = tmp82 * tmp82
tmp87 = tmp86 * tmp82
tmp88 = tmp85 * tmp87
tmp89 = 0.0078125
tmp90 = tmp88 * tmp89
tmp92 = tmp91.to(tl.float32)
tmp93 = 2.0
tmp94 = tmp92 * tmp93
tmp95 = tmp90 * tmp94
tmp96 = tmp83 + tmp95
tmp97 = tmp96.to(tl.float32)
tl.store(out_ptr5 + (r0_2 + 128*x0 + 3072*x1), tmp97, r0_mask)
''', device_str='cuda')
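# Rough eager sketch of the rotary-embedding backward implemented by the cat_30/cat_31
# branches above: the incoming gradient halves are re-rotated through cos/sin before
# the fused RMSNorm backward (which this sketch omits). The exact signs depend on the
# forward RoPE variant; this assumes y = (x1*cos - x2*sin, x1*sin + x2*cos).
def reference_rope_bwd(dy, cos, sin):
    d = dy.shape[-1] // 2
    dy1, dy2 = dy[..., :d], dy[..., d:]
    dx1 = dy1 * cos + dy2 * sin     # inverse rotation of the assumed forward
    dx2 = -dy1 * sin + dy2 * cos
    return torch.cat([dx1, dx2], dim=-1)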
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/5d/c5dlh2bjlk4ei2idmqft6cji5oj5hyp7yehj6tk5eb5qsk2hiubd.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_18 : [num_users=30] = call_function[target=torch.ops.aten.full.default](args = ([4, 1024, 1024], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %select_scatter_default : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_18, %mm_69, 0, 3), kwargs = {})
# %slice_scatter_default : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_18, %view_196, 0, 0, 3), kwargs = {})
# %add_193 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default, %slice_scatter_default), kwargs = {})
triton_poi_fused_add_select_backward_12 = async_compile.triton('triton_poi_fused_add_select_backward_12', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 4194304},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_select_backward_12', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_select_backward_12(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4194304
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x1 = xindex // 1048576
x0 = (xindex % 1048576)
x2 = xindex
tmp3 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
tmp0 = x1
tmp1 = tl.full([1], 3, tl.int32)
tmp2 = tmp0 == tmp1
tmp4 = 0.0
tmp5 = tl.where(tmp2, tmp3, tmp4)
tmp6 = tl.full([1], 3, tl.int64)
tmp7 = tmp0 < tmp6
tmp8 = tl.load(in_ptr1 + (x2), tmp7, other=0.0).to(tl.float32)
tmp9 = tl.where(tmp7, tmp8, tmp4)
tmp10 = tmp5 + tmp9
tl.store(out_ptr0 + (x2), tmp10, None)
''', device_str='cuda')
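# Illustrative eager equivalent of the select_backward fusion above: the grad of x[3]
# is scattered into slot 3 of a zero tensor, the grad of x[0:3] into slots 0..2, and
# the two scatters are summed in a single pointwise pass. Shapes follow full_default_18.
def reference_select_slice_backward(grad_sel, grad_slice):
    # grad_sel: [1024, 1024] (grad of x[3]); grad_slice: [3, 1024, 1024] (grad of x[:3])
    out = torch.zeros(4, 1024, 1024, dtype=grad_sel.dtype, device=grad_sel.device)
    out[3] = grad_sel
    out[:3] = grad_slice
    return out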
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/rh/crhlyhq55hgkmgz65givk2zndj2w4bvvle5lwhpxcz4poewwdmnz.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {})
# %mul_296 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %select_99), kwargs = {})
triton_poi_fused_add_mul_13 = async_compile.triton('triton_poi_fused_add_mul_13', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_13', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_13(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (46))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
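# Illustrative eager equivalent of the pointwise fusion above: a residual add followed
# by a multiply with one learned scalar (index 46 of in_ptr2, per the tl.load above).
def reference_add_scale(a, b, scalars):
    return ((a.float() + b.float()) * scalars[46]).to(a.dtype)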
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/hc/chc7hawo3j62h3g5bffii536wpkihnsguegjwg6n26mklfjjcjn3.py
# Topologically Sorted Source Nodes: [rms_norm_57], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# rms_norm_57 => convert_element_type_292
# Graph fragment:
# %convert_element_type_373 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_200, torch.float32), kwargs = {})
# %convert_element_type_292 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_163, torch.float32), kwargs = {})
# %mul_300 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_373, %convert_element_type_292), kwargs = {})
# %mul_301 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_373, %rsqrt_57), kwargs = {})
# %sum_20 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_300, [2], True), kwargs = {})
# %div_8 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_25, 1024), kwargs = {})
# %pow_96 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_292, 1.0), kwargs = {})
# %mul_304 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_96, 2.0), kwargs = {})
# %mul_305 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_8, %mul_304), kwargs = {})
# %add_195 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_301, %mul_305), kwargs = {})
# %convert_element_type_374 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_195, torch.bfloat16), kwargs = {})
# %add_196 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_296, %convert_element_type_374), kwargs = {})
triton_per_fused__to_copy_add_div_mul_pow_sum_14 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_14', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_14', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_14(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp8 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 * tmp3
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
tmp10 = tmp1 * tmp9
tmp11 = -0.5
tmp12 = tmp7 * tmp11
tmp13 = tmp9 * tmp9
tmp14 = tmp13 * tmp9
tmp15 = tmp12 * tmp14
tmp16 = 0.0009765625
tmp17 = tmp15 * tmp16
tmp18 = 2.0
tmp19 = tmp3 * tmp18
tmp20 = tmp17 * tmp19
tmp21 = tmp10 + tmp20
tmp22 = tmp21.to(tl.float32)
tmp23 = tmp8 + tmp22
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp23, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/dq/cdqthdpfu5gimvpxpp775dnqoysmbyfs3mtk4f5xsqfjde6wil3e.py
# Topologically Sorted Source Nodes: [v_40], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_40 => convert_element_type_288
# Graph fragment:
# %mul_308 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_164, %select_94), kwargs = {})
# %convert_element_type_288 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_131, torch.float32), kwargs = {})
# %convert_element_type_381 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_308, torch.float32), kwargs = {})
# %mul_310 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_381, %convert_element_type_288), kwargs = {})
# %mul_311 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_381, %rsqrt_56), kwargs = {})
# %sum_23 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_310, [3], True), kwargs = {})
# %div_9 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_26, 128), kwargs = {})
# %pow_98 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_288, 1.0), kwargs = {})
# %mul_314 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_98, 2.0), kwargs = {})
# %mul_315 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_9, %mul_314), kwargs = {})
# %add_198 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_311, %mul_315), kwargs = {})
# %convert_element_type_382 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_198, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_15 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_15', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_15', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_15(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (76))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (76))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/rw/crwt7reaapsh2trvps6lvjbbnnscw3ujc7z7ghzdznxwumm6b43r.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {})
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %mul_295 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %convert_element_type_11), kwargs = {})
# %sum_18 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_295,), kwargs = {})
# %mul_297 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %add_165), kwargs = {})
# %sum_19 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_297,), kwargs = {})
# %add_205 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_196, %view_207), kwargs = {})
# %mul_337 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %convert_element_type_11), kwargs = {})
# %sum_26 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_337,), kwargs = {})
# %mul_338 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %select_92), kwargs = {})
# %mul_339 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %add_153), kwargs = {})
# %sum_27 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_339,), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_16 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_16', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_16', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 9, 'num_reduction': 4, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_16(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp29 = tl.load(in_ptr8 + (44))
tmp30 = tl.broadcast_to(tmp29, [R0_BLOCK])
tmp2 = tmp0 + tmp1
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp2 * tmp7
tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
tmp13 = tmp2 * tmp12
tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp19 = tmp17 + tmp18
tmp20 = tmp19 * tmp7
tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0))
tmp25 = tmp19 * tmp24
tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK])
tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tmp31 = tmp30.to(tl.float32)
tmp32 = tmp19 * tmp31
tl.store(out_ptr4 + (r0_1 + 1024*x0), tmp32, None)
tl.store(out_ptr0 + (x0), tmp11, None)
tl.store(out_ptr1 + (x0), tmp16, None)
tl.store(out_ptr2 + (x0), tmp23, None)
tl.store(out_ptr3 + (x0), tmp28, None)
''', device_str='cuda')
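# Unlike the looped reductions, this kernel is a persistent reduction: XBLOCK = 1 and
# R0_BLOCK = 1024, so each program reduces one full 1024-element row in a single pass
# with no r0-loop and no mask. It fuses four row-wise dot products (the gradient rows
# in_ptr0 + in_ptr1 and in_ptr5 + in_ptr6, each against the reconstructed normalized
# input in_ptr2 * in_ptr3 or a saved activation) with one elementwise product by the
# fp32 scalar at in_ptr8[44]. The four per-row scalars land in sum_18/sum_19/sum_26/
# sum_27 of the fragment above, which reads as gradients of learned scalar skip
# coefficients; that interpretation is inferred from the code, not stated in the dump.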
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/2e/c2e4k3473pxml3aiowmqgdvsvsjm635x3t5jxglwwuodopxjzrqe.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {})
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %mul_295 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %convert_element_type_11), kwargs = {})
# %sum_18 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_295,), kwargs = {})
triton_red_fused__to_copy_add_mul_sum_17 = async_compile.triton('triton_red_fused__to_copy_add_mul_sum_17', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 1, 'r0_': 65536},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 3), 'tt.equal_to': (2,)}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mul_sum_17', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_mul_sum_17(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 1
r0_numel = 65536
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_0 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_0), None, eviction_policy='evict_first')
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tmp3
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp2, None)
''', device_str='cuda')
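# This kernel finishes the sum_18 reduction: xnumel == 1, so a single program loops
# over all 65536 per-row partials written by the persistent-reduction kernel above
# and collapses them to one bf16 scalar, roughly
# out_ptr0[0] = in_ptr0.sum().to(torch.bfloat16) in eager terms (an equivalence read
# off the code, not taken from the dump).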
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/yb/cybckfvcqgm55idccl4nb26mpv4fp2n6sltf2mq7jzlrn37bjned.py
# Topologically Sorted Source Nodes: [v_37], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_37 => convert_element_type_268
# Graph fragment:
# %mul_350 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_184, %select_87), kwargs = {})
# %convert_element_type_268 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_122, torch.float32), kwargs = {})
# %convert_element_type_413 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_350, torch.float32), kwargs = {})
# %mul_352 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_413, %convert_element_type_268), kwargs = {})
# %mul_353 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_413, %rsqrt_52), kwargs = {})
# %sum_31 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_352, [3], True), kwargs = {})
# %div_13 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_30, 128), kwargs = {})
# %pow_107 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_268, 1.0), kwargs = {})
# %mul_356 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_107, 2.0), kwargs = {})
# %mul_357 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_13, %mul_356), kwargs = {})
# %add_214 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_353, %mul_357), kwargs = {})
# %convert_element_type_414 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_214, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_18 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_18', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_18', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_18(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (74))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (74))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
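# Same two-pass rsqrt-backward pattern sketched near the top of this section, here
# applied to v_37 with the fp32 scalar taken from in_ptr1[74] rather than in_ptr1[76].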
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/3w/c3wolzn6xji5tvbc62o57dehpylfh33s7yixvo3nby3vsodrtfpj.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten.mul, aten._to_copy, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {})
# %mul_294 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %select_100), kwargs = {})
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_205 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_196, %view_207), kwargs = {})
# %mul_336 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %select_93), kwargs = {})
# %add_207 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_294, %mul_336), kwargs = {})
# %add_221 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_212, %view_219), kwargs = {})
# %mul_378 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %select_86), kwargs = {})
# %mul_379 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %convert_element_type_11), kwargs = {})
# %sum_34 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_379,), kwargs = {})
# %add_223 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_207, %mul_378), kwargs = {})
# %mul_380 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %select_85), kwargs = {})
# %mul_381 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %add_141), kwargs = {})
# %sum_35 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_381,), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_19 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_19', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*fp32', 'in_ptr8': '*bf16', 'out_ptr0': '*bf16', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_19', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 2, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_19(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr1 + (47))
tmp4 = tl.broadcast_to(tmp3, [R0_BLOCK])
tmp7 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp8 = tl.load(in_ptr3 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp10 = tl.load(in_ptr1 + (45))
tmp11 = tl.broadcast_to(tmp10, [R0_BLOCK])
tmp15 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp16 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr1 + (43))
tmp19 = tl.broadcast_to(tmp18, [R0_BLOCK])
tmp23 = tl.load(in_ptr1 + (42))
tmp24 = tl.broadcast_to(tmp23, [R0_BLOCK])
tmp27 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp29 = tl.load(in_ptr7 + (x0), None, eviction_policy='evict_last')
tmp36 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tmp9 = tmp7 + tmp8
tmp12 = tmp11.to(tl.float32)
tmp13 = tmp9 * tmp12
tmp14 = tmp6 + tmp13
tmp17 = tmp15 + tmp16
tmp20 = tmp19.to(tl.float32)
tmp21 = tmp17 * tmp20
tmp22 = tmp14 + tmp21
tmp25 = tmp24.to(tl.float32)
tmp26 = tmp17 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp30 = tmp28 * tmp29
tmp31 = tmp30.to(tl.float32)
tmp32 = tmp17 * tmp31
tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK])
tmp35 = triton_helpers.promote_to_tensor(tl.sum(tmp33, 0))
tmp37 = tmp17 * tmp36
tmp38 = tl.broadcast_to(tmp37, [R0_BLOCK])
tmp40 = triton_helpers.promote_to_tensor(tl.sum(tmp38, 0))
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp22, None)
tl.store(out_ptr0 + (r0_1 + 1024*x0), tmp26, None)
tl.store(out_ptr1 + (x0), tmp35, None)
tl.store(out_ptr2 + (x0), tmp40, None)
''', device_str='cuda')
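# Note 'mutated_arg_names': ['in_out_ptr0'] above: this kernel accumulates in place.
# Three branch gradients, each scaled by an fp32 scalar (in_ptr1[47], [45], [43]), are
# summed into the running residual-stream gradient and stored back to in_out_ptr0,
# while in_ptr1[42] scales a side output (out_ptr0) and two row-wise dot products are
# reduced for further scalar-coefficient gradients. Reusing in_out_ptr0 avoids
# allocating another 65536 x 1024 buffer for the accumulation.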
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/vm/cvmnq4yi35zsfxtrc6yd7p7r3wbfy6nb6zghevp246rtyncjsb46.py
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten._to_copy, aten.mul, aten.sum]
# Source node to ATen node mapping:
# v_34 => convert_element_type_248, convert_element_type_249, mul_187
# Graph fragment:
# %convert_element_type_248 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_113, torch.float32), kwargs = {})
# %mul_187 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_248, %rsqrt_48), kwargs = {})
# %convert_element_type_249 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_187, torch.bfloat16), kwargs = {})
# %mul_391 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_204, %convert_element_type_249), kwargs = {})
# %sum_37 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_391,), kwargs = {})
triton_red_fused__to_copy_mul_sum_20 = async_compile.triton('triton_red_fused__to_copy_mul_sum_20', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 512, 'r0_': 131072},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_20', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_mul_sum_20(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 512
r0_numel = 131072
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
_tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (3072*(r0_1 // 1024) + 393216*x0 + ((r0_1 % 1024))), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (1024*x0 + (r0_1 // 128)), xmask, eviction_policy='evict_last', other=0.0)
tmp2 = tmp1.to(tl.float32)
tmp4 = tmp2 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp0 * tmp5
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
tmp9 = _tmp8 + tmp7
_tmp8 = tl.where(xmask, tmp9, _tmp8)
tmp8 = tl.sum(_tmp8, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp8, xmask)
''', device_str='cuda')
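# This reduction emits sum_37: one fp32 scalar per x (512 outputs), each reduced over
# 131072 elements. The stride on in_ptr1, 3072*(r0_1 // 1024) + 393216*x0 + (r0_1 % 1024),
# gathers a 1024-wide slice out of a 3072-wide fused buffer, and in_ptr2 is the saved
# rsqrt broadcast across each 128-element group via r0_1 // 128. Per the %mul_391 and
# %sum_37 nodes, the output is the gradient of a scalar multiplying the normalized
# v_34; reading the 3072-wide buffer as a fused QKV projection is an inference from
# the strides, not something the dump states.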
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/fe/cfepqopoqmnpzbgue7ezqbcqee5l7skpuhz2cft2vnpout7de5lb.py
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_34 => convert_element_type_248
# Graph fragment:
# %mul_390 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_204, %select_81), kwargs = {})
# %convert_element_type_248 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_113, torch.float32), kwargs = {})
# %convert_element_type_444 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_390, torch.float32), kwargs = {})
# %mul_392 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_444, %convert_element_type_248), kwargs = {})
# %mul_393 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_444, %rsqrt_48), kwargs = {})
# %sum_38 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_392, [3], True), kwargs = {})
# %div_17 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_34, 128), kwargs = {})
# %pow_116 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_248, 1.0), kwargs = {})
# %mul_396 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_116, 2.0), kwargs = {})
# %mul_397 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_17, %mul_396), kwargs = {})
# %add_229 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_393, %mul_397), kwargs = {})
# %convert_element_type_445 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_229, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_21 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_21', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_21', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_21(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (72))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (72))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
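# The v_34 instance of the same rsqrt-backward pattern; only the scalar index changes
# (in_ptr1[72]).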
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/yj/cyjkhrb27hvv3l3kua6flocagrnqlqmc5754eswnsklhtleriys4.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {})
# %mul_420 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %select_79), kwargs = {})
triton_poi_fused_add_mul_22 = async_compile.triton('triton_poi_fused_add_mul_22', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_22', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_22(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (40))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
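# Pointwise epilogue with no reduction: across all 67108864 elements this computes
# out_ptr0 = ((in_ptr0 + in_ptr1) * in_ptr2[40]).to(bf16), i.e. the %add_236 and
# %mul_420 pair from the fragment, with the fp32 scalar broadcast to every element.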
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/5l/c5l5r3u6uyjim3crjlcztqsca7jjltxx4qahm4o4c3sfeyhnral7.py
# Topologically Sorted Source Nodes: [v_31], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_31 => convert_element_type_228
# Graph fragment:
# %mul_430 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_224, %select_75), kwargs = {})
# %convert_element_type_228 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_104, torch.float32), kwargs = {})
# %convert_element_type_475 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_430, torch.float32), kwargs = {})
# %mul_432 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_475, %convert_element_type_228), kwargs = {})
# %mul_433 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_475, %rsqrt_44), kwargs = {})
# %sum_45 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_432, [3], True), kwargs = {})
# %div_21 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_38, 128), kwargs = {})
# %pow_125 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_228, 1.0), kwargs = {})
# %mul_436 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_125, 2.0), kwargs = {})
# %mul_437 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_21, %mul_436), kwargs = {})
# %add_244 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_433, %mul_437), kwargs = {})
# %convert_element_type_476 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_244, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_23 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_23', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_23', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_23(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (70))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (70))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
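# The v_31 instance of the rsqrt-backward pattern, scalar at in_ptr1[70].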
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/ee/ceemqyuqc5se34jwuz7av5jovsufciprlj22yvbhuuydkqt6tt3k.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {})
# %mul_419 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %convert_element_type_11), kwargs = {})
# %sum_41 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_419,), kwargs = {})
# %mul_421 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %add_130), kwargs = {})
# %sum_42 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_421,), kwargs = {})
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {})
# %mul_459 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %convert_element_type_11), kwargs = {})
# %sum_48 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_459,), kwargs = {})
# %mul_460 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_73), kwargs = {})
# %mul_461 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %add_119), kwargs = {})
# %sum_49 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_461,), kwargs = {})
# %mul_463 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_460, %add_36), kwargs = {})
# %sum_50 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_463,), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_24 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_24', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*fp32', 'in_ptr9': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*fp32', 'out_ptr5': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_24', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 10, 'num_reduction': 5, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_24(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, out_ptr5, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp29 = tl.load(in_ptr8 + (38))
tmp30 = tl.broadcast_to(tmp29, [R0_BLOCK])
tmp33 = tl.load(in_ptr9 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tmp0 + tmp1
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp2 * tmp7
tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
tmp13 = tmp2 * tmp12
tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp19 = tmp17 + tmp18
tmp20 = tmp19 * tmp7
tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0))
tmp25 = tmp19 * tmp24
tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK])
tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tmp31 = tmp30.to(tl.float32)
tmp32 = tmp19 * tmp31
tmp34 = tmp32 * tmp33
tmp35 = tl.broadcast_to(tmp34, [R0_BLOCK])
tmp37 = triton_helpers.promote_to_tensor(tl.sum(tmp35, 0))
tl.store(out_ptr5 + (r0_1 + 1024*x0), tmp32, None)
tl.store(out_ptr0 + (x0), tmp11, None)
tl.store(out_ptr1 + (x0), tmp16, None)
tl.store(out_ptr2 + (x0), tmp23, None)
tl.store(out_ptr3 + (x0), tmp28, None)
tl.store(out_ptr4 + (x0), tmp37, None)
''', device_str='cuda')
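# A wider variant of the persistent-reduction pattern: five row reductions
# (out_ptr0..out_ptr4, matching sum_41/42/48/49/50 above) fused with one elementwise
# store. The final reduction reuses the just-computed tmp32 = grad * in_ptr8[38]
# against in_ptr9 (%mul_463 over %add_36), so the scaled tensor is both written out
# (out_ptr5, the num_users=4 %mul_460) and consumed in registers without a reload.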
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/jg/cjgcteapcf4w5vv6fwhouzamepnj6zi7b3lj4lab2tysbpwawhfk.py
# Topologically Sorted Source Nodes: [v_28], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_28 => convert_element_type_208
# Graph fragment:
# %mul_472 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_244, %select_68), kwargs = {})
# %convert_element_type_208 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_95, torch.float32), kwargs = {})
# %convert_element_type_507 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_472, torch.float32), kwargs = {})
# %mul_474 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_507, %convert_element_type_208), kwargs = {})
# %mul_475 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_507, %rsqrt_40), kwargs = {})
# %sum_53 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_474, [3], True), kwargs = {})
# %div_25 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_42, 128), kwargs = {})
# %pow_134 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_208, 1.0), kwargs = {})
# %mul_478 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_134, 2.0), kwargs = {})
# %mul_479 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_25, %mul_478), kwargs = {})
# %add_259 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_475, %mul_479), kwargs = {})
# %convert_element_type_508 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_259, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_25 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_25', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_25', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_25(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (68))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (68))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
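# The v_28 instance of the rsqrt-backward pattern, scalar at in_ptr1[68].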
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/tg/ctg6mqjt4r3k5i3go5y7jwzewso2vhohvzrh7cb2aqs7s3qv7ndn.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {})
# %mul_418 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %select_80), kwargs = {})
# %add_238 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_223, %mul_418), kwargs = {})
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {})
# %mul_458 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_74), kwargs = {})
# %add_253 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_238, %mul_458), kwargs = {})
# %add_266 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_258, %view_252), kwargs = {})
# %mul_500 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_67), kwargs = {})
# %mul_501 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %convert_element_type_11), kwargs = {})
# %sum_56 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_501,), kwargs = {})
# %add_268 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_253, %mul_500), kwargs = {})
# %mul_502 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_66), kwargs = {})
# %mul_503 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %add_107), kwargs = {})
# %sum_57 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_503,), kwargs = {})
# %mul_505 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_502, %add_58), kwargs = {})
# %sum_58 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_505,), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_26 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_26', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_26', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 15, 'num_reduction': 3, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_26(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr5 + (36))
tmp18 = tl.broadcast_to(tmp17, [R0_BLOCK])
tmp21 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp26 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp27 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp28 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp30 = tl.load(in_ptr5 + (41))
tmp31 = tl.broadcast_to(tmp30, [R0_BLOCK])
tmp35 = tl.load(in_ptr9 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp36 = tl.load(in_ptr10 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp38 = tl.load(in_ptr5 + (39))
tmp39 = tl.broadcast_to(tmp38, [R0_BLOCK])
tmp43 = tl.load(in_ptr5 + (37))
tmp44 = tl.broadcast_to(tmp43, [R0_BLOCK])
tmp2 = tmp0 + tmp1
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp2 * tmp7
tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
tmp13 = tmp2 * tmp12
tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp19 = tmp18.to(tl.float32)
tmp20 = tmp2 * tmp19
tmp22 = tmp20 * tmp21
tmp23 = tl.broadcast_to(tmp22, [R0_BLOCK])
tmp25 = triton_helpers.promote_to_tensor(tl.sum(tmp23, 0))
tmp29 = tmp27 + tmp28
tmp32 = tmp31.to(tl.float32)
tmp33 = tmp29 * tmp32
tmp34 = tmp26 + tmp33
tmp37 = tmp35 + tmp36
tmp40 = tmp39.to(tl.float32)
tmp41 = tmp37 * tmp40
tmp42 = tmp34 + tmp41
tmp45 = tmp44.to(tl.float32)
tmp46 = tmp2 * tmp45
tmp47 = tmp42 + tmp46
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp47, None)
tl.store(out_ptr3 + (r0_1 + 1024*x0), tmp20, None)
tl.store(out_ptr0 + (x0), tmp11, None)
tl.store(out_ptr1 + (x0), tmp16, None)
tl.store(out_ptr2 + (x0), tmp25, None)
''', device_str='cuda')
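# A minimal eager sketch of the fusion above (names are illustrative, not
# taken from the graph; shapes are [R, D] with R = 65536 rows and D = 1024).
# g_a and g_b are the two incoming branch gradients, x_norm is the normalized
# embedding (in_ptr2 scaled by the per-row rsqrt), h and q are saved
# activations, w is the fp32 vector of learned scalar scales read at fixed
# offsets (36, 37, 39, 41 above), and acc is the running residual gradient
# (in_out_ptr0). Uses the torch imported at the top of this file.
def per_fused_mul_sum_26_ref(acc, g_a, g_b, x_norm, h, q, w, r_a, r_b, s_a, s_b):
    g = g_a.float() + g_b.float()                  # tmp2: join two branches
    d0 = (g * x_norm).sum(dim=1)                   # out_ptr0: per-row dot
    d1 = (g * h.float()).sum(dim=1)                # out_ptr1
    gw = g * w[36]                                 # out_ptr3: scaled grad, reused downstream
    d2 = (gw * q.float()).sum(dim=1)               # out_ptr2
    acc = (acc.float() + (r_a.float() + r_b.float()) * w[41]
           + (s_a.float() + s_b.float()) * w[39] + g * w[37])  # in_out_ptr0
    return acc.to(torch.bfloat16), gw.to(torch.bfloat16), d0, d1, d2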
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/am/camyez2exa6htwupvbzdzegpgqt727hcqpldhjcpx3zt5nnd7onq.py
# Topologically Sorted Source Nodes: [v_25], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_25 => convert_element_type_188
# Graph fragment:
# %mul_514 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_264, %select_61), kwargs = {})
# %convert_element_type_188 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_86, torch.float32), kwargs = {})
# %convert_element_type_539 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_514, torch.float32), kwargs = {})
# %mul_516 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_539, %convert_element_type_188), kwargs = {})
# %mul_517 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_539, %rsqrt_36), kwargs = {})
# %sum_61 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_516, [3], True), kwargs = {})
# %div_29 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_46, 128), kwargs = {})
# %pow_143 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_188, 1.0), kwargs = {})
# %mul_520 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_143, 2.0), kwargs = {})
# %mul_521 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_29, %mul_520), kwargs = {})
# %add_275 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_517, %mul_521), kwargs = {})
# %convert_element_type_540 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_275, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_27 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_27', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_27', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_27(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (66))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (66))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
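# Reference sketch of the per-head RMS-norm backward computed above (a
# minimal eager version; names are illustrative, not from the graph). x is
# the saved pre-norm value with head dimension D = 128 (0.0078125 == 1/128),
# rs is rsqrt(mean(x*x, dim=-1, keepdim=True)) saved from the forward, g is
# the incoming bf16 gradient, and c is the learned scalar scale folded into
# it (w[66] for this instance). Uses the torch imported at the top of this
# file.
def rms_norm_bwd_ref(g, x, rs, c, D=128):
    x = x.to(torch.float32)
    gc = g.to(torch.float32) * c                    # both loops: g * scale
    s = (gc * x).sum(dim=-1, keepdim=True)          # first loop: sum(g*x)
    # d/dx [x * rsqrt(mean(x^2))] applied to gc; -0.5 and 2.0 are the rsqrt
    # and square derivatives, 1/D comes from the mean:
    return (gc * rs + (-0.5 * s * rs**3 / D) * (2.0 * x)).to(torch.bfloat16)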
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/w5/cw5lm2xilanikcb3utse3lrzhou2ow25ytp5um7ilwwffymj7vvv.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_282 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_274, %view_263), kwargs = {})
# %mul_544 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_59), kwargs = {})
triton_poi_fused_add_mul_28 = async_compile.triton('triton_poi_fused_add_mul_28', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_28', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_28(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (34))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
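# Eager equivalent of the pointwise kernel above (illustrative name; a and b
# are the two bf16 gradient branches being joined, and w is the fp32 vector
# of learned per-layer scalars, read at offset 34 here):
def add_mul_ref(a, b, w):
    return ((a.float() + b.float()) * w[34]).to(torch.bfloat16)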
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/2m/c2mids55xghemi35me3p4us3qt5tzutbm52tf45sv4r3kifthkqn.py
# Topologically Sorted Source Nodes: [v_22], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_22 => convert_element_type_168
# Graph fragment:
# %mul_556 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_284, %select_54), kwargs = {})
# %convert_element_type_168 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_77, torch.float32), kwargs = {})
# %convert_element_type_571 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_556, torch.float32), kwargs = {})
# %mul_558 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_571, %convert_element_type_168), kwargs = {})
# %mul_559 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_571, %rsqrt_32), kwargs = {})
# %sum_69 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_558, [3], True), kwargs = {})
# %div_33 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_50, 128), kwargs = {})
# %pow_152 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_168, 1.0), kwargs = {})
# %mul_562 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_152, 2.0), kwargs = {})
# %mul_563 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_33, %mul_562), kwargs = {})
# %add_291 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_559, %mul_563), kwargs = {})
# %convert_element_type_572 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_291, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_29 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_29', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_29', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_29(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (64))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (64))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
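# triton_red_fused__to_copy_add_div_mul_pow_sum_29 repeats the kernel above
# for an earlier layer: the code is identical except that the scalar scale is
# read at offset 64 of in_ptr1 instead of 66, so rms_norm_bwd_ref applies
# with c = w[64].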
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/7g/c7ggnywfvwyw3iiwnlm66m4wogxu6ssncnpxmdwolqb7dfrdx2i3.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_298 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_290, %view_274), kwargs = {})
# %mul_586 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %select_52), kwargs = {})
triton_poi_fused_add_mul_30 = async_compile.triton('triton_poi_fused_add_mul_30', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_30', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_30(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (32))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
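# triton_poi_fused_add_mul_30 is the same join as triton_poi_fused_add_mul_28
# above, reading the scalar scale at offset 32 instead of 34; add_mul_ref
# applies with w[32].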
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/7d/c7dbizqshojxmody3a5tubxthr47qpwkudrrqa5hfb6qgxo5ekow.py
# Topologically Sorted Source Nodes: [x, rms_norm_29], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum, aten.div, aten.pow]
# Source node to ATen node mapping:
# rms_norm_29 => convert_element_type_152
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_282 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_274, %view_263), kwargs = {})
# %mul_542 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_60), kwargs = {})
# %mul_543 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %convert_element_type_11), kwargs = {})
# %sum_64 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_543,), kwargs = {})
# %add_284 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_268, %mul_542), kwargs = {})
# %mul_544 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_59), kwargs = {})
# %mul_545 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %add_95), kwargs = {})
# %sum_65 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_545,), kwargs = {})
# %mul_546 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_544, %select_56), kwargs = {})
# %mul_547 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_544, %add_80), kwargs = {})
# %sum_66 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_547,), kwargs = {})
# %add_298 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_290, %view_274), kwargs = {})
# %mul_584 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %select_53), kwargs = {})
# %mul_585 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %convert_element_type_11), kwargs = {})
# %sum_72 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_585,), kwargs = {})
# %add_300 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_284, %mul_584), kwargs = {})
# %mul_587 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %add_83), kwargs = {})
# %sum_73 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_587,), kwargs = {})
# %convert_element_type_595 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_279, torch.float32), kwargs = {})
# %convert_element_type_152 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_81, torch.float32), kwargs = {})
# %mul_590 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_595, %convert_element_type_152), kwargs = {})
# %mul_591 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_595, %rsqrt_29), kwargs = {})
# %sum_74 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_590, [2], True), kwargs = {})
# %div_36 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_53, 1024), kwargs = {})
# %pow_159 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_152, 1.0), kwargs = {})
# %mul_594 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_159, 2.0), kwargs = {})
# %mul_595 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_36, %mul_594), kwargs = {})
# %add_304 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_591, %mul_595), kwargs = {})
# %convert_element_type_596 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_304, torch.bfloat16), kwargs = {})
# %add_305 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_586, %convert_element_type_596), kwargs = {})
# %mul_596 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %select_49), kwargs = {})
# %mul_597 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %convert_element_type_11), kwargs = {})
# %sum_75 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_597,), kwargs = {})
# %add_306 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_300, %mul_596), kwargs = {})
# %mul_598 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %select_48), kwargs = {})
# %mul_599 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %add_80), kwargs = {})
# %sum_76 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_599,), kwargs = {})
# %add_307 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_546, %mul_598), kwargs = {})
triton_per_fused__to_copy_add_div_mul_pow_sum_31 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_31', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_out_ptr1': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'out_ptr1': '*bf16', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*fp32', 'out_ptr5': '*fp32', 'out_ptr6': '*fp32', 'out_ptr7': '*fp32', 'out_ptr8': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_31', 'mutated_arg_names': ['in_out_ptr0', 'in_out_ptr1'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 20, 'num_reduction': 8, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_31(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, out_ptr1, out_ptr2, out_ptr3, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp8 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
tmp24 = tl.load(in_out_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp25 = tl.load(in_ptr3 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp26 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp28 = tl.load(in_ptr5 + (35))
tmp29 = tl.broadcast_to(tmp28, [R0_BLOCK])
tmp33 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp34 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp36 = tl.load(in_ptr5 + (33))
tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK])
tmp41 = tl.load(in_ptr5 + (31))
tmp42 = tl.broadcast_to(tmp41, [R0_BLOCK])
tmp46 = tl.load(in_ptr5 + (34))
tmp47 = tl.broadcast_to(tmp46, [R0_BLOCK])
tmp50 = tl.load(in_ptr5 + (6))
tmp51 = tl.broadcast_to(tmp50, [R0_BLOCK])
tmp54 = tl.load(in_ptr5 + (30))
tmp55 = tl.broadcast_to(tmp54, [R0_BLOCK])
tmp59 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp61 = tl.load(in_ptr9 + (x0), None, eviction_policy='evict_last')
tmp68 = tl.load(in_ptr10 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp73 = tl.load(in_ptr11 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp82 = tl.load(in_ptr12 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 * tmp3
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
tmp10 = tmp1 * tmp9
tmp11 = -0.5
tmp12 = tmp7 * tmp11
tmp13 = tmp9 * tmp9
tmp14 = tmp13 * tmp9
tmp15 = tmp12 * tmp14
    tmp16 = 0.0009765625  # 1/1024: mean over the model width
tmp17 = tmp15 * tmp16
tmp18 = 2.0
tmp19 = tmp3 * tmp18
tmp20 = tmp17 * tmp19
tmp21 = tmp10 + tmp20
tmp22 = tmp21.to(tl.float32)
tmp23 = tmp8 + tmp22
tmp27 = tmp25 + tmp26
tmp30 = tmp29.to(tl.float32)
tmp31 = tmp27 * tmp30
tmp32 = tmp24 + tmp31
tmp35 = tmp33 + tmp34
tmp38 = tmp37.to(tl.float32)
tmp39 = tmp35 * tmp38
tmp40 = tmp32 + tmp39
tmp43 = tmp42.to(tl.float32)
tmp44 = tmp23 * tmp43
tmp45 = tmp40 + tmp44
tmp48 = tmp47.to(tl.float32)
tmp49 = tmp27 * tmp48
tmp52 = tmp51.to(tl.float32)
tmp53 = tmp49 * tmp52
tmp56 = tmp55.to(tl.float32)
tmp57 = tmp23 * tmp56
tmp58 = tmp53 + tmp57
tmp60 = tmp59.to(tl.float32)
tmp62 = tmp60 * tmp61
tmp63 = tmp62.to(tl.float32)
tmp64 = tmp27 * tmp63
tmp65 = tl.broadcast_to(tmp64, [R0_BLOCK])
tmp67 = triton_helpers.promote_to_tensor(tl.sum(tmp65, 0))
tmp69 = tmp27 * tmp68
tmp70 = tl.broadcast_to(tmp69, [R0_BLOCK])
tmp72 = triton_helpers.promote_to_tensor(tl.sum(tmp70, 0))
tmp74 = tmp49 * tmp73
tmp75 = tl.broadcast_to(tmp74, [R0_BLOCK])
tmp77 = triton_helpers.promote_to_tensor(tl.sum(tmp75, 0))
tmp78 = tmp35 * tmp63
tmp79 = tl.broadcast_to(tmp78, [R0_BLOCK])
tmp81 = triton_helpers.promote_to_tensor(tl.sum(tmp79, 0))
tmp83 = tmp35 * tmp82
tmp84 = tl.broadcast_to(tmp83, [R0_BLOCK])
tmp86 = triton_helpers.promote_to_tensor(tl.sum(tmp84, 0))
tmp87 = tmp23 * tmp63
tmp88 = tl.broadcast_to(tmp87, [R0_BLOCK])
tmp90 = triton_helpers.promote_to_tensor(tl.sum(tmp88, 0))
tmp91 = tmp23 * tmp73
tmp92 = tl.broadcast_to(tmp91, [R0_BLOCK])
tmp94 = triton_helpers.promote_to_tensor(tl.sum(tmp92, 0))
tl.store(in_out_ptr1 + (r0_1 + 1024*x0), tmp45, None)
tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp58, None)
tl.store(out_ptr2 + (x0), tmp67, None)
tl.store(out_ptr3 + (x0), tmp72, None)
tl.store(out_ptr4 + (x0), tmp77, None)
tl.store(out_ptr5 + (x0), tmp81, None)
tl.store(out_ptr6 + (x0), tmp86, None)
tl.store(out_ptr7 + (x0), tmp90, None)
tl.store(out_ptr8 + (x0), tmp94, None)
''', device_str='cuda')
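# The persistent reduction above fuses the model-width RMS-norm input
# gradient (the same formula as rms_norm_bwd_ref, with D = 1024 since
# 0.0009765625 == 1/1024, and no extra scalar scale) into the running
# residual-gradient accumulator in_out_ptr1, stores a second combined
# gradient stream to out_ptr1, and takes seven per-row dot products
# (out_ptr2 .. out_ptr8) between the accumulated gradient streams and saved
# activations. Those row sums are presumably reduced across rows by a later
# kernel to form the scalar gradients for the in_ptr5 scale entries read
# here (offsets 6, 30, 31, 33, 34, 35).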
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/tv/ctvx4p64xr75uwoae5pprzrej2pwv7rbwkrsu5y2zaip2qp6b45i.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten._to_copy, aten.add]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_20 : [num_users=53] = call_function[target=torch.ops.aten.full.default](args = ([2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %convert_element_type_363 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_18, torch.float32), kwargs = {})
# %select_scatter_default_3 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_363, 0, 1), kwargs = {})
# %convert_element_type_364 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_19, torch.float32), kwargs = {})
# %select_scatter_default_4 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_364, 0, 0), kwargs = {})
# %add_194 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_3, %select_scatter_default_4), kwargs = {})
# %full_default_25 : [num_users=31] = call_function[target=torch.ops.aten.full.default](args = ([16, 2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %select_scatter_default_6 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_194, 0, 15), kwargs = {})
# %convert_element_type_395 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_26, torch.float32), kwargs = {})
# %select_scatter_default_10 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_395, 0, 1), kwargs = {})
# %convert_element_type_396 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_27, torch.float32), kwargs = {})
# %select_scatter_default_11 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_396, 0, 0), kwargs = {})
# %add_208 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_10, %select_scatter_default_11), kwargs = {})
# %select_scatter_default_13 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_208, 0, 14), kwargs = {})
# %add_210 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_6, %select_scatter_default_13), kwargs = {})
# %convert_element_type_427 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_34, torch.float32), kwargs = {})
# %select_scatter_default_17 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_427, 0, 1), kwargs = {})
# %convert_element_type_428 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_35, torch.float32), kwargs = {})
# %select_scatter_default_18 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_428, 0, 0), kwargs = {})
# %add_224 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_17, %select_scatter_default_18), kwargs = {})
# %select_scatter_default_20 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_224, 0, 13), kwargs = {})
# %add_226 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_210, %select_scatter_default_20), kwargs = {})
# %convert_element_type_458 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_41, torch.float32), kwargs = {})
# %select_scatter_default_23 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_458, 0, 1), kwargs = {})
# %convert_element_type_459 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_42, torch.float32), kwargs = {})
# %select_scatter_default_24 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_459, 0, 0), kwargs = {})
# %add_239 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_23, %select_scatter_default_24), kwargs = {})
# %select_scatter_default_26 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_239, 0, 12), kwargs = {})
# %add_241 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_226, %select_scatter_default_26), kwargs = {})
# %convert_element_type_489 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_48, torch.float32), kwargs = {})
# %select_scatter_default_29 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_489, 0, 1), kwargs = {})
# %convert_element_type_490 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_49, torch.float32), kwargs = {})
# %select_scatter_default_30 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_490, 0, 0), kwargs = {})
# %add_254 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_29, %select_scatter_default_30), kwargs = {})
# %select_scatter_default_32 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_254, 0, 11), kwargs = {})
# %add_256 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_241, %select_scatter_default_32), kwargs = {})
# %convert_element_type_521 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_56, torch.float32), kwargs = {})
# %select_scatter_default_36 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_521, 0, 1), kwargs = {})
# %convert_element_type_522 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_57, torch.float32), kwargs = {})
# %select_scatter_default_37 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_522, 0, 0), kwargs = {})
# %add_269 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_36, %select_scatter_default_37), kwargs = {})
# %select_scatter_default_39 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_269, 0, 10), kwargs = {})
# %add_271 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_256, %select_scatter_default_39), kwargs = {})
# %convert_element_type_553 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_64, torch.float32), kwargs = {})
# %select_scatter_default_43 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_553, 0, 1), kwargs = {})
# %convert_element_type_554 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_65, torch.float32), kwargs = {})
# %select_scatter_default_44 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_554, 0, 0), kwargs = {})
# %add_285 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_43, %select_scatter_default_44), kwargs = {})
# %select_scatter_default_46 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_285, 0, 9), kwargs = {})
# %add_287 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_271, %select_scatter_default_46), kwargs = {})
# %convert_element_type_585 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_72, torch.float32), kwargs = {})
# %select_scatter_default_50 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_585, 0, 1), kwargs = {})
# %convert_element_type_586 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_73, torch.float32), kwargs = {})
# %select_scatter_default_51 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_586, 0, 0), kwargs = {})
# %add_301 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_50, %select_scatter_default_51), kwargs = {})
# %select_scatter_default_53 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_301, 0, 8), kwargs = {})
# %add_303 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_287, %select_scatter_default_53), kwargs = {})
# %convert_element_type_597 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_75, torch.float32), kwargs = {})
# %select_scatter_default_54 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_597, 0, 1), kwargs = {})
# %convert_element_type_598 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_76, torch.float32), kwargs = {})
# %select_scatter_default_55 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_598, 0, 0), kwargs = {})
# %add_308 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_54, %select_scatter_default_55), kwargs = {})
# %select_scatter_default_56 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_308, 0, 7), kwargs = {})
# %add_309 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_303, %select_scatter_default_56), kwargs = {})
triton_poi_fused__to_copy_add_select_backward_32 = async_compile.triton('triton_poi_fused__to_copy_add_select_backward_32', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 32},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'in_ptr13': '*bf16', 'in_ptr14': '*bf16', 'in_ptr15': '*bf16', 'in_ptr16': '*bf16', 'in_ptr17': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_select_backward_32', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 18, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_select_backward_32(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, in_ptr15, in_ptr16, in_ptr17, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 32
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = xindex // 2
x0 = (xindex % 2)
x2 = xindex
tmp6 = tl.load(in_ptr0 + (0)).to(tl.float32)
tmp7 = tl.broadcast_to(tmp6, [XBLOCK])
tmp13 = tl.load(in_ptr1 + (0)).to(tl.float32)
tmp14 = tl.broadcast_to(tmp13, [XBLOCK])
tmp21 = tl.load(in_ptr2 + (0)).to(tl.float32)
tmp22 = tl.broadcast_to(tmp21, [XBLOCK])
tmp25 = tl.load(in_ptr3 + (0)).to(tl.float32)
tmp26 = tl.broadcast_to(tmp25, [XBLOCK])
tmp34 = tl.load(in_ptr4 + (0)).to(tl.float32)
tmp35 = tl.broadcast_to(tmp34, [XBLOCK])
tmp38 = tl.load(in_ptr5 + (0)).to(tl.float32)
tmp39 = tl.broadcast_to(tmp38, [XBLOCK])
tmp47 = tl.load(in_ptr6 + (0)).to(tl.float32)
tmp48 = tl.broadcast_to(tmp47, [XBLOCK])
tmp51 = tl.load(in_ptr7 + (0)).to(tl.float32)
tmp52 = tl.broadcast_to(tmp51, [XBLOCK])
tmp60 = tl.load(in_ptr8 + (0)).to(tl.float32)
tmp61 = tl.broadcast_to(tmp60, [XBLOCK])
tmp64 = tl.load(in_ptr9 + (0)).to(tl.float32)
tmp65 = tl.broadcast_to(tmp64, [XBLOCK])
tmp73 = tl.load(in_ptr10 + (0)).to(tl.float32)
tmp74 = tl.broadcast_to(tmp73, [XBLOCK])
tmp77 = tl.load(in_ptr11 + (0)).to(tl.float32)
tmp78 = tl.broadcast_to(tmp77, [XBLOCK])
tmp86 = tl.load(in_ptr12 + (0)).to(tl.float32)
tmp87 = tl.broadcast_to(tmp86, [XBLOCK])
tmp90 = tl.load(in_ptr13 + (0)).to(tl.float32)
tmp91 = tl.broadcast_to(tmp90, [XBLOCK])
tmp99 = tl.load(in_ptr14 + (0)).to(tl.float32)
tmp100 = tl.broadcast_to(tmp99, [XBLOCK])
tmp103 = tl.load(in_ptr15 + (0)).to(tl.float32)
tmp104 = tl.broadcast_to(tmp103, [XBLOCK])
tmp112 = tl.load(in_ptr16 + (0)).to(tl.float32)
tmp113 = tl.broadcast_to(tmp112, [XBLOCK])
tmp116 = tl.load(in_ptr17 + (0)).to(tl.float32)
tmp117 = tl.broadcast_to(tmp116, [XBLOCK])
tmp0 = x1
tmp1 = tl.full([1], 15, tl.int32)
tmp2 = tmp0 == tmp1
tmp3 = x0
tmp4 = tl.full([1], 1, tl.int32)
tmp5 = tmp3 == tmp4
tmp8 = tmp7.to(tl.float32)
tmp9 = 0.0
tmp10 = tl.where(tmp5, tmp8, tmp9)
tmp11 = tl.full([1], 0, tl.int32)
tmp12 = tmp3 == tmp11
tmp15 = tmp14.to(tl.float32)
tmp16 = tl.where(tmp12, tmp15, tmp9)
tmp17 = tmp10 + tmp16
tmp18 = tl.where(tmp2, tmp17, tmp9)
tmp19 = tl.full([1], 14, tl.int32)
tmp20 = tmp0 == tmp19
tmp23 = tmp22.to(tl.float32)
tmp24 = tl.where(tmp5, tmp23, tmp9)
tmp27 = tmp26.to(tl.float32)
tmp28 = tl.where(tmp12, tmp27, tmp9)
tmp29 = tmp24 + tmp28
tmp30 = tl.where(tmp20, tmp29, tmp9)
tmp31 = tmp18 + tmp30
tmp32 = tl.full([1], 13, tl.int32)
tmp33 = tmp0 == tmp32
tmp36 = tmp35.to(tl.float32)
tmp37 = tl.where(tmp5, tmp36, tmp9)
tmp40 = tmp39.to(tl.float32)
tmp41 = tl.where(tmp12, tmp40, tmp9)
tmp42 = tmp37 + tmp41
tmp43 = tl.where(tmp33, tmp42, tmp9)
tmp44 = tmp31 + tmp43
tmp45 = tl.full([1], 12, tl.int32)
tmp46 = tmp0 == tmp45
tmp49 = tmp48.to(tl.float32)
tmp50 = tl.where(tmp5, tmp49, tmp9)
tmp53 = tmp52.to(tl.float32)
tmp54 = tl.where(tmp12, tmp53, tmp9)
tmp55 = tmp50 + tmp54
tmp56 = tl.where(tmp46, tmp55, tmp9)
tmp57 = tmp44 + tmp56
tmp58 = tl.full([1], 11, tl.int32)
tmp59 = tmp0 == tmp58
tmp62 = tmp61.to(tl.float32)
tmp63 = tl.where(tmp5, tmp62, tmp9)
tmp66 = tmp65.to(tl.float32)
tmp67 = tl.where(tmp12, tmp66, tmp9)
tmp68 = tmp63 + tmp67
tmp69 = tl.where(tmp59, tmp68, tmp9)
tmp70 = tmp57 + tmp69
tmp71 = tl.full([1], 10, tl.int32)
tmp72 = tmp0 == tmp71
tmp75 = tmp74.to(tl.float32)
tmp76 = tl.where(tmp5, tmp75, tmp9)
tmp79 = tmp78.to(tl.float32)
tmp80 = tl.where(tmp12, tmp79, tmp9)
tmp81 = tmp76 + tmp80
tmp82 = tl.where(tmp72, tmp81, tmp9)
tmp83 = tmp70 + tmp82
tmp84 = tl.full([1], 9, tl.int32)
tmp85 = tmp0 == tmp84
tmp88 = tmp87.to(tl.float32)
tmp89 = tl.where(tmp5, tmp88, tmp9)
tmp92 = tmp91.to(tl.float32)
tmp93 = tl.where(tmp12, tmp92, tmp9)
tmp94 = tmp89 + tmp93
tmp95 = tl.where(tmp85, tmp94, tmp9)
tmp96 = tmp83 + tmp95
tmp97 = tl.full([1], 8, tl.int32)
tmp98 = tmp0 == tmp97
tmp101 = tmp100.to(tl.float32)
tmp102 = tl.where(tmp5, tmp101, tmp9)
tmp105 = tmp104.to(tl.float32)
tmp106 = tl.where(tmp12, tmp105, tmp9)
tmp107 = tmp102 + tmp106
tmp108 = tl.where(tmp98, tmp107, tmp9)
tmp109 = tmp96 + tmp108
tmp110 = tl.full([1], 7, tl.int32)
tmp111 = tmp0 == tmp110
tmp114 = tmp113.to(tl.float32)
tmp115 = tl.where(tmp5, tmp114, tmp9)
tmp118 = tmp117.to(tl.float32)
tmp119 = tl.where(tmp12, tmp118, tmp9)
tmp120 = tmp115 + tmp119
tmp121 = tl.where(tmp111, tmp120, tmp9)
tmp122 = tmp109 + tmp121
tl.store(out_ptr0 + (x2), tmp122, xmask)
''', device_str='cuda')
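# Eager sketch of the select_scatter backward above (illustrative names).
# Eighteen scalar partial sums, one (col-1, col-0) pair per layer, are
# scattered into rows 15 down to 7 of a zero [16, 2] tensor, which appears
# to be the gradient of a [16, 2] per-layer scalar-parameter tensor; rows
# 0..6 stay zero. Uses the torch imported at the top of this file.
def select_backward_ref(pair_grads):
    # pair_grads: list of (g_col1, g_col0) scalars for rows 15, 14, ..., 7
    out = torch.zeros(16, 2, dtype=torch.float32)
    for row, (g1, g0) in zip(range(15, 6, -1), pair_grads):
        out[row, 0] = float(g0)
        out[row, 1] = float(g1)
    return out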
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/qz/cqzbz3hka7het3halgo66rhyco4yixsw5yvj3ce7xtlbrrkl3q44.py
# Topologically Sorted Source Nodes: [v_19], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_19 => convert_element_type_142
# Graph fragment:
# %mul_608 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_312, %select_45), kwargs = {})
# %convert_element_type_142 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_68, torch.float32), kwargs = {})
# %convert_element_type_614 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_608, torch.float32), kwargs = {})
# %mul_610 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_614, %convert_element_type_142), kwargs = {})
# %mul_611 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_614, %rsqrt_27), kwargs = {})
# %sum_79 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_610, [3], True), kwargs = {})
# %div_38 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_55, 128), kwargs = {})
# %pow_164 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_142, 1.0), kwargs = {})
# %mul_614 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_164, 2.0), kwargs = {})
# %mul_615 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_38, %mul_614), kwargs = {})
# %add_312 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_611, %mul_615), kwargs = {})
# %convert_element_type_615 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_312, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_33 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_33', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_33', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_33(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (60))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (60))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
    # second pass: re-read the inputs and form the final gradient
    for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
        tmp20 = -0.5
        tmp21 = tmp10 * tmp20  # -0.5 * sum(g * x) from the first pass
        tmp22 = tmp18 * tmp18
        tmp23 = tmp22 * tmp18  # rsqrt ** 3
        tmp24 = tmp21 * tmp23
        tmp25 = 0.0078125  # 1 / 128, the reduced head dimension
        tmp26 = tmp24 * tmp25
        tmp28 = tmp27.to(tl.float32)
        tmp29 = 2.0  # d(x**2)/dx = 2 * x
        tmp30 = tmp28 * tmp29
        tmp31 = tmp26 * tmp30
        tmp32 = tmp19 + tmp31  # g * rsqrt plus the variance-correction term
        tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
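
# Reference sketch (not Inductor output; names are illustrative): the reduction
# kernel above is the backward of an RMS-style normalization over the 128-wide
# head dimension. Assuming `g` is the upstream gradient already multiplied by
# the learned scalar at in_ptr1[60], `x` the saved input (in_ptr2), and `r` the
# saved rsqrt of mean(x**2) (in_ptr3, shape (..., 1)); `torch` is already
# imported at the top of this file.
def _rms_norm_backward_sketch(g, x, r):
    # dx = g * r - (sum(g * x, dim=-1) / N) * r**3 * x, with N = 128
    s = (g.float() * x.float()).sum(dim=-1, keepdim=True)          # first pass
    dx = g.float() * r - (s / x.shape[-1]) * (r ** 3) * x.float()  # second pass
    return dx.to(torch.bfloat16)
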
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/cf/ccfrqdp3jfaejzety3ut7kp3qjdtsb2f57ehfivai7bm4enygzav.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {})
# %mul_638 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %select_43), kwargs = {})
triton_poi_fused_add_mul_34 = async_compile.triton('triton_poi_fused_add_mul_34', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_34', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_34(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (28))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
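
# Reference sketch (not Inductor output): the pointwise kernel above fuses a
# residual add of two bf16 gradient streams with a multiply by one learned
# fp32 scalar (element 28 of the scalar-parameter vector in_ptr2). Names are
# illustrative.
def _add_mul_sketch(a, b, scale):
    # out = (a + b) * scale, computed in fp32 and stored back as bf16
    return ((a.float() + b.float()) * scale.float()).to(torch.bfloat16)
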
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/zu/czuup4uodzelkv3jrzpk43tivi7ihymxeqi6lh5xpm4tylw5ufnq.py
# Topologically Sorted Source Nodes: [v_16], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_16 => convert_element_type_122
# Graph fragment:
# %mul_648 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_332, %select_39), kwargs = {})
# %convert_element_type_122 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_59, torch.float32), kwargs = {})
# %convert_element_type_645 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_648, torch.float32), kwargs = {})
# %mul_650 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_645, %convert_element_type_122), kwargs = {})
# %mul_651 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_645, %rsqrt_23), kwargs = {})
# %sum_86 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_650, [3], True), kwargs = {})
# %div_42 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_59, 128), kwargs = {})
# %pow_173 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_122, 1.0), kwargs = {})
# %mul_654 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_173, 2.0), kwargs = {})
# %mul_655 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_42, %mul_654), kwargs = {})
# %add_327 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_651, %mul_655), kwargs = {})
# %convert_element_type_646 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_327, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_35 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_35', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_35', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_35(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (58))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (58))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
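# Note: kernel 35 repeats the kernel-33 pattern one layer down; only the
# learned scalar (in_ptr1[58] instead of [60]) and the saved rsqrt/input pair
# (rsqrt_23 / getitem_59) differ. See _rms_norm_backward_sketch above for the
# math.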
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/ye/cyevw6rg4ip64gwg4p72a7ympinl3ibkmxfljw4vtadz2pjlhji5.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_266 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_258, %view_252), kwargs = {})
# %mul_502 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_66), kwargs = {})
# %mul_504 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_502, %select_63), kwargs = {})
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {})
# %mul_637 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %convert_element_type_11), kwargs = {})
# %sum_82 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_637,), kwargs = {})
# %mul_639 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %add_69), kwargs = {})
# %sum_83 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_639,), kwargs = {})
# %add_334 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_326, %view_300), kwargs = {})
# %mul_677 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %convert_element_type_11), kwargs = {})
# %sum_89 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_677,), kwargs = {})
# %mul_678 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %select_37), kwargs = {})
# %mul_679 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %add_58), kwargs = {})
# %sum_90 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_679,), kwargs = {})
# %add_337 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_504, %mul_678), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_36 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_36', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_36', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 4, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_36(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp29 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp30 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp32 = tl.load(in_ptr9 + (36))
tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK])
tmp36 = tl.load(in_ptr9 + (4))
tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK])
tmp40 = tl.load(in_ptr9 + (26))
tmp41 = tl.broadcast_to(tmp40, [R0_BLOCK])
tmp2 = tmp0 + tmp1
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp2 * tmp7
tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
tmp13 = tmp2 * tmp12
tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp19 = tmp17 + tmp18
tmp20 = tmp19 * tmp7
tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0))
tmp25 = tmp19 * tmp24
tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK])
tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tmp31 = tmp29 + tmp30
tmp34 = tmp33.to(tl.float32)
tmp35 = tmp31 * tmp34
tmp38 = tmp37.to(tl.float32)
tmp39 = tmp35 * tmp38
tmp42 = tmp41.to(tl.float32)
tmp43 = tmp19 * tmp42
tmp44 = tmp39 + tmp43
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp44, None)
tl.store(out_ptr0 + (x0), tmp11, None)
tl.store(out_ptr1 + (x0), tmp16, None)
tl.store(out_ptr2 + (x0), tmp23, None)
tl.store(out_ptr3 + (x0), tmp28, None)
''', device_str='cuda')
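
# Reference sketch (not Inductor output; names illustrative, dtype casts
# elided): the persistent reduction above fuses four row-wise dot products --
# the gradients of four learned per-layer scalars -- with one in-place update
# of the running residual gradient. s_a, s_b, s_c stand for the scalars read
# from in_ptr9 at indices 36, 4 and 26.
def _fused_scalar_grads_sketch(g1, g2, x_norm, u, v, acc, skip, s_a, s_b, s_c):
    # one dot product per learned scalar, reduced over the 1024-wide model dim
    sums = ((g1 * x_norm).sum(-1), (g1 * u).sum(-1),
            (g2 * x_norm).sum(-1), (g2 * v).sum(-1))
    # fused residual-gradient update (the in_out_ptr0 store)
    acc = (acc + skip) * s_a * s_b + g2 * s_c
    return acc, sums
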
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/ad/caddjc5juf35pqoynq6vglejv2oh6vukmzqdu6mtjvkncst3kqyo.py
# Topologically Sorted Source Nodes: [v_13], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_13 => convert_element_type_102
# Graph fragment:
# %mul_688 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_352, %select_33), kwargs = {})
# %convert_element_type_102 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_50, torch.float32), kwargs = {})
# %convert_element_type_676 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_688, torch.float32), kwargs = {})
# %mul_690 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_676, %convert_element_type_102), kwargs = {})
# %mul_691 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_676, %rsqrt_19), kwargs = {})
# %sum_93 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_690, [3], True), kwargs = {})
# %div_46 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_63, 128), kwargs = {})
# %pow_182 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_102, 1.0), kwargs = {})
# %mul_694 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_182, 2.0), kwargs = {})
# %mul_695 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_46, %mul_694), kwargs = {})
# %add_343 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_691, %mul_695), kwargs = {})
# %convert_element_type_677 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_343, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_37 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_37', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_37', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_37(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (56))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (56))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
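# Note: kernel 37 is the same normalization backward as kernel 33, for v_13
# (scalar in_ptr1[56], saved pair rsqrt_19 / getitem_50). See
# _rms_norm_backward_sketch above.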
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/7f/c7fhg3d33i52ftlp3mhnqlchtl7dk3n3m3n4hihwykvl52gttcka.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {})
# %mul_636 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %select_44), kwargs = {})
# %add_321 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_306, %mul_636), kwargs = {})
# %add_334 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_326, %view_300), kwargs = {})
# %mul_676 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %select_38), kwargs = {})
# %add_336 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_321, %mul_676), kwargs = {})
# %add_350 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_342, %view_311), kwargs = {})
# %mul_716 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %select_32), kwargs = {})
# %add_352 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_336, %mul_716), kwargs = {})
# %mul_718 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %select_31), kwargs = {})
triton_poi_fused_add_mul_38 = async_compile.triton('triton_poi_fused_add_mul_38', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_38', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 11, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_38(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (29))
tmp5 = tl.broadcast_to(tmp4, [XBLOCK])
tmp9 = tl.load(in_ptr3 + (x0), None).to(tl.float32)
tmp10 = tl.load(in_ptr4 + (x0), None).to(tl.float32)
tmp12 = tl.load(in_ptr2 + (27))
tmp13 = tl.broadcast_to(tmp12, [XBLOCK])
tmp17 = tl.load(in_ptr5 + (x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr6 + (x0), None).to(tl.float32)
tmp20 = tl.load(in_ptr2 + (25))
tmp21 = tl.broadcast_to(tmp20, [XBLOCK])
tmp25 = tl.load(in_ptr2 + (24))
tmp26 = tl.broadcast_to(tmp25, [XBLOCK])
tmp3 = tmp1 + tmp2
tmp6 = tmp5.to(tl.float32)
tmp7 = tmp3 * tmp6
tmp8 = tmp0 + tmp7
tmp11 = tmp9 + tmp10
tmp14 = tmp13.to(tl.float32)
tmp15 = tmp11 * tmp14
tmp16 = tmp8 + tmp15
tmp19 = tmp17 + tmp18
tmp22 = tmp21.to(tl.float32)
tmp23 = tmp19 * tmp22
tmp24 = tmp16 + tmp23
tmp27 = tmp26.to(tl.float32)
tmp28 = tmp19 * tmp27
tl.store(in_out_ptr0 + (x0), tmp24, None)
tl.store(out_ptr0 + (x0), tmp28, None)
''', device_str='cuda')
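# Note: kernel 38 is the multi-branch version of kernel 34. In eager terms it
# accumulates acc += (a + b) * s29 + (c + d) * s27 + (e + f) * s25 into the
# running residual gradient and also emits (e + f) * s24 as a second output;
# s29, s27, s25 and s24 are elements 29, 27, 25 and 24 of the fp32 scalar
# vector in_ptr2.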
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/kw/ckwwptmpfidft6quaist42sma7dmjhpwkm7ihduw6nsauc6r776j.py
# Topologically Sorted Source Nodes: [v_10], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_10 => convert_element_type_82
# Graph fragment:
# %mul_728 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_372, %select_27), kwargs = {})
# %convert_element_type_82 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_41, torch.float32), kwargs = {})
# %convert_element_type_707 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_728, torch.float32), kwargs = {})
# %mul_730 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_707, %convert_element_type_82), kwargs = {})
# %mul_731 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_707, %rsqrt_15), kwargs = {})
# %sum_100 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_730, [3], True), kwargs = {})
# %div_50 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_67, 128), kwargs = {})
# %pow_191 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_82, 1.0), kwargs = {})
# %mul_734 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_191, 2.0), kwargs = {})
# %mul_735 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_50, %mul_734), kwargs = {})
# %add_358 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_731, %mul_735), kwargs = {})
# %convert_element_type_708 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_358, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_39 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_39', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_39', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_39(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (54))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (54))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
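# Note: kernel 39 repeats the kernel-33 normalization backward for v_10
# (scalar in_ptr1[54], saved pair rsqrt_15 / getitem_41). See
# _rms_norm_backward_sketch above.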
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/dy/cdyh5oakmezmdxspcqgny4mt5oeduceikz5mfycxi4la3phv73r5.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {})
# %mul_460 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_73), kwargs = {})
# %mul_462 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_460, %select_70), kwargs = {})
# %add_350 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_342, %view_311), kwargs = {})
# %mul_717 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %convert_element_type_11), kwargs = {})
# %sum_96 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_717,), kwargs = {})
# %mul_719 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %add_47), kwargs = {})
# %sum_97 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_719,), kwargs = {})
# %add_365 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_357, %view_322), kwargs = {})
# %mul_757 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %convert_element_type_11), kwargs = {})
# %sum_103 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_757,), kwargs = {})
# %mul_758 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %select_25), kwargs = {})
# %mul_759 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %add_36), kwargs = {})
# %sum_104 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_759,), kwargs = {})
# %add_368 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_462, %mul_758), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_40 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_40', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_40', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 4, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_40(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp29 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp30 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp32 = tl.load(in_ptr9 + (38))
tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK])
tmp36 = tl.load(in_ptr9 + (2))
tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK])
tmp40 = tl.load(in_ptr9 + (22))
tmp41 = tl.broadcast_to(tmp40, [R0_BLOCK])
tmp2 = tmp0 + tmp1
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp2 * tmp7
tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
tmp13 = tmp2 * tmp12
tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp19 = tmp17 + tmp18
tmp20 = tmp19 * tmp7
tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0))
tmp25 = tmp19 * tmp24
tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK])
tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tmp31 = tmp29 + tmp30
tmp34 = tmp33.to(tl.float32)
tmp35 = tmp31 * tmp34
tmp38 = tmp37.to(tl.float32)
tmp39 = tmp35 * tmp38
tmp42 = tmp41.to(tl.float32)
tmp43 = tmp19 * tmp42
tmp44 = tmp39 + tmp43
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp44, None)
tl.store(out_ptr0 + (x0), tmp11, None)
tl.store(out_ptr1 + (x0), tmp16, None)
tl.store(out_ptr2 + (x0), tmp23, None)
tl.store(out_ptr3 + (x0), tmp28, None)
''', device_str='cuda')
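# Note: kernel 40 matches kernel 36 structurally; only the scalar indices into
# in_ptr9 change (38, 2 and 22 instead of 36, 4 and 26). See
# _fused_scalar_grads_sketch above.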
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/lv/clvyuwrkiqq5tjp6zq253fl6uzqefz5dja2cdbo5svtjygwezv6b.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy, aten.select_backward, aten.add]
# Source node to ATen node mapping:
# Graph fragment:
# %convert_element_type_347 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_13, torch.float32), kwargs = {})
# %full_default_20 : [num_users=53] = call_function[target=torch.ops.aten.full.default](args = ([2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %select_scatter_default_1 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_347, 0, 1), kwargs = {})
# %convert_element_type_348 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_14, torch.float32), kwargs = {})
# %select_scatter_default_2 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_348, 0, 0), kwargs = {})
# %add_184 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_1, %select_scatter_default_2), kwargs = {})
# %full_default_25 : [num_users=31] = call_function[target=torch.ops.aten.full.default](args = ([16, 2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %select_scatter_default_5 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_184, 0, 15), kwargs = {})
# %convert_element_type_379 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_21, torch.float32), kwargs = {})
# %select_scatter_default_8 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_379, 0, 1), kwargs = {})
# %convert_element_type_380 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_22, torch.float32), kwargs = {})
# %select_scatter_default_9 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_380, 0, 0), kwargs = {})
# %add_197 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_8, %select_scatter_default_9), kwargs = {})
# %select_scatter_default_12 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_197, 0, 14), kwargs = {})
# %add_209 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_5, %select_scatter_default_12), kwargs = {})
# %convert_element_type_411 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_29, torch.float32), kwargs = {})
# %select_scatter_default_15 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_411, 0, 1), kwargs = {})
# %convert_element_type_412 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_30, torch.float32), kwargs = {})
# %select_scatter_default_16 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_412, 0, 0), kwargs = {})
# %add_213 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_15, %select_scatter_default_16), kwargs = {})
# %select_scatter_default_19 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_213, 0, 13), kwargs = {})
# %add_225 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_209, %select_scatter_default_19), kwargs = {})
# %convert_element_type_443 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_37, torch.float32), kwargs = {})
# %select_scatter_default_22 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_443, 0, 0), kwargs = {})
# %select_scatter_default_25 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_22, 0, 12), kwargs = {})
# %add_240 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_225, %select_scatter_default_25), kwargs = {})
# %convert_element_type_474 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_44, torch.float32), kwargs = {})
# %select_scatter_default_28 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_474, 0, 0), kwargs = {})
# %select_scatter_default_31 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_28, 0, 11), kwargs = {})
# %add_255 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_240, %select_scatter_default_31), kwargs = {})
# %convert_element_type_506 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_52, torch.float32), kwargs = {})
# %select_scatter_default_35 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_506, 0, 0), kwargs = {})
# %select_scatter_default_38 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_35, 0, 10), kwargs = {})
# %add_270 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_255, %select_scatter_default_38), kwargs = {})
# %convert_element_type_538 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_60, torch.float32), kwargs = {})
# %select_scatter_default_42 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_538, 0, 0), kwargs = {})
# %select_scatter_default_45 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_42, 0, 9), kwargs = {})
# %add_286 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_270, %select_scatter_default_45), kwargs = {})
# %convert_element_type_570 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_68, torch.float32), kwargs = {})
# %select_scatter_default_49 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_570, 0, 0), kwargs = {})
# %select_scatter_default_52 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_49, 0, 8), kwargs = {})
# %add_302 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_286, %select_scatter_default_52), kwargs = {})
# %convert_element_type_613 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_78, torch.float32), kwargs = {})
# %select_scatter_default_58 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_613, 0, 0), kwargs = {})
# %select_scatter_default_61 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_58, 0, 6), kwargs = {})
# %add_323 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_302, %select_scatter_default_61), kwargs = {})
# %convert_element_type_644 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_85, torch.float32), kwargs = {})
# %select_scatter_default_64 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_644, 0, 0), kwargs = {})
# %select_scatter_default_67 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_64, 0, 5), kwargs = {})
# %add_339 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_323, %select_scatter_default_67), kwargs = {})
# %convert_element_type_675 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_92, torch.float32), kwargs = {})
# %select_scatter_default_70 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_675, 0, 0), kwargs = {})
# %select_scatter_default_73 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_70, 0, 4), kwargs = {})
# %add_354 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_339, %select_scatter_default_73), kwargs = {})
# %convert_element_type_706 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_99, torch.float32), kwargs = {})
# %select_scatter_default_76 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_706, 0, 0), kwargs = {})
# %select_scatter_default_79 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_76, 0, 3), kwargs = {})
# %add_370 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_354, %select_scatter_default_79), kwargs = {})
triton_poi_fused__to_copy_add_select_backward_41 = async_compile.triton('triton_poi_fused__to_copy_add_select_backward_41', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 32},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'in_ptr13': '*bf16', 'in_ptr14': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_select_backward_41', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 15, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_select_backward_41(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 32
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = xindex // 2
x0 = (xindex % 2)
x2 = xindex
tmp6 = tl.load(in_ptr0 + (0)).to(tl.float32)
tmp7 = tl.broadcast_to(tmp6, [XBLOCK])
tmp13 = tl.load(in_ptr1 + (0)).to(tl.float32)
tmp14 = tl.broadcast_to(tmp13, [XBLOCK])
tmp21 = tl.load(in_ptr2 + (0)).to(tl.float32)
tmp22 = tl.broadcast_to(tmp21, [XBLOCK])
tmp25 = tl.load(in_ptr3 + (0)).to(tl.float32)
tmp26 = tl.broadcast_to(tmp25, [XBLOCK])
tmp34 = tl.load(in_ptr4 + (0)).to(tl.float32)
tmp35 = tl.broadcast_to(tmp34, [XBLOCK])
tmp38 = tl.load(in_ptr5 + (0)).to(tl.float32)
tmp39 = tl.broadcast_to(tmp38, [XBLOCK])
tmp47 = tl.load(in_ptr6 + (0)).to(tl.float32)
tmp48 = tl.broadcast_to(tmp47, [XBLOCK])
tmp55 = tl.load(in_ptr7 + (0)).to(tl.float32)
tmp56 = tl.broadcast_to(tmp55, [XBLOCK])
tmp63 = tl.load(in_ptr8 + (0)).to(tl.float32)
tmp64 = tl.broadcast_to(tmp63, [XBLOCK])
tmp71 = tl.load(in_ptr9 + (0)).to(tl.float32)
tmp72 = tl.broadcast_to(tmp71, [XBLOCK])
tmp79 = tl.load(in_ptr10 + (0)).to(tl.float32)
tmp80 = tl.broadcast_to(tmp79, [XBLOCK])
tmp87 = tl.load(in_ptr11 + (0)).to(tl.float32)
tmp88 = tl.broadcast_to(tmp87, [XBLOCK])
tmp95 = tl.load(in_ptr12 + (0)).to(tl.float32)
tmp96 = tl.broadcast_to(tmp95, [XBLOCK])
tmp103 = tl.load(in_ptr13 + (0)).to(tl.float32)
tmp104 = tl.broadcast_to(tmp103, [XBLOCK])
tmp111 = tl.load(in_ptr14 + (0)).to(tl.float32)
tmp112 = tl.broadcast_to(tmp111, [XBLOCK])
tmp0 = x1
tmp1 = tl.full([1], 15, tl.int32)
tmp2 = tmp0 == tmp1
tmp3 = x0
tmp4 = tl.full([1], 1, tl.int32)
tmp5 = tmp3 == tmp4
tmp8 = tmp7.to(tl.float32)
tmp9 = 0.0
tmp10 = tl.where(tmp5, tmp8, tmp9)
tmp11 = tl.full([1], 0, tl.int32)
tmp12 = tmp3 == tmp11
tmp15 = tmp14.to(tl.float32)
tmp16 = tl.where(tmp12, tmp15, tmp9)
tmp17 = tmp10 + tmp16
tmp18 = tl.where(tmp2, tmp17, tmp9)
tmp19 = tl.full([1], 14, tl.int32)
tmp20 = tmp0 == tmp19
tmp23 = tmp22.to(tl.float32)
tmp24 = tl.where(tmp5, tmp23, tmp9)
tmp27 = tmp26.to(tl.float32)
tmp28 = tl.where(tmp12, tmp27, tmp9)
tmp29 = tmp24 + tmp28
tmp30 = tl.where(tmp20, tmp29, tmp9)
tmp31 = tmp18 + tmp30
tmp32 = tl.full([1], 13, tl.int32)
tmp33 = tmp0 == tmp32
tmp36 = tmp35.to(tl.float32)
tmp37 = tl.where(tmp5, tmp36, tmp9)
tmp40 = tmp39.to(tl.float32)
tmp41 = tl.where(tmp12, tmp40, tmp9)
tmp42 = tmp37 + tmp41
tmp43 = tl.where(tmp33, tmp42, tmp9)
tmp44 = tmp31 + tmp43
tmp45 = tl.full([1], 12, tl.int32)
tmp46 = tmp0 == tmp45
tmp49 = tmp48.to(tl.float32)
tmp50 = tl.where(tmp12, tmp49, tmp9)
tmp51 = tl.where(tmp46, tmp50, tmp9)
tmp52 = tmp44 + tmp51
tmp53 = tl.full([1], 11, tl.int32)
tmp54 = tmp0 == tmp53
tmp57 = tmp56.to(tl.float32)
tmp58 = tl.where(tmp12, tmp57, tmp9)
tmp59 = tl.where(tmp54, tmp58, tmp9)
tmp60 = tmp52 + tmp59
tmp61 = tl.full([1], 10, tl.int32)
tmp62 = tmp0 == tmp61
tmp65 = tmp64.to(tl.float32)
tmp66 = tl.where(tmp12, tmp65, tmp9)
tmp67 = tl.where(tmp62, tmp66, tmp9)
tmp68 = tmp60 + tmp67
tmp69 = tl.full([1], 9, tl.int32)
tmp70 = tmp0 == tmp69
tmp73 = tmp72.to(tl.float32)
tmp74 = tl.where(tmp12, tmp73, tmp9)
tmp75 = tl.where(tmp70, tmp74, tmp9)
tmp76 = tmp68 + tmp75
tmp77 = tl.full([1], 8, tl.int32)
tmp78 = tmp0 == tmp77
tmp81 = tmp80.to(tl.float32)
tmp82 = tl.where(tmp12, tmp81, tmp9)
tmp83 = tl.where(tmp78, tmp82, tmp9)
tmp84 = tmp76 + tmp83
tmp85 = tl.full([1], 6, tl.int32)
tmp86 = tmp0 == tmp85
tmp89 = tmp88.to(tl.float32)
tmp90 = tl.where(tmp12, tmp89, tmp9)
tmp91 = tl.where(tmp86, tmp90, tmp9)
tmp92 = tmp84 + tmp91
tmp93 = tl.full([1], 5, tl.int32)
tmp94 = tmp0 == tmp93
tmp97 = tmp96.to(tl.float32)
tmp98 = tl.where(tmp12, tmp97, tmp9)
tmp99 = tl.where(tmp94, tmp98, tmp9)
tmp100 = tmp92 + tmp99
tmp101 = tl.full([1], 4, tl.int32)
tmp102 = tmp0 == tmp101
tmp105 = tmp104.to(tl.float32)
tmp106 = tl.where(tmp12, tmp105, tmp9)
tmp107 = tl.where(tmp102, tmp106, tmp9)
tmp108 = tmp100 + tmp107
tmp109 = tl.full([1], 3, tl.int32)
tmp110 = tmp0 == tmp109
tmp113 = tmp112.to(tl.float32)
tmp114 = tl.where(tmp12, tmp113, tmp9)
tmp115 = tl.where(tmp110, tmp114, tmp9)
tmp116 = tmp108 + tmp115
tl.store(out_ptr0 + (x2), tmp116, xmask)
''', device_str='cuda')
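
# A minimal PyTorch sketch of what kernel 41 computes (this file imports torch
# at the top). xnumel == 32 implies a [16, 2] scalar-parameter gradient tensor;
# the chain of tl.where selects above routes each of the 15 scalar inputs into
# one (row, col) slot, with overlapping select_scatter contributions summed.
# The dict packaging below is hypothetical; the actual (row, col) pairs are
# fixed by the integer comparisons in the kernel.
def scalar_grad_accumulate_sketch(grads):
    # grads: {(row, col): float} -- the 15 scalars passed as in_ptr0..in_ptr14
    out = torch.zeros(16, 2, dtype=torch.float32)
    for (row, col), g in grads.items():
        out[row, col] += g
    return out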
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/o4/co4wxq5g7foyzmi2ihlzwhjvufk3vmnc4dgakkdaqoxojgejv4d2.py
# Topologically Sorted Source Nodes: [v_7], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_7 => convert_element_type_62
# Graph fragment:
# %mul_770 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_392, %select_20), kwargs = {})
# %convert_element_type_62 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_32, torch.float32), kwargs = {})
# %convert_element_type_739 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_770, torch.float32), kwargs = {})
# %mul_772 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_739, %convert_element_type_62), kwargs = {})
# %mul_773 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_739, %rsqrt_11), kwargs = {})
# %sum_108 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_772, [3], True), kwargs = {})
# %div_54 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_71, 128), kwargs = {})
# %pow_200 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_62, 1.0), kwargs = {})
# %mul_776 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_200, 2.0), kwargs = {})
# %mul_777 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_54, %mul_776), kwargs = {})
# %add_376 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_773, %mul_777), kwargs = {})
# %convert_element_type_740 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_376, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_42 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_42', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_42', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_42(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (52))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (52))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
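
# A minimal PyTorch sketch of this reduction kernel, per 128-wide head-dim row
# (the 0.0078125 constant above is 1/128). in_ptr0 is the upstream gradient
# pre-scaled by the scalar at offset 52 of the packed fp32 parameter buffer,
# in_ptr2 is the saved pre-norm activation x, and in_ptr3 is the saved rsqrt of
# mean(x^2); the sketch recomputes r for self-containment, assuming no eps term.
def rms_norm_backward_sketch(grad_out, x):
    g = grad_out.float()
    xf = x.float()
    d = xf.shape[-1]                                      # assumed 128 here
    r = torch.rsqrt(xf.pow(2).mean(-1, keepdim=True))     # kernel loads this saved
    s = (g * xf).sum(-1, keepdim=True)                    # first loop: _tmp10
    out = g * r + (-0.5 * s * r.pow(3) / d) * (2.0 * xf)  # second loop
    return out.to(torch.bfloat16)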
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/bo/cbowtcvxfb4pbtd63playhswlkgy4jhcpl4b3bcqkbz3mfy4xrzq.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_383 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_373, %view_334), kwargs = {})
# %mul_800 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %select_18), kwargs = {})
triton_poi_fused_add_mul_43 = async_compile.triton('triton_poi_fused_add_mul_43', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_43', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_43(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (20))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
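
# Functionally this pointwise kernel is a residual add followed by a scalar
# multiply, with the scalar read from offset 20 of the packed fp32 parameter
# buffer. A one-line sketch:
def fused_add_mul_sketch(a, b, scale):
    # a, b: bf16 tensors; scale: a 0-dim fp32 scalar tensor
    return ((a + b) * scale.float()).to(torch.bfloat16)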
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/p3/cp356rh6ewfn6qtk7pjhx6degjloji6ndbxsbgpu5ffy373kln4d.py
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward]
# Source node to ATen node mapping:
# loss => full_default_13
# Graph fragment:
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %add_391 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_204, %view_343), kwargs = {})
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %convert_element_type_827 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_391, torch.float32), kwargs = {})
# %where_26 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_827), kwargs = {})
# %index_put_6 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_26, True), kwargs = {})
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44 = async_compile.triton('triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44(out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 51463168
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = 0.0
tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')
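
# This kernel only zero-fills the embedding-gradient buffer ahead of the
# atomic-add scatters below: xnumel == 51463168 == 50257 * 1024. Equivalent
# PyTorch:
def zero_embedding_grad_sketch(device='cuda'):
    return torch.zeros(50257, 1024, dtype=torch.float32, device=device)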
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/qe/cqetkj23jbj4oymf7m4upmjbom7qob3kerhbagz6hn5pzn6li6pu.py
# Topologically Sorted Source Nodes: [loss, v_4], Original ATen: [aten.nll_loss_forward, aten.add, aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.embedding_dense_backward]
# Source node to ATen node mapping:
# loss => full_default_13
# v_4 => convert_element_type_42
# Graph fragment:
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %add_391 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_204, %view_343), kwargs = {})
# %mul_812 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_412, %select_13), kwargs = {})
# %convert_element_type_42 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_23, torch.float32), kwargs = {})
# %convert_element_type_771 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_812, torch.float32), kwargs = {})
# %mul_814 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_771, %convert_element_type_42), kwargs = {})
# %mul_815 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_771, %rsqrt_7), kwargs = {})
# %sum_116 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_814, [3], True), kwargs = {})
# %div_58 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_75, 128), kwargs = {})
# %pow_209 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_42, 1.0), kwargs = {})
# %mul_818 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_209, 2.0), kwargs = {})
# %mul_819 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_58, %mul_818), kwargs = {})
# %add_393 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_815, %mul_819), kwargs = {})
# %convert_element_type_772 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_393, torch.bfloat16), kwargs = {})
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %convert_element_type_827 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_391, torch.float32), kwargs = {})
# %where_26 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_827), kwargs = {})
# %index_put_6 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_26, True), kwargs = {})
triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45 = async_compile.triton('triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*i32', 'in_ptr5': '*bf16', 'out_ptr1': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45', 'mutated_arg_names': ['out_ptr2'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 11, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (50))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (50))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
tmp34 = tl.load(in_ptr4 + (x1), None, eviction_policy='evict_last')
tmp44 = tl.load(in_ptr1 + (77))
tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK])
tmp48 = tl.load(in_ptr1 + (51))
tmp49 = tl.broadcast_to(tmp48, [XBLOCK, R0_BLOCK])
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp43 = tl.load(in_ptr5 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tmp35 = tmp34.to(tl.int64)
tmp36 = tl.full([XBLOCK, R0_BLOCK], 50257, tl.int32)
tmp37 = tmp35 + tmp36
tmp38 = tmp35 < 0
tmp39 = tl.where(tmp38, tmp37, tmp35)
tl.device_assert((0 <= tmp39) & (tmp39 < 50257), "index out of bounds: 0 <= tmp39 < 50257")
tmp41 = tl.full([1, 1], -1, tl.int64)
tmp42 = tmp35 == tmp41
tmp46 = tmp45.to(tl.float32)
tmp47 = tmp43 * tmp46
tmp50 = tmp49.to(tl.float32)
tmp51 = tmp12 * tmp50
tmp52 = tmp47 + tmp51
tmp53 = tmp52.to(tl.float32)
tmp54 = 0.0
tmp55 = tl.where(tmp42, tmp54, tmp53)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
tl.atomic_add(out_ptr2 + (r0_2 + 128*x0 + 1024*tmp39), tmp55, r0_mask, sem='relaxed')
''', device_str='cuda')
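
# Besides the rsqrt-backward correction, the second loop above scatters an
# embedding gradient: rows whose token index is -1 contribute zero, all others
# are atomically accumulated into the [50257, 1024] fp32 buffer. A minimal
# PyTorch sketch; grad stands for the already-combined per-token gradient
# (in_ptr5 * scalar[77] + in_ptr0 * scalar[51] in the kernel):
def embedding_grad_scatter_sketch(grad, idx, num_embeddings=50257):
    grad_weight = torch.zeros(num_embeddings, grad.shape[-1],
                              dtype=torch.float32, device=grad.device)
    vals = torch.where((idx == -1).unsqueeze(-1),
                       torch.zeros_like(grad, dtype=torch.float32),
                       grad.float())
    # the kernel wraps index -1 to num_embeddings - 1; since its value is
    # zeroed, any valid row works here
    rows = torch.where(idx < 0, idx + num_embeddings, idx).long()
    grad_weight.index_put_((rows,), vals, accumulate=True)
    return grad_weight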
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/xd/cxde4n2e7lw45tanddyed35fd2lchomkxmgk6ds234dvjtxmqmgj.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_365 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_357, %view_322), kwargs = {})
# %mul_756 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %select_26), kwargs = {})
# %add_367 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_352, %mul_756), kwargs = {})
# %add_383 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_373, %view_334), kwargs = {})
# %mul_798 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %select_19), kwargs = {})
# %mul_799 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %convert_element_type_11), kwargs = {})
# %sum_111 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_799,), kwargs = {})
# %add_385 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_367, %mul_798), kwargs = {})
# %mul_801 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %add_24), kwargs = {})
# %sum_112 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_801,), kwargs = {})
# %add_400 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_390, %view_346), kwargs = {})
# %mul_840 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_400, %select_12), kwargs = {})
# %mul_841 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_400, %convert_element_type_11), kwargs = {})
# %sum_119 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_841,), kwargs = {})
# %add_402 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_385, %mul_840), kwargs = {})
# %mul_842 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_400, %select_11), kwargs = {})
# %mul_843 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_400, %add_12), kwargs = {})
# %sum_120 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_843,), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_46 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_46', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*fp32', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'out_ptr0': '*bf16', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_46', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 15, 'num_reduction': 4, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_46(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (23))
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
tmp9 = tl.load(in_ptr3 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp10 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp12 = tl.load(in_ptr2 + (21))
tmp13 = tl.broadcast_to(tmp12, [R0_BLOCK])
tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp20 = tl.load(in_ptr2 + (19))
tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
tmp25 = tl.load(in_ptr2 + (18))
tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK])
tmp29 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp31 = tl.load(in_ptr8 + (x0), None, eviction_policy='evict_last')
tmp38 = tl.load(in_ptr9 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp47 = tl.load(in_ptr10 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tmp1 + tmp2
tmp6 = tmp5.to(tl.float32)
tmp7 = tmp3 * tmp6
tmp8 = tmp0 + tmp7
tmp11 = tmp9 + tmp10
tmp14 = tmp13.to(tl.float32)
tmp15 = tmp11 * tmp14
tmp16 = tmp8 + tmp15
tmp19 = tmp17 + tmp18
tmp22 = tmp21.to(tl.float32)
tmp23 = tmp19 * tmp22
tmp24 = tmp16 + tmp23
tmp27 = tmp26.to(tl.float32)
tmp28 = tmp19 * tmp27
tmp30 = tmp29.to(tl.float32)
tmp32 = tmp30 * tmp31
tmp33 = tmp32.to(tl.float32)
tmp34 = tmp11 * tmp33
tmp35 = tl.broadcast_to(tmp34, [R0_BLOCK])
tmp37 = triton_helpers.promote_to_tensor(tl.sum(tmp35, 0))
tmp39 = tmp11 * tmp38
tmp40 = tl.broadcast_to(tmp39, [R0_BLOCK])
tmp42 = triton_helpers.promote_to_tensor(tl.sum(tmp40, 0))
tmp43 = tmp19 * tmp33
tmp44 = tl.broadcast_to(tmp43, [R0_BLOCK])
tmp46 = triton_helpers.promote_to_tensor(tl.sum(tmp44, 0))
tmp48 = tmp19 * tmp47
tmp49 = tl.broadcast_to(tmp48, [R0_BLOCK])
tmp51 = triton_helpers.promote_to_tensor(tl.sum(tmp49, 0))
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp24, None)
tl.store(out_ptr0 + (r0_1 + 1024*x0), tmp28, None)
tl.store(out_ptr1 + (x0), tmp37, None)
tl.store(out_ptr2 + (x0), tmp42, None)
tl.store(out_ptr3 + (x0), tmp46, None)
tl.store(out_ptr4 + (x0), tmp51, None)
''', device_str='cuda')
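
# Kernel 46 extends the running residual gradient in place
# (in_out_ptr0 += (in_ptr0+in_ptr1)*scalar[23] + (in_ptr3+in_ptr4)*scalar[21]
#  + (in_ptr5+in_ptr6)*scalar[19]), emits the scalar[18]-scaled branch to
# out_ptr0, and reduces four per-row dot products over the 1024-wide model dim
# (against the normalized embedding and two saved activations); presumably a
# later kernel sums these per-row partials into the scalar-parameter grads.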
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/d6/cd6bxpx2rotlzolnhqlsypn4c3ofsxep56v7rmvfvdcg45tofh7f.py
# Topologically Sorted Source Nodes: [loss, v_1], Original ATen: [aten.nll_loss_forward, aten.add, aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.embedding_dense_backward]
# Source node to ATen node mapping:
# loss => full_default_13
# v_1 => convert_element_type_22
# Graph fragment:
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %add_408 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_216, %view_355), kwargs = {})
# %mul_854 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_432, %select_6), kwargs = {})
# %convert_element_type_22 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_14, torch.float32), kwargs = {})
# %convert_element_type_803 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_854, torch.float32), kwargs = {})
# %mul_856 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_803, %convert_element_type_22), kwargs = {})
# %mul_857 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_803, %rsqrt_3), kwargs = {})
# %sum_124 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_856, [3], True), kwargs = {})
# %div_62 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_79, 128), kwargs = {})
# %pow_218 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_22, 1.0), kwargs = {})
# %mul_860 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_218, 2.0), kwargs = {})
# %mul_861 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_62, %mul_860), kwargs = {})
# %add_410 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_857, %mul_861), kwargs = {})
# %convert_element_type_804 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_410, torch.bfloat16), kwargs = {})
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %convert_element_type_830 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_408, torch.float32), kwargs = {})
# %where_27 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_830), kwargs = {})
# %index_put_7 : [num_users=1] = call_function[target=torch.ops.aten.index_put_.default](args = (%full_default_169, [%convert_element_type_822], %where_27, True), kwargs = {})
triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_47 = async_compile.triton('triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_47', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*i32', 'in_ptr5': '*bf16', 'out_ptr1': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_47', 'mutated_arg_names': ['out_ptr2'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 11, 'num_reduction': 1, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_47(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (48))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (48))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
tmp34 = tl.load(in_ptr4 + (x1), None, eviction_policy='evict_last')
tmp44 = tl.load(in_ptr1 + (75))
tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK])
tmp48 = tl.load(in_ptr1 + (49))
tmp49 = tl.broadcast_to(tmp48, [XBLOCK, R0_BLOCK])
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp43 = tl.load(in_ptr5 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tmp35 = tmp34.to(tl.int64)
tmp36 = tl.full([XBLOCK, R0_BLOCK], 50257, tl.int32)
tmp37 = tmp35 + tmp36
tmp38 = tmp35 < 0
tmp39 = tl.where(tmp38, tmp37, tmp35)
tl.device_assert((0 <= tmp39) & (tmp39 < 50257), "index out of bounds: 0 <= tmp39 < 50257")
tmp41 = tl.full([1, 1], -1, tl.int64)
tmp42 = tmp35 == tmp41
tmp46 = tmp45.to(tl.float32)
tmp47 = tmp43 * tmp46
tmp50 = tmp49.to(tl.float32)
tmp51 = tmp12 * tmp50
tmp52 = tmp47 + tmp51
tmp53 = tmp52.to(tl.float32)
tmp54 = 0.0
tmp55 = tl.where(tmp42, tmp54, tmp53)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
tl.atomic_add(out_ptr2 + (r0_2 + 128*x0 + 1024*tmp39), tmp55, r0_mask, sem='relaxed')
''', device_str='cuda')
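
# Kernel 47 is structurally identical to kernel 45 above; it only reads
# different slots of the packed scalar buffer (offsets 48, 49, and 75 instead
# of 50, 51, and 77) and corresponds to the in-place index_put_ graph node.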
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/ku/cku3y346hkc2rt3ol4ul22b7a2hr3nvz4j76utjfqol6py33tpfs.py
# Topologically Sorted Source Nodes: [loss, x], Original ATen: [aten.nll_loss_forward, aten._to_copy, aten.mul, aten.add, aten.sum, aten.embedding_dense_backward]
# Source node to ATen node mapping:
# loss => full_default_13
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_417 : [num_users=3] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_407, %view_358), kwargs = {})
# %mul_882 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_417, %select_5), kwargs = {})
# %mul_883 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_417, %convert_element_type_11), kwargs = {})
# %sum_127 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_883,), kwargs = {})
# %add_419 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_402, %mul_882), kwargs = {})
# %mul_884 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_417, %select_4), kwargs = {})
# %add_420 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_419, %mul_884), kwargs = {})
# %convert_element_type_819 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_420, torch.float32), kwargs = {})
# %mul_886 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_819, %convert_element_type_10), kwargs = {})
# %sum_129 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_886, [2], True), kwargs = {})
# %convert_element_type_821 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%squeeze_1, torch.float32), kwargs = {})
# %where_24 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_821), kwargs = {})
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %index_put_4 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_24, True), kwargs = {})
triton_per_fused__to_copy_add_embedding_dense_backward_mul_nll_loss_forward_sum_48 = async_compile.triton('triton_per_fused__to_copy_add_embedding_dense_backward_mul_nll_loss_forward_sum_48', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*i32', 'out_ptr1': '*fp32', 'out_ptr3': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_embedding_dense_backward_mul_nll_loss_forward_sum_48', 'mutated_arg_names': ['out_ptr3'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 8, 'num_reduction': 2, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_embedding_dense_backward_mul_nll_loss_forward_sum_48(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr1, out_ptr3, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp4 = tl.load(in_ptr3 + (17))
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
tmp9 = tl.load(in_ptr3 + (16))
tmp10 = tl.broadcast_to(tmp9, [R0_BLOCK])
tmp15 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr5 + (x0), None, eviction_policy='evict_last')
tmp28 = tl.load(in_ptr6 + (x0), None, eviction_policy='evict_last')
tmp3 = tmp1 + tmp2
tmp6 = tmp5.to(tl.float32)
tmp7 = tmp3 * tmp6
tmp8 = tmp0 + tmp7
tmp11 = tmp10.to(tl.float32)
tmp12 = tmp3 * tmp11
tmp13 = tmp8 + tmp12
tmp14 = tmp13.to(tl.float32)
tmp16 = tmp15.to(tl.float32)
tmp18 = tmp16 * tmp17
tmp19 = tmp18.to(tl.float32)
tmp20 = tmp3 * tmp19
tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0))
tmp24 = tmp14 * tmp16
tmp25 = tl.broadcast_to(tmp24, [R0_BLOCK])
tmp27 = triton_helpers.promote_to_tensor(tl.sum(tmp25, 0))
tmp29 = tmp28.to(tl.int64)
tmp30 = tl.full([R0_BLOCK], 50257, tl.int32)
tmp31 = tmp29 + tmp30
tmp32 = tmp29 < 0
tmp33 = tl.where(tmp32, tmp31, tmp29)
tl.device_assert((0 <= tmp33) & (tmp33 < 50257), "index out of bounds: 0 <= tmp33 < 50257")
tmp35 = tl.full([1], -1, tl.int64)
tmp36 = tmp29 == tmp35
tmp37 = tmp14 * tmp17
tmp38 = -0.5
tmp39 = tmp27 * tmp38
tmp40 = tmp17 * tmp17
tmp41 = tmp40 * tmp17
tmp42 = tmp39 * tmp41
tmp43 = 0.0009765625
tmp44 = tmp42 * tmp43
tmp45 = 2.0
tmp46 = tmp16 * tmp45
tmp47 = tmp44 * tmp46
tmp48 = tmp37 + tmp47
tmp49 = tmp48.to(tl.float32)
tmp50 = tmp49.to(tl.float32)
tmp51 = 0.0
tmp52 = tl.where(tmp36, tmp51, tmp50)
tl.atomic_add(out_ptr3 + (tl.broadcast_to(r0_1 + 1024*tmp33, [R0_BLOCK])), tmp52, None, sem='relaxed')
tl.store(out_ptr1 + (x0), tmp23, None)
''', device_str='cuda')
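
# Kernel 48 applies the same rsqrt-backward correction as kernels 42/45/47 but
# over the 1024-wide model dimension (0.0009765625 == 1/1024) in a single
# persistent pass (R0_BLOCK == 1024), after folding in the residual branches
# scaled by scalar offsets 17 and 16, and fuses the embedding-gradient scatter
# via tl.atomic_add with rows for token index -1 zeroed out.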
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/lx/clxwsbf6ichdepzkafcu43p6rfa33ojvhrn7mlshq6ooqw2t4b5b.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy, aten.select_backward, aten.add, aten.slice_backward]
# Source node to ATen node mapping:
# Graph fragment:
# %convert_element_type_491 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_50, torch.float32), kwargs = {})
# %full_default_65 : [num_users=3] = call_function[target=torch.ops.aten.full.default](args = ([16], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %select_scatter_default_33 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_65, %convert_element_type_491, 0, 2), kwargs = {})
# %convert_element_type_523 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_58, torch.float32), kwargs = {})
# %select_scatter_default_40 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_65, %convert_element_type_523, 0, 4), kwargs = {})
# %add_272 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_33, %select_scatter_default_40), kwargs = {})
# %convert_element_type_555 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_66, torch.float32), kwargs = {})
# %select_scatter_default_47 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_65, %convert_element_type_555, 0, 6), kwargs = {})
# %add_288 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_272, %select_scatter_default_47), kwargs = {})
# %full_default_165 : [num_users=3] = call_function[target=torch.ops.aten.full.default](args = ([80], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %slice_scatter_default_15 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_165, %view_360, 0, 48, 80), kwargs = {})
# %slice_scatter_default_16 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_165, %view_361, 0, 16, 48), kwargs = {})
# %add_424 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%slice_scatter_default_15, %slice_scatter_default_16), kwargs = {})
# %slice_scatter_default_17 : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_165, %add_288, 0, 0, 16), kwargs = {})
# %add_425 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_424, %slice_scatter_default_17), kwargs = {})
triton_poi_fused__to_copy_add_select_backward_slice_backward_49 = async_compile.triton('triton_poi_fused__to_copy_add_select_backward_slice_backward_49', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 128},
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*fp32', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'in_ptr13': '*bf16', 'in_ptr14': '*bf16', 'in_ptr15': '*bf16', 'in_ptr16': '*bf16', 'in_ptr17': '*bf16', 'in_ptr18': '*bf16', 'in_ptr19': '*bf16', 'in_ptr20': '*bf16', 'in_ptr21': '*bf16', 'in_ptr22': '*bf16', 'in_ptr23': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_select_backward_slice_backward_49', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 24, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_select_backward_slice_backward_49(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, in_ptr15, in_ptr16, in_ptr17, in_ptr18, in_ptr19, in_ptr20, in_ptr21, in_ptr22, in_ptr23, xnumel, XBLOCK : tl.constexpr):
xnumel = 80
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = x0
tmp1 = tl.full([1], 48, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = tl.load(in_ptr0 + (2*(((((-48) + x0) // 2) % 16)) + ((x0 % 2))), xmask & tmp2, other=0.0)
tmp4 = ((((-48) + x0) // 2) % 16)
tmp5 = tl.full([1], 2, tl.int32)
tmp6 = tmp4 == tmp5
tmp7 = (x0 % 2)
tmp8 = tl.full([1], 1, tl.int32)
tmp9 = tmp7 == tmp8
tmp10 = tl.load(in_ptr1 + (0)).to(tl.float32)
tmp11 = tl.broadcast_to(tmp10, [XBLOCK])
tmp12 = tl.where(tmp2, tmp11, 0.0)
tmp13 = tmp12.to(tl.float32)
tmp14 = 0.0
tmp15 = tl.where(tmp9, tmp13, tmp14)
tmp16 = tl.full([1], 0, tl.int32)
tmp17 = tmp7 == tmp16
tmp18 = tl.load(in_ptr2 + (0)).to(tl.float32)
tmp19 = tl.broadcast_to(tmp18, [XBLOCK])
tmp20 = tl.where(tmp2, tmp19, 0.0)
tmp21 = tmp20.to(tl.float32)
tmp22 = tl.where(tmp17, tmp21, tmp14)
tmp23 = tmp15 + tmp22
tmp24 = tl.where(tmp6, tmp23, tmp14)
tmp25 = tmp3 + tmp24
tmp26 = tmp4 == tmp8
tmp27 = tl.load(in_ptr3 + (0)).to(tl.float32)
tmp28 = tl.broadcast_to(tmp27, [XBLOCK])
tmp29 = tl.where(tmp2, tmp28, 0.0)
tmp30 = tmp29.to(tl.float32)
tmp31 = tl.where(tmp9, tmp30, tmp14)
tmp32 = tl.load(in_ptr4 + (0)).to(tl.float32)
tmp33 = tl.broadcast_to(tmp32, [XBLOCK])
tmp34 = tl.where(tmp2, tmp33, 0.0)
tmp35 = tmp34.to(tl.float32)
tmp36 = tl.where(tmp17, tmp35, tmp14)
tmp37 = tmp31 + tmp36
tmp38 = tl.where(tmp26, tmp37, tmp14)
tmp39 = tmp25 + tmp38
tmp40 = tmp4 == tmp16
tmp41 = tl.load(in_ptr5 + (0)).to(tl.float32)
tmp42 = tl.broadcast_to(tmp41, [XBLOCK])
tmp43 = tl.where(tmp2, tmp42, 0.0)
tmp44 = tmp43.to(tl.float32)
tmp45 = tl.where(tmp9, tmp44, tmp14)
tmp46 = tl.load(in_ptr6 + (0)).to(tl.float32)
tmp47 = tl.broadcast_to(tmp46, [XBLOCK])
tmp48 = tl.where(tmp2, tmp47, 0.0)
tmp49 = tmp48.to(tl.float32)
tmp50 = tl.where(tmp17, tmp49, tmp14)
tmp51 = tmp45 + tmp50
tmp52 = tl.where(tmp40, tmp51, tmp14)
tmp53 = tmp39 + tmp52
tmp54 = tl.full(tmp53.shape, 0.0, tmp53.dtype)
tmp55 = tl.where(tmp2, tmp53, tmp54)
tmp56 = 0.0
tmp57 = tl.where(tmp2, tmp55, tmp56)
tmp58 = tl.full([1], 16, tl.int64)
tmp59 = tmp0 >= tmp58
tmp60 = tmp0 < tmp1
tmp61 = tmp59 & tmp60
tmp62 = tl.load(in_ptr7 + (2*(((((-16) + x0) // 2) % 16)) + ((x0 % 2))), xmask & tmp61, other=0.0)
tmp63 = ((((-16) + x0) // 2) % 16)
tmp64 = tl.full([1], 6, tl.int32)
tmp65 = tmp63 == tmp64
tmp66 = (x0 % 2)
tmp67 = tl.full([1], 1, tl.int32)
tmp68 = tmp66 == tmp67
tmp69 = tl.load(in_ptr8 + (0)).to(tl.float32)
tmp70 = tl.broadcast_to(tmp69, [XBLOCK])
tmp71 = tl.where(tmp61, tmp70, 0.0)
tmp72 = tmp71.to(tl.float32)
tmp73 = 0.0
tmp74 = tl.where(tmp68, tmp72, tmp73)
tmp75 = tl.full([1], 0, tl.int32)
tmp76 = tmp66 == tmp75
tmp77 = tl.load(in_ptr9 + (0)).to(tl.float32)
tmp78 = tl.broadcast_to(tmp77, [XBLOCK])
tmp79 = tl.where(tmp61, tmp78, 0.0)
tmp80 = tmp79.to(tl.float32)
tmp81 = tl.where(tmp76, tmp80, tmp73)
tmp82 = tmp74 + tmp81
tmp83 = tl.where(tmp65, tmp82, tmp73)
tmp84 = tmp62 + tmp83
tmp85 = tl.full([1], 5, tl.int32)
tmp86 = tmp63 == tmp85
tmp87 = tl.load(in_ptr10 + (0)).to(tl.float32)
tmp88 = tl.broadcast_to(tmp87, [XBLOCK])
tmp89 = tl.where(tmp61, tmp88, 0.0)
tmp90 = tmp89.to(tl.float32)
tmp91 = tl.where(tmp68, tmp90, tmp73)
tmp92 = tl.load(in_ptr11 + (0)).to(tl.float32)
tmp93 = tl.broadcast_to(tmp92, [XBLOCK])
tmp94 = tl.where(tmp61, tmp93, 0.0)
tmp95 = tmp94.to(tl.float32)
tmp96 = tl.where(tmp76, tmp95, tmp73)
tmp97 = tmp91 + tmp96
tmp98 = tl.where(tmp86, tmp97, tmp73)
tmp99 = tmp84 + tmp98
tmp100 = tl.full([1], 4, tl.int32)
tmp101 = tmp63 == tmp100
tmp102 = tl.load(in_ptr12 + (0)).to(tl.float32)
tmp103 = tl.broadcast_to(tmp102, [XBLOCK])
tmp104 = tl.where(tmp61, tmp103, 0.0)
tmp105 = tmp104.to(tl.float32)
tmp106 = tl.where(tmp68, tmp105, tmp73)
tmp107 = tl.load(in_ptr13 + (0)).to(tl.float32)
tmp108 = tl.broadcast_to(tmp107, [XBLOCK])
tmp109 = tl.where(tmp61, tmp108, 0.0)
tmp110 = tmp109.to(tl.float32)
tmp111 = tl.where(tmp76, tmp110, tmp73)
tmp112 = tmp106 + tmp111
tmp113 = tl.where(tmp101, tmp112, tmp73)
tmp114 = tmp99 + tmp113
tmp115 = tl.full([1], 3, tl.int32)
tmp116 = tmp63 == tmp115
tmp117 = tl.load(in_ptr14 + (0)).to(tl.float32)
tmp118 = tl.broadcast_to(tmp117, [XBLOCK])
tmp119 = tl.where(tmp61, tmp118, 0.0)
tmp120 = tmp119.to(tl.float32)
tmp121 = tl.where(tmp68, tmp120, tmp73)
tmp122 = tl.load(in_ptr15 + (0)).to(tl.float32)
tmp123 = tl.broadcast_to(tmp122, [XBLOCK])
tmp124 = tl.where(tmp61, tmp123, 0.0)
tmp125 = tmp124.to(tl.float32)
tmp126 = tl.where(tmp76, tmp125, tmp73)
tmp127 = tmp121 + tmp126
tmp128 = tl.where(tmp116, tmp127, tmp73)
tmp129 = tmp114 + tmp128
tmp130 = tl.full([1], 2, tl.int32)
tmp131 = tmp63 == tmp130
tmp132 = tl.load(in_ptr16 + (0)).to(tl.float32)
tmp133 = tl.broadcast_to(tmp132, [XBLOCK])
tmp134 = tl.where(tmp61, tmp133, 0.0)
tmp135 = tmp134.to(tl.float32)
tmp136 = tl.where(tmp68, tmp135, tmp73)
tmp137 = tl.load(in_ptr17 + (0)).to(tl.float32)
tmp138 = tl.broadcast_to(tmp137, [XBLOCK])
tmp139 = tl.where(tmp61, tmp138, 0.0)
tmp140 = tmp139.to(tl.float32)
tmp141 = tl.where(tmp76, tmp140, tmp73)
tmp142 = tmp136 + tmp141
tmp143 = tl.where(tmp131, tmp142, tmp73)
tmp144 = tmp129 + tmp143
tmp145 = tmp63 == tmp67
tmp146 = tl.load(in_ptr18 + (0)).to(tl.float32)
tmp147 = tl.broadcast_to(tmp146, [XBLOCK])
tmp148 = tl.where(tmp61, tmp147, 0.0)
tmp149 = tmp148.to(tl.float32)
tmp150 = tl.where(tmp68, tmp149, tmp73)
tmp151 = tl.load(in_ptr19 + (0)).to(tl.float32)
tmp152 = tl.broadcast_to(tmp151, [XBLOCK])
tmp153 = tl.where(tmp61, tmp152, 0.0)
tmp154 = tmp153.to(tl.float32)
tmp155 = tl.where(tmp76, tmp154, tmp73)
tmp156 = tmp150 + tmp155
tmp157 = tl.where(tmp145, tmp156, tmp73)
tmp158 = tmp144 + tmp157
tmp159 = tmp63 == tmp75
tmp160 = tl.load(in_ptr20 + (0)).to(tl.float32)
tmp161 = tl.broadcast_to(tmp160, [XBLOCK])
tmp162 = tl.where(tmp61, tmp161, 0.0)
tmp163 = tmp162.to(tl.float32)
tmp164 = tl.where(tmp68, tmp163, tmp73)
tmp165 = tl.where(tmp76, tmp163, tmp73)
tmp166 = tmp164 + tmp165
tmp167 = tl.where(tmp159, tmp166, tmp73)
tmp168 = tmp158 + tmp167
tmp169 = tl.full(tmp168.shape, 0.0, tmp168.dtype)
tmp170 = tl.where(tmp61, tmp168, tmp169)
tmp171 = tl.where(tmp61, tmp170, tmp56)
tmp172 = tmp57 + tmp171
tmp173 = tmp0 < tmp58
tmp174 = x0
tmp175 = tl.full([1], 2, tl.int32)
tmp176 = tmp174 == tmp175
tmp177 = tl.load(in_ptr21 + (0)).to(tl.float32)
tmp178 = tl.broadcast_to(tmp177, [XBLOCK])
tmp179 = tl.where(tmp173, tmp178, 0.0)
tmp180 = tmp179.to(tl.float32)
tmp181 = 0.0
tmp182 = tl.where(tmp176, tmp180, tmp181)
tmp183 = tl.full([1], 4, tl.int32)
tmp184 = tmp174 == tmp183
tmp185 = tl.load(in_ptr22 + (0)).to(tl.float32)
tmp186 = tl.broadcast_to(tmp185, [XBLOCK])
tmp187 = tl.where(tmp173, tmp186, 0.0)
tmp188 = tmp187.to(tl.float32)
tmp189 = tl.where(tmp184, tmp188, tmp181)
tmp190 = tmp182 + tmp189
tmp191 = tl.full([1], 6, tl.int32)
tmp192 = tmp174 == tmp191
tmp193 = tl.load(in_ptr23 + (0)).to(tl.float32)
tmp194 = tl.broadcast_to(tmp193, [XBLOCK])
tmp195 = tl.where(tmp173, tmp194, 0.0)
tmp196 = tmp195.to(tl.float32)
tmp197 = tl.where(tmp192, tmp196, tmp181)
tmp198 = tmp190 + tmp197
tmp199 = tl.full(tmp198.shape, 0.0, tmp198.dtype)
tmp200 = tl.where(tmp173, tmp198, tmp199)
tmp201 = tl.where(tmp173, tmp200, tmp56)
tmp202 = tmp172 + tmp201
tl.store(in_out_ptr0 + (x0), tmp202, xmask)
''', device_str='cuda')
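# NOTE: the kernel above stitches together dozens of 0-dim partial gradients
# (the scalar loads from in_ptr5..in_ptr23) into one small 1-D buffer,
# selecting each contribution via equality tests on the flattened index x0 and
# summing the overlapping ranges. This is the fused gradient of a small vector
# of learned scalar coefficients (plausibly primals_7, shape (80,), unpacked in
# call() below); Inductor collapses what would otherwise be many tiny
# scatter/add ops into a single pointwise launch.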
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/xa/cxacj2l5hljhfhgtubky2uytsytumvxkk5obolbxfdymabkc3jde.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
# Source node to ATen node mapping:
# Graph fragment:
# %convert_element_type_823 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%index_put_4, torch.bfloat16), kwargs = {})
triton_poi_fused_embedding_dense_backward_50 = async_compile.triton('triton_poi_fused_embedding_dense_backward_50', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_embedding_dense_backward_50', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_embedding_dense_backward_50(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 51463168
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = tmp0.to(tl.float32)
tl.store(out_ptr0 + (x0), tmp1, xmask)
''', device_str='cuda')
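# NOTE: a plain dtype-cast kernel for the embedding gradient
# (convert_element_type_823, fp32 -> bf16). The `.to(tl.float32)` on an fp32
# load is a no-op; the downcast happens when the value is stored through the
# bf16 out_ptr0. A one-line PyTorch sketch (hypothetical names, illustration
# only):
#     grad_bf16 = grad_fp32.to(torch.bfloat16)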
# kernel path: /tmp/torchinductor_xmfan/tmp_fz92xz8/kg/ckgmdoifhue6qhtyygsdiopbhe342puvu34bhdcnkwnjowz6irqm.py
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward]
# Source node to ATen node mapping:
# loss => full_default_13
# Graph fragment:
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %add_374 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_192, %view_331), kwargs = {})
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %convert_element_type_824 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_374, torch.float32), kwargs = {})
# %where_25 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_824), kwargs = {})
# %index_put_5 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_25, True), kwargs = {})
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_51 = async_compile.triton('triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_51', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i32', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_51', 'mutated_arg_names': ['out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 0, 'backend_hash': 'CE326918D8FE98067798167ABD26A2E4EDFD110D9ECD4380441C64512F6D164E', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_51(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x1 = xindex // 1024
x2 = xindex
x0 = (xindex % 1024)
tmp0 = tl.load(in_ptr0 + (x1), None, eviction_policy='evict_last')
tmp9 = tl.load(in_ptr1 + (x2), None).to(tl.float32)
tmp10 = tl.load(in_ptr2 + (79))
tmp11 = tl.broadcast_to(tmp10, [XBLOCK])
tmp14 = tl.load(in_ptr3 + (x2), None).to(tl.float32)
tmp15 = tl.load(in_ptr2 + (53))
tmp16 = tl.broadcast_to(tmp15, [XBLOCK])
tmp1 = tmp0.to(tl.int64)
tmp2 = tl.full([XBLOCK], 50257, tl.int32)
tmp3 = tmp1 + tmp2
tmp4 = tmp1 < 0
tmp5 = tl.where(tmp4, tmp3, tmp1)
tl.device_assert((0 <= tmp5) & (tmp5 < 50257), "index out of bounds: 0 <= tmp5 < 50257")
tmp7 = tl.full([1], -1, tl.int64)
tmp8 = tmp1 == tmp7
tmp12 = tmp11.to(tl.float32)
tmp13 = tmp9 * tmp12
tmp17 = tmp16.to(tl.float32)
tmp18 = tmp14 * tmp17
tmp19 = tmp13 + tmp18
tmp20 = tmp19.to(tl.float32)
tmp21 = 0.0
tmp22 = tl.where(tmp8, tmp21, tmp20)
tl.atomic_add(out_ptr0 + (x0 + 1024*tmp5), tmp22, None, sem='relaxed')
''', device_str='cuda')
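# NOTE: embedding_dense_backward realized as a masked atomic scatter-add. Two
# bf16 gradient streams (in_ptr1, in_ptr3) are scaled by scalars read from
# in_ptr2 (offsets 79 and 53), summed in fp32, zeroed where the token index is
# -1, and accumulated into the (50257, 1024) fp32 weight grad with relaxed
# atomics, so the result is not bitwise deterministic. A rough PyTorch sketch
# (hypothetical names, illustration only):
#     grad_w = torch.zeros(50257, 1024, device='cuda')
#     g = (g_a * s_a + g_b * s_b).float()
#     g = torch.where(idx.eq(-1).unsqueeze(-1), 0.0, g)
#     grad_w.index_put_((idx.long(),), g, accumulate=True)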
async_compile.wait(globals())
del async_compile
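# NOTE: async_compile.wait() above blocks until every Triton kernel registered
# in this file has finished compiling; the helper is then deleted because each
# compiled kernel is already bound to its module-level name.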
def call(args):
primals_1, primals_7, primals_86, embedding, embedding_1, embedding_2, cumsum, unsqueeze_9, unsqueeze_13, clamp_max, clamp_max_1, convert_element_type_2, clone_4, convert_element_type_4, clone_7, clamp_max_2, clamp_max_3, convert_element_type_6, clone_10, convert_element_type_8, clone_13, embedding_3, rsqrt, view_9, getitem_12, getitem_13, getitem_14, rsqrt_1, rsqrt_2, unsqueeze_44, unsqueeze_46, rsqrt_3, permute_5, permute_6, permute_7, getitem_19, getitem_20, add_10, rsqrt_4, view_16, mm_2, view_18, add_12, view_21, getitem_21, getitem_22, getitem_23, rsqrt_5, rsqrt_6, unsqueeze_52, unsqueeze_54, rsqrt_7, permute_13, permute_14, permute_15, getitem_28, getitem_29, add_22, rsqrt_8, view_28, mm_6, view_30, add_24, view_33, getitem_30, getitem_31, getitem_32, rsqrt_9, rsqrt_10, unsqueeze_60, unsqueeze_62, rsqrt_11, permute_21, permute_22, permute_23, getitem_37, getitem_38, add_34, rsqrt_12, view_40, mm_10, view_42, add_36, view_45, getitem_39, getitem_40, getitem_41, rsqrt_13, rsqrt_14, unsqueeze_68, unsqueeze_70, rsqrt_15, permute_29, permute_30, permute_31, getitem_46, getitem_47, add_45, rsqrt_16, view_51, mm_14, view_53, add_47, view_56, getitem_48, getitem_49, getitem_50, rsqrt_17, rsqrt_18, unsqueeze_76, unsqueeze_78, rsqrt_19, permute_37, permute_38, permute_39, getitem_55, getitem_56, add_56, rsqrt_20, view_62, mm_18, view_64, add_58, view_67, getitem_57, getitem_58, getitem_59, rsqrt_21, rsqrt_22, unsqueeze_84, unsqueeze_86, rsqrt_23, permute_45, permute_46, permute_47, getitem_64, getitem_65, add_67, rsqrt_24, view_73, mm_22, view_75, add_69, view_78, getitem_66, getitem_67, getitem_68, rsqrt_25, rsqrt_26, unsqueeze_92, unsqueeze_94, rsqrt_27, permute_53, permute_54, permute_55, getitem_73, getitem_74, add_78, rsqrt_28, view_84, mm_26, view_86, add_80, add_81, rsqrt_29, view_88, mm_28, view_90, add_83, view_93, getitem_75, getitem_76, getitem_77, rsqrt_30, rsqrt_31, unsqueeze_100, unsqueeze_102, rsqrt_32, permute_63, permute_64, permute_65, getitem_82, getitem_83, add_92, rsqrt_33, view_99, mm_32, view_101, add_95, view_104, getitem_84, getitem_85, getitem_86, rsqrt_34, rsqrt_35, unsqueeze_108, unsqueeze_110, rsqrt_36, permute_71, permute_72, permute_73, getitem_91, getitem_92, add_104, rsqrt_37, view_110, mm_36, view_112, add_107, view_115, getitem_93, getitem_94, getitem_95, rsqrt_38, rsqrt_39, unsqueeze_116, unsqueeze_118, rsqrt_40, permute_79, permute_80, permute_81, getitem_100, getitem_101, add_116, rsqrt_41, view_121, mm_40, view_123, add_119, view_126, getitem_102, getitem_103, getitem_104, rsqrt_42, rsqrt_43, unsqueeze_124, unsqueeze_126, rsqrt_44, permute_87, permute_88, permute_89, getitem_109, getitem_110, add_128, rsqrt_45, view_132, mm_44, view_134, add_130, view_137, getitem_111, getitem_112, getitem_113, rsqrt_46, rsqrt_47, unsqueeze_132, unsqueeze_134, rsqrt_48, permute_95, permute_96, permute_97, getitem_118, getitem_119, add_139, rsqrt_49, view_143, mm_48, view_145, add_141, view_148, getitem_120, getitem_121, getitem_122, rsqrt_50, rsqrt_51, unsqueeze_140, unsqueeze_142, rsqrt_52, permute_103, permute_104, permute_105, getitem_127, getitem_128, add_151, rsqrt_53, view_155, mm_52, view_157, add_153, view_160, getitem_129, getitem_130, getitem_131, rsqrt_54, rsqrt_55, unsqueeze_148, unsqueeze_150, rsqrt_56, permute_111, permute_112, permute_113, getitem_136, getitem_137, add_163, rsqrt_57, view_167, mm_56, view_169, add_165, view_172, getitem_138, getitem_139, getitem_140, rsqrt_58, rsqrt_59, unsqueeze_156, unsqueeze_158, rsqrt_60, permute_119, permute_120, 
permute_121, getitem_145, getitem_146, add_175, rsqrt_61, view_179, mm_60, view_181, add_177, rsqrt_62, view_183, mm_62, amax, log, convert_element_type_324, permute_129, permute_133, permute_137, permute_141, permute_149, permute_153, permute_157, permute_161, permute_169, permute_173, permute_177, permute_181, permute_189, permute_193, permute_197, permute_201, permute_209, permute_213, permute_217, permute_221, permute_229, permute_233, permute_237, permute_241, permute_249, permute_253, permute_257, permute_261, permute_269, permute_273, permute_277, permute_281, permute_289, permute_293, permute_297, permute_301, permute_305, permute_309, permute_317, permute_321, permute_325, permute_329, permute_337, permute_341, permute_345, permute_349, permute_357, permute_361, permute_365, permute_369, permute_377, permute_381, permute_385, permute_389, permute_397, permute_401, permute_405, permute_409, permute_417, permute_421, permute_425, permute_429, permute_437, tangents_1 = args
args.clear()
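    # NOTE: Inductor's calling convention unpacks every saved forward tensor
    # into a local and then clears `args`, so the only remaining reference to
    # each input is its local name; the `del` statements below free each
    # buffer at its last use.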
assert_size_stride(primals_1, (65536, ), (1, ))
assert_size_stride(primals_7, (80, ), (1, ))
assert_size_stride(primals_86, (65536, ), (1, ))
assert_size_stride(embedding, (65536, 1024), (1024, 1))
assert_size_stride(embedding_1, (65536, 1024), (1024, 1))
assert_size_stride(embedding_2, (65536, 1024), (1024, 1))
assert_size_stride(cumsum, (65536, ), (1, ))
assert_size_stride(unsqueeze_9, (1, 1, 512, 512), (262144, 262144, 512, 1))
assert_size_stride(unsqueeze_13, (1, 1, 512, 512), (262144, 262144, 512, 1))
assert_size_stride(clamp_max, (1, 1, 512), (512, 512, 1))
assert_size_stride(clamp_max_1, (1, 1, 512), (512, 512, 1))
assert_size_stride(convert_element_type_2, (1, 1, 512), (512, 512, 1))
assert_size_stride(clone_4, (1, 1, 512, 512), (262144, 262144, 512, 1))
assert_size_stride(convert_element_type_4, (1, 1, 512), (512, 512, 1))
assert_size_stride(clone_7, (1, 1, 512, 512), (262144, 262144, 512, 1))
assert_size_stride(clamp_max_2, (1, 1, 512), (512, 512, 1))
assert_size_stride(clamp_max_3, (1, 1, 512), (512, 512, 1))
assert_size_stride(convert_element_type_6, (1, 1, 512), (512, 512, 1))
assert_size_stride(clone_10, (1, 1, 512, 512), (262144, 262144, 512, 1))
assert_size_stride(convert_element_type_8, (1, 1, 512), (512, 512, 1))
assert_size_stride(clone_13, (1, 1, 512, 512), (262144, 262144, 512, 1))
assert_size_stride(embedding_3, (65536, 1024), (1024, 1))
assert_size_stride(rsqrt, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_9, (65536, 1024), (1024, 1))
assert_size_stride(getitem_12, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_13, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_14, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_1, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_2, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_44, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_46, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_3, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_5, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_6, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_7, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_19, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_20, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_10, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_4, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_16, (65536, 1024), (1024, 1))
assert_size_stride(mm_2, (65536, 4096), (4096, 1))
assert_size_stride(view_18, (65536, 4096), (4096, 1))
assert_size_stride(add_12, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_21, (65536, 1024), (1024, 1))
assert_size_stride(getitem_21, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_22, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_23, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_5, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_6, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_52, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_54, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_7, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_13, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_14, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_15, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_28, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_29, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_22, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_8, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_28, (65536, 1024), (1024, 1))
assert_size_stride(mm_6, (65536, 4096), (4096, 1))
assert_size_stride(view_30, (65536, 4096), (4096, 1))
assert_size_stride(add_24, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_33, (65536, 1024), (1024, 1))
assert_size_stride(getitem_30, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_31, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_32, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_9, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_10, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_60, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_62, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_11, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_21, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_22, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_23, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_37, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_38, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_34, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_12, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_40, (65536, 1024), (1024, 1))
assert_size_stride(mm_10, (65536, 4096), (4096, 1))
assert_size_stride(view_42, (65536, 4096), (4096, 1))
assert_size_stride(add_36, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_45, (65536, 1024), (1024, 1))
assert_size_stride(getitem_39, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_40, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_41, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_13, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_14, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_68, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_70, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_15, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_29, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_30, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_31, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_46, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_47, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_45, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_16, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_51, (65536, 1024), (1024, 1))
assert_size_stride(mm_14, (65536, 4096), (4096, 1))
assert_size_stride(view_53, (65536, 4096), (4096, 1))
assert_size_stride(add_47, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_56, (65536, 1024), (1024, 1))
assert_size_stride(getitem_48, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_49, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_50, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_17, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_18, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_76, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_78, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_19, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_37, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_38, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_39, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_55, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_56, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_56, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_20, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_62, (65536, 1024), (1024, 1))
assert_size_stride(mm_18, (65536, 4096), (4096, 1))
assert_size_stride(view_64, (65536, 4096), (4096, 1))
assert_size_stride(add_58, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_67, (65536, 1024), (1024, 1))
assert_size_stride(getitem_57, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_58, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_59, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_21, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_22, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_84, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_86, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_23, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_45, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_46, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_47, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_64, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_65, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_67, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_24, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_73, (65536, 1024), (1024, 1))
assert_size_stride(mm_22, (65536, 4096), (4096, 1))
assert_size_stride(view_75, (65536, 4096), (4096, 1))
assert_size_stride(add_69, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_78, (65536, 1024), (1024, 1))
assert_size_stride(getitem_66, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_67, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_68, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_25, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_26, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_92, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_94, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_27, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_53, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_54, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_55, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_73, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_74, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_78, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_28, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_84, (65536, 1024), (1024, 1))
assert_size_stride(mm_26, (65536, 4096), (4096, 1))
assert_size_stride(view_86, (65536, 4096), (4096, 1))
assert_size_stride(add_80, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(add_81, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_29, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_88, (65536, 1024), (1024, 1))
assert_size_stride(mm_28, (65536, 4096), (4096, 1))
assert_size_stride(view_90, (65536, 4096), (4096, 1))
assert_size_stride(add_83, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_93, (65536, 1024), (1024, 1))
assert_size_stride(getitem_75, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_76, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_77, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_30, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_31, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_100, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_102, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_32, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_63, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_64, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_65, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_82, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_83, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_92, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_33, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_99, (65536, 1024), (1024, 1))
assert_size_stride(mm_32, (65536, 4096), (4096, 1))
assert_size_stride(view_101, (65536, 4096), (4096, 1))
assert_size_stride(add_95, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_104, (65536, 1024), (1024, 1))
assert_size_stride(getitem_84, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_85, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_86, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_34, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_35, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_108, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_110, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_36, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_71, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_72, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_73, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_91, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_92, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_104, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_37, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_110, (65536, 1024), (1024, 1))
assert_size_stride(mm_36, (65536, 4096), (4096, 1))
assert_size_stride(view_112, (65536, 4096), (4096, 1))
assert_size_stride(add_107, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_115, (65536, 1024), (1024, 1))
assert_size_stride(getitem_93, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_94, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_95, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_38, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_39, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_116, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_118, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_40, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_79, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_80, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_81, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_100, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_101, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_116, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_41, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_121, (65536, 1024), (1024, 1))
assert_size_stride(mm_40, (65536, 4096), (4096, 1))
assert_size_stride(view_123, (65536, 4096), (4096, 1))
assert_size_stride(add_119, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_126, (65536, 1024), (1024, 1))
assert_size_stride(getitem_102, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_103, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_104, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_42, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_43, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_124, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_126, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_44, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_87, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_88, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_89, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_109, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_110, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_128, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_45, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_132, (65536, 1024), (1024, 1))
assert_size_stride(mm_44, (65536, 4096), (4096, 1))
assert_size_stride(view_134, (65536, 4096), (4096, 1))
assert_size_stride(add_130, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_137, (65536, 1024), (1024, 1))
assert_size_stride(getitem_111, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_112, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_113, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_46, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_47, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_132, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_134, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_48, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_95, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_96, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_97, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_118, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_119, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_139, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_49, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_143, (65536, 1024), (1024, 1))
assert_size_stride(mm_48, (65536, 4096), (4096, 1))
assert_size_stride(view_145, (65536, 4096), (4096, 1))
assert_size_stride(add_141, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_148, (65536, 1024), (1024, 1))
assert_size_stride(getitem_120, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_121, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_122, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_50, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_51, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_140, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_142, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_52, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_103, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_104, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_105, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_127, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_128, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_151, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_53, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_155, (65536, 1024), (1024, 1))
assert_size_stride(mm_52, (65536, 4096), (4096, 1))
assert_size_stride(view_157, (65536, 4096), (4096, 1))
assert_size_stride(add_153, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_160, (65536, 1024), (1024, 1))
assert_size_stride(getitem_129, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_130, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_131, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_54, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_55, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_148, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_150, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_56, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_111, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_112, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_113, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_136, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_137, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_163, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_57, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_167, (65536, 1024), (1024, 1))
assert_size_stride(mm_56, (65536, 4096), (4096, 1))
assert_size_stride(view_169, (65536, 4096), (4096, 1))
assert_size_stride(add_165, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(view_172, (65536, 1024), (1024, 1))
assert_size_stride(getitem_138, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_139, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(getitem_140, (1, 65536, 8, 128), (201326592, 3072, 128, 1))
assert_size_stride(rsqrt_58, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(rsqrt_59, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(unsqueeze_156, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(unsqueeze_158, (1, 65536, 1, 64), (16777216, 64, 64, 1))
assert_size_stride(rsqrt_60, (1, 65536, 8, 1), (524288, 8, 1, 1))
assert_size_stride(permute_119, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_120, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(permute_121, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_145, (1, 8, 65536, 128), (67108864, 128, 1024, 1))
assert_size_stride(getitem_146, (1, 8, 65536), (524288, 65536, 1))
assert_size_stride(add_175, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_61, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_179, (65536, 1024), (1024, 1))
assert_size_stride(mm_60, (65536, 4096), (4096, 1))
assert_size_stride(view_181, (65536, 4096), (4096, 1))
assert_size_stride(add_177, (1, 65536, 1024), (67108864, 1024, 1))
assert_size_stride(rsqrt_62, (1, 65536, 1), (65536, 1, 1))
assert_size_stride(view_183, (65536, 1024), (1024, 1))
assert_size_stride(mm_62, (65536, 50304), (50304, 1))
assert_size_stride(amax, (65536, 1), (1, 1))
assert_size_stride(log, (65536, 1), (1, 1))
assert_size_stride(convert_element_type_324, (), ())
assert_size_stride(permute_129, (50304, 1024), (1024, 1))
assert_size_stride(permute_133, (1024, 4096), (4096, 1))
assert_size_stride(permute_137, (4096, 1024), (1024, 1))
assert_size_stride(permute_141, (1024, 1024), (1024, 1))
assert_size_stride(permute_149, (3072, 1024), (1024, 1))
assert_size_stride(permute_153, (1024, 4096), (4096, 1))
assert_size_stride(permute_157, (4096, 1024), (1024, 1))
assert_size_stride(permute_161, (1024, 1024), (1024, 1))
assert_size_stride(permute_169, (3072, 1024), (1024, 1))
assert_size_stride(permute_173, (1024, 4096), (4096, 1))
assert_size_stride(permute_177, (4096, 1024), (1024, 1))
assert_size_stride(permute_181, (1024, 1024), (1024, 1))
assert_size_stride(permute_189, (3072, 1024), (1024, 1))
assert_size_stride(permute_193, (1024, 4096), (4096, 1))
assert_size_stride(permute_197, (4096, 1024), (1024, 1))
assert_size_stride(permute_201, (1024, 1024), (1024, 1))
assert_size_stride(permute_209, (3072, 1024), (1024, 1))
assert_size_stride(permute_213, (1024, 4096), (4096, 1))
assert_size_stride(permute_217, (4096, 1024), (1024, 1))
assert_size_stride(permute_221, (1024, 1024), (1024, 1))
assert_size_stride(permute_229, (3072, 1024), (1024, 1))
assert_size_stride(permute_233, (1024, 4096), (4096, 1))
assert_size_stride(permute_237, (4096, 1024), (1024, 1))
assert_size_stride(permute_241, (1024, 1024), (1024, 1))
assert_size_stride(permute_249, (3072, 1024), (1024, 1))
assert_size_stride(permute_253, (1024, 4096), (4096, 1))
assert_size_stride(permute_257, (4096, 1024), (1024, 1))
assert_size_stride(permute_261, (1024, 1024), (1024, 1))
assert_size_stride(permute_269, (3072, 1024), (1024, 1))
assert_size_stride(permute_273, (1024, 4096), (4096, 1))
assert_size_stride(permute_277, (4096, 1024), (1024, 1))
assert_size_stride(permute_281, (1024, 1024), (1024, 1))
assert_size_stride(permute_289, (3072, 1024), (1024, 1))
assert_size_stride(permute_293, (1024, 4096), (4096, 1))
assert_size_stride(permute_297, (4096, 1024), (1024, 1))
assert_size_stride(permute_301, (1024, 4096), (4096, 1))
assert_size_stride(permute_305, (4096, 1024), (1024, 1))
assert_size_stride(permute_309, (1024, 1024), (1024, 1))
assert_size_stride(permute_317, (3072, 1024), (1024, 1))
assert_size_stride(permute_321, (1024, 4096), (4096, 1))
assert_size_stride(permute_325, (4096, 1024), (1024, 1))
assert_size_stride(permute_329, (1024, 1024), (1024, 1))
assert_size_stride(permute_337, (3072, 1024), (1024, 1))
assert_size_stride(permute_341, (1024, 4096), (4096, 1))
assert_size_stride(permute_345, (4096, 1024), (1024, 1))
assert_size_stride(permute_349, (1024, 1024), (1024, 1))
assert_size_stride(permute_357, (3072, 1024), (1024, 1))
assert_size_stride(permute_361, (1024, 4096), (4096, 1))
assert_size_stride(permute_365, (4096, 1024), (1024, 1))
assert_size_stride(permute_369, (1024, 1024), (1024, 1))
assert_size_stride(permute_377, (3072, 1024), (1024, 1))
assert_size_stride(permute_381, (1024, 4096), (4096, 1))
assert_size_stride(permute_385, (4096, 1024), (1024, 1))
assert_size_stride(permute_389, (1024, 1024), (1024, 1))
assert_size_stride(permute_397, (3072, 1024), (1024, 1))
assert_size_stride(permute_401, (1024, 4096), (4096, 1))
assert_size_stride(permute_405, (4096, 1024), (1024, 1))
assert_size_stride(permute_409, (1024, 1024), (1024, 1))
assert_size_stride(permute_417, (3072, 1024), (1024, 1))
assert_size_stride(permute_421, (1024, 4096), (4096, 1))
assert_size_stride(permute_425, (4096, 1024), (1024, 1))
assert_size_stride(permute_429, (1024, 1024), (1024, 1))
assert_size_stride(permute_437, (3072, 1024), (1024, 1))
assert_size_stride(tangents_1, (), ())
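    # NOTE: the assert_size_stride calls above pin down the exact shapes and
    # strides the compiled kernels assume; they fail loudly if the graph is
    # invoked with tensors that no longer match the guarded layout.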
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf2 = mm_62; del mm_62 # reuse
# Topologically Sorted Source Nodes: [loss, logits, mul_176, square_16, add_116, rsqrt, mul_177], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward, aten._to_copy, aten.mul, aten.pow, aten.add, aten.rsqrt, aten._log_softmax, aten._log_softmax_backward_data]
stream0 = get_raw_stream(0)
triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0.run(buf2, primals_86, tangents_1, convert_element_type_324, amax, log, 65536, 50304, grid=grid(65536), stream=stream0)
del amax
del convert_element_type_324
del log
del primals_86
del tangents_1
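        # NOTE: the fused kernel above forms d(loss)/d(logits) in place,
        # reusing the forward logits' storage (buf2 = mm_62). For mean-reduced
        # cross entropy this is
        #     dlogits = (softmax(logits) - onehot(target)) * tangents_1 / n_valid
        # with rows at the ignore index zeroed; softmax is recomputed from the
        # saved row max (amax) and the log of the exp-sum (log).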
buf3 = empty_strided_cuda((50304, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf2, (50304, 65536), (1, 50304), 0), view_183, out=buf3)
del view_183
buf4 = empty_strided_cuda((65536, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(buf2, permute_129, out=buf4)
del buf2
del permute_129
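        # NOTE: standard matmul backward for the lm-head projection. With
        # logits = x @ W.T and dY = buf2 (65536 x 50304):
        #     dW = dY.T @ x   -> buf3 (50304 x 1024)
        #     dx = dY @ W     -> buf4 (65536 x 1024)
        # reinterpret_tensor swaps strides to express dY.T without a copy.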
buf5 = empty_strided_cuda((50304, 1024), (1024, 1), torch.float32)
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_1.run(buf3, buf5, 51511296, grid=grid(51511296), stream=stream0)
del buf3
buf7 = reinterpret_tensor(buf4, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf4 # reuse
# Topologically Sorted Source Nodes: [x_144], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_2.run(buf7, add_177, rsqrt_62, 65536, 1024, grid=grid(65536), stream=stream0)
del add_177
del rsqrt_62
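        # NOTE: fused RMSNorm backward (no learned scale), reusing rsqrt_62
        # saved from the forward. For y = x * r with r = rsqrt(mean(x^2) + eps),
        #     dx = r * dy - x * r**3 * mean(x * dy)
        # which is what the _to_copy/mul/sum/div/pow/add op list above fuses.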
buf8 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf7, (1024, 65536), (1, 1024), 0), view_181, out=buf8)
del view_181
buf9 = empty_strided_cuda((65536, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf7, (65536, 1024), (1024, 1), 0), permute_133, out=buf9)
del permute_133
buf10 = reinterpret_tensor(mm_60, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_60 # reuse
# Topologically Sorted Source Nodes: [relu_15], Original ATen: [aten.relu, aten.pow, aten.mul, aten.threshold_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf10, buf9, 268435456, grid=grid(268435456), stream=stream0)
del buf9
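        # NOTE: backward of a squared-ReLU MLP activation (relu_15). For
        # a = relu(h)**2 the gradient is
        #     dh = 2 * relu(h) * da
        # (threshold_backward zeroes h <= 0, which relu(h) already does); the
        # kernel recomputes relu(h) from the saved matmul output (mm_60) and
        # overwrites it in place with dh.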
buf11 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf10, (4096, 65536), (1, 4096), 0), view_179, out=buf11)
del view_179
buf12 = empty_strided_cuda((65536, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf10, (65536, 4096), (4096, 1), 0), permute_137, out=buf12)
del permute_137
buf14 = buf7; del buf7 # reuse
# Topologically Sorted Source Nodes: [rms_norm_61], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf14, buf12, add_175, rsqrt_61, 65536, 1024, grid=grid(65536), stream=stream0)
del add_175
del rsqrt_61
buf15 = empty_strided_cuda((1024, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf14, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_145, (65536, 1024), (1024, 1), 0), out=buf15)
buf16 = buf12; del buf12 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf14, (65536, 1024), (1024, 1), 0), permute_141, out=buf16)
del permute_141
buf17 = empty_strided_cuda((1, 8, 65536), (524288, 1, 8), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_145, buf16, buf17, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_145
buf18 = empty_strided_cuda((1, 8, 65536), (524288, 65536, 1), torch.float32)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf17, buf18, 8, 65536, grid=grid(8, 65536), stream=stream0)
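        # NOTE: attention-backward preprocessing, consistent with the standard
        # FlashAttention recipe: the per-row reduction above computes
        #     delta = sum(dO * O, dim=-1)
        # per (head, row) from the saved attention output (getitem_145) and dO
        # (buf16); the second kernel transposes and casts it into the fp32
        # layout the flex_attention backward template expects (buf18).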
buf20 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
buf21 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
buf22 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_119, permute_120, permute_121, getitem_146, buf18, buf16, buf20, buf21, clamp_max, unsqueeze_9, convert_element_type_2, clone_4, clamp_max_1, unsqueeze_13, convert_element_type_4, clone_7, cumsum, buf22, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_146
del permute_119
del permute_120
del permute_121
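        # NOTE: templated flex_attention backward. Given q/k/v
        # (permute_119..121), the saved logsumexp (getitem_146), delta (buf18)
        # and dO (buf16), it writes the three input gradients into
        # buf20/buf21/buf22, reusing the block-mask metadata from the forward
        # (clamp_max*, unsqueeze_9/13, clone_4/7, cumsum). buf21 feeds the
        # v_43 reductions just below, while buf20 and buf22 flow into the
        # fused q/k backward; the same launch pattern repeats once per layer.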
buf25 = empty_strided_cuda((512, ), (1, ), torch.float32)
buf27 = empty_strided_cuda((512, ), (1, ), torch.float32)
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten.sum, aten._to_copy]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_8.run(buf21, embedding_2, getitem_140, rsqrt_60, buf25, buf27, 512, 131072, grid=grid(512), stream=stream0)
buf26 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf25, buf26, 1, 512, grid=grid(1), stream=stream0)
buf28 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf27, buf28, 1, 512, grid=grid(1), stream=stream0)
buf37 = empty_strided_cuda((1, 65536, 24, 128), (201326592, 3072, 128, 1), torch.bfloat16)
buf36 = reinterpret_tensor(buf37, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_10.run(buf21, primals_7, getitem_140, rsqrt_60, buf36, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_140
del rsqrt_60
buf30 = empty_strided_cuda((1, 65536, 8, 128), (67108864, 1024, 128, 1), torch.float32)
buf31 = empty_strided_cuda((1, 65536, 8, 128), (67108864, 1024, 128, 1), torch.float32)
buf35 = reinterpret_tensor(buf37, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf34 = reinterpret_tensor(buf37, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_43, q_43], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf22, unsqueeze_158, unsqueeze_156, buf20, getitem_139, rsqrt_59, getitem_138, rsqrt_58, buf30, buf31, buf35, buf34, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_138
del getitem_139
del rsqrt_58
del rsqrt_59
del unsqueeze_156
del unsqueeze_158
del buf34
del buf35
del buf36
buf38 = empty_strided_cuda((3072, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf37, (3072, 65536), (1, 3072), 0), view_172, out=buf38)
del view_172
buf39 = reinterpret_tensor(buf22, (65536, 1024), (1024, 1), 0); del buf22 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf37, (65536, 3072), (3072, 1), 0), permute_149, out=buf39)
del permute_149
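        # NOTE: buf37 is a single (1, 65536, 24, 128) staging buffer whose
        # three 8-head slices (element offsets 0, 1024, 2048) alias the
        # query/key/value gradients written by the kernels above. Viewing it
        # as (65536, 3072) lets one pair of mms produce the fused qkv
        # projection's weight grad (buf38) and input grad (buf39) instead of
        # three separate matmuls.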
buf40 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf15, buf38, buf40, 4194304, grid=grid(4194304), stream=stream0)
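        # NOTE: select_backward for what appears to be a single stacked
        # (4, 1024, 1024) attention parameter read via select in the forward:
        # the out-proj grad (buf15) and the three 1024x1024 slices of the qkv
        # grad (buf38) are placed into their slots of buf40, zeros elsewhere.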
buf45 = reinterpret_tensor(buf20, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf20 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_add_mul_13.run(buf14, buf39, primals_7, buf45, 67108864, grid=grid(67108864), stream=stream0)
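        # NOTE: residual-connection backward: buf45 = buf14 + buf39 * s for
        # some scalar s taken from primals_7, i.e. the upstream residual-stream
        # grad plus the branch grad scaled by the learned coefficient that
        # weighted this branch in the forward.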
buf47 = reinterpret_tensor(buf10, (65536, 4096), (4096, 1), 0); del buf10 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf45, (65536, 1024), (1024, 1), 0), permute_153, out=buf47)
del permute_153
buf48 = reinterpret_tensor(mm_56, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_56 # reuse
# Topologically Sorted Source Nodes: [relu_14], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf48, buf47, 268435456, grid=grid(268435456), stream=stream0)
del buf47
buf50 = buf16; del buf16 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf48, (65536, 4096), (4096, 1), 0), permute_157, out=buf50)
del permute_157
buf52 = reinterpret_tensor(buf50, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf50 # reuse
# Topologically Sorted Source Nodes: [rms_norm_57], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf52, add_163, buf45, rsqrt_57, 65536, 1024, grid=grid(65536), stream=stream0)
del add_163
del rsqrt_57
buf54 = empty_strided_cuda((65536, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf52, (65536, 1024), (1024, 1), 0), permute_161, out=buf54)
del permute_161
buf55 = buf17; del buf17 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_136, buf54, buf55, 524288, 128, grid=grid(524288), stream=stream0)
buf56 = buf18; del buf18 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf55, buf56, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf58 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
buf59 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
buf60 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_111, permute_112, permute_113, getitem_137, buf56, buf54, buf58, buf59, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf60, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_137
del permute_111
del permute_112
del permute_113
buf68 = buf31; del buf31 # reuse
buf69 = buf30; del buf30 # reuse
buf75 = buf37; del buf37 # reuse
buf73 = reinterpret_tensor(buf75, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf72 = reinterpret_tensor(buf75, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_40, q_40], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf60, unsqueeze_150, unsqueeze_148, buf58, getitem_130, rsqrt_55, getitem_129, rsqrt_54, buf68, buf69, buf73, buf72, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_129
del getitem_130
del rsqrt_54
del rsqrt_55
del unsqueeze_148
del unsqueeze_150
buf74 = reinterpret_tensor(buf75, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_40], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_15.run(buf59, primals_7, getitem_131, rsqrt_56, buf74, 524288, 128, grid=grid(524288), stream=stream0)
del buf72
del buf73
del buf74
buf77 = reinterpret_tensor(buf60, (65536, 1024), (1024, 1), 0); del buf60 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf75, (65536, 3072), (3072, 1), 0), permute_169, out=buf77)
del permute_169
buf41 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32)
buf43 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32)
buf79 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32)
buf81 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32)
buf83 = reinterpret_tensor(buf58, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf58 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mul_sum_16.run(buf14, buf39, embedding_3, rsqrt, add_165, buf52, buf77, add_153, primals_7, buf41, buf43, buf79, buf81, buf83, 65536, 1024, grid=grid(65536), stream=stream0)
del add_153
del add_165
buf42 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf41, buf42, 1, 65536, grid=grid(1), stream=stream0)
buf44 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf43, buf44, 1, 65536, grid=grid(1), stream=stream0)
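        # NOTE: the scalar-coefficient grads use a two-stage reduction: the
        # fused kernel above leaves one partial sum per token row in
        # buf41/buf43/buf79/buf81 (shape (1, 65536, 1)), and
        # triton_red_fused__to_copy_add_mul_sum_17 then collapses the 65536
        # partials into 0-dim bf16 scalars (buf42 and buf44 here; buf80 and
        # buf82 below).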
buf46 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf45, (1024, 65536), (1, 1024), 0), view_169, out=buf46)
del view_169
buf49 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf48, (4096, 65536), (1, 4096), 0), view_167, out=buf49)
del view_167
buf53 = buf15; del buf15 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf52, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_136, (65536, 1024), (1024, 1), 0), out=buf53)
del getitem_136
buf63 = buf27; del buf27 # reuse
buf65 = buf25; del buf25 # reuse
# Topologically Sorted Source Nodes: [v_40], Original ATen: [aten.mul, aten.sum, aten._to_copy]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_8.run(buf59, embedding_1, getitem_131, rsqrt_56, buf63, buf65, 512, 131072, grid=grid(512), stream=stream0)
del getitem_131
del rsqrt_56
buf64 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf63, buf64, 1, 512, grid=grid(1), stream=stream0)
buf66 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [v_40], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf65, buf66, 1, 512, grid=grid(1), stream=stream0)
buf76 = buf38; del buf38 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf75, (3072, 65536), (1, 3072), 0), view_160, out=buf76)
del view_160
buf78 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf53, buf76, buf78, 4194304, grid=grid(4194304), stream=stream0)
buf80 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf79, buf80, 1, 65536, grid=grid(1), stream=stream0)
buf82 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf81, buf82, 1, 65536, grid=grid(1), stream=stream0)
buf84 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf83, (1024, 65536), (1, 1024), 0), view_157, out=buf84)
del view_157
buf85 = reinterpret_tensor(buf48, (65536, 4096), (4096, 1), 0); del buf48 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf83, (65536, 1024), (1024, 1), 0), permute_173, out=buf85)
del permute_173
buf86 = reinterpret_tensor(mm_52, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_52 # reuse
# Topologically Sorted Source Nodes: [relu_13], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf86, buf85, 268435456, grid=grid(268435456), stream=stream0)
del buf85
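    # NOTE (editor): triton_poi_fused_mul_pow_relu_threshold_backward_3 rewrites
    # buf86 in place with the gradient through the MLP activation. The fused op
    # list (relu, pow, mul, threshold_backward) is consistent with a squared-ReLU
    # activation, i.e. y = relu(x)**2 with dx = 2 * relu(x) * dy (assumption
    # inferred from the kernel name; the Triton source is not shown here).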
buf87 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf86, (4096, 65536), (1, 4096), 0), view_155, out=buf87)
del view_155
buf88 = reinterpret_tensor(buf45, (65536, 1024), (1024, 1), 0); del buf45 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf86, (65536, 4096), (4096, 1), 0), permute_177, out=buf88)
del permute_177
buf90 = buf83; del buf83 # reuse
# Topologically Sorted Source Nodes: [rms_norm_53], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf90, buf88, add_151, rsqrt_53, 65536, 1024, grid=grid(65536), stream=stream0)
del add_151
del rsqrt_53
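    # NOTE (editor): triton_per_fused__to_copy_add_div_mul_pow_sum_4 is the fused
    # RMSNorm backward (source node rms_norm_53): it consumes the incoming grad
    # (buf88), the saved forward input (add_151), and the saved reciprocal RMS
    # (rsqrt_53), accumulating into buf90 in place, one 1024-wide row per program.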
buf91 = buf53; del buf53 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf90, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_127, (65536, 1024), (1024, 1), 0), out=buf91)
buf92 = buf88; del buf88 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf90, (65536, 1024), (1024, 1), 0), permute_181, out=buf92)
del permute_181
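    # NOTE (editor): the next three launches are the flex-attention backward.
    # triton_per_fused_zeros_5 / triton_poi_fused_zeros_6 build a per-(head, row)
    # workspace over 8 * 65536 rows of width 128 (assumption: the delta term
    # rowsum(dO * O) that the attention backward needs), and
    # triton_tem_fused_zeros_7 is the flex_attention backward template itself:
    # it takes the saved q/k/v (permute_103/104/105), the saved logsumexp
    # (getitem_128, an assumption from its position), and the mask/score-mod
    # buffers, and writes the three q/k/v grads into buf96/buf97/buf98 (exact
    # order not shown here), with the launch grid supplied by
    # flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0).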
buf93 = buf55; del buf55 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_127, buf92, buf93, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_127
buf94 = buf56; del buf56 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf93, buf94, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf96 = reinterpret_tensor(buf54, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf54 # reuse
buf97 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
buf98 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_103, permute_104, permute_105, getitem_128, buf94, buf92, buf96, buf97, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf98, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_128
del permute_103
del permute_104
del permute_105
buf101 = buf65; del buf65 # reuse
buf103 = buf63; del buf63 # reuse
# Topologically Sorted Source Nodes: [v_37], Original ATen: [aten.mul, aten.sum, aten._to_copy]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_8.run(buf97, embedding, getitem_122, rsqrt_52, buf101, buf103, 512, 131072, grid=grid(512), stream=stream0)
buf102 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf101, buf102, 1, 512, grid=grid(1), stream=stream0)
buf104 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [v_37], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf103, buf104, 1, 512, grid=grid(1), stream=stream0)
buf113 = buf75; del buf75 # reuse
buf112 = reinterpret_tensor(buf113, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_37], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_18.run(buf97, primals_7, getitem_122, rsqrt_52, buf112, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_122
del rsqrt_52
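    # NOTE (editor): the cat + div/pow/sum kernel below produces the fused q/k
    # gradients for this layer. Assumption: unsqueeze_140/unsqueeze_142 are the
    # saved rotary cos/sin tables, so one pass folds together the rotary
    # backward and the q/k RMSNorm backward (rsqrt_50/rsqrt_51), writing
    # straight into the aliased slices of the packed QKV grad buffer.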
buf106 = buf69; del buf69 # reuse
buf107 = buf68; del buf68 # reuse
buf111 = reinterpret_tensor(buf113, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf110 = reinterpret_tensor(buf113, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_37, q_37], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf98, unsqueeze_142, unsqueeze_140, buf96, getitem_121, rsqrt_51, getitem_120, rsqrt_50, buf106, buf107, buf111, buf110, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_120
del getitem_121
del rsqrt_50
del rsqrt_51
del unsqueeze_140
del unsqueeze_142
del buf110
del buf111
del buf112
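    # NOTE (editor): buf113 packs the q/k/v grads as one (1, 65536, 3072) buffer;
    # buf110/buf111/buf112 above are reinterpret_tensor aliases at element
    # offsets 0/1024/2048 that the kernels wrote into directly, so no concat is
    # ever materialized. Deleting the aliases keeps buf113 itself alive for the
    # fused QKV weight-grad mm below.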
buf114 = buf76; del buf76 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf113, (3072, 65536), (1, 3072), 0), view_148, out=buf114)
del view_148
buf115 = reinterpret_tensor(buf98, (65536, 1024), (1024, 1), 0); del buf98 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf113, (65536, 3072), (3072, 1), 0), permute_189, out=buf115)
del permute_189
buf116 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf91, buf114, buf116, 4194304, grid=grid(4194304), stream=stream0)
buf119 = buf14; del buf14 # reuse
buf122 = reinterpret_tensor(buf96, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf96 # reuse
buf117 = buf81; del buf81 # reuse
buf120 = buf79; del buf79 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten.mul, aten._to_copy, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mul_sum_19.run(buf119, buf39, primals_7, buf52, buf77, buf90, buf115, embedding_3, rsqrt, add_141, buf122, buf117, buf120, 65536, 1024, grid=grid(65536), stream=stream0)
del add_141
buf118 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf117, buf118, 1, 65536, grid=grid(1), stream=stream0)
buf121 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf120, buf121, 1, 65536, grid=grid(1), stream=stream0)
buf123 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf122, (1024, 65536), (1, 1024), 0), view_145, out=buf123)
del view_145
buf124 = reinterpret_tensor(buf86, (65536, 4096), (4096, 1), 0); del buf86 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf122, (65536, 1024), (1024, 1), 0), permute_193, out=buf124)
del permute_193
buf125 = reinterpret_tensor(mm_48, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_48 # reuse
# Topologically Sorted Source Nodes: [relu_12], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf125, buf124, 268435456, grid=grid(268435456), stream=stream0)
del buf124
buf126 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf125, (4096, 65536), (1, 4096), 0), view_143, out=buf126)
del view_143
buf127 = reinterpret_tensor(buf90, (65536, 1024), (1024, 1), 0); del buf90 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf125, (65536, 4096), (4096, 1), 0), permute_197, out=buf127)
del permute_197
buf129 = buf122; del buf122 # reuse
# Topologically Sorted Source Nodes: [rms_norm_49], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf129, buf127, add_139, rsqrt_49, 65536, 1024, grid=grid(65536), stream=stream0)
del add_139
del rsqrt_49
buf130 = buf91; del buf91 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf129, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_118, (65536, 1024), (1024, 1), 0), out=buf130)
buf131 = buf127; del buf127 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf129, (65536, 1024), (1024, 1), 0), permute_201, out=buf131)
del permute_201
buf132 = buf93; del buf93 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_118, buf131, buf132, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_118
buf133 = buf94; del buf94 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf132, buf133, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf135 = reinterpret_tensor(buf77, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf77 # reuse
buf136 = reinterpret_tensor(buf52, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf52 # reuse
buf137 = reinterpret_tensor(buf39, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf39 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_95, permute_96, permute_97, getitem_119, buf133, buf131, buf135, buf136, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf137, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_119
del permute_95
del permute_96
del permute_97
buf140 = buf103; del buf103 # reuse
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_20.run(buf136, getitem_113, rsqrt_48, buf140, 512, 131072, grid=grid(512), stream=stream0)
buf141 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf140, buf141, 1, 512, grid=grid(1), stream=stream0)
buf150 = buf113; del buf113 # reuse
buf149 = reinterpret_tensor(buf150, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_21.run(buf136, primals_7, getitem_113, rsqrt_48, buf149, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_113
del rsqrt_48
buf143 = buf107; del buf107 # reuse
buf144 = buf106; del buf106 # reuse
buf148 = reinterpret_tensor(buf150, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf147 = reinterpret_tensor(buf150, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_34, q_34], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf137, unsqueeze_134, unsqueeze_132, buf135, getitem_112, rsqrt_47, getitem_111, rsqrt_46, buf143, buf144, buf148, buf147, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_111
del getitem_112
del rsqrt_46
del rsqrt_47
del unsqueeze_132
del unsqueeze_134
del buf147
del buf148
del buf149
buf151 = buf114; del buf114 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf150, (3072, 65536), (1, 3072), 0), view_137, out=buf151)
del view_137
buf152 = reinterpret_tensor(buf137, (65536, 1024), (1024, 1), 0); del buf137 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf150, (65536, 3072), (3072, 1), 0), permute_209, out=buf152)
del permute_209
buf153 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf130, buf151, buf153, 4194304, grid=grid(4194304), stream=stream0)
buf158 = reinterpret_tensor(buf135, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf135 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_add_mul_22.run(buf129, buf152, primals_7, buf158, 67108864, grid=grid(67108864), stream=stream0)
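    # NOTE (editor): triton_poi_fused_add_mul_22 rebuilds the gradient entering
    # this block, roughly buf158 = buf129 + scale * buf152 (the exact indexing
    # into primals_7 lives in the Triton source). Assumption: primals_7 holds
    # learned scalar coefficients applied to residual branches in the forward,
    # which is why it appears in nearly every mul kernel of this backward.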
buf160 = reinterpret_tensor(buf125, (65536, 4096), (4096, 1), 0); del buf125 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf158, (65536, 1024), (1024, 1), 0), permute_213, out=buf160)
del permute_213
buf161 = reinterpret_tensor(mm_44, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_44 # reuse
# Topologically Sorted Source Nodes: [relu_11], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf161, buf160, 268435456, grid=grid(268435456), stream=stream0)
del buf160
buf163 = reinterpret_tensor(buf136, (65536, 1024), (1024, 1), 0); del buf136 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf161, (65536, 4096), (4096, 1), 0), permute_217, out=buf163)
del permute_217
buf165 = reinterpret_tensor(buf163, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf163 # reuse
# Topologically Sorted Source Nodes: [rms_norm_45], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf165, add_128, buf158, rsqrt_45, 65536, 1024, grid=grid(65536), stream=stream0)
del add_128
del rsqrt_45
buf167 = buf131; del buf131 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf165, (65536, 1024), (1024, 1), 0), permute_221, out=buf167)
del permute_221
buf168 = buf132; del buf132 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_109, buf167, buf168, 524288, 128, grid=grid(524288), stream=stream0)
buf169 = buf133; del buf133 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf168, buf169, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf171 = reinterpret_tensor(buf115, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf115 # reuse
buf172 = reinterpret_tensor(buf92, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf92 # reuse
buf173 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_87, permute_88, permute_89, getitem_110, buf169, buf167, buf171, buf172, clamp_max, unsqueeze_9, convert_element_type_2, clone_4, clamp_max_1, unsqueeze_13, convert_element_type_4, clone_7, cumsum, buf173, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_110
del permute_87
del permute_88
del permute_89
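    # NOTE (editor): unlike the other attention-backward launches in this file,
    # this one is fed clamp_max / clamp_max_1, convert_element_type_2/4 and
    # clone_4/clone_7 rather than the clamp_max_2/_3 family. Assumption: this
    # layer's flex-attention uses a different block mask (e.g. a different
    # sliding-window size), so it needs a separate set of saved mask buffers.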
buf179 = buf144; del buf144 # reuse
buf180 = buf143; del buf143 # reuse
buf186 = buf150; del buf150 # reuse
buf184 = reinterpret_tensor(buf186, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf183 = reinterpret_tensor(buf186, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_31, q_31], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf173, unsqueeze_126, unsqueeze_124, buf171, getitem_103, rsqrt_43, getitem_102, rsqrt_42, buf179, buf180, buf184, buf183, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_102
del getitem_103
del rsqrt_42
del rsqrt_43
del unsqueeze_124
del unsqueeze_126
buf185 = reinterpret_tensor(buf186, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_31], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_23.run(buf172, primals_7, getitem_104, rsqrt_44, buf185, 524288, 128, grid=grid(524288), stream=stream0)
del buf183
del buf184
del buf185
buf188 = reinterpret_tensor(buf173, (65536, 1024), (1024, 1), 0); del buf173 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf186, (65536, 3072), (3072, 1), 0), permute_229, out=buf188)
del permute_229
buf154 = buf120; del buf120 # reuse
buf156 = buf117; del buf117 # reuse
buf190 = buf43; del buf43 # reuse
buf192 = buf41; del buf41 # reuse
buf194 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32)
buf196 = reinterpret_tensor(buf171, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf171 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mul_sum_24.run(buf129, buf152, embedding_3, rsqrt, add_130, buf165, buf188, add_119, primals_7, add_36, buf154, buf156, buf190, buf192, buf194, buf196, 65536, 1024, grid=grid(65536), stream=stream0)
del add_119
del add_130
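    # NOTE (editor): triton_per_fused__to_copy_add_mul_sum_24 is a large fusion:
    # in one pass over the 65536 rows it accumulates the gradients of several
    # residual branches (buf129, buf152, buf165, buf188), applies the embedding
    # RMSNorm backward terms (embedding_3, rsqrt), and emits the per-row partial
    # sums (buf154..buf194) that the scalar reductions below collapse.
    # Assumption: the add_36/add_119 arguments are activations saved from much
    # earlier layers, i.e. long U-Net-style skip connections between blocks.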
buf155 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf154, buf155, 1, 65536, grid=grid(1), stream=stream0)
buf157 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf156, buf157, 1, 65536, grid=grid(1), stream=stream0)
buf159 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf158, (1024, 65536), (1, 1024), 0), view_134, out=buf159)
del view_134
buf162 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf161, (4096, 65536), (1, 4096), 0), view_132, out=buf162)
del view_132
buf166 = buf130; del buf130 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf165, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_109, (65536, 1024), (1024, 1), 0), out=buf166)
del getitem_109
buf176 = buf140; del buf140 # reuse
# Topologically Sorted Source Nodes: [v_31], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_20.run(buf172, getitem_104, rsqrt_44, buf176, 512, 131072, grid=grid(512), stream=stream0)
del getitem_104
del rsqrt_44
buf177 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [v_31], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf176, buf177, 1, 512, grid=grid(1), stream=stream0)
buf187 = buf151; del buf151 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf186, (3072, 65536), (1, 3072), 0), view_126, out=buf187)
del view_126
buf189 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf166, buf187, buf189, 4194304, grid=grid(4194304), stream=stream0)
buf191 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf190, buf191, 1, 65536, grid=grid(1), stream=stream0)
buf193 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf192, buf193, 1, 65536, grid=grid(1), stream=stream0)
buf195 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf194, buf195, 1, 65536, grid=grid(1), stream=stream0)
buf197 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf196, (1024, 65536), (1, 1024), 0), view_123, out=buf197)
del view_123
buf198 = reinterpret_tensor(buf161, (65536, 4096), (4096, 1), 0); del buf161 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf196, (65536, 1024), (1024, 1), 0), permute_233, out=buf198)
del permute_233
buf199 = reinterpret_tensor(mm_40, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_40 # reuse
# Topologically Sorted Source Nodes: [relu_10], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf199, buf198, 268435456, grid=grid(268435456), stream=stream0)
del buf198
buf200 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf199, (4096, 65536), (1, 4096), 0), view_121, out=buf200)
del view_121
buf201 = reinterpret_tensor(buf172, (65536, 1024), (1024, 1), 0); del buf172 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf199, (65536, 4096), (4096, 1), 0), permute_237, out=buf201)
del permute_237
buf203 = buf196; del buf196 # reuse
# Topologically Sorted Source Nodes: [rms_norm_41], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf203, buf201, add_116, rsqrt_41, 65536, 1024, grid=grid(65536), stream=stream0)
del add_116
del rsqrt_41
buf204 = buf166; del buf166 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf203, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_100, (65536, 1024), (1024, 1), 0), out=buf204)
buf205 = buf201; del buf201 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf203, (65536, 1024), (1024, 1), 0), permute_241, out=buf205)
del permute_241
buf206 = buf168; del buf168 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_100, buf205, buf206, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_100
buf207 = buf169; del buf169 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf206, buf207, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf209 = reinterpret_tensor(buf158, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf158 # reuse
buf210 = reinterpret_tensor(buf167, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf167 # reuse
buf211 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_79, permute_80, permute_81, getitem_101, buf207, buf205, buf209, buf210, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf211, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_101
del permute_79
del permute_80
del permute_81
buf214 = buf176; del buf176 # reuse
# Topologically Sorted Source Nodes: [v_28], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_20.run(buf210, getitem_95, rsqrt_40, buf214, 512, 131072, grid=grid(512), stream=stream0)
buf215 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [v_28], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf214, buf215, 1, 512, grid=grid(1), stream=stream0)
buf224 = buf186; del buf186 # reuse
buf223 = reinterpret_tensor(buf224, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_28], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_25.run(buf210, primals_7, getitem_95, rsqrt_40, buf223, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_95
del rsqrt_40
buf217 = buf180; del buf180 # reuse
buf218 = buf179; del buf179 # reuse
buf222 = reinterpret_tensor(buf224, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf221 = reinterpret_tensor(buf224, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_28, q_28], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf211, unsqueeze_118, unsqueeze_116, buf209, getitem_94, rsqrt_39, getitem_93, rsqrt_38, buf217, buf218, buf222, buf221, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_93
del getitem_94
del rsqrt_38
del rsqrt_39
del unsqueeze_116
del unsqueeze_118
del buf221
del buf222
del buf223
buf225 = buf187; del buf187 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf224, (3072, 65536), (1, 3072), 0), view_115, out=buf225)
del view_115
buf226 = reinterpret_tensor(buf211, (65536, 1024), (1024, 1), 0); del buf211 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf224, (65536, 3072), (3072, 1), 0), permute_249, out=buf226)
del permute_249
buf227 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf204, buf225, buf227, 4194304, grid=grid(4194304), stream=stream0)
buf228 = buf194; del buf194 # reuse
buf231 = buf192; del buf192 # reuse
buf233 = buf190; del buf190 # reuse
buf230 = buf119; del buf119 # reuse
buf235 = reinterpret_tensor(buf209, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf209 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mul_sum_26.run(buf230, buf203, buf226, embedding_3, rsqrt, add_107, primals_7, add_58, buf129, buf152, buf165, buf188, buf228, buf231, buf233, buf235, 65536, 1024, grid=grid(65536), stream=stream0)
del add_107
buf229 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf228, buf229, 1, 65536, grid=grid(1), stream=stream0)
buf232 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf231, buf232, 1, 65536, grid=grid(1), stream=stream0)
buf234 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf233, buf234, 1, 65536, grid=grid(1), stream=stream0)
buf236 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf235, (1024, 65536), (1, 1024), 0), view_112, out=buf236)
del view_112
buf237 = reinterpret_tensor(buf199, (65536, 4096), (4096, 1), 0); del buf199 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf235, (65536, 1024), (1024, 1), 0), permute_253, out=buf237)
del permute_253
buf238 = reinterpret_tensor(mm_36, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_36 # reuse
# Topologically Sorted Source Nodes: [relu_9], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf238, buf237, 268435456, grid=grid(268435456), stream=stream0)
del buf237
buf239 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf238, (4096, 65536), (1, 4096), 0), view_110, out=buf239)
del view_110
buf240 = buf152; del buf152 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf238, (65536, 4096), (4096, 1), 0), permute_257, out=buf240)
del permute_257
buf242 = buf235; del buf235 # reuse
# Topologically Sorted Source Nodes: [rms_norm_37], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf242, buf240, add_104, rsqrt_37, 65536, 1024, grid=grid(65536), stream=stream0)
del add_104
del rsqrt_37
buf243 = buf204; del buf204 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf242, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_91, (65536, 1024), (1024, 1), 0), out=buf243)
buf244 = buf240; del buf240 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf242, (65536, 1024), (1024, 1), 0), permute_261, out=buf244)
del permute_261
buf245 = buf206; del buf206 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_91, buf244, buf245, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_91
buf246 = buf207; del buf207 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf245, buf246, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf248 = reinterpret_tensor(buf129, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf129 # reuse
buf249 = buf210; del buf210 # reuse
buf250 = reinterpret_tensor(buf205, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf205 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_71, permute_72, permute_73, getitem_92, buf246, buf244, buf248, buf249, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf250, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_92
del permute_71
del permute_72
del permute_73
buf253 = buf214; del buf214 # reuse
# Topologically Sorted Source Nodes: [v_25], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_20.run(buf249, getitem_86, rsqrt_36, buf253, 512, 131072, grid=grid(512), stream=stream0)
buf254 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [v_25], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf253, buf254, 1, 512, grid=grid(1), stream=stream0)
buf263 = buf224; del buf224 # reuse
buf262 = reinterpret_tensor(buf263, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_25], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_27.run(buf249, primals_7, getitem_86, rsqrt_36, buf262, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_86
del rsqrt_36
buf256 = buf218; del buf218 # reuse
buf257 = buf217; del buf217 # reuse
buf261 = reinterpret_tensor(buf263, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf260 = reinterpret_tensor(buf263, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_25, q_25], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf250, unsqueeze_110, unsqueeze_108, buf248, getitem_85, rsqrt_35, getitem_84, rsqrt_34, buf256, buf257, buf261, buf260, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_84
del getitem_85
del rsqrt_34
del rsqrt_35
del unsqueeze_108
del unsqueeze_110
del buf260
del buf261
del buf262
buf264 = buf225; del buf225 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf263, (3072, 65536), (1, 3072), 0), view_104, out=buf264)
del view_104
buf265 = reinterpret_tensor(buf250, (65536, 1024), (1024, 1), 0); del buf250 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf263, (65536, 3072), (3072, 1), 0), permute_269, out=buf265)
del permute_269
buf266 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf243, buf264, buf266, 4194304, grid=grid(4194304), stream=stream0)
buf273 = reinterpret_tensor(buf248, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf248 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_add_mul_28.run(buf242, buf265, primals_7, buf273, 67108864, grid=grid(67108864), stream=stream0)
buf275 = reinterpret_tensor(buf238, (65536, 4096), (4096, 1), 0); del buf238 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf273, (65536, 1024), (1024, 1), 0), permute_273, out=buf275)
del permute_273
buf276 = reinterpret_tensor(mm_32, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_32 # reuse
# Topologically Sorted Source Nodes: [relu_8], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf276, buf275, 268435456, grid=grid(268435456), stream=stream0)
buf278 = reinterpret_tensor(buf249, (65536, 1024), (1024, 1), 0); del buf249 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf276, (65536, 4096), (4096, 1), 0), permute_277, out=buf278)
del permute_277
buf280 = reinterpret_tensor(buf278, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf278 # reuse
# Topologically Sorted Source Nodes: [rms_norm_33], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf280, add_92, buf273, rsqrt_33, 65536, 1024, grid=grid(65536), stream=stream0)
del add_92
del rsqrt_33
buf282 = buf244; del buf244 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf280, (65536, 1024), (1024, 1), 0), permute_281, out=buf282)
del permute_281
buf283 = buf245; del buf245 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_82, buf282, buf283, 524288, 128, grid=grid(524288), stream=stream0)
buf284 = buf246; del buf246 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf283, buf284, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf286 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
buf287 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
buf288 = empty_strided_cuda((1, 8, 65536, 128), (67108864, 128, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_63, permute_64, permute_65, getitem_83, buf284, buf282, buf286, buf287, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf288, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_83
del permute_63
del permute_64
del permute_65
buf294 = buf257; del buf257 # reuse
buf295 = buf256; del buf256 # reuse
buf301 = buf263; del buf263 # reuse
buf299 = reinterpret_tensor(buf301, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf298 = reinterpret_tensor(buf301, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_22, q_22], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf288, unsqueeze_102, unsqueeze_100, buf286, getitem_76, rsqrt_31, getitem_75, rsqrt_30, buf294, buf295, buf299, buf298, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_75
del getitem_76
del rsqrt_30
del rsqrt_31
del unsqueeze_100
del unsqueeze_102
buf300 = reinterpret_tensor(buf301, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_22], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_29.run(buf287, primals_7, getitem_77, rsqrt_32, buf300, 524288, 128, grid=grid(524288), stream=stream0)
del buf298
del buf299
del buf300
buf303 = reinterpret_tensor(buf288, (65536, 1024), (1024, 1), 0); del buf288 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf301, (65536, 3072), (3072, 1), 0), permute_289, out=buf303)
del permute_289
buf309 = reinterpret_tensor(buf286, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf286 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_add_mul_30.run(buf280, buf303, primals_7, buf309, 67108864, grid=grid(67108864), stream=stream0)
buf311 = buf275; del buf275 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf309, (65536, 1024), (1024, 1), 0), permute_293, out=buf311)
del permute_293
buf312 = reinterpret_tensor(mm_28, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_28 # reuse
# Topologically Sorted Source Nodes: [relu_7], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf312, buf311, 268435456, grid=grid(268435456), stream=stream0)
del buf311
buf314 = buf282; del buf282 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf312, (65536, 4096), (4096, 1), 0), permute_297, out=buf314)
del permute_297
buf316 = reinterpret_tensor(buf314, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf314 # reuse
buf319 = buf230; del buf230 # reuse
buf322 = empty_strided_cuda((1, 65536, 1024), (67108864, 1024, 1), torch.bfloat16)
buf267 = buf233; del buf233 # reuse
buf269 = buf231; del buf231 # reuse
buf271 = buf228; del buf228 # reuse
buf305 = buf156; del buf156 # reuse
buf307 = buf154; del buf154 # reuse
buf317 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32)
buf320 = empty_strided_cuda((1, 65536, 1), (65536, 1, 65536), torch.float32)
# Topologically Sorted Source Nodes: [x, rms_norm_29], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum, aten.div, aten.pow]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_31.run(buf316, buf319, add_81, buf309, rsqrt_29, buf242, buf265, primals_7, buf280, buf303, embedding_3, rsqrt, add_95, add_80, add_83, buf322, buf267, buf269, buf271, buf305, buf307, buf317, buf320, 65536, 1024, grid=grid(65536), stream=stream0)
del add_80
del add_81
del add_83
del add_95
del buf242
del rsqrt_29
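    # NOTE (editor): this is the widest fusion in the section:
    # triton_per_fused__to_copy_add_div_mul_pow_sum_31 covers the rms_norm_29
    # backward plus the accumulation of several residual-branch gradients and
    # eight per-row reduction outputs (buf267..buf320) in a single 65536-row
    # pass. These ops are all bandwidth-bound, so one kernel with many operands
    # beats re-reading the (1, 65536, 1024) activations once per branch.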
buf268 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf267, buf268, 1, 65536, grid=grid(1), stream=stream0)
del buf267
buf270 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf269, buf270, 1, 65536, grid=grid(1), stream=stream0)
del buf269
buf272 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf271, buf272, 1, 65536, grid=grid(1), stream=stream0)
del buf271
buf274 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf273, (1024, 65536), (1, 1024), 0), view_101, out=buf274)
del view_101
buf277 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf276, (4096, 65536), (1, 4096), 0), view_99, out=buf277)
del buf276
del view_99
buf281 = buf243; del buf243 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf280, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_82, (65536, 1024), (1024, 1), 0), out=buf281)
del getitem_82
buf291 = buf253; del buf253 # reuse
# Topologically Sorted Source Nodes: [v_22], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_20.run(buf287, getitem_77, rsqrt_32, buf291, 512, 131072, grid=grid(512), stream=stream0)
del getitem_77
del rsqrt_32
buf292 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [v_22], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf291, buf292, 1, 512, grid=grid(1), stream=stream0)
buf302 = buf264; del buf264 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf301, (3072, 65536), (1, 3072), 0), view_93, out=buf302)
del view_93
buf304 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf281, buf302, buf304, 4194304, grid=grid(4194304), stream=stream0)
buf306 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf305, buf306, 1, 65536, grid=grid(1), stream=stream0)
buf308 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf307, buf308, 1, 65536, grid=grid(1), stream=stream0)
buf310 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf309, (1024, 65536), (1, 1024), 0), view_90, out=buf310)
del view_90
buf313 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf312, (4096, 65536), (1, 4096), 0), view_88, out=buf313)
del view_88
buf318 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf317, buf318, 1, 65536, grid=grid(1), stream=stream0)
buf321 = empty_strided_cuda((), (), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf320, buf321, 1, 65536, grid=grid(1), stream=stream0)
buf323 = empty_strided_cuda((16, 2), (2, 1), torch.float32)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten._to_copy, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_add_select_backward_32.run(buf42, buf44, buf80, buf82, buf118, buf121, buf155, buf157, buf191, buf193, buf229, buf232, buf268, buf270, buf306, buf308, buf318, buf321, buf323, 32, grid=grid(32), stream=stream0)
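    # NOTE (editor): buf323 is a (16, 2) fp32 gradient assembled from the 18
    # scalar bf16 reductions computed throughout this section (buf42, buf44,
    # ..., buf321). Assumption: the model keeps its learned scalars in one
    # (16, 2) table read with select() in the forward pass, so select_backward
    # routes each scalar grad to its slot and zero-fills the unused entries.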
del buf118
del buf121
del buf155
del buf157
del buf191
del buf193
buf324 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf322, (1024, 65536), (1, 1024), 0), view_86, out=buf324)
del view_86
buf325 = reinterpret_tensor(buf312, (65536, 4096), (4096, 1), 0); del buf312 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf322, (65536, 1024), (1024, 1), 0), permute_301, out=buf325)
del permute_301
buf326 = reinterpret_tensor(mm_26, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_26 # reuse
# Topologically Sorted Source Nodes: [relu_6], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf326, buf325, 268435456, grid=grid(268435456), stream=stream0)
del buf325
buf327 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf326, (4096, 65536), (1, 4096), 0), view_84, out=buf327)
del view_84
buf328 = reinterpret_tensor(buf309, (65536, 1024), (1024, 1), 0); del buf309 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf326, (65536, 4096), (4096, 1), 0), permute_305, out=buf328)
del permute_305
buf330 = buf322; del buf322 # reuse
# Topologically Sorted Source Nodes: [rms_norm_28], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf330, buf328, add_78, rsqrt_28, 65536, 1024, grid=grid(65536), stream=stream0)
del add_78
del rsqrt_28
buf331 = buf281; del buf281 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf330, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_73, (65536, 1024), (1024, 1), 0), out=buf331)
buf332 = buf328; del buf328 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf330, (65536, 1024), (1024, 1), 0), permute_309, out=buf332)
del permute_309
buf333 = buf283; del buf283 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_73, buf332, buf333, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_73
buf334 = buf284; del buf284 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf333, buf334, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf336 = buf287; del buf287 # reuse
buf337 = reinterpret_tensor(buf280, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf280 # reuse
buf338 = reinterpret_tensor(buf273, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf273 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_53, permute_54, permute_55, getitem_74, buf334, buf332, buf336, buf337, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf338, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_74
del permute_53
del permute_54
del permute_55
buf341 = buf291; del buf291 # reuse
# Topologically Sorted Source Nodes: [v_19], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_20.run(buf337, getitem_68, rsqrt_27, buf341, 512, 131072, grid=grid(512), stream=stream0)
buf342 = buf82; del buf82 # reuse
# Topologically Sorted Source Nodes: [v_19], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf341, buf342, 1, 512, grid=grid(1), stream=stream0)
buf351 = buf301; del buf301 # reuse
buf350 = reinterpret_tensor(buf351, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_19], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_33.run(buf337, primals_7, getitem_68, rsqrt_27, buf350, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_68
del rsqrt_27
buf344 = buf295; del buf295 # reuse
buf345 = buf294; del buf294 # reuse
buf349 = reinterpret_tensor(buf351, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf348 = reinterpret_tensor(buf351, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_19, q_19], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf338, unsqueeze_94, unsqueeze_92, buf336, getitem_67, rsqrt_26, getitem_66, rsqrt_25, buf344, buf345, buf349, buf348, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_66
del getitem_67
del rsqrt_25
del rsqrt_26
del unsqueeze_92
del unsqueeze_94
del buf348
del buf349
del buf350
buf352 = buf302; del buf302 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf351, (3072, 65536), (1, 3072), 0), view_78, out=buf352)
del view_78
buf353 = reinterpret_tensor(buf338, (65536, 1024), (1024, 1), 0); del buf338 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf351, (65536, 3072), (3072, 1), 0), permute_317, out=buf353)
del permute_317
buf354 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf331, buf352, buf354, 4194304, grid=grid(4194304), stream=stream0)
buf359 = reinterpret_tensor(buf336, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf336 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_add_mul_34.run(buf330, buf353, primals_7, buf359, 67108864, grid=grid(67108864), stream=stream0)
buf361 = reinterpret_tensor(buf326, (65536, 4096), (4096, 1), 0); del buf326 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf359, (65536, 1024), (1024, 1), 0), permute_321, out=buf361)
del permute_321
buf362 = reinterpret_tensor(mm_22, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_22 # reuse
# Topologically Sorted Source Nodes: [relu_5], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf362, buf361, 268435456, grid=grid(268435456), stream=stream0)
del buf361
buf364 = reinterpret_tensor(buf337, (65536, 1024), (1024, 1), 0); del buf337 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf362, (65536, 4096), (4096, 1), 0), permute_325, out=buf364)
del permute_325
buf366 = reinterpret_tensor(buf364, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf364 # reuse
# Topologically Sorted Source Nodes: [rms_norm_24], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf366, add_67, buf359, rsqrt_24, 65536, 1024, grid=grid(65536), stream=stream0)
del add_67
del rsqrt_24
buf368 = buf332; del buf332 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf366, (65536, 1024), (1024, 1), 0), permute_329, out=buf368)
del permute_329
buf369 = buf333; del buf333 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_64, buf368, buf369, 524288, 128, grid=grid(524288), stream=stream0)
buf370 = buf334; del buf334 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf369, buf370, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf372 = reinterpret_tensor(buf316, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf316 # reuse
buf373 = reinterpret_tensor(buf303, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf303 # reuse
buf374 = reinterpret_tensor(buf265, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf265 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_45, permute_46, permute_47, getitem_65, buf370, buf368, buf372, buf373, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf374, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del buf368
del getitem_65
del permute_45
del permute_46
del permute_47
buf380 = buf345; del buf345 # reuse
buf381 = buf344; del buf344 # reuse
buf387 = buf351; del buf351 # reuse
buf385 = reinterpret_tensor(buf387, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf384 = reinterpret_tensor(buf387, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_16, q_16], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf374, unsqueeze_86, unsqueeze_84, buf372, getitem_58, rsqrt_22, getitem_57, rsqrt_21, buf380, buf381, buf385, buf384, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_57
del getitem_58
del rsqrt_21
del rsqrt_22
del unsqueeze_84
del unsqueeze_86
buf386 = reinterpret_tensor(buf387, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_16], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_35.run(buf373, primals_7, getitem_59, rsqrt_23, buf386, 524288, 128, grid=grid(524288), stream=stream0)
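# Note: buf384/buf385/buf386 alias disjoint 1024-wide slices (element offsets 0,
# 1024, 2048) of buf387, so the two reduction kernels above assemble the gradient
# of the concatenated qkv projection output in place; buf387 then feeds the input-
# and weight-gradient matmuls below. The dels that follow only drop the alias
# handles, not the underlying storage.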
del buf384
del buf385
del buf386
buf389 = reinterpret_tensor(buf374, (65536, 1024), (1024, 1), 0); del buf374 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf387, (65536, 3072), (3072, 1), 0), permute_337, out=buf389)
del permute_337
buf355 = buf320; del buf320 # reuse
buf357 = buf317; del buf317 # reuse
buf391 = buf307; del buf307 # reuse
buf393 = buf305; del buf305 # reuse
buf395 = buf203; del buf203 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mul_sum_36.run(buf395, buf330, buf353, embedding_3, rsqrt, add_69, buf366, buf389, add_58, buf226, primals_7, buf355, buf357, buf391, buf393, 65536, 1024, grid=grid(65536), stream=stream0)
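# Note: this single per-row reduction pass appears to produce partial sums for
# several scale-parameter gradients at once (buf355/buf357/buf391/buf393) while
# accumulating the residual-stream gradient into buf395; the 65536-element
# partials are then collapsed to scalars by triton_red_fused__to_copy_add_mul_sum_17
# in the launches below.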
del add_58
del add_69
buf356 = buf80; del buf80 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf355, buf356, 1, 65536, grid=grid(1), stream=stream0)
buf358 = buf44; del buf44 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf357, buf358, 1, 65536, grid=grid(1), stream=stream0)
buf360 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf359, (1024, 65536), (1, 1024), 0), view_75, out=buf360)
del view_75
buf363 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf362, (4096, 65536), (1, 4096), 0), view_73, out=buf363)
del view_73
buf367 = buf331; del buf331 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf366, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_64, (65536, 1024), (1024, 1), 0), out=buf367)
del getitem_64
buf377 = buf341; del buf341 # reuse
# Topologically Sorted Source Nodes: [v_16], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_20.run(buf373, getitem_59, rsqrt_23, buf377, 512, 131072, grid=grid(512), stream=stream0)
del getitem_59
del rsqrt_23
buf378 = buf42; del buf42 # reuse
# Topologically Sorted Source Nodes: [v_16], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf377, buf378, 1, 512, grid=grid(1), stream=stream0)
buf388 = buf352; del buf352 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf387, (3072, 65536), (1, 3072), 0), view_67, out=buf388)
del view_67
buf390 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf367, buf388, buf390, 4194304, grid=grid(4194304), stream=stream0)
buf392 = buf321; del buf321 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf391, buf392, 1, 65536, grid=grid(1), stream=stream0)
buf394 = buf318; del buf318 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf393, buf394, 1, 65536, grid=grid(1), stream=stream0)
buf396 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf395, (1024, 65536), (1, 1024), 0), view_64, out=buf396)
del view_64
buf397 = reinterpret_tensor(buf362, (65536, 4096), (4096, 1), 0); del buf362 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf395, (65536, 1024), (1024, 1), 0), permute_341, out=buf397)
del permute_341
buf398 = reinterpret_tensor(mm_18, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_18 # reuse
# Topologically Sorted Source Nodes: [relu_4], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf398, buf397, 268435456, grid=grid(268435456), stream=stream0)
del buf397
buf399 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf398, (4096, 65536), (1, 4096), 0), view_62, out=buf399)
del view_62
buf400 = reinterpret_tensor(buf373, (65536, 1024), (1024, 1), 0); del buf373 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf398, (65536, 4096), (4096, 1), 0), permute_345, out=buf400)
del permute_345
buf402 = buf395; del buf395 # reuse
# Topologically Sorted Source Nodes: [rms_norm_20], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf402, buf400, add_56, rsqrt_20, 65536, 1024, grid=grid(65536), stream=stream0)
del add_56
del rsqrt_20
buf403 = buf367; del buf367 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf402, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_55, (65536, 1024), (1024, 1), 0), out=buf403)
buf404 = buf400; del buf400 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf402, (65536, 1024), (1024, 1), 0), permute_349, out=buf404)
del permute_349
buf405 = buf369; del buf369 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_55, buf404, buf405, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_55
buf406 = buf370; del buf370 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf405, buf406, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf408 = reinterpret_tensor(buf359, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf359 # reuse
buf409 = reinterpret_tensor(buf226, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf226 # reuse
buf410 = buf372; del buf372 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_37, permute_38, permute_39, getitem_56, buf406, buf404, buf408, buf409, clamp_max, unsqueeze_9, convert_element_type_2, clone_4, clamp_max_1, unsqueeze_13, convert_element_type_4, clone_7, cumsum, buf410, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_56
del permute_37
del permute_38
del permute_39
buf413 = buf377; del buf377 # reuse
# Topologically Sorted Source Nodes: [v_13], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_20.run(buf409, getitem_50, rsqrt_19, buf413, 512, 131072, grid=grid(512), stream=stream0)
buf414 = buf308; del buf308 # reuse
# Topologically Sorted Source Nodes: [v_13], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf413, buf414, 1, 512, grid=grid(1), stream=stream0)
buf423 = buf387; del buf387 # reuse
buf422 = reinterpret_tensor(buf423, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_13], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_37.run(buf409, primals_7, getitem_50, rsqrt_19, buf422, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_50
del rsqrt_19
buf416 = buf381; del buf381 # reuse
buf417 = buf380; del buf380 # reuse
buf421 = reinterpret_tensor(buf423, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf420 = reinterpret_tensor(buf423, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_13, q_13], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf410, unsqueeze_78, unsqueeze_76, buf408, getitem_49, rsqrt_18, getitem_48, rsqrt_17, buf416, buf417, buf421, buf420, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_48
del getitem_49
del rsqrt_17
del rsqrt_18
del unsqueeze_76
del unsqueeze_78
del buf420
del buf421
del buf422
buf424 = buf388; del buf388 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf423, (3072, 65536), (1, 3072), 0), view_56, out=buf424)
del view_56
buf425 = reinterpret_tensor(buf410, (65536, 1024), (1024, 1), 0); del buf410 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf423, (65536, 3072), (3072, 1), 0), permute_357, out=buf425)
del permute_357
buf426 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf403, buf424, buf426, 4194304, grid=grid(4194304), stream=stream0)
buf429 = buf319; del buf319 # reuse
buf432 = reinterpret_tensor(buf408, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf408 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_add_mul_38.run(buf429, buf330, buf353, primals_7, buf366, buf389, buf402, buf425, buf432, 67108864, grid=grid(67108864), stream=stream0)
buf434 = reinterpret_tensor(buf398, (65536, 4096), (4096, 1), 0); del buf398 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf432, (65536, 1024), (1024, 1), 0), permute_361, out=buf434)
del permute_361
buf435 = reinterpret_tensor(mm_14, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_14 # reuse
# Topologically Sorted Source Nodes: [relu_3], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf435, buf434, 268435456, grid=grid(268435456), stream=stream0)
del buf434
buf437 = buf389; del buf389 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf435, (65536, 4096), (4096, 1), 0), permute_365, out=buf437)
del permute_365
buf439 = reinterpret_tensor(buf437, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf437 # reuse
# Topologically Sorted Source Nodes: [rms_norm_16], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf439, add_45, buf432, rsqrt_16, 65536, 1024, grid=grid(65536), stream=stream0)
del add_45
del rsqrt_16
buf441 = reinterpret_tensor(buf366, (65536, 1024), (1024, 1), 0); del buf366 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf439, (65536, 1024), (1024, 1), 0), permute_369, out=buf441)
del permute_369
buf442 = buf405; del buf405 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_46, buf441, buf442, 524288, 128, grid=grid(524288), stream=stream0)
buf443 = buf406; del buf406 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf442, buf443, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf445 = reinterpret_tensor(buf353, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf353 # reuse
buf446 = reinterpret_tensor(buf330, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf330 # reuse
buf447 = buf409; del buf409 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_29, permute_30, permute_31, getitem_47, buf443, buf441, buf445, buf446, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf447, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_47
del permute_29
del permute_30
del permute_31
buf453 = buf417; del buf417 # reuse
buf454 = buf416; del buf416 # reuse
buf460 = buf423; del buf423 # reuse
buf458 = reinterpret_tensor(buf460, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf457 = reinterpret_tensor(buf460, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_10, q_10], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf447, unsqueeze_70, unsqueeze_68, buf445, getitem_40, rsqrt_14, getitem_39, rsqrt_13, buf453, buf454, buf458, buf457, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_39
del getitem_40
del rsqrt_13
del rsqrt_14
del unsqueeze_68
del unsqueeze_70
buf459 = reinterpret_tensor(buf460, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_10], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_39.run(buf446, primals_7, getitem_41, rsqrt_15, buf459, 524288, 128, grid=grid(524288), stream=stream0)
del buf457
del buf458
del buf459
buf462 = reinterpret_tensor(buf447, (65536, 1024), (1024, 1), 0); del buf447 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf460, (65536, 3072), (3072, 1), 0), permute_377, out=buf462)
del permute_377
buf427 = buf393; del buf393 # reuse
buf430 = buf391; del buf391 # reuse
buf464 = buf357; del buf357 # reuse
buf466 = buf355; del buf355 # reuse
buf468 = buf165; del buf165 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mul_sum_40.run(buf468, buf402, buf425, embedding_3, rsqrt, add_47, buf439, buf462, add_36, buf188, primals_7, buf427, buf430, buf464, buf466, 65536, 1024, grid=grid(65536), stream=stream0)
del add_36
del add_47
buf428 = buf306; del buf306 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf427, buf428, 1, 65536, grid=grid(1), stream=stream0)
buf431 = buf270; del buf270 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf430, buf431, 1, 65536, grid=grid(1), stream=stream0)
buf433 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf432, (1024, 65536), (1, 1024), 0), view_53, out=buf433)
del view_53
buf436 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf435, (4096, 65536), (1, 4096), 0), view_51, out=buf436)
del view_51
buf440 = buf403; del buf403 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf439, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_46, (65536, 1024), (1024, 1), 0), out=buf440)
del getitem_46
buf450 = buf413; del buf413 # reuse
# Topologically Sorted Source Nodes: [v_10], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_20.run(buf446, getitem_41, rsqrt_15, buf450, 512, 131072, grid=grid(512), stream=stream0)
del getitem_41
del rsqrt_15
buf451 = buf268; del buf268 # reuse
# Topologically Sorted Source Nodes: [v_10], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf450, buf451, 1, 512, grid=grid(1), stream=stream0)
buf461 = buf424; del buf424 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf460, (3072, 65536), (1, 3072), 0), view_45, out=buf461)
del view_45
buf463 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf440, buf461, buf463, 4194304, grid=grid(4194304), stream=stream0)
buf465 = buf232; del buf232 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf464, buf465, 1, 65536, grid=grid(1), stream=stream0)
buf467 = buf229; del buf229 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf466, buf467, 1, 65536, grid=grid(1), stream=stream0)
buf469 = empty_strided_cuda((16, 2), (2, 1), torch.float32)
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy, aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_add_select_backward_41.run(buf26, buf28, buf64, buf66, buf102, buf104, buf141, buf177, buf215, buf254, buf292, buf342, buf378, buf414, buf451, buf469, 32, grid=grid(32), stream=stream0)
del buf102
del buf104
del buf141
del buf177
buf470 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf468, (1024, 65536), (1, 1024), 0), view_42, out=buf470)
del view_42
buf471 = reinterpret_tensor(buf435, (65536, 4096), (4096, 1), 0); del buf435 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf468, (65536, 1024), (1024, 1), 0), permute_381, out=buf471)
del permute_381
buf472 = reinterpret_tensor(mm_10, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_10 # reuse
# Topologically Sorted Source Nodes: [relu_2], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf472, buf471, 268435456, grid=grid(268435456), stream=stream0)
del buf471
buf473 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf472, (4096, 65536), (1, 4096), 0), view_40, out=buf473)
del view_40
buf474 = reinterpret_tensor(buf446, (65536, 1024), (1024, 1), 0); del buf446 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf472, (65536, 4096), (4096, 1), 0), permute_385, out=buf474)
del permute_385
buf476 = buf468; del buf468 # reuse
# Topologically Sorted Source Nodes: [rms_norm_12], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf476, buf474, add_34, rsqrt_12, 65536, 1024, grid=grid(65536), stream=stream0)
del add_34
del rsqrt_12
buf477 = buf440; del buf440 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf476, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_37, (65536, 1024), (1024, 1), 0), out=buf477)
buf478 = buf474; del buf474 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf476, (65536, 1024), (1024, 1), 0), permute_389, out=buf478)
del permute_389
buf479 = buf442; del buf442 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_37, buf478, buf479, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_37
buf480 = buf443; del buf443 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf479, buf480, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf482 = reinterpret_tensor(buf432, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf432 # reuse
buf483 = reinterpret_tensor(buf425, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf425 # reuse
buf484 = reinterpret_tensor(buf402, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf402 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_21, permute_22, permute_23, getitem_38, buf480, buf478, buf482, buf483, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf484, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del getitem_38
del permute_21
del permute_22
del permute_23
buf487 = buf450; del buf450 # reuse
buf489 = buf101; del buf101 # reuse
# Topologically Sorted Source Nodes: [v_7], Original ATen: [aten.mul, aten.sum, aten._to_copy]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_8.run(buf483, embedding_2, getitem_32, rsqrt_11, buf487, buf489, 512, 131072, grid=grid(512), stream=stream0)
del embedding_2
buf488 = buf66; del buf66 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf487, buf488, 1, 512, grid=grid(1), stream=stream0)
buf490 = buf64; del buf64 # reuse
# Topologically Sorted Source Nodes: [v_7], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf489, buf490, 1, 512, grid=grid(1), stream=stream0)
buf499 = buf460; del buf460 # reuse
buf498 = reinterpret_tensor(buf499, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [v_7], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_mul_pow_sum_42.run(buf483, primals_7, getitem_32, rsqrt_11, buf498, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_32
del rsqrt_11
buf492 = buf454; del buf454 # reuse
buf493 = buf453; del buf453 # reuse
buf497 = reinterpret_tensor(buf499, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf496 = reinterpret_tensor(buf499, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_7, q_7], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf484, unsqueeze_62, unsqueeze_60, buf482, getitem_31, rsqrt_10, getitem_30, rsqrt_9, buf492, buf493, buf497, buf496, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_30
del getitem_31
del rsqrt_10
del rsqrt_9
del unsqueeze_60
del unsqueeze_62
del buf496
del buf497
del buf498
buf500 = buf461; del buf461 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf499, (3072, 65536), (1, 3072), 0), view_33, out=buf500)
del view_33
buf501 = reinterpret_tensor(buf484, (65536, 1024), (1024, 1), 0); del buf484 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf499, (65536, 3072), (3072, 1), 0), permute_397, out=buf501)
del permute_397
buf502 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf477, buf500, buf502, 4194304, grid=grid(4194304), stream=stream0)
buf507 = reinterpret_tensor(buf482, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf482 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_add_mul_43.run(buf476, buf501, primals_7, buf507, 67108864, grid=grid(67108864), stream=stream0)
buf509 = reinterpret_tensor(buf472, (65536, 4096), (4096, 1), 0); del buf472 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf507, (65536, 1024), (1024, 1), 0), permute_401, out=buf509)
del permute_401
buf510 = reinterpret_tensor(mm_6, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_6 # reuse
# Topologically Sorted Source Nodes: [relu_1], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf510, buf509, 268435456, grid=grid(268435456), stream=stream0)
del buf509
buf512 = buf478; del buf478 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf510, (65536, 4096), (4096, 1), 0), permute_405, out=buf512)
del permute_405
buf514 = reinterpret_tensor(buf512, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf512 # reuse
# Topologically Sorted Source Nodes: [rms_norm_8], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_14.run(buf514, add_22, buf507, rsqrt_8, 65536, 1024, grid=grid(65536), stream=stream0)
del add_22
del rsqrt_8
buf516 = buf188; del buf188 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf514, (65536, 1024), (1024, 1), 0), permute_409, out=buf516)
del permute_409
buf517 = buf479; del buf479 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_28, buf516, buf517, 524288, 128, grid=grid(524288), stream=stream0)
buf518 = buf480; del buf480 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf517, buf518, 8, 65536, grid=grid(8, 65536), stream=stream0)
buf520 = buf445; del buf445 # reuse
buf521 = reinterpret_tensor(buf441, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf441 # reuse
buf522 = reinterpret_tensor(buf404, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf404 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_tem_fused_zeros_7.run(permute_13, permute_14, permute_15, getitem_29, buf518, buf516, buf520, buf521, clamp_max_2, unsqueeze_9, convert_element_type_6, clone_10, clamp_max_3, unsqueeze_13, convert_element_type_8, clone_13, cumsum, buf522, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
del buf516
del clamp_max_2
del clamp_max_3
del clone_10
del clone_13
del convert_element_type_6
del convert_element_type_8
del getitem_29
del permute_13
del permute_14
del permute_15
buf530 = buf493; del buf493 # reuse
buf531 = buf492; del buf492 # reuse
buf537 = buf499; del buf499 # reuse
buf535 = reinterpret_tensor(buf537, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf534 = reinterpret_tensor(buf537, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_4, q_4], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf522, unsqueeze_54, unsqueeze_52, buf520, getitem_22, rsqrt_6, getitem_21, rsqrt_5, buf530, buf531, buf535, buf534, 524288, 128, grid=grid(524288), stream=stream0)
del buf520
del getitem_21
del getitem_22
del rsqrt_5
del rsqrt_6
del unsqueeze_52
del unsqueeze_54
buf592 = empty_strided_cuda((50257, 1024), (1024, 1), torch.float32)
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44.run(buf592, 51463168, grid=grid(51463168), stream=stream0)
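# Note: triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44 receives
# only the output buffer and its numel (51463168 = 50257 * 1024), so it appears to
# zero-fill buf592, a float32 (50257, 1024) accumulator. The kernel after it then
# scatter-adds embedding gradients into buf592, which is later downcast to bfloat16
# by triton_poi_fused_embedding_dense_backward_50.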
buf536 = reinterpret_tensor(buf537, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [loss, v_4], Original ATen: [aten.nll_loss_forward, aten.add, aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45.run(buf521, primals_7, getitem_23, rsqrt_7, primals_1, buf59, buf536, buf592, 524288, 128, grid=grid(524288), stream=stream0)
del buf534
del buf535
del buf536
buf539 = reinterpret_tensor(buf59, (65536, 1024), (1024, 1), 0); del buf59 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf537, (65536, 3072), (3072, 1), 0), permute_417, out=buf539)
del permute_417
buf543 = buf429; del buf429 # reuse
buf546 = reinterpret_tensor(buf522, (1, 65536, 1024), (67108864, 1024, 1), 0); del buf522 # reuse
buf503 = buf466; del buf466 # reuse
buf505 = buf464; del buf464 # reuse
buf541 = buf430; del buf430 # reuse
buf544 = buf427; del buf427 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_mul_sum_46.run(buf543, buf439, buf462, primals_7, buf476, buf501, buf514, buf539, embedding_3, rsqrt, add_24, add_12, buf546, buf503, buf505, buf541, buf544, 65536, 1024, grid=grid(65536), stream=stream0)
del add_12
del add_24
del buf439
del buf462
del buf476
del buf501
buf504 = buf451; del buf451 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf503, buf504, 1, 65536, grid=grid(1), stream=stream0)
del buf503
buf506 = buf414; del buf414 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf505, buf506, 1, 65536, grid=grid(1), stream=stream0)
del buf505
buf508 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf507, (1024, 65536), (1, 1024), 0), view_30, out=buf508)
del view_30
buf511 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf510, (4096, 65536), (1, 4096), 0), view_28, out=buf511)
del view_28
buf515 = buf477; del buf477 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf514, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_28, (65536, 1024), (1024, 1), 0), out=buf515)
del getitem_28
buf525 = buf489; del buf489 # reuse
buf527 = buf487; del buf487 # reuse
# Topologically Sorted Source Nodes: [v_4], Original ATen: [aten.mul, aten.sum, aten._to_copy]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_8.run(buf521, embedding_1, getitem_23, rsqrt_7, buf525, buf527, 512, 131072, grid=grid(512), stream=stream0)
del embedding_1
del getitem_23
del rsqrt_7
buf526 = buf378; del buf378 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf525, buf526, 1, 512, grid=grid(1), stream=stream0)
buf528 = buf342; del buf342 # reuse
# Topologically Sorted Source Nodes: [v_4], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf527, buf528, 1, 512, grid=grid(1), stream=stream0)
buf538 = buf500; del buf500 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf537, (3072, 65536), (1, 3072), 0), view_21, out=buf538)
del view_21
buf540 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf515, buf538, buf540, 4194304, grid=grid(4194304), stream=stream0)
buf542 = buf292; del buf292 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf541, buf542, 1, 65536, grid=grid(1), stream=stream0)
del buf541
buf545 = buf28; del buf28 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf544, buf545, 1, 65536, grid=grid(1), stream=stream0)
buf547 = empty_strided_cuda((1024, 4096), (4096, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf546, (1024, 65536), (1, 1024), 0), view_18, out=buf547)
del view_18
buf548 = reinterpret_tensor(buf510, (65536, 4096), (4096, 1), 0); del buf510 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf546, (65536, 1024), (1024, 1), 0), permute_421, out=buf548)
del permute_421
buf549 = reinterpret_tensor(mm_2, (1, 65536, 4096), (268435456, 4096, 1), 0); del mm_2 # reuse
# Topologically Sorted Source Nodes: [relu], Original ATen: [aten.threshold_backward, aten.relu, aten.pow, aten.mul]
stream0 = get_raw_stream(0)
triton_poi_fused_mul_pow_relu_threshold_backward_3.run(buf549, buf548, 268435456, grid=grid(268435456), stream=stream0)
del buf548
buf550 = empty_strided_cuda((4096, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf549, (4096, 65536), (1, 4096), 0), view_16, out=buf550)
del view_16
buf551 = reinterpret_tensor(buf521, (65536, 1024), (1024, 1), 0); del buf521 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf549, (65536, 4096), (4096, 1), 0), permute_425, out=buf551)
del buf549
del permute_425
buf553 = buf546; del buf546 # reuse
# Topologically Sorted Source Nodes: [rms_norm_4], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_div_mul_pow_sum_4.run(buf553, buf551, add_10, rsqrt_4, 65536, 1024, grid=grid(65536), stream=stream0)
del add_10
del rsqrt_4
buf554 = buf515; del buf515 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf553, (1024, 65536), (1, 1024), 0), reinterpret_tensor(getitem_19, (65536, 1024), (1024, 1), 0), out=buf554)
buf555 = buf551; del buf551 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf553, (65536, 1024), (1024, 1), 0), permute_429, out=buf555)
del permute_429
buf556 = buf517; del buf517 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_per_fused_zeros_5.run(getitem_19, buf555, buf556, 524288, 128, grid=grid(524288), stream=stream0)
del getitem_19
buf557 = buf518; del buf518 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
triton_poi_fused_zeros_6.run(buf556, buf557, 8, 65536, grid=grid(8, 65536), stream=stream0)
del buf556
buf559 = reinterpret_tensor(buf514, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf514 # reuse
buf560 = reinterpret_tensor(buf507, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf507 # reuse
buf561 = reinterpret_tensor(buf539, (1, 8, 65536, 128), (67108864, 128, 1024, 1), 0); del buf539 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
stream0 = get_raw_stream(0)
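# Note: the torch.save(...) below is a manual edit to the generated wrapper, not
# Inductor output: it dumps the full argument tuple of this flex_attention-backward
# template launch so it can be reproduced offline from "intermediate_3.2.pt". The
# equivalent launch sites for the other layers above carry no such dump.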
torch.save(
(permute_5, permute_6, permute_7, getitem_20, buf557, buf555, buf559, buf560, clamp_max, unsqueeze_9, convert_element_type_2, clone_4, clamp_max_1, unsqueeze_13, convert_element_type_4, clone_7, cumsum, buf561),
"intermediate_3.2.pt",
)
triton_tem_fused_zeros_7.run(permute_5, permute_6, permute_7, getitem_20, buf557, buf555, buf559, buf560, clamp_max, unsqueeze_9, convert_element_type_2, clone_4, clamp_max_1, unsqueeze_13, convert_element_type_4, clone_7, cumsum, buf561, grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(1, 8, 65536, 128, 8, 65536, meta0), stream=stream0)
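# A minimal replay sketch (an assumption, not part of the generated code: it needs
# the same device and dtypes, and meta0 plus the compiled triton_tem_fused_zeros_7
# kernel must still come from this module):
#
#   args = torch.load("intermediate_3.2.pt", map_location="cuda:0")
#   triton_tem_fused_zeros_7.run(
#       *args,
#       grid=torch._inductor.kernel.flex_attention.flex_attention_backward_grid(
#           1, 8, 65536, 128, 8, 65536, meta0),
#       stream=get_raw_stream(0),
#   )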
del buf555
del buf557
del clamp_max
del clamp_max_1
del clone_4
del clone_7
del convert_element_type_2
del convert_element_type_4
del cumsum
del getitem_20
del permute_5
del permute_6
del permute_7
del unsqueeze_13
del unsqueeze_9
buf564 = buf527; del buf527 # reuse
buf566 = buf525; del buf525 # reuse
# Topologically Sorted Source Nodes: [v_1], Original ATen: [aten.mul, aten.sum, aten._to_copy]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_mul_sum_8.run(buf560, embedding, getitem_14, rsqrt_3, buf564, buf566, 512, 131072, grid=grid(512), stream=stream0)
del embedding
buf565 = buf26; del buf26 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf564, buf565, 1, 512, grid=grid(1), stream=stream0)
del buf564
buf567 = buf254; del buf254 # reuse
# Topologically Sorted Source Nodes: [v_1], Original ATen: [aten._to_copy, aten.mul, aten.sum]
stream0 = get_raw_stream(0)
triton_per_fused_mul_sum_9.run(buf566, buf567, 1, 512, grid=grid(1), stream=stream0)
del buf566
buf595 = empty_strided_cuda((50257, 1024), (1024, 1), torch.float32)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44.run(buf595, 51463168, grid=grid(51463168), stream=stream0)
buf576 = buf537; del buf537 # reuse
buf575 = reinterpret_tensor(buf576, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 2048) # alias
# Topologically Sorted Source Nodes: [loss, v_1], Original ATen: [aten.nll_loss_forward, aten.add, aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_47.run(buf560, primals_7, getitem_14, rsqrt_3, primals_1, buf97, buf575, buf595, 524288, 128, grid=grid(524288), stream=stream0)
del buf560
del buf97
del getitem_14
del rsqrt_3
buf569 = buf531; del buf531 # reuse
buf570 = buf530; del buf530 # reuse
buf574 = reinterpret_tensor(buf576, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 1024) # alias
buf573 = reinterpret_tensor(buf576, (1, 65536, 8, 128), (201326592, 3072, 128, 1), 0) # alias
# Topologically Sorted Source Nodes: [k_1, q_1], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11.run(buf561, unsqueeze_46, unsqueeze_44, buf559, getitem_13, rsqrt_2, getitem_12, rsqrt_1, buf569, buf570, buf574, buf573, 524288, 128, grid=grid(524288), stream=stream0)
del buf559
del buf569
del buf570
del getitem_12
del getitem_13
del rsqrt_1
del rsqrt_2
del unsqueeze_44
del unsqueeze_46
del buf573
del buf574
del buf575
buf577 = buf538; del buf538 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf576, (3072, 65536), (1, 3072), 0), view_9, out=buf577)
del view_9
buf578 = reinterpret_tensor(buf561, (65536, 1024), (1024, 1), 0); del buf561 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mm]
extern_kernels.mm(reinterpret_tensor(buf576, (65536, 3072), (3072, 1), 0), permute_437, out=buf578)
del buf576
del permute_437
buf579 = empty_strided_cuda((4, 1024, 1024), (1048576, 1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
stream0 = get_raw_stream(0)
triton_poi_fused_add_select_backward_12.run(buf554, buf577, buf579, 4194304, grid=grid(4194304), stream=stream0)
del buf554
del buf577
buf586 = empty_strided_cuda((50257, 1024), (1024, 1), torch.float32)
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44.run(buf586, 51463168, grid=grid(51463168), stream=stream0)
buf580 = buf544; del buf544 # reuse
# Topologically Sorted Source Nodes: [loss, x], Original ATen: [aten.nll_loss_forward, aten._to_copy, aten.mul, aten.add, aten.sum, aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_per_fused__to_copy_add_embedding_dense_backward_mul_nll_loss_forward_sum_48.run(buf543, buf553, buf578, primals_7, embedding_3, rsqrt, primals_1, buf580, buf586, 65536, 1024, grid=grid(65536), stream=stream0)
del buf543
del buf553
del buf578
del embedding_3
del rsqrt
buf581 = buf215; del buf215 # reuse
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
stream0 = get_raw_stream(0)
triton_red_fused__to_copy_add_mul_sum_17.run(buf580, buf581, 1, 65536, grid=grid(1), stream=stream0)
del buf580
buf582 = empty_strided_cuda((80, ), (1, ), torch.float32)
buf583 = buf582; del buf582 # reuse
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy, aten.select_backward, aten.add, aten.slice_backward]
stream0 = get_raw_stream(0)
triton_poi_fused__to_copy_add_select_backward_slice_backward_49.run(buf583, buf469, buf488, buf490, buf526, buf528, buf565, buf567, buf323, buf356, buf358, buf392, buf394, buf428, buf431, buf465, buf467, buf504, buf506, buf542, buf545, buf581, buf195, buf234, buf272, 80, grid=grid(80), stream=stream0)
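# Note: buf583 is the (80,)-element gradient for primals_7 (the float32 scalar
# parameter vector stubbed out in benchmark_compiled_module below); the kernel
# above assembles it from the scalar reductions accumulated earlier (buf469,
# buf488, ..., buf581), with the select/slice_backward fusion routing each scalar
# to its slot. This reading is inferred from the buffer shapes and the return
# tuple, not from the kernel source.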
del buf195
del buf234
del buf272
del buf323
del buf356
del buf358
del buf392
del buf394
del buf428
del buf431
del buf465
del buf467
del buf469
del buf488
del buf490
del buf504
del buf506
del buf526
del buf528
del buf542
del buf545
del buf565
del buf567
del buf581
buf588 = empty_strided_cuda((50257, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_embedding_dense_backward_50.run(buf586, buf588, 51463168, grid=grid(51463168), stream=stream0)
buf589 = buf586; del buf586 # reuse
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44.run(buf589, 51463168, grid=grid(51463168), stream=stream0)
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_51.run(primals_1, buf21, primals_7, buf483, buf589, 67108864, grid=grid(67108864), stream=stream0)
del buf21
del buf483
del primals_1
del primals_7
buf591 = empty_strided_cuda((50257, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_embedding_dense_backward_50.run(buf589, buf591, 51463168, grid=grid(51463168), stream=stream0)
del buf589
buf594 = empty_strided_cuda((50257, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_embedding_dense_backward_50.run(buf592, buf594, 51463168, grid=grid(51463168), stream=stream0)
del buf592
buf597 = empty_strided_cuda((50257, 1024), (1024, 1), torch.bfloat16)
# Topologically Sorted Source Nodes: [], Original ATen: [aten.embedding_dense_backward]
stream0 = get_raw_stream(0)
triton_poi_fused_embedding_dense_backward_50.run(buf595, buf597, 51463168, grid=grid(51463168), stream=stream0)
del buf595
return (None, buf597, buf594, buf591, None, buf588, buf583, buf579, None, None, buf550, buf547, buf540, None, None, buf511, buf508, buf502, None, None, buf473, buf470, buf463, None, None, buf436, buf433, buf426, None, None, buf399, buf396, buf390, None, None, buf363, buf360, buf354, None, None, buf327, buf324, buf313, buf310, buf304, None, None, buf277, buf274, buf266, None, None, buf239, buf236, buf227, None, None, buf200, buf197, buf189, None, None, buf162, buf159, buf153, None, None, buf126, buf123, buf116, None, None, buf87, buf84, buf78, None, None, buf49, buf46, buf40, None, None, buf11, buf8, buf5, None, )
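# Note: the Nones in the tuple above appear to correspond to inputs that receive
# no gradient (integer token ids and activations saved from the forward), while
# the named buffers are the parameter gradients in the order the AOT autograd
# runtime expects. The helper below rebuilds random inputs with the recorded
# shapes and strides so this compiled backward can be benchmarked standalone.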
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
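# rand_strided builds random tensors with the exact shape, stride, device and
# dtype recorded at trace time, so the wrapper can be exercised without the real
# model weights or saved activations.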
primals_1 = rand_strided((65536, ), (1, ), device='cuda:0', dtype=torch.int32)
primals_7 = rand_strided((80, ), (1, ), device='cuda:0', dtype=torch.float32)
primals_86 = rand_strided((65536, ), (1, ), device='cuda:0', dtype=torch.int64)
embedding = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
embedding_1 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
embedding_2 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
cumsum = rand_strided((65536, ), (1, ), device='cuda:0', dtype=torch.int64)
unsqueeze_9 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32)
unsqueeze_13 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32)
clamp_max = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32)
clamp_max_1 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32)
convert_element_type_2 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32)
clone_4 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32)
convert_element_type_4 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32)
clone_7 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32)
clamp_max_2 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32)
clamp_max_3 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32)
convert_element_type_6 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32)
clone_10 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32)
convert_element_type_8 = rand_strided((1, 1, 512), (512, 512, 1), device='cuda:0', dtype=torch.int32)
clone_13 = rand_strided((1, 1, 512, 512), (262144, 262144, 512, 1), device='cuda:0', dtype=torch.int32)
embedding_3 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_9 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_12 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_13 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_14 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_1 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_2 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_44 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_46 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_3 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_5 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_6 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_7 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_19 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_20 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_10 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_4 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_16 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_2 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_18 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_12 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_21 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_21 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_22 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_23 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_5 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_6 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_52 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_54 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_7 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_13 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_14 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_15 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_28 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_29 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_22 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_8 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_28 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_6 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_30 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_24 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_33 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_30 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_31 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_32 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_9 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_10 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_60 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_62 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_11 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_21 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_22 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_23 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_37 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_38 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_34 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_12 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_40 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_10 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_42 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_36 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_45 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_39 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_40 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_41 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_13 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_14 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_68 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_70 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_15 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_29 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_30 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_31 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_46 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_47 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_45 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_16 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_51 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_14 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_53 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_47 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_56 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_48 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_49 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_50 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_17 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_18 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_76 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_78 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_19 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_37 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_38 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_39 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_55 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_56 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_56 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_20 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_62 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_18 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_64 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_58 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_67 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_57 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_58 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_59 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_21 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_22 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_84 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_86 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_23 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_45 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_46 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_47 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_64 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_65 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_67 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_24 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_73 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_22 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_75 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_69 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_78 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_66 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_67 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_68 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_25 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_26 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_92 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_94 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_27 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_53 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_54 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_55 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_73 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_74 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_78 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_28 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_84 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_26 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_86 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_80 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
add_81 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_29 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_88 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_28 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_90 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_83 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_93 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_75 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_76 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_77 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_30 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_31 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_100 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_102 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_32 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_63 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_64 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_65 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_82 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_83 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_92 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_33 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_99 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_32 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_101 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_95 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_104 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_84 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_85 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_86 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_34 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_35 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_108 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_110 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_36 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_71 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_72 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_73 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_91 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_92 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_104 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_37 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_110 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_36 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_112 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_107 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_115 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_93 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_94 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_95 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_38 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_39 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_116 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_118 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_40 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_79 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_80 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_81 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_100 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_101 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_116 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_41 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_121 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_40 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_123 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_119 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_126 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_102 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_103 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_104 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_42 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_43 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_124 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_126 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_44 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_87 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_88 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_89 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_109 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_110 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_128 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_45 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_132 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_44 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_134 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_130 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_137 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_111 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_112 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_113 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_46 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_47 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_132 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_134 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_48 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_95 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_96 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_97 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_118 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_119 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_139 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_49 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_143 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_48 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_145 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_141 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_148 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_120 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_121 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_122 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_50 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_51 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_140 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_142 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_52 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_103 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_104 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_105 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_127 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_128 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_151 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_53 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_155 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_52 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_157 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_153 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_160 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_129 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_130 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_131 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_54 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_55 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_148 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_150 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_56 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_111 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_112 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_113 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_136 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_137 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_163 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_57 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_167 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_56 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_169 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_165 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
view_172 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_138 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_139 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_140 = rand_strided((1, 65536, 8, 128), (201326592, 3072, 128, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_58 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
rsqrt_59 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_156 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
unsqueeze_158 = rand_strided((1, 65536, 1, 64), (16777216, 64, 64, 1), device='cuda:0', dtype=torch.float32)
rsqrt_60 = rand_strided((1, 65536, 8, 1), (524288, 8, 1, 1), device='cuda:0', dtype=torch.float32)
permute_119 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_120 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_121 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_145 = rand_strided((1, 8, 65536, 128), (67108864, 128, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
getitem_146 = rand_strided((1, 8, 65536), (524288, 65536, 1), device='cuda:0', dtype=torch.float32)
add_175 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_61 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_179 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_60 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
view_181 = rand_strided((65536, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
add_177 = rand_strided((1, 65536, 1024), (67108864, 1024, 1), device='cuda:0', dtype=torch.bfloat16)
rsqrt_62 = rand_strided((1, 65536, 1), (65536, 1, 1), device='cuda:0', dtype=torch.float32)
view_183 = rand_strided((65536, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
mm_62 = rand_strided((65536, 50304), (50304, 1), device='cuda:0', dtype=torch.bfloat16)
amax = rand_strided((65536, 1), (1, 1), device='cuda:0', dtype=torch.float32)
log = rand_strided((65536, 1), (1, 1), device='cuda:0', dtype=torch.float32)
convert_element_type_324 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
permute_129 = rand_strided((50304, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_133 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_137 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_141 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_149 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_153 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_157 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_161 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_169 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_173 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_177 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_181 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_189 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_193 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_197 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_201 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_209 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_213 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_217 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_221 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_229 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_233 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_237 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_241 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_249 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_253 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_257 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_261 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_269 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_273 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_277 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_281 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_289 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_293 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_297 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_301 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_305 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_309 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_317 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_321 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_325 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_329 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_337 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_341 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_345 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_349 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_357 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_361 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_365 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_369 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_377 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_381 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_385 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_389 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_397 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_401 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_405 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_409 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_417 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_421 = rand_strided((1024, 4096), (4096, 1), device='cuda:0', dtype=torch.bfloat16)
permute_425 = rand_strided((4096, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_429 = rand_strided((1024, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
permute_437 = rand_strided((3072, 1024), (1024, 1), device='cuda:0', dtype=torch.bfloat16)
tangents_1 = rand_strided((), (), device='cuda:0', dtype=torch.float32)
fn = lambda: call([primals_1, primals_7, primals_86, embedding, embedding_1, embedding_2, cumsum, unsqueeze_9, unsqueeze_13, clamp_max, clamp_max_1, convert_element_type_2, clone_4, convert_element_type_4, clone_7, clamp_max_2, clamp_max_3, convert_element_type_6, clone_10, convert_element_type_8, clone_13, embedding_3, rsqrt, view_9, getitem_12, getitem_13, getitem_14, rsqrt_1, rsqrt_2, unsqueeze_44, unsqueeze_46, rsqrt_3, permute_5, permute_6, permute_7, getitem_19, getitem_20, add_10, rsqrt_4, view_16, mm_2, view_18, add_12, view_21, getitem_21, getitem_22, getitem_23, rsqrt_5, rsqrt_6, unsqueeze_52, unsqueeze_54, rsqrt_7, permute_13, permute_14, permute_15, getitem_28, getitem_29, add_22, rsqrt_8, view_28, mm_6, view_30, add_24, view_33, getitem_30, getitem_31, getitem_32, rsqrt_9, rsqrt_10, unsqueeze_60, unsqueeze_62, rsqrt_11, permute_21, permute_22, permute_23, getitem_37, getitem_38, add_34, rsqrt_12, view_40, mm_10, view_42, add_36, view_45, getitem_39, getitem_40, getitem_41, rsqrt_13, rsqrt_14, unsqueeze_68, unsqueeze_70, rsqrt_15, permute_29, permute_30, permute_31, getitem_46, getitem_47, add_45, rsqrt_16, view_51, mm_14, view_53, add_47, view_56, getitem_48, getitem_49, getitem_50, rsqrt_17, rsqrt_18, unsqueeze_76, unsqueeze_78, rsqrt_19, permute_37, permute_38, permute_39, getitem_55, getitem_56, add_56, rsqrt_20, view_62, mm_18, view_64, add_58, view_67, getitem_57, getitem_58, getitem_59, rsqrt_21, rsqrt_22, unsqueeze_84, unsqueeze_86, rsqrt_23, permute_45, permute_46, permute_47, getitem_64, getitem_65, add_67, rsqrt_24, view_73, mm_22, view_75, add_69, view_78, getitem_66, getitem_67, getitem_68, rsqrt_25, rsqrt_26, unsqueeze_92, unsqueeze_94, rsqrt_27, permute_53, permute_54, permute_55, getitem_73, getitem_74, add_78, rsqrt_28, view_84, mm_26, view_86, add_80, add_81, rsqrt_29, view_88, mm_28, view_90, add_83, view_93, getitem_75, getitem_76, getitem_77, rsqrt_30, rsqrt_31, unsqueeze_100, unsqueeze_102, rsqrt_32, permute_63, permute_64, permute_65, getitem_82, getitem_83, add_92, rsqrt_33, view_99, mm_32, view_101, add_95, view_104, getitem_84, getitem_85, getitem_86, rsqrt_34, rsqrt_35, unsqueeze_108, unsqueeze_110, rsqrt_36, permute_71, permute_72, permute_73, getitem_91, getitem_92, add_104, rsqrt_37, view_110, mm_36, view_112, add_107, view_115, getitem_93, getitem_94, getitem_95, rsqrt_38, rsqrt_39, unsqueeze_116, unsqueeze_118, rsqrt_40, permute_79, permute_80, permute_81, getitem_100, getitem_101, add_116, rsqrt_41, view_121, mm_40, view_123, add_119, view_126, getitem_102, getitem_103, getitem_104, rsqrt_42, rsqrt_43, unsqueeze_124, unsqueeze_126, rsqrt_44, permute_87, permute_88, permute_89, getitem_109, getitem_110, add_128, rsqrt_45, view_132, mm_44, view_134, add_130, view_137, getitem_111, getitem_112, getitem_113, rsqrt_46, rsqrt_47, unsqueeze_132, unsqueeze_134, rsqrt_48, permute_95, permute_96, permute_97, getitem_118, getitem_119, add_139, rsqrt_49, view_143, mm_48, view_145, add_141, view_148, getitem_120, getitem_121, getitem_122, rsqrt_50, rsqrt_51, unsqueeze_140, unsqueeze_142, rsqrt_52, permute_103, permute_104, permute_105, getitem_127, getitem_128, add_151, rsqrt_53, view_155, mm_52, view_157, add_153, view_160, getitem_129, getitem_130, getitem_131, rsqrt_54, rsqrt_55, unsqueeze_148, unsqueeze_150, rsqrt_56, permute_111, permute_112, permute_113, getitem_136, getitem_137, add_163, rsqrt_57, view_167, mm_56, view_169, add_165, view_172, getitem_138, getitem_139, getitem_140, rsqrt_58, rsqrt_59, unsqueeze_156, unsqueeze_158, rsqrt_60, 
permute_119, permute_120, permute_121, getitem_145, getitem_146, add_175, rsqrt_61, view_179, mm_60, view_181, add_177, rsqrt_62, view_183, mm_62, amax, log, convert_element_type_324, permute_129, permute_133, permute_137, permute_141, permute_149, permute_153, permute_157, permute_161, permute_169, permute_173, permute_177, permute_181, permute_189, permute_193, permute_197, permute_201, permute_209, permute_213, permute_217, permute_221, permute_229, permute_233, permute_237, permute_241, permute_249, permute_253, permute_257, permute_261, permute_269, permute_273, permute_277, permute_281, permute_289, permute_293, permute_297, permute_301, permute_305, permute_309, permute_317, permute_321, permute_325, permute_329, permute_337, permute_341, permute_345, permute_349, permute_357, permute_361, permute_365, permute_369, permute_377, permute_381, permute_385, permute_389, permute_397, permute_401, permute_405, permute_409, permute_417, permute_421, permute_425, permute_429, permute_437, tangents_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)
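# Usage note (a sketch of the harness; the file name below is an assumption):
# running this generated module directly, e.g.
#
#   python output_code.py
#
# enters the block above. compiled_module_main invokes benchmark_compiled_module,
# which allocates the rand_strided stand-ins for the saved activations/weights,
# replays the compiled backward graph through fn = lambda: call([...]), and
# reports wall-clock timings via print_performance.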
# This file has been truncated; the full file is available in the gist.
# AOT ID: ['0_backward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import (
grid,
split_scan_grid,
grid_combo_kernels,
start_graph,
end_graph,
cooperative_reduction_grid,
)
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import torch._inductor.kernel.flex_attention
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
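# Note: as in the first file, each fused kernel below is registered with
# async_compile.triton(kernel_name, source), so the Triton sources compile
# asynchronously while the rest of the module is defined; the compiled kernels
# are then launched by the module's call(...) wrapper.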
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/hh/chhrb4fk5xruv6uuerqegmi4rbdnnzuq2z23xh2gwuvych6si5w7.py
# Topologically Sorted Source Nodes: [loss, logits, mul_176, square_16, add_116, rsqrt, mul_177], Original ATen: [aten.nll_loss_backward, aten.nll_loss_forward, aten._to_copy, aten.mul, aten.pow, aten.add, aten.rsqrt, aten._log_softmax, aten._log_softmax_backward_data]
# Source node to ATen node mapping:
# add_116 => add_179
# logits => convert_element_type_323
# loss => full_default_12, full_default_13, sub_4, sub_5
# mul_176 => mul_239
# mul_177 => mul_240
# rsqrt => rsqrt_63
# square_16 => pow_80
# Graph fragment:
# %div_2 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%tangents_1, %convert_element_type_324), kwargs = {})
# %ne_3 : [num_users=2] = call_function[target=torch.ops.aten.ne.Scalar](args = (%unsqueeze_164, -100), kwargs = {})
# %full_default_12 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_6 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %unsqueeze_164, %full_default_12), kwargs = {})
# %scatter_upon_const_tensor : [num_users=1] = call_function[target=torch._inductor.fx_passes.post_grad.scatter_upon_const_tensor](args = (), kwargs = {shape: [65536, 50304], background_val: 0, dtype: torch.float32, dim: 1, selector: %where_6, val: -1.0})
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_7 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_3, %div_2, %full_default_13), kwargs = {})
# %mul_241 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%scatter_upon_const_tensor, %where_7), kwargs = {})
# %convert_element_type_323 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_62, torch.float32), kwargs = {})
# %mul_239 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_323, 15), kwargs = {})
# %pow_80 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 2), kwargs = {})
# %add_179 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%pow_80, 225), kwargs = {})
# %rsqrt_63 : [num_users=3] = call_function[target=torch.ops.aten.rsqrt.default](args = (%add_179,), kwargs = {})
# %mul_240 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_239, %rsqrt_63), kwargs = {})
# %sub_4 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_240, %amax), kwargs = {})
# %sub_5 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub_4, %log), kwargs = {})
# %exp_1 : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub_5,), kwargs = {})
# %sum_10 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_241, [1], True), kwargs = {})
# %mul_242 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%exp_1, %sum_10), kwargs = {})
# %sub_6 : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%mul_241, %mul_242), kwargs = {})
# %mul_243 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %mul_239), kwargs = {})
# %mul_244 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%sub_6, %rsqrt_63), kwargs = {})
# %pow_81 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%rsqrt_63, 3), kwargs = {})
# %mul_245 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%mul_243, -0.5), kwargs = {})
# %mul_246 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_245, %pow_81), kwargs = {})
# %pow_82 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_323, 1.0), kwargs = {})
# %mul_247 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_82, 2.0), kwargs = {})
# %mul_248 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_246, %mul_247), kwargs = {})
# %mul_249 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_244, 15), kwargs = {})
# %add_180 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_248, %mul_249), kwargs = {})
# %convert_element_type_325 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_180, torch.bfloat16), kwargs = {})
triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0 = async_compile.triton('triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 65536, 'r0_': 65536},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*i64', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'xnumel': 'i64', 'r0_numel': 'i64', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 8, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__log_softmax__log_softmax_backward_data__to_copy_add_mul_nll_loss_backward_nll_loss_forward_pow_rsqrt_0(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 65536
    r0_numel = 50304
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0).to(tl.int64) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None].to(tl.int64)
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :].to(tl.int64)
    rbase = r0_base
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last')
    tmp10 = tl.load(in_ptr1 + (0))
    tmp11 = tl.broadcast_to(tmp10, [XBLOCK, R0_BLOCK])
    tmp12 = tl.load(in_ptr2 + (0))
    tmp13 = tl.broadcast_to(tmp12, [XBLOCK, R0_BLOCK])
    _tmp18 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp1 = tl.full([1, 1], -100, tl.int64)
        tmp2 = tmp0 != tmp1
        tmp3 = tl.full([1, 1], 0, tl.int64)
        tmp4 = tl.where(tmp2, tmp0, tmp3)
        tmp5 = r0_1
        tmp6 = tmp4 == tmp5
        tmp7 = -1.0
        tmp8 = 0.0
        tmp9 = tl.where(tmp6, tmp7, tmp8)
        tmp14 = (tmp11 / tmp13)
        tmp15 = tl.where(tmp2, tmp14, tmp8)
        tmp16 = tmp9 * tmp15
        tmp17 = tl.broadcast_to(tmp16, [XBLOCK, R0_BLOCK])
        tmp19 = _tmp18 + tmp17
        _tmp18 = tl.where(r0_mask, tmp19, _tmp18)
    tmp18 = tl.sum(_tmp18, 1)[:, None]
    tmp29 = tl.load(in_ptr1 + (0))
    tmp30 = tl.broadcast_to(tmp29, [XBLOCK, R0_BLOCK])
    tmp31 = tl.load(in_ptr2 + (0))
    tmp32 = tl.broadcast_to(tmp31, [XBLOCK, R0_BLOCK])
    tmp45 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
    tmp47 = tl.load(in_ptr4 + (x0), None, eviction_policy='evict_last')
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp36 = tl.load(in_out_ptr0 + (r0_1 + 50304*x0), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
        tmp20 = tl.full([1, 1], -100, tl.int64)
        tmp21 = tmp0 != tmp20
        tmp22 = tl.full([1, 1], 0, tl.int64)
        tmp23 = tl.where(tmp21, tmp0, tmp22)
        tmp24 = r0_1
        tmp25 = tmp23 == tmp24
        tmp26 = -1.0
        tmp27 = 0.0
        tmp28 = tl.where(tmp25, tmp26, tmp27)
        tmp33 = (tmp30 / tmp32)
        tmp34 = tl.where(tmp21, tmp33, tmp27)
        tmp35 = tmp28 * tmp34
        tmp37 = tmp36.to(tl.float32)
        tmp38 = 15.0
        tmp39 = tmp37 * tmp38
        tmp40 = tmp37 * tmp37
        tmp41 = 225.0
        tmp42 = tmp40 + tmp41
        tmp43 = libdevice.rsqrt(tmp42)
        tmp44 = tmp39 * tmp43
        tmp46 = tmp44 - tmp45
        tmp48 = tmp46 - tmp47
        tmp49 = tl_math.exp(tmp48)
        tmp50 = tmp49 * tmp18
        tmp51 = tmp35 - tmp50
        tmp52 = tmp51 * tmp39
        tmp53 = -0.5
        tmp54 = tmp52 * tmp53
        tmp55 = tmp43 * tmp43
        tmp56 = tmp55 * tmp43
        tmp57 = tmp54 * tmp56
        tmp58 = 2.0
        tmp59 = tmp37 * tmp58
        tmp60 = tmp57 * tmp59
        tmp61 = tmp51 * tmp43
        tmp62 = tmp61 * tmp38
        tmp63 = tmp60 + tmp62
        tmp64 = tmp63.to(tl.float32)
        tl.store(in_out_ptr0 + (r0_1 + 50304*x0), tmp64, r0_mask)
''', device_str='cuda')
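# Eager-mode reference for the fused kernel above -- a minimal sketch under the
# shapes assumed by the benchmark inputs (65536 rows, 50304 classes), not the
# generated code itself; all names below are illustrative. It mirrors the graph
# fragment: the bf16 logits are soft-capped as y = 15 * x * rsqrt(x*x + 225)
# (so |y| < 15), log-softmax is recovered from the saved amax/log row stats,
# and the NLL + log-softmax backward is chained through the soft cap. The two
# loops in the kernel correspond to the dy.sum pass and the per-element pass.
import torch

def _softcap_nll_backward_reference(mm_62, targets, tangents_1, n_valid, amax, log):
    x = mm_62.float()                               # convert_element_type_323
    r = torch.rsqrt(x * x + 225.0)                  # rsqrt_63
    y = 15.0 * x * r                                # mul_240: soft-capped logits
    logp = y - amax - log                           # log-softmax from saved stats
    valid = targets != -100                         # ne_3: ignore_index handling
    g = torch.where(valid, tangents_1 / n_valid, x.new_zeros(()))
    idx = torch.where(valid, targets, targets.new_zeros(())).unsqueeze(1)
    dy = torch.zeros_like(x).scatter_(1, idx, -1.0) # scatter_upon_const_tensor
    dy = dy * g.unsqueeze(1)                        # mul_241
    dy = dy - logp.exp() * dy.sum(1, keepdim=True)  # log-softmax backward (sub_6)
    # chain rule through the cap: dy/dx = 15*r - 15*x^2*r^3, split as in the kernel
    dx = dy * (15.0 * r) + (dy * 15.0 * x) * (-0.5) * r.pow(3) * (2.0 * x)
    return dx.to(torch.bfloat16)                    # convert_element_type_325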
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/tn/ctnjpswwi3gqhcnwhzij5s7dx5ph7fovddyn7g76xfykplagbfrm.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy]
# Source node to ATen node mapping:
# Graph fragment:
# %convert_element_type_330 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mm_63, torch.float32), kwargs = {})
triton_poi_fused__to_copy_1 = async_compile.triton('triton_poi_fused__to_copy_1', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_1', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_1(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 51511296
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tmp0.to(tl.float32)
    tl.store(out_ptr0 + (x0), tmp1, None)
''', device_str='cuda')
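# Eager equivalent of the cast kernel above (a sketch): the 51511296 elements
# match a 50304 x 1024 tensor, presumably the lm_head weight-gradient matmul
# result mm_63 (an inference from the surrounding graph), upcast elementwise:
#
#   convert_element_type_330 = mm_63.to(torch.float32)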
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/nm/cnml7os74h3shqb5iuld3alyjqrwei2qthuqadrqzek7sgkvlydx.py
# Topologically Sorted Source Nodes: [x_144], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# x_144 => convert_element_type_318
# Graph fragment:
# %convert_element_type_331 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_184, torch.float32), kwargs = {})
# %convert_element_type_318 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_177, torch.float32), kwargs = {})
# %mul_250 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %convert_element_type_318), kwargs = {})
# %mul_251 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_331, %rsqrt_62), kwargs = {})
# %sum_11 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_250, [2], True), kwargs = {})
# %div_3 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_20, 1024), kwargs = {})
# %pow_84 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_318, 1.0), kwargs = {})
# %mul_254 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_84, 2.0), kwargs = {})
# %mul_255 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_3, %mul_254), kwargs = {})
# %add_181 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_251, %mul_255), kwargs = {})
# %convert_element_type_332 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_181, torch.bfloat16), kwargs = {})
triton_per_fused__to_copy_add_div_mul_pow_sum_2 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_2', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_2', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_2(in_out_ptr0, in_ptr0, in_ptr1, xnumel, r0_numel):
    xnumel = 65536
    XBLOCK: tl.constexpr = 1
    r0_numel = 1024
    R0_BLOCK: tl.constexpr = 1024
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = tl.full([1], xoffset, tl.int32)
    xmask = tl.full([R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[:]
    r0_offset = 0
    r0_mask = tl.full([R0_BLOCK], True, tl.int1)
    roffset = r0_offset
    rindex = r0_index
    r0_1 = r0_index
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
    tmp8 = tl.load(in_ptr1 + (x0), None, eviction_policy='evict_last')
    tmp1 = tmp0.to(tl.float32)
    tmp3 = tmp2.to(tl.float32)
    tmp4 = tmp1 * tmp3
    tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
    tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
    tmp9 = tmp1 * tmp8
    tmp10 = -0.5
    tmp11 = tmp7 * tmp10
    tmp12 = tmp8 * tmp8
    tmp13 = tmp12 * tmp8
    tmp14 = tmp11 * tmp13
    tmp15 = 0.0009765625
    tmp16 = tmp14 * tmp15
    tmp17 = 2.0
    tmp18 = tmp3 * tmp17
    tmp19 = tmp16 * tmp18
    tmp20 = tmp9 + tmp19
    tmp21 = tmp20.to(tl.float32)
    tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp21, None)
''', device_str='cuda')
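# Eager-mode sketch of the fused RMS-norm input gradient above (hidden size
# 1024, no affine weight; names illustrative, shapes assumed (65536, 1024)).
# Note 0.0009765625 in the kernel is 1/1024, the mean over the hidden dim.
import torch

def _rms_norm_backward_reference(grad_y, x, r):
    # r: (65536, 1), the saved rsqrt_62 = rsqrt(mean(x**2, dim=-1) + eps)
    gy, xf = grad_y.float(), x.float()
    s = (gy * xf).sum(dim=-1, keepdim=True)                    # sum_11
    dx = gy * r + (-0.5 * s * r.pow(3) / 1024.0) * (2.0 * xf)  # add_181
    return dx.to(torch.bfloat16)                               # convert_element_type_332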
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/62/c62mj7hc5xoylhivn7ul5ogozxwe57duoiwwckqoeymqoszlcalf.py
# Topologically Sorted Source Nodes: [relu_15], Original ATen: [aten.relu, aten.pow, aten.mul, aten.threshold_backward]
# Source node to ATen node mapping:
# relu_15 => relu_15
# Graph fragment:
# %relu_15 : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%view_180,), kwargs = {})
# %pow_85 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%relu_15, 1.0), kwargs = {})
# %mul_256 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_85, 2.0), kwargs = {})
# %mul_257 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%view_186, %mul_256), kwargs = {})
# %le_1 : [num_users=1] = call_function[target=torch.ops.aten.le.Scalar](args = (%relu_15, 0), kwargs = {})
# %full_default_17 : [num_users=16] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %where_8 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%le_1, %full_default_17, %mul_257), kwargs = {})
triton_poi_fused_mul_pow_relu_threshold_backward_3 = async_compile.triton('triton_poi_fused_mul_pow_relu_threshold_backward_3', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 268435456},
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_mul_pow_relu_threshold_backward_3', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_mul_pow_relu_threshold_backward_3(in_out_ptr0, in_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 268435456
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
    tmp5 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
    tmp1 = tl.full([1], 0, tl.int32)
    tmp2 = triton_helpers.maximum(tmp1, tmp0)
    tmp3 = 0.0
    tmp4 = tmp2 <= tmp3
    tmp6 = 2.0
    tmp7 = tmp2 * tmp6
    tmp8 = tmp5 * tmp7
    tmp9 = tl.where(tmp4, tmp3, tmp8)
    tl.store(in_out_ptr0 + (x0), tmp9, None)
''', device_str='cuda')
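# For reference, a minimal eager-mode sketch of what the fused kernel above
# computes (an editorial sketch, not generated output; `grad_out` and
# `pre_act` are hypothetical names for in_ptr0 and the initial contents of
# in_out_ptr0):
import torch

def _reference_relu_squared_backward(grad_out, pre_act):
    # Forward was relu(x) ** 2, so the local gradient is 2 * relu(x);
    # threshold_backward zeroes every position where relu(x) <= 0.
    r = torch.relu(pre_act.float())
    grad_in = torch.where(r <= 0, torch.zeros_like(r), grad_out.float() * 2.0 * r)
    return grad_in.to(torch.bfloat16)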
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/ul/culpiuny7rsb7xrfxqlunbb4ygzdfdptdamd6p2d43ojix2u5dvg.py
# Topologically Sorted Source Nodes: [rms_norm_61], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# rms_norm_61 => convert_element_type_312
# Graph fragment:
# %convert_element_type_341 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_188, torch.float32), kwargs = {})
# %convert_element_type_312 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_175, torch.float32), kwargs = {})
# %mul_258 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %convert_element_type_312), kwargs = {})
# %mul_259 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_341, %rsqrt_61), kwargs = {})
# %sum_12 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_258, [2], True), kwargs = {})
# %div_4 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_21, 1024), kwargs = {})
# %pow_87 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_312, 1.0), kwargs = {})
# %mul_262 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_87, 2.0), kwargs = {})
# %mul_263 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_4, %mul_262), kwargs = {})
# %add_182 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_259, %mul_263), kwargs = {})
# %convert_element_type_342 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_182, torch.bfloat16), kwargs = {})
# %add_183 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%convert_element_type_332, %convert_element_type_342), kwargs = {})
triton_per_fused__to_copy_add_div_mul_pow_sum_4 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_4', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_4', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_4(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp8 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 * tmp3
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
tmp10 = tmp1 * tmp9
tmp11 = -0.5
tmp12 = tmp7 * tmp11
tmp13 = tmp9 * tmp9
tmp14 = tmp13 * tmp9
tmp15 = tmp12 * tmp14
tmp16 = 0.0009765625
tmp17 = tmp15 * tmp16
tmp18 = 2.0
tmp19 = tmp3 * tmp18
tmp20 = tmp17 * tmp19
tmp21 = tmp10 + tmp20
tmp22 = tmp21.to(tl.float32)
tmp23 = tmp8 + tmp22
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp23, None)
''', device_str='cuda')
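# The persistent reduction above is the input gradient of an RMSNorm over a
# hidden size of 1024 (the constant 0.0009765625 is 1/1024), fused with a
# residual-gradient accumulation. A minimal eager sketch of the same math,
# assuming in_ptr0 is the upstream grad, in_ptr1 the pre-norm input, in_ptr2
# the saved rsqrt (one value per row, kept with a trailing singleton dim),
# and in_out_ptr0 the accumulated residual grad (hypothetical names):
def _reference_rms_norm_backward(grad_y, x, rs, grad_residual):
    gy, xf = grad_y.float(), x.float()
    s = (gy * xf).sum(-1, keepdim=True)  # per-row dot product, as in the tl.sum above
    # chain rule through rsqrt(mean(x ** 2)): d rsqrt(u) / du = -0.5 * rsqrt(u) ** 3
    dx = gy * rs + (-0.5 * s * rs ** 3 / 1024.0) * (2.0 * xf)
    return grad_residual + dx.to(torch.bfloat16)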
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/3d/c3dqvlchxst3ygwswuqw55ensuczxcndkjpqvcilnha65gtorjio.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {})
triton_per_fused_zeros_5 = async_compile.triton('triton_per_fused_zeros_5', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_zeros_5', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused_zeros_5(in_ptr0, in_ptr1, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
R0_BLOCK: tl.constexpr = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[None, :]
r0_offset = 0
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 128*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 128*x0), None).to(tl.float32)
tmp2 = tmp0 * tmp1
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
tmp5 = tl.sum(tmp3, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp5, None)
''', device_str='cuda')
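# This reduction computes sum(OUT * DO) over the 128-wide head dim for each
# (token, head) pair; together with the next kernel it materializes the DELTA
# input of the flex-attention backward template below (see the DELTA note
# inside that kernel). An eager sketch, with `out` and `grad_out` as
# hypothetical names for the attention output and its incoming gradient:
def _reference_delta(out, grad_out):
    return (out.float() * grad_out.float()).sum(dim=-1)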
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/g7/cg7vrix5hzkplkjfm4y5yzz67ap26cysqwhzfniuom4fca22jmrl.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {})
triton_poi_fused_zeros_6 = async_compile.triton('triton_poi_fused_zeros_6', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'y': 8, 'x': 65536}, tile_hint=TileHint.SQUARE,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_zeros_6', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_zeros_6(in_ptr0, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 8
xnumel = 65536
yoffset = tl.program_id(1) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, YBLOCK], True, tl.int1)
x1 = xindex
y0 = yindex
tmp0 = tl.load(in_ptr0 + (y0 + 8*x1), ymask, eviction_policy='evict_last').to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp2 = 0.0
tmp3 = tmp1 - tmp2
tl.store(out_ptr0 + (x1 + 65536*y0), tmp3, ymask)
''', device_str='cuda')
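# Kernel 6 appears to finish the DELTA precompute: it reads the bf16
# per-(token, head) sums head-interleaved (stride 8 over heads), upcasts to
# fp32, and writes them head-major into a (1, 8, 65536) buffer; the `- 0.0`
# is numerically a no-op left over from fusing with the %full_default_19
# zeros node.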
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/g5/cg5pjx7bwly5zrm6xs4clxzfh2xs7znnajmp4og73hop4xabftat.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.zeros]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_19 : [num_users=15] = call_function[target=torch.ops.aten.full.default](args = ([1, 8, 65536], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %flex_attention_backward : [num_users=3] = call_function[target=torch.ops.higher_order.flex_attention_backward](args = (%permute_119, %permute_120, %permute_121, %getitem_145, %getitem_146, %permute_143, %full_default_19, %fw_graph0, %joint_graph0, (65536, 65536, %clamp_max, %unsqueeze_9, %clamp_max_1, %unsqueeze_13, %convert_element_type_2, %clone_4, %convert_element_type_4, %clone_7, 128, 128, %mask_graph0), 0.12, {PRESCALE_QK: False, ROWS_GUARANTEED_SAFE: False, BLOCKS_ARE_CONTIGUOUS: False, WRITE_DQ: True, OUTPUT_LOGSUMEXP: True}, (), (%cumsum,)), kwargs = {})
triton_tem_fused_zeros_7 = async_compile.triton('triton_tem_fused_zeros_7', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
@triton_heuristics.template(
num_stages=3,
num_warps=8,
triton_meta={'signature': {'arg_Q': '*bf16', 'arg_K': '*bf16', 'arg_V': '*bf16', 'arg_LSE': '*fp32', 'arg_DELTA': '*fp32', 'arg_DO': '*bf16', 'arg_DQ': '*bf16', 'arg_DV': '*bf16', 'arg_KV_NUM_BLKS': '*i32', 'arg_KV_IDX': '*i32', 'arg_Q_NUM_BLKS': '*i32', 'arg_Q_IDX': '*i32', 'arg_FULL_KV_NUM_BLKS': '*i32', 'arg_FULL_KV_IDX': '*i32', 'arg_FULL_Q_NUM_BLKS': '*i32', 'arg_FULL_Q_IDX': '*i32', 'in_ptr16': '*i64', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
inductor_meta={'kernel_name': 'triton_tem_fused_zeros_7', 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
)
@triton.jit
def triton_tem_fused_zeros_7(arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'ieee'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.12
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 128
BLOCK_M2 : tl.constexpr = 128
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
Q = arg_Q
K = arg_K
V = arg_V
LSE = arg_LSE
DELTA = arg_DELTA
DO = arg_DO
DQ = arg_DQ
DV = arg_DV
KV_NUM_BLKS = arg_KV_NUM_BLKS
KV_IDX = arg_KV_IDX
Q_NUM_BLKS = arg_Q_NUM_BLKS
Q_IDX = arg_Q_IDX
FULL_KV_NUM_BLKS = arg_FULL_KV_NUM_BLKS
FULL_KV_IDX = arg_FULL_KV_IDX
FULL_Q_NUM_BLKS = arg_FULL_Q_NUM_BLKS
FULL_Q_IDX = arg_FULL_Q_IDX
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
# DELTA: Precomputed sum(OUT*DO, axis=-1)
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
# DK: Derivative of Key, which is written out via the store_output call due to some
# limitations with inductor codegen
# M: Number of queries, N: Number of keys/values
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
# (Modifiable) Performance tuning options
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
#
# The following FULL_* and non-FULL_* counts/indices are defined on the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query block.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query block.
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each key/value block.
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each key/value block.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query block.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query block.
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each key/value block.
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each key/value block.
# The kernel options below can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and the change of base. This
# has about 20% more numerical error, but is slightly faster.
# Define strides of inputs
stride_qz, stride_qh, stride_qm, stride_qd = 67108864, 128, 1024, 1
stride_kz, stride_kh, stride_kn, stride_kd = 67108864, 128, 1024, 1
stride_vz, stride_vh, stride_vn, stride_vd = 67108864, 128, 1024, 1
stride_doz, stride_doh, stride_dom, stride_dod = 67108864, 128, 1024, 1
stride_dqz, stride_dqh, stride_dqm, stride_dqd = 67108864, 128, 1024, 1
stride_dvz, stride_dvh, stride_dvm, stride_dvd = 67108864, 128, 1024, 1
ZQ = 1
HQ = 8
HKV = 8
Q_LEN = 65536
ZKV = 1
KV_LEN = 65536
MATMUL_PRECISION = Q.dtype.element_ty
pid = tl.program_id(0)
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
off_hz = tl.program_id(2)
off_zq = off_hz // HKV # q batch idx
off_hkv = off_hz % HKV # kv head idx
off_zkv = off_zq % ZKV # kv batch idx
SPARSE_Z = 1
SPARSE_HQ = 1
sparse_idx_z = off_zq % SPARSE_Z
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
# offset K, V, DV pointers for batch/kv-head
K += k_adj
V += v_adj
DV += dv_adj
RCP_LN2 = 1.44269504
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
if pid >= NUM_KV_BLOCKS:
off_pid = pid - NUM_KV_BLOCKS
# THIS BLOCK DOES DQ
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
start_m2_block = off_pid % NUM_Q_BLOCKS
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
stride_kv_num_blks_h = 512
stride_kv_idx_h = 262144
stride_kv_idx_m = 512
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query head.
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
Q2 = Q + q_adj2
DO2 = DO + do_adj2
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
DQ2 = DQ + dq_adj2
LSE2 = LSE + off_chz2
DELTA2 = DELTA + off_chz2
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
start_m2 = start_m2_block * BLOCK_M2
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
# load Q and do: they stay in SRAM throughout the inner loop.
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
if IS_DIVISIBLE:
Di = tl.load(DELTA2 + offs_m2)
lse = tl.load(LSE2 + offs_m2)
else:
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = tl.where(lse == -float("inf"), 0.0, lse)
lse = lse[:, None]
# ~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# KV_IDX and KV_NUM_BLKS are always contiguous.
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
K, V,
dq, q, do, Di, lse,
off_zq, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
K, V,
dq, q, do, Di, lse,
off_zq, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# Write back dQ.
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
dq *= SM_SCALE
if IS_DIVISIBLE and SAFE_HEAD_DIM:
tl.store(dq_ptrs, dq)
else:
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
else:
# THIS BLOCK DOES DK & DV
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
pid_mask = pid // SPARSE_KV_MULTIPLE
stride_q_num_blks_h = 512
stride_q_idx_h = 262144
stride_q_idx_n = 512
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
start_n1 = pid * BLOCK_N1
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
# load K and V: they stay in SRAM throughout the inner loop.
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
if PRESCALE_QK:
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
for off_g in range(0, GQA_SHARED_HEADS):
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query head.
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
Q1 = Q + q_adj1
DO1 = DO + do_adj1
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
LSE1 = LSE + off_chz1
DELTA1 = DELTA + off_chz1
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
# ~~~~~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Q_IDX and Q_NUM_BLKS are always contiguous.
q_indices = Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_zq, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
q_indices = FULL_Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_zq, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# Write back dV and dK.
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
index_n = offs_n1[:, None]
index_k = offs_k[None, :]
index_v = offs_v[None, :]
if IS_DIVISIBLE and SAFE_HEAD_DIM:
tl.store(dv_ptrs, dv)
else:
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
dk *= SM_SCALE
if SAFE_HEAD_DIM:
mask = index_n < KV_LEN
else:
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
xindex = index_k + 128*index_n + 8388608*off_hkv + 67108864*off_zq
tl.store(out_ptr0 + (tl.broadcast_to(index_k + 128*off_hkv + 1024*index_n, dk.shape)), dk, mask)
@triton.jit
def bwd_dq_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
K, V, # pointers
dq, q, do, Di, lse,
off_z, off_hq, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'ieee'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.12
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 128
BLOCK_M2 : tl.constexpr = 128
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = 65536
KV_LEN = 65536
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
if not IS_DIVISIBLE:
if hi >= 1:
for start_n in range(0, hi - 1):
dq = bwd_dq_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_n, kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
)
kT_ptrs += offset * stride_kn
vT_ptrs += offset * stride_vn
offs_n2 += offset
dq = bwd_dq_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
else:
for start_n in range(0, hi):
dq = bwd_dq_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_n, kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
)
kT_ptrs += offset * stride_kn
vT_ptrs += offset * stride_vn
offs_n2 += offset
return dq
@triton.jit
def bwd_dq_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'ieee'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.12
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 128
BLOCK_M2 : tl.constexpr = 128
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
# NB reversed order since K is transposed
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
pre_mod_scores = qk
n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None)
# The boundary check is done for the outer loop, but since we're iterating across the N dim
# here, the M reads can go out of bounds prior to the last loop
m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
tmp0 = (qk)
post_mod_scores = tmp0
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
tmp1 = (m)
tmp2 = (n)
tmp3 = tmp1 >= tmp2
tmp4 = tl.load(in_ptr16 + tmp1)
tmp5 = tl.load(in_ptr16 + tmp2)
tmp6 = tmp4 == tmp5
tmp7 = tmp3 & tmp6
mask_mod_output = tmp7
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# apply mask for partially masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
p = tl.math.exp2(post_mod_scores - lse)
# Compute dP and dS.
# NB reversed order since V is transposed
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
ds = p * (dp - Di[:, None])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
tmp8 = (ds)
grad_scores = tmp8
if CHECK_BLOCK_BOUNDARY:
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
if WRITE_DQ:
scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = grad_scores
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially unmasked block
ds = tl.where(mask_mod_output, ds, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = ds.to(MATMUL_PRECISION)
# Compute dQ.
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
return dq
@triton.jit
def bwd_dkdv_inner(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
Q, DO, DELTA, LSE, # pointers
dk, dv, k, v,
off_z, off_hq, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'ieee'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.12
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 128
BLOCK_M2 : tl.constexpr = 128
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = 65536
KV_LEN = 65536
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
if not IS_DIVISIBLE:
if hi >= 1:
for start_m in range(0, hi - 1):
dk, dv = bwd_dkdv_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
dk, dv = bwd_dkdv_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
else:
for start_m in range(0, hi):
dk, dv = bwd_dkdv_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
return dk, dv
@triton.jit
def bwd_dkdv_block_mn(
arg_Q, arg_K, arg_V, arg_LSE, arg_DELTA, arg_DO, arg_DQ, arg_DV, arg_KV_NUM_BLKS, arg_KV_IDX, arg_Q_NUM_BLKS, arg_Q_IDX, arg_FULL_KV_NUM_BLKS, arg_FULL_KV_IDX, arg_FULL_Q_NUM_BLKS, arg_FULL_Q_IDX, in_ptr16, out_ptr0,
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
PRESCALE_QK : tl.constexpr = False
ROWS_GUARANTEED_SAFE : tl.constexpr = False
BLOCKS_ARE_CONTIGUOUS : tl.constexpr = False
WRITE_DQ : tl.constexpr = True
OUTPUT_LOGSUMEXP : tl.constexpr = True
FLOAT32_PRECISION : tl.constexpr = 'ieee'
IS_DIVISIBLE : tl.constexpr = True
SM_SCALE : tl.constexpr = 0.12
GQA_SHARED_HEADS : tl.constexpr = 1
HAS_FULL_BLOCKS : tl.constexpr = True
QK_HEAD_DIM : tl.constexpr = 128
QK_HEAD_DIM_ROUNDED : tl.constexpr = 128
V_HEAD_DIM : tl.constexpr = 128
V_HEAD_DIM_ROUNDED : tl.constexpr = 128
SAFE_HEAD_DIM : tl.constexpr = True
BLOCK_M1 : tl.constexpr = 64
BLOCK_N1 : tl.constexpr = 128
BLOCK_M2 : tl.constexpr = 128
BLOCK_N2 : tl.constexpr = 64
SPARSE_Q_BLOCK_SIZE : tl.constexpr = 128
SPARSE_KV_BLOCK_SIZE : tl.constexpr = 128
# NB reversed order since Q is transposed
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
# Load LSE before computing qk to reduce pipeline stall.
if IS_DIVISIBLE:
lse = tl.load(LSE + offs_m1)
else:
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
lse = tl.where(lse == -float("inf"), 0.0, lse)
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
if not PRESCALE_QK:
qkT *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None)
# The boundary check is done for the outer loop, but since we're iterating across the M dim
# here, the n reads can go out of bounds prior to the last loop
n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
pre_mod_scores = qkT
tmp9 = (qkT)
post_mod_scores = tmp9
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
tmp10 = (m)
tmp11 = (n)
tmp12 = tmp10 >= tmp11
tmp13 = tl.load(in_ptr16 + tmp10)
tmp14 = tl.load(in_ptr16 + tmp11)
tmp15 = tmp13 == tmp14
tmp16 = tmp12 & tmp15
mask_mod_output = tmp16
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# apply mask for partially masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
pT = tl.math.exp2(post_mod_scores - lse[None, :])
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
# Compute dV.
ppT = pT
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
if IS_DIVISIBLE:
Di = tl.load(DELTA + offs_m1)
else:
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
tmp17 = (dsT)
grad_scores = tmp17
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
if not WRITE_DQ:
idx_b = off_z
idx_h = off_hq
idx_m = m
idx_n = n
scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if CHECK_BLOCK_BOUNDARY:
grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)
dsT = grad_scores
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially unmasked block
dsT = tl.where(mask_mod_output, dsT, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
return dk, dv
@triton.jit
def get_offset_for_next_block(
loop_iter, col_indices, total_blocks,
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
if BLOCKS_ARE_CONTIGUOUS:
return BLOCK
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
return offset
@triton.jit
def get_bounded_indices(indices, max_len=None):
return indices % max_len if max_len is not None else indices
@triton.jit
def load_checked_2d(
ptr,
offs_m,
offs_n,
stride_m,
stride_n,
IS_DIVISIBLE_M: tl.constexpr,
IS_DIVISIBLE_N: tl.constexpr,
M_LEN: tl.constexpr,
N_DIM: tl.constexpr,
):
# Calculate final pointer if strides are provided
if stride_m is not None and stride_n is not None:
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
# Handle all masking cases
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0)
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0)
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
else: # Both divisible
return tl.load(ptr)
''', device_str='cuda')
meta0 = {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'FLOAT32_PRECISION': "'ieee'", 'IS_DIVISIBLE': True, 'SM_SCALE': 0.12, 'GQA_SHARED_HEADS': 1, 'HAS_FULL_BLOCKS': True, 'QK_HEAD_DIM': 128, 'QK_HEAD_DIM_ROUNDED': 128, 'V_HEAD_DIM': 128, 'V_HEAD_DIM_ROUNDED': 128, 'SAFE_HEAD_DIM': True, 'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'SPARSE_Q_BLOCK_SIZE': 128, 'SPARSE_KV_BLOCK_SIZE': 128}
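# The inlined mask in bwd_dq_block_mn / bwd_dkdv_block_mn above (the m >= n
# comparison combined with equality of the values loaded through in_ptr16)
# corresponds to a causal, document-aware mask_mod. An editorial sketch of
# such a mask_mod, with `doc_id` standing in for the cumsum buffer the
# kernels read through in_ptr16 (hypothetical name):
def _make_document_causal_mask(doc_id):
    def mask_mod(b, h, q_idx, kv_idx):
        # attend only causally, and only within the same document
        return (q_idx >= kv_idx) & (doc_id[q_idx] == doc_id[kv_idx])
    return mask_mod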
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/sl/csl7aubmdgjkdh44byrp4ynq374ifxk76hyvh6pkbszhj2arqklp.py
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten.sum, aten._to_copy]
# Source node to ATen node mapping:
# v_43 => convert_element_type_308, convert_element_type_309, mul_234
# Graph fragment:
# %mul_265 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %view_36), kwargs = {})
# %sum_13 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_265,), kwargs = {})
# %convert_element_type_308 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_140, torch.float32), kwargs = {})
# %mul_234 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_308, %rsqrt_60), kwargs = {})
# %convert_element_type_309 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_234, torch.bfloat16), kwargs = {})
# %mul_267 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %convert_element_type_309), kwargs = {})
# %sum_14 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_267,), kwargs = {})
triton_red_fused__to_copy_mul_sum_8 = async_compile.triton('triton_red_fused__to_copy_mul_sum_8', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 512, 'r0_': 131072},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_8', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 4, 'num_reduction': 2, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_mul_sum_8(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 512
r0_numel = 131072
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
_tmp4 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
_tmp13 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (3072*(r0_1 // 1024) + 393216*x0 + ((r0_1 % 1024))), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp8 = tl.load(in_ptr3 + (1024*x0 + (r0_1 // 128)), xmask, eviction_policy='evict_last', other=0.0)
tmp2 = tmp0 * tmp1
tmp3 = tl.broadcast_to(tmp2, [XBLOCK, R0_BLOCK])
tmp5 = _tmp4 + tmp3
_tmp4 = tl.where(xmask, tmp5, _tmp4)
tmp7 = tmp6.to(tl.float32)
tmp9 = tmp7 * tmp8
tmp10 = tmp9.to(tl.float32)
tmp11 = tmp0 * tmp10
tmp12 = tl.broadcast_to(tmp11, [XBLOCK, R0_BLOCK])
tmp14 = _tmp13 + tmp12
_tmp13 = tl.where(xmask, tmp14, _tmp13)
tmp4 = tl.sum(_tmp4, 1)[:, None]
tmp13 = tl.sum(_tmp13, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp4, xmask)
tl.store(out_ptr1 + (x0), tmp13, xmask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/ip/cipvpmrkhwlcho5qp6b5hh7tf5gzn2zmmw6oqxmrgh7e6dgbrpl4.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.mul, aten.sum]
# Source node to ATen node mapping:
# Graph fragment:
# %mul_265 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %view_36), kwargs = {})
# %sum_13 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_265,), kwargs = {})
triton_per_fused_mul_sum_9 = async_compile.triton('triton_per_fused_mul_sum_9', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 1, 'r0_': 512},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_mul_sum_9', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused_mul_sum_9(in_ptr0, out_ptr0, xnumel, r0_numel):
xnumel = 1
XBLOCK: tl.constexpr = 1
r0_numel = 512
R0_BLOCK: tl.constexpr = 512
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_0 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_0), None)
tmp1 = tl.broadcast_to(tmp0, [R0_BLOCK])
tmp3 = triton_helpers.promote_to_tensor(tl.sum(tmp1, 0))
tl.store(out_ptr0 + (tl.full([1], 0, tl.int32)), tmp3, None)
''', device_str='cuda')
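# Kernels 8 and 9 implement full reductions to scalars: kernel 8 emits 512
# fp32 partial sums per operand pair, and kernel 9 folds one such vector of
# partials (the %sum_13 operand here) into a single bf16 scalar. An eager
# sketch, with `a` and `b` as hypothetical names for the two elementwise
# operands being dotted:
def _reference_scalar_grad(a, b):
    return (a.float() * b.float()).sum().to(torch.bfloat16)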
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/t7/ct7jp2xwhotrjgyyvvxtn4sdhjkd3odeyv5jotixf4j4625b7o64.py
# Topologically Sorted Source Nodes: [v_43], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_43 => convert_element_type_308
# Graph fragment:
# %mul_266 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_144, %select_101), kwargs = {})
# %convert_element_type_308 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_140, torch.float32), kwargs = {})
# %convert_element_type_349 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_266, torch.float32), kwargs = {})
# %mul_268 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_349, %convert_element_type_308), kwargs = {})
# %mul_269 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_349, %rsqrt_60), kwargs = {})
# %sum_15 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_268, [3], True), kwargs = {})
# %div_5 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_22, 128), kwargs = {})
# %pow_89 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_308, 1.0), kwargs = {})
# %mul_272 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_89, 2.0), kwargs = {})
# %mul_273 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_5, %mul_272), kwargs = {})
# %add_185 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_269, %mul_273), kwargs = {})
# %convert_element_type_350 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_185, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_10 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_10', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_10', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_10(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (78))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (78))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
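# The two-pass reduction above is the standard RMS-norm backward over the
# 128-wide head dimension (0.0078125 == 1/128), applied after the incoming
# v-gradient has been scaled by the coefficient at in_ptr1[78]. A minimal
# eager-mode sketch of the fused math, assuming g is the (already scaled)
# upstream gradient, x the saved activation, and r the saved
# rsqrt(mean(x**2, -1, keepdim=True)) -- names illustrative:
def _rms_norm_bwd_sketch(g, x, r):
    # sum_15 above: row-wise sum of g*x along the normalized dimension
    s = (g * x).sum(dim=-1, keepdim=True)
    # mul_269 + mul_273: g*r + (-0.5*s*r**3) * (2*x/128)
    return g * r - (s / x.shape[-1]) * r.pow(3) * x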
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/qm/cqmbpar5nxuou6yq7grwenf2cpvetjywkppq53bzvstjg6xoegcz.py
# Topologically Sorted Source Nodes: [k_43, q_43], Original ATen: [aten.cat, aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# k_43 => convert_element_type_302
# q_43 => convert_element_type_300
# Graph fragment:
# %cat_30 : [num_users=2] = call_function[target=torch.ops.aten.cat.default](args = ([%add_187, %add_186], 3), kwargs = {})
# %cat_31 : [num_users=2] = call_function[target=torch.ops.aten.cat.default](args = ([%add_189, %add_188], 3), kwargs = {})
# %convert_element_type_302 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_139, torch.float32), kwargs = {})
# %mul_282 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_30, %convert_element_type_302), kwargs = {})
# %mul_283 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_30, %rsqrt_59), kwargs = {})
# %sum_16 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_282, [3], True), kwargs = {})
# %div_6 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_23, 128), kwargs = {})
# %pow_91 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_302, 1.0), kwargs = {})
# %mul_286 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_91, 2.0), kwargs = {})
# %mul_287 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_6, %mul_286), kwargs = {})
# %add_190 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_283, %mul_287), kwargs = {})
# %convert_element_type_356 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_190, torch.bfloat16), kwargs = {})
# %convert_element_type_300 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_138, torch.float32), kwargs = {})
# %mul_288 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_31, %convert_element_type_300), kwargs = {})
# %mul_289 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%cat_31, %rsqrt_58), kwargs = {})
# %sum_17 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_288, [3], True), kwargs = {})
# %div_7 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_24, 128), kwargs = {})
# %pow_93 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_300, 1.0), kwargs = {})
# %mul_292 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_93, 2.0), kwargs = {})
# %mul_293 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_7, %mul_292), kwargs = {})
# %add_191 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_289, %mul_293), kwargs = {})
# %convert_element_type_358 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_191, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11 = async_compile.triton('triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr3': '*bf16', 'out_ptr5': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 21, 'num_reduction': 2, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_cat_div_mul_pow_sum_11(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, out_ptr0, out_ptr1, out_ptr3, out_ptr5, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
x1 = xindex // 8
x0 = (xindex % 8)
_tmp55 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp51 = tl.load(in_ptr4 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp0 = r0_2
tmp1 = tl.full([1, 1], 0, tl.int64)
tmp2 = tmp0 >= tmp1
tmp3 = tl.full([1, 1], 64, tl.int64)
tmp4 = tmp0 < tmp3
tmp5 = tl.load(in_ptr0 + (64 + 128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tmp5.to(tl.float32)
tmp7 = tl.load(in_ptr1 + (64*x1 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0)
tmp8 = -tmp7
tmp9 = tmp6 * tmp8
tmp10 = tl.load(in_ptr0 + (128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp11 = tmp10.to(tl.float32)
tmp12 = tl.load(in_ptr2 + (64*x1 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0)
tmp13 = tmp11 * tmp12
tmp14 = tmp9 + tmp13
tmp15 = tl.full(tmp14.shape, 0.0, tmp14.dtype)
tmp16 = tl.where(tmp4, tmp14, tmp15)
tmp17 = tmp0 >= tmp3
tmp18 = tl.full([1, 1], 128, tl.int64)
tmp19 = tmp0 < tmp18
tmp20 = tl.load(in_ptr0 + (64 + 128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp21 = tmp20.to(tl.float32)
tmp22 = tl.load(in_ptr2 + (64*x1 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0)
tmp23 = tmp21 * tmp22
tmp24 = tl.load(in_ptr0 + (128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp25 = tmp24.to(tl.float32)
tmp26 = tl.load(in_ptr1 + (64*x1 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0)
tmp27 = tmp25 * tmp26
tmp28 = tmp23 + tmp27
tmp29 = tl.full(tmp28.shape, 0.0, tmp28.dtype)
tmp30 = tl.where(tmp17, tmp28, tmp29)
tmp31 = tl.where(tmp4, tmp16, tmp30)
tmp32 = tl.load(in_ptr3 + (64 + 128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp33 = tmp32.to(tl.float32)
tmp34 = tmp33 * tmp8
tmp35 = tl.load(in_ptr3 + (128*x3 + (r0_2)), r0_mask & tmp4, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp36 = tmp35.to(tl.float32)
tmp37 = tmp36 * tmp12
tmp38 = tmp34 + tmp37
tmp39 = tl.full(tmp38.shape, 0.0, tmp38.dtype)
tmp40 = tl.where(tmp4, tmp38, tmp39)
tmp41 = tl.load(in_ptr3 + (64 + 128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp42 = tmp41.to(tl.float32)
tmp43 = tmp42 * tmp22
tmp44 = tl.load(in_ptr3 + (128*x3 + ((-64) + r0_2)), r0_mask & tmp17, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp45 = tmp44.to(tl.float32)
tmp46 = tmp45 * tmp26
tmp47 = tmp43 + tmp46
tmp48 = tl.full(tmp47.shape, 0.0, tmp47.dtype)
tmp49 = tl.where(tmp17, tmp47, tmp48)
tmp50 = tl.where(tmp4, tmp40, tmp49)
tmp52 = tmp51.to(tl.float32)
tmp53 = tmp31 * tmp52
tmp54 = tl.broadcast_to(tmp53, [XBLOCK, R0_BLOCK])
tmp56 = _tmp55 + tmp54
_tmp55 = tl.where(r0_mask, tmp56, _tmp55)
tl.store(out_ptr0 + (r0_2 + 128*x3), tmp31, r0_mask)
tl.store(out_ptr1 + (r0_2 + 128*x3), tmp50, r0_mask)
tmp55 = tl.sum(_tmp55, 1)[:, None]
tmp58 = tl.load(in_ptr5 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp57 = tl.load(out_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0)
tmp67 = tl.load(in_ptr4 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp59 = tmp57 * tmp58
tmp60 = -0.5
tmp61 = tmp55 * tmp60
tmp62 = tmp58 * tmp58
tmp63 = tmp62 * tmp58
tmp64 = tmp61 * tmp63
tmp65 = 0.0078125
tmp66 = tmp64 * tmp65
tmp68 = tmp67.to(tl.float32)
tmp69 = 2.0
tmp70 = tmp68 * tmp69
tmp71 = tmp66 * tmp70
tmp72 = tmp59 + tmp71
tmp73 = tmp72.to(tl.float32)
tl.store(out_ptr3 + (r0_2 + 128*x0 + 3072*x1), tmp73, r0_mask)
_tmp79 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp74 = tl.load(out_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0)
tmp75 = tl.load(in_ptr6 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp76 = tmp75.to(tl.float32)
tmp77 = tmp74 * tmp76
tmp78 = tl.broadcast_to(tmp77, [XBLOCK, R0_BLOCK])
tmp80 = _tmp79 + tmp78
_tmp79 = tl.where(r0_mask, tmp80, _tmp79)
tmp79 = tl.sum(_tmp79, 1)[:, None]
tmp82 = tl.load(in_ptr7 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp81 = tl.load(out_ptr1 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0)
tmp91 = tl.load(in_ptr6 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp83 = tmp81 * tmp82
tmp84 = -0.5
tmp85 = tmp79 * tmp84
tmp86 = tmp82 * tmp82
tmp87 = tmp86 * tmp82
tmp88 = tmp85 * tmp87
tmp89 = 0.0078125
tmp90 = tmp88 * tmp89
tmp92 = tmp91.to(tl.float32)
tmp93 = 2.0
tmp94 = tmp92 * tmp93
tmp95 = tmp90 * tmp94
tmp96 = tmp83 + tmp95
tmp97 = tmp96.to(tl.float32)
tl.store(out_ptr5 + (r0_2 + 128*x0 + 3072*x1), tmp97, r0_mask)
''', device_str='cuda')
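# The first loop above rotates the incoming q/k gradients back through the
# rotary embedding (cos table in in_ptr2, sin table in in_ptr1) before the
# later loops repeat the RMS-norm backward of _rms_norm_bwd_sketch for k and
# q. A minimal sketch of the un-rotation, assuming g carries the rotated
# halves (g1, g2) on its last dimension -- names illustrative:
def _rope_bwd_sketch(g, cos, sin):
    g1, g2 = g.chunk(2, dim=-1)
    # matches cat_30/cat_31: the transposed rotation [g1*c - g2*s, g1*s + g2*c]
    return torch.cat([g1 * cos - g2 * sin, g1 * sin + g2 * cos], dim=-1)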
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/k2/ck2d2pyt4c2xy4hbpuabbajwzm2ocm7rodqo6uzeivtga7ibn3pa.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten.add]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_18 : [num_users=30] = call_function[target=torch.ops.aten.full.default](args = ([4, 1024, 1024], 0), kwargs = {dtype: torch.bfloat16, layout: torch.strided, device: cuda:0, pin_memory: False})
# %select_scatter_default : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_18, %mm_69, 0, 3), kwargs = {})
# %slice_scatter_default : [num_users=1] = call_function[target=torch.ops.aten.slice_scatter.default](args = (%full_default_18, %view_196, 0, 0, 3), kwargs = {})
# %add_193 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default, %slice_scatter_default), kwargs = {})
triton_poi_fused_add_select_backward_12 = async_compile.triton('triton_poi_fused_add_select_backward_12', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 4194304},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_select_backward_12', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_select_backward_12(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 4194304
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x1 = xindex // 1048576
x0 = (xindex % 1048576)
x2 = xindex
tmp3 = tl.load(in_ptr0 + (x0), None, eviction_policy='evict_last').to(tl.float32)
tmp0 = x1
tmp1 = tl.full([1], 3, tl.int32)
tmp2 = tmp0 == tmp1
tmp4 = 0.0
tmp5 = tl.where(tmp2, tmp3, tmp4)
tmp6 = tl.full([1], 3, tl.int64)
tmp7 = tmp0 < tmp6
tmp8 = tl.load(in_ptr1 + (x2), tmp7, other=0.0).to(tl.float32)
tmp9 = tl.where(tmp7, tmp8, tmp4)
tmp10 = tmp5 + tmp9
tl.store(out_ptr0 + (x2), tmp10, None)
''', device_str='cuda')
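# The pointwise kernel above materializes the select_scatter/slice_scatter
# gradients into one zero-initialized [4, 1024, 1024] buffer: in_ptr0 lands at
# index 3 of dim 0 and in_ptr1 fills indices 0..2. A minimal sketch, assuming
# g_sel is [1024, 1024] and g_slice is [3, 1024, 1024] (names illustrative):
def _select_slice_scatter_bwd_sketch(g_sel, g_slice):
    out = torch.zeros(4, 1024, 1024, dtype=torch.bfloat16, device=g_sel.device)
    out[3] = g_sel        # select_scatter_default: grad routed to index 3
    out[:3] += g_slice    # slice_scatter_default: grad routed to slice [0:3]
    return out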
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/a5/ca56au744fshqomiwmgjdngi2et5vflsgkwrpod6zptxlqclh6dv.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {})
# %mul_296 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %select_99), kwargs = {})
triton_poi_fused_add_mul_13 = async_compile.triton('triton_poi_fused_add_mul_13', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_13', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_13(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (46))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
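# add_192 fused with mul_296: the two gradient branches are summed in fp32 and
# scaled by the single coefficient at in_ptr2[46]. Eager equivalent, names
# illustrative: out = ((a.float() + b.float()) * coef[46]).to(torch.bfloat16).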
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/md/cmdhjwnox4mptoiaadn5oichuh2ehyfrlgqib4p7z5y6kig7caz6.py
# Topologically Sorted Source Nodes: [rms_norm_57], Original ATen: [aten._to_copy, aten.mul, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# rms_norm_57 => convert_element_type_292
# Graph fragment:
# %convert_element_type_373 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_200, torch.float32), kwargs = {})
# %convert_element_type_292 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_163, torch.float32), kwargs = {})
# %mul_300 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_373, %convert_element_type_292), kwargs = {})
# %mul_301 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_373, %rsqrt_57), kwargs = {})
# %sum_20 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_300, [2], True), kwargs = {})
# %div_8 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_25, 1024), kwargs = {})
# %pow_96 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_292, 1.0), kwargs = {})
# %mul_304 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_96, 2.0), kwargs = {})
# %mul_305 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_8, %mul_304), kwargs = {})
# %add_195 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_301, %mul_305), kwargs = {})
# %convert_element_type_374 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_195, torch.bfloat16), kwargs = {})
# %add_196 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_296, %convert_element_type_374), kwargs = {})
triton_per_fused__to_copy_add_div_mul_pow_sum_14 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_14', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_14', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 4, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_14(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp8 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 * tmp3
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
tmp10 = tmp1 * tmp9
tmp11 = -0.5
tmp12 = tmp7 * tmp11
tmp13 = tmp9 * tmp9
tmp14 = tmp13 * tmp9
tmp15 = tmp12 * tmp14
tmp16 = 0.0009765625
tmp17 = tmp15 * tmp16
tmp18 = 2.0
tmp19 = tmp3 * tmp18
tmp20 = tmp17 * tmp19
tmp21 = tmp10 + tmp20
tmp22 = tmp21.to(tl.float32)
tmp23 = tmp8 + tmp22
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp23, None)
''', device_str='cuda')
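# Same RMS-norm backward as _rms_norm_bwd_sketch, but as a one-pass persistent
# reduction over the 1024-wide model dimension (0.0009765625 == 1/1024) and
# fused with the residual accumulation into in_out_ptr0 (add_196).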
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/xg/cxged3dxs4opgztavvr7fkyy2lg6uqfqee4esalldmdyijnennf5.py
# Topologically Sorted Source Nodes: [v_40], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_40 => convert_element_type_288
# Graph fragment:
# %mul_308 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_164, %select_94), kwargs = {})
# %convert_element_type_288 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_131, torch.float32), kwargs = {})
# %convert_element_type_381 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_308, torch.float32), kwargs = {})
# %mul_310 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_381, %convert_element_type_288), kwargs = {})
# %mul_311 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_381, %rsqrt_56), kwargs = {})
# %sum_23 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_310, [3], True), kwargs = {})
# %div_9 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_26, 128), kwargs = {})
# %pow_98 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_288, 1.0), kwargs = {})
# %mul_314 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_98, 2.0), kwargs = {})
# %mul_315 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_9, %mul_314), kwargs = {})
# %add_198 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_311, %mul_315), kwargs = {})
# %convert_element_type_382 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_198, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_15 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_15', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_15', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_15(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (76))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (76))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
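# Identical pattern to triton_red_fused__to_copy_add_div_mul_pow_sum_10,
# re-instantiated for v_40 with the coefficient at in_ptr1[76] instead of [78].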
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/st/cstxqu5wndh2mme6n7qysm6mhvq6r36u3keo4fwrmmgxo5rjyjzu.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {})
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %mul_295 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %convert_element_type_11), kwargs = {})
# %sum_18 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_295,), kwargs = {})
# %mul_297 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %add_165), kwargs = {})
# %sum_19 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_297,), kwargs = {})
# %add_205 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_196, %view_207), kwargs = {})
# %mul_337 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %convert_element_type_11), kwargs = {})
# %sum_26 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_337,), kwargs = {})
# %mul_338 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %select_92), kwargs = {})
# %mul_339 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %add_153), kwargs = {})
# %sum_27 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_339,), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_16 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_16', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_16', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 9, 'num_reduction': 4, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_16(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp29 = tl.load(in_ptr8 + (44))
tmp30 = tl.broadcast_to(tmp29, [R0_BLOCK])
tmp2 = tmp0 + tmp1
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp2 * tmp7
tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
tmp13 = tmp2 * tmp12
tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp19 = tmp17 + tmp18
tmp20 = tmp19 * tmp7
tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0))
tmp25 = tmp19 * tmp24
tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK])
tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tmp31 = tmp30.to(tl.float32)
tmp32 = tmp19 * tmp31
tl.store(out_ptr4 + (r0_1 + 1024*x0), tmp32, None)
tl.store(out_ptr0 + (x0), tmp11, None)
tl.store(out_ptr1 + (x0), tmp16, None)
tl.store(out_ptr2 + (x0), tmp23, None)
tl.store(out_ptr3 + (x0), tmp28, None)
''', device_str='cuda')
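# The persistent reduction above batches the per-row partials of four
# full-tensor dot products (sum_18, sum_19, sum_26, sum_27) -- each pairs a
# residual-stream gradient with either the normalized embedding
# (in_ptr2 * rsqrt, i.e. convert_element_type_11) or a saved branch -- and
# also emits add_205 scaled by the coefficient at in_ptr8[44] (mul_338).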
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/4y/c4ykabbdrwrdol5x6eczezz3t7rsq5fk7zv6xmtnuvk22t2q5jdc.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten._to_copy, aten.mul, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {})
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %mul_295 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %convert_element_type_11), kwargs = {})
# %sum_18 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_295,), kwargs = {})
triton_red_fused__to_copy_add_mul_sum_17 = async_compile.triton('triton_red_fused__to_copy_add_mul_sum_17', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 1, 'r0_': 65536},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_mul_sum_17', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_mul_sum_17(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 1
r0_numel = 65536
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
_tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_0 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_0), None, eviction_policy='evict_first')
tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
tmp3 = _tmp2 + tmp1
_tmp2 = tmp3
tmp2 = tl.sum(_tmp2, 1)[:, None]
tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp2, None)
''', device_str='cuda')
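# Final stage of sum_18: the 65536 per-row partials written by
# triton_per_fused__to_copy_add_mul_sum_16 are reduced to a single scalar and
# stored as bf16 -- equivalent to partials.sum().to(torch.bfloat16).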
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/vf/cvf3n5zrn4y5xotwv7lzpbxws7wmfhbxv2cnzgqiqengfefpfz4g.py
# Topologically Sorted Source Nodes: [v_37], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_37 => convert_element_type_268
# Graph fragment:
# %mul_350 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_184, %select_87), kwargs = {})
# %convert_element_type_268 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_122, torch.float32), kwargs = {})
# %convert_element_type_413 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_350, torch.float32), kwargs = {})
# %mul_352 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_413, %convert_element_type_268), kwargs = {})
# %mul_353 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_413, %rsqrt_52), kwargs = {})
# %sum_31 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_352, [3], True), kwargs = {})
# %div_13 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_30, 128), kwargs = {})
# %pow_107 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_268, 1.0), kwargs = {})
# %mul_356 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_107, 2.0), kwargs = {})
# %mul_357 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_13, %mul_356), kwargs = {})
# %add_214 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_353, %mul_357), kwargs = {})
# %convert_element_type_414 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_214, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_18 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_18', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_18', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_18(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (74))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (74))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
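# Third instance of the v-projection RMS-norm backward (cf. kernels _10 and
# _15), here for v_37 with the coefficient at in_ptr1[74].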
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/he/chewvuajuq55rjjxwpgjkdibqap2uf4svllmqmaiitre66ivyggh.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten.add, aten.mul, aten._to_copy, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %add_192 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_183, %view_195), kwargs = {})
# %mul_294 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_192, %select_100), kwargs = {})
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_205 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_196, %view_207), kwargs = {})
# %mul_336 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_205, %select_93), kwargs = {})
# %add_207 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_294, %mul_336), kwargs = {})
# %add_221 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_212, %view_219), kwargs = {})
# %mul_378 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %select_86), kwargs = {})
# %mul_379 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %convert_element_type_11), kwargs = {})
# %sum_34 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_379,), kwargs = {})
# %add_223 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_207, %mul_378), kwargs = {})
# %mul_380 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %select_85), kwargs = {})
# %mul_381 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_221, %add_141), kwargs = {})
# %sum_35 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_381,), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_19 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_19', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*fp32', 'in_ptr8': '*bf16', 'out_ptr0': '*bf16', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_19', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 2, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_19(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, out_ptr0, out_ptr1, out_ptr2, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr1 + (47))
tmp4 = tl.broadcast_to(tmp3, [R0_BLOCK])
tmp7 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp8 = tl.load(in_ptr3 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp10 = tl.load(in_ptr1 + (45))
tmp11 = tl.broadcast_to(tmp10, [R0_BLOCK])
tmp15 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp16 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr1 + (43))
tmp19 = tl.broadcast_to(tmp18, [R0_BLOCK])
tmp23 = tl.load(in_ptr1 + (42))
tmp24 = tl.broadcast_to(tmp23, [R0_BLOCK])
tmp27 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp29 = tl.load(in_ptr7 + (x0), None, eviction_policy='evict_last')
tmp36 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tmp9 = tmp7 + tmp8
tmp12 = tmp11.to(tl.float32)
tmp13 = tmp9 * tmp12
tmp14 = tmp6 + tmp13
tmp17 = tmp15 + tmp16
tmp20 = tmp19.to(tl.float32)
tmp21 = tmp17 * tmp20
tmp22 = tmp14 + tmp21
tmp25 = tmp24.to(tl.float32)
tmp26 = tmp17 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp30 = tmp28 * tmp29
tmp31 = tmp30.to(tl.float32)
tmp32 = tmp17 * tmp31
tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK])
tmp35 = triton_helpers.promote_to_tensor(tl.sum(tmp33, 0))
tmp37 = tmp17 * tmp36
tmp38 = tl.broadcast_to(tmp37, [R0_BLOCK])
tmp40 = triton_helpers.promote_to_tensor(tl.sum(tmp38, 0))
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp22, None)
tl.store(out_ptr0 + (r0_1 + 1024*x0), tmp26, None)
tl.store(out_ptr1 + (x0), tmp35, None)
tl.store(out_ptr2 + (x0), tmp40, None)
''', device_str='cuda')
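# Like triton_per_fused__to_copy_add_mul_sum_16, this persistent reduction
# folds three scaled gradient branches into in_out_ptr0 (coefficients at
# in_ptr1 indices 47, 45, 43), writes the branch scaled by index 42 to
# out_ptr0, and reduces two more per-row dot products (over the 1024-wide
# model dim) into out_ptr1 and out_ptr2.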
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/zw/czwk3untwmpllx2wu6vrpai3qkimsauvxusvmc74tmu73zz3k6zn.py
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten._to_copy, aten.mul, aten.sum]
# Source node to ATen node mapping:
# v_34 => convert_element_type_248, convert_element_type_249, mul_187
# Graph fragment:
# %convert_element_type_248 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_113, torch.float32), kwargs = {})
# %mul_187 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_248, %rsqrt_48), kwargs = {})
# %convert_element_type_249 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_187, torch.bfloat16), kwargs = {})
# %mul_391 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_204, %convert_element_type_249), kwargs = {})
# %sum_37 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_391,), kwargs = {})
triton_red_fused__to_copy_mul_sum_20 = async_compile.triton('triton_red_fused__to_copy_mul_sum_20', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 512, 'r0_': 131072},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_mul_sum_20', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_mul_sum_20(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 512
r0_numel = 131072
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x0 = xindex
_tmp8 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_1 + 131072*x0), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (3072*(r0_1 // 1024) + 393216*x0 + ((r0_1 % 1024))), xmask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (1024*x0 + (r0_1 // 128)), xmask, eviction_policy='evict_last', other=0.0)
tmp2 = tmp1.to(tl.float32)
tmp4 = tmp2 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp0 * tmp5
tmp7 = tl.broadcast_to(tmp6, [XBLOCK, R0_BLOCK])
tmp9 = _tmp8 + tmp7
_tmp8 = tl.where(xmask, tmp9, _tmp8)
tmp8 = tl.sum(_tmp8, 1)[:, None]
tl.store(out_ptr0 + (x0), tmp8, xmask)
''', device_str='cuda')
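
# NOTE: readability sketch for the looped reduction above (never called;
# names illustrative). It produces the 512 fp32 partial sums of
# grad * RMS-normalized v that a later kernel collapses into the scalar
# gradient of a learned scale; the real kernel walks the same data through
# the strided (3072 / 393216) layout of the packed projection buffer, which
# this sketch flattens for clarity.
def _ref_red_fused_mul_sum_20(grad, v, inv_rms):
    # grad, v: [512, 131072] bf16 views; inv_rms: fp32, one value per
    # 128-element group, assumed pre-broadcast to grad's shape here.
    vn = (v.float() * inv_rms).to(torch.bfloat16)   # re-normalize v in fp32
    return (grad.float() * vn.float()).sum(dim=1)   # one fp32 partial per row
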
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/xt/cxt3lc5hr6vgginbc2ivwhx3tof46klook7rqcj5a7asaqgh3yez.py
# Topologically Sorted Source Nodes: [v_34], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_34 => convert_element_type_248
# Graph fragment:
# %mul_390 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_204, %select_81), kwargs = {})
# %convert_element_type_248 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_113, torch.float32), kwargs = {})
# %convert_element_type_444 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_390, torch.float32), kwargs = {})
# %mul_392 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_444, %convert_element_type_248), kwargs = {})
# %mul_393 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_444, %rsqrt_48), kwargs = {})
# %sum_38 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_392, [3], True), kwargs = {})
# %div_17 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_34, 128), kwargs = {})
# %pow_116 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_248, 1.0), kwargs = {})
# %mul_396 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_116, 2.0), kwargs = {})
# %mul_397 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_17, %mul_396), kwargs = {})
# %add_229 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_393, %mul_397), kwargs = {})
# %convert_element_type_445 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_229, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_21 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_21', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_21', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_21(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (72))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    # First pass over the 128-wide head dim: accumulate sum(g * v) per row.
    for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (72))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
    # Second pass: re-read the inputs and apply the RMSNorm backward formula
    # dv = g*rsqrt + (-0.5 * sum * rsqrt**3 / head_dim) * 2v.
    for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
        tmp25 = 0.0078125  # 1/128 == 1/head_dim
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
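
# NOTE: the two-pass kernel above is the fused backward of an RMS
# normalization over a 128-wide head dimension: the first loop accumulates
# sum(g * v) per row, the second applies the closed form
# dv = g*rsqrt + (-0.5 * sum * rsqrt**3 / head_dim) * 2v, matching the graph
# fragment (sum over dim 3, div by 128, pow, mul 2.0, add). A minimal sketch,
# never called; `scale` stands in for the fp32 scalar read from in_ptr1:
def _ref_rmsnorm_backward(grad_out, v, inv_rms, scale, head_dim=128):
    # grad_out, v: [..., head_dim] bf16; inv_rms: [...] fp32 rsqrt(mean(v**2))
    g = grad_out.float() * scale                    # pre-scaled upstream grad
    s = (g * v.float()).sum(dim=-1, keepdim=True)   # first loop: row dot product
    r = inv_rms.unsqueeze(-1)
    dv = g * r + (-0.5 * s * r**3 / head_dim) * (2.0 * v.float())
    return dv.to(torch.bfloat16)
# The same body recurs below as kernels _23, _25, _27 and _29; only the
# layer's scalar offset into in_ptr1 (70, 68, 66, 64) changes.
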
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/i5/ci5mcynkhhv2drjtnbviwl5ibwpvhvcagyfiku76l7vlecpg2j4c.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {})
# %mul_420 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %select_79), kwargs = {})
triton_poi_fused_add_mul_22 = async_compile.triton('triton_poi_fused_add_mul_22', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_22', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_22(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (40))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
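
# NOTE: the pointwise kernel above is a fused residual add followed by a
# broadcast scalar scale; kernels _28 and _30 below repeat it verbatim with
# different offsets into the same fp32 scalar table. Sketch (never called,
# names illustrative):
def _ref_add_mul(a, b, scalar):
    # a, b: bf16 tensors of equal shape; scalar: 0-dim fp32 from in_ptr2
    return ((a + b).float() * scalar).to(torch.bfloat16)
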
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/2q/c2q426mapna6iekh3hgp2hihn6zqzbqso4gblzbvbonknafgmvyc.py
# Topologically Sorted Source Nodes: [v_31], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_31 => convert_element_type_228
# Graph fragment:
# %mul_430 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_224, %select_75), kwargs = {})
# %convert_element_type_228 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_104, torch.float32), kwargs = {})
# %convert_element_type_475 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_430, torch.float32), kwargs = {})
# %mul_432 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_475, %convert_element_type_228), kwargs = {})
# %mul_433 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_475, %rsqrt_44), kwargs = {})
# %sum_45 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_432, [3], True), kwargs = {})
# %div_21 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_38, 128), kwargs = {})
# %pow_125 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_228, 1.0), kwargs = {})
# %mul_436 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_125, 2.0), kwargs = {})
# %mul_437 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_21, %mul_436), kwargs = {})
# %add_244 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_433, %mul_437), kwargs = {})
# %convert_element_type_476 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_244, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_23 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_23', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_23', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_23(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (70))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (70))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/7j/c7j54wiarheofqb4gbyajl45ernnjlddmdz77p5ehb46w46cawmo.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {})
# %mul_419 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %convert_element_type_11), kwargs = {})
# %sum_41 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_419,), kwargs = {})
# %mul_421 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %add_130), kwargs = {})
# %sum_42 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_421,), kwargs = {})
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {})
# %mul_459 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %convert_element_type_11), kwargs = {})
# %sum_48 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_459,), kwargs = {})
# %mul_460 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_73), kwargs = {})
# %mul_461 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %add_119), kwargs = {})
# %sum_49 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_461,), kwargs = {})
# %mul_463 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_460, %add_36), kwargs = {})
# %sum_50 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_463,), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_24 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_24', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*fp32', 'in_ptr9': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*fp32', 'out_ptr5': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_24', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 10, 'num_reduction': 5, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_24(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, out_ptr5, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp29 = tl.load(in_ptr8 + (38))
tmp30 = tl.broadcast_to(tmp29, [R0_BLOCK])
tmp33 = tl.load(in_ptr9 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tmp0 + tmp1
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp2 * tmp7
tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
tmp13 = tmp2 * tmp12
tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp19 = tmp17 + tmp18
tmp20 = tmp19 * tmp7
tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0))
tmp25 = tmp19 * tmp24
tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK])
tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tmp31 = tmp30.to(tl.float32)
tmp32 = tmp19 * tmp31
tmp34 = tmp32 * tmp33
tmp35 = tl.broadcast_to(tmp34, [R0_BLOCK])
tmp37 = triton_helpers.promote_to_tensor(tl.sum(tmp35, 0))
tl.store(out_ptr5 + (r0_1 + 1024*x0), tmp32, None)
tl.store(out_ptr0 + (x0), tmp11, None)
tl.store(out_ptr1 + (x0), tmp16, None)
tl.store(out_ptr2 + (x0), tmp23, None)
tl.store(out_ptr3 + (x0), tmp28, None)
tl.store(out_ptr4 + (x0), tmp37, None)
''', device_str='cuda')
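
# NOTE: sketch of the five fused row reductions plus one elementwise product
# in the kernel above (never called; names illustrative; the 2-D inputs are
# [65536, 1024], `inv_rms` is the per-row fp32 rsqrt from in_ptr3, and
# `scalars` is the fp32 table from in_ptr8):
def _ref_per_fused_add_mul_sum_24(a, b, x, inv_rms, d, e, f, g, scalars, h):
    u = (a + b).float()                    # grad branch shared by outputs 0/1
    x_hat = x.float() * inv_rms[:, None]   # re-normalized activations
    v = (e + f).float()                    # grad branch shared by outputs 2-4
    scaled = v * scalars[38]
    return ((u * x_hat).sum(1),            # -> out_ptr0
            (u * d.float()).sum(1),        # -> out_ptr1
            (v * x_hat).sum(1),            # -> out_ptr2
            (v * g.float()).sum(1),        # -> out_ptr3
            (scaled * h.float()).sum(1),   # -> out_ptr4
            scaled.to(torch.bfloat16))     # elementwise -> out_ptr5
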
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/f6/cf6ojfz6bfezrzis2b5mf55ez4x3btg3ufefdfa5vevfoiqfwice.py
# Topologically Sorted Source Nodes: [v_28], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_28 => convert_element_type_208
# Graph fragment:
# %mul_472 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_244, %select_68), kwargs = {})
# %convert_element_type_208 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_95, torch.float32), kwargs = {})
# %convert_element_type_507 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_472, torch.float32), kwargs = {})
# %mul_474 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_507, %convert_element_type_208), kwargs = {})
# %mul_475 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_507, %rsqrt_40), kwargs = {})
# %sum_53 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_474, [3], True), kwargs = {})
# %div_25 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_42, 128), kwargs = {})
# %pow_134 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_208, 1.0), kwargs = {})
# %mul_478 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_134, 2.0), kwargs = {})
# %mul_479 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_25, %mul_478), kwargs = {})
# %add_259 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_475, %mul_479), kwargs = {})
# %convert_element_type_508 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_259, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_25 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_25', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_25', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_25(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (68))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (68))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/yt/cytha5ygxd572kkbk6gsg75fiya7avvgj346towpq75mnegn35vm.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_236 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_228, %view_230), kwargs = {})
# %mul_418 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_236, %select_80), kwargs = {})
# %add_238 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_223, %mul_418), kwargs = {})
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {})
# %mul_458 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_74), kwargs = {})
# %add_253 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_238, %mul_458), kwargs = {})
# %add_266 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_258, %view_252), kwargs = {})
# %mul_500 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_67), kwargs = {})
# %mul_501 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %convert_element_type_11), kwargs = {})
# %sum_56 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_501,), kwargs = {})
# %add_268 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_253, %mul_500), kwargs = {})
# %mul_502 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_66), kwargs = {})
# %mul_503 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %add_107), kwargs = {})
# %sum_57 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_503,), kwargs = {})
# %mul_505 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_502, %add_58), kwargs = {})
# %sum_58 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_505,), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_26 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_26', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_26', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 15, 'num_reduction': 3, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_26(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr5 + (36))
tmp18 = tl.broadcast_to(tmp17, [R0_BLOCK])
tmp21 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp26 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp27 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp28 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp30 = tl.load(in_ptr5 + (41))
tmp31 = tl.broadcast_to(tmp30, [R0_BLOCK])
tmp35 = tl.load(in_ptr9 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp36 = tl.load(in_ptr10 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp38 = tl.load(in_ptr5 + (39))
tmp39 = tl.broadcast_to(tmp38, [R0_BLOCK])
tmp43 = tl.load(in_ptr5 + (37))
tmp44 = tl.broadcast_to(tmp43, [R0_BLOCK])
tmp2 = tmp0 + tmp1
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp2 * tmp7
tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
tmp13 = tmp2 * tmp12
tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp19 = tmp18.to(tl.float32)
tmp20 = tmp2 * tmp19
tmp22 = tmp20 * tmp21
tmp23 = tl.broadcast_to(tmp22, [R0_BLOCK])
tmp25 = triton_helpers.promote_to_tensor(tl.sum(tmp23, 0))
tmp29 = tmp27 + tmp28
tmp32 = tmp31.to(tl.float32)
tmp33 = tmp29 * tmp32
tmp34 = tmp26 + tmp33
tmp37 = tmp35 + tmp36
tmp40 = tmp39.to(tl.float32)
tmp41 = tmp37 * tmp40
tmp42 = tmp34 + tmp41
tmp45 = tmp44.to(tl.float32)
tmp46 = tmp2 * tmp45
tmp47 = tmp42 + tmp46
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp47, None)
tl.store(out_ptr3 + (r0_1 + 1024*x0), tmp20, None)
tl.store(out_ptr0 + (x0), tmp11, None)
tl.store(out_ptr1 + (x0), tmp16, None)
tl.store(out_ptr2 + (x0), tmp25, None)
''', device_str='cuda')
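
# NOTE: like _24 above, but fused with the in-place accumulation of three
# more scaled residual branches into in_out_ptr0. Sketch (never called;
# names illustrative; `scalars` is the fp32 table from in_ptr5):
def _ref_per_fused_add_mul_sum_26(acc, a, b, x, inv_rms, d, w, p, q, r, t, scalars):
    u = (a + b).float()
    x_hat = x.float() * inv_rms[:, None]
    scaled = u * scalars[36]               # elementwise -> out_ptr3
    new_acc = (acc.float()
               + (p + q).float() * scalars[41]
               + (r + t).float() * scalars[39]
               + u * scalars[37])          # running residual gradient
    return (new_acc.to(torch.bfloat16),    # -> in_out_ptr0
            (u * x_hat).sum(1),            # -> out_ptr0
            (u * d.float()).sum(1),        # -> out_ptr1
            (scaled * w.float()).sum(1),   # -> out_ptr2
            scaled.to(torch.bfloat16))     # -> out_ptr3
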
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/7j/c7julr3b2yeaunzxy3p4uhxnmaq2mkueth2riqtaxxiceiks7e4f.py
# Topologically Sorted Source Nodes: [v_25], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_25 => convert_element_type_188
# Graph fragment:
# %mul_514 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_264, %select_61), kwargs = {})
# %convert_element_type_188 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_86, torch.float32), kwargs = {})
# %convert_element_type_539 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_514, torch.float32), kwargs = {})
# %mul_516 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_539, %convert_element_type_188), kwargs = {})
# %mul_517 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_539, %rsqrt_36), kwargs = {})
# %sum_61 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_516, [3], True), kwargs = {})
# %div_29 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_46, 128), kwargs = {})
# %pow_143 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_188, 1.0), kwargs = {})
# %mul_520 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_143, 2.0), kwargs = {})
# %mul_521 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_29, %mul_520), kwargs = {})
# %add_275 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_517, %mul_521), kwargs = {})
# %convert_element_type_540 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_275, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_27 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_27', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_27', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_27(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (66))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (66))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/px/cpxzk6lleynsgxocbapsdzouctqzdj734tcdpezlppd4n46yo5ov.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_282 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_274, %view_263), kwargs = {})
# %mul_544 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_59), kwargs = {})
triton_poi_fused_add_mul_28 = async_compile.triton('triton_poi_fused_add_mul_28', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_28', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_28(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (34))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/ha/chasxebvoqyu5qi3wjqxf4oumcj4fcqploj745fbxu6kdwm3tfoh.py
# Topologically Sorted Source Nodes: [v_22], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_22 => convert_element_type_168
# Graph fragment:
# %mul_556 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_284, %select_54), kwargs = {})
# %convert_element_type_168 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_77, torch.float32), kwargs = {})
# %convert_element_type_571 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_556, torch.float32), kwargs = {})
# %mul_558 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_571, %convert_element_type_168), kwargs = {})
# %mul_559 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_571, %rsqrt_32), kwargs = {})
# %sum_69 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_558, [3], True), kwargs = {})
# %div_33 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_50, 128), kwargs = {})
# %pow_152 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_168, 1.0), kwargs = {})
# %mul_562 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_152, 2.0), kwargs = {})
# %mul_563 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_33, %mul_562), kwargs = {})
# %add_291 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_559, %mul_563), kwargs = {})
# %convert_element_type_572 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_291, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_29 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_29', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_29', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_29(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (64))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (64))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/hi/chiekkfjcumovdpaptfvgzwej74uquxfu45x5g5on3x3zr6rjrc2.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_298 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_290, %view_274), kwargs = {})
# %mul_586 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %select_52), kwargs = {})
triton_poi_fused_add_mul_30 = async_compile.triton('triton_poi_fused_add_mul_30', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_30', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_30(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (32))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
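# The pointwise kernel above (and triton_poi_fused_add_mul_34 below, which is
# identical except that it reads the scalar at index 28 instead of 32) fuses a
# residual add with a scalar scale: out = (a + b) * lambdas[idx]. A small
# eager-mode sketch under illustrative names; lambdas stands in for the fp32
# parameter vector behind in_ptr2.
def _sketch_fused_add_mul(a, b, lambdas, idx=32):
    # bf16 loads are upcast to fp32, the scaled sum is stored back as bf16
    return ((a.float() + b.float()) * lambdas[idx]).to(torch.bfloat16)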
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/a6/ca646ft4uesvxwm7jyd4ahfawp3vmjghfta3dwmjfxoyizhpowhp.py
# Topologically Sorted Source Nodes: [x, rms_norm_29], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum, aten.div, aten.pow]
# Source node to ATen node mapping:
# rms_norm_29 => convert_element_type_152
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_282 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_274, %view_263), kwargs = {})
# %mul_542 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_60), kwargs = {})
# %mul_543 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %convert_element_type_11), kwargs = {})
# %sum_64 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_543,), kwargs = {})
# %add_284 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_268, %mul_542), kwargs = {})
# %mul_544 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %select_59), kwargs = {})
# %mul_545 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_282, %add_95), kwargs = {})
# %sum_65 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_545,), kwargs = {})
# %mul_546 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_544, %select_56), kwargs = {})
# %mul_547 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_544, %add_80), kwargs = {})
# %sum_66 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_547,), kwargs = {})
# %add_298 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_290, %view_274), kwargs = {})
# %mul_584 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %select_53), kwargs = {})
# %mul_585 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %convert_element_type_11), kwargs = {})
# %sum_72 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_585,), kwargs = {})
# %add_300 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_284, %mul_584), kwargs = {})
# %mul_587 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_298, %add_83), kwargs = {})
# %sum_73 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_587,), kwargs = {})
# %convert_element_type_595 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%view_279, torch.float32), kwargs = {})
# %convert_element_type_152 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_81, torch.float32), kwargs = {})
# %mul_590 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_595, %convert_element_type_152), kwargs = {})
# %mul_591 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_595, %rsqrt_29), kwargs = {})
# %sum_74 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_590, [2], True), kwargs = {})
# %div_36 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_53, 1024), kwargs = {})
# %pow_159 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_152, 1.0), kwargs = {})
# %mul_594 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_159, 2.0), kwargs = {})
# %mul_595 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_36, %mul_594), kwargs = {})
# %add_304 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_591, %mul_595), kwargs = {})
# %convert_element_type_596 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_304, torch.bfloat16), kwargs = {})
# %add_305 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_586, %convert_element_type_596), kwargs = {})
# %mul_596 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %select_49), kwargs = {})
# %mul_597 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %convert_element_type_11), kwargs = {})
# %sum_75 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_597,), kwargs = {})
# %add_306 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_300, %mul_596), kwargs = {})
# %mul_598 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %select_48), kwargs = {})
# %mul_599 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_305, %add_80), kwargs = {})
# %sum_76 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_599,), kwargs = {})
# %add_307 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_546, %mul_598), kwargs = {})
triton_per_fused__to_copy_add_div_mul_pow_sum_31 = async_compile.triton('triton_per_fused__to_copy_add_div_mul_pow_sum_31', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_out_ptr1': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*fp32', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'out_ptr1': '*bf16', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'out_ptr4': '*fp32', 'out_ptr5': '*fp32', 'out_ptr6': '*fp32', 'out_ptr7': '*fp32', 'out_ptr8': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]], (18,): [['tt.divisibility', 16]], (19,): [['tt.divisibility', 16]], (20,): [['tt.divisibility', 16]], (21,): [['tt.divisibility', 16]], (22,): [['tt.divisibility', 16]], (23,): [['tt.divisibility', 16]], (24,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_div_mul_pow_sum_31', 'mutated_arg_names': ['in_out_ptr0', 'in_out_ptr1'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 20, 'num_reduction': 8, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_div_mul_pow_sum_31(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, out_ptr1, out_ptr2, out_ptr3, out_ptr4, out_ptr5, out_ptr6, out_ptr7, out_ptr8, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp8 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp9 = tl.load(in_ptr2 + (x0), None, eviction_policy='evict_last')
tmp24 = tl.load(in_out_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp25 = tl.load(in_ptr3 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp26 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp28 = tl.load(in_ptr5 + (35))
tmp29 = tl.broadcast_to(tmp28, [R0_BLOCK])
tmp33 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp34 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp36 = tl.load(in_ptr5 + (33))
tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK])
tmp41 = tl.load(in_ptr5 + (31))
tmp42 = tl.broadcast_to(tmp41, [R0_BLOCK])
tmp46 = tl.load(in_ptr5 + (34))
tmp47 = tl.broadcast_to(tmp46, [R0_BLOCK])
tmp50 = tl.load(in_ptr5 + (6))
tmp51 = tl.broadcast_to(tmp50, [R0_BLOCK])
tmp54 = tl.load(in_ptr5 + (30))
tmp55 = tl.broadcast_to(tmp54, [R0_BLOCK])
tmp59 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp61 = tl.load(in_ptr9 + (x0), None, eviction_policy='evict_last')
tmp68 = tl.load(in_ptr10 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp73 = tl.load(in_ptr11 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp82 = tl.load(in_ptr12 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 * tmp3
tmp5 = tl.broadcast_to(tmp4, [R0_BLOCK])
tmp7 = triton_helpers.promote_to_tensor(tl.sum(tmp5, 0))
tmp10 = tmp1 * tmp9
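# same RMS-norm backward correction as in the kernels above, here over the
# model dimension: 0.0009765625 == 1/1024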
tmp11 = -0.5
tmp12 = tmp7 * tmp11
tmp13 = tmp9 * tmp9
tmp14 = tmp13 * tmp9
tmp15 = tmp12 * tmp14
tmp16 = 0.0009765625
tmp17 = tmp15 * tmp16
tmp18 = 2.0
tmp19 = tmp3 * tmp18
tmp20 = tmp17 * tmp19
tmp21 = tmp10 + tmp20
tmp22 = tmp21.to(tl.float32)
tmp23 = tmp8 + tmp22
tmp27 = tmp25 + tmp26
tmp30 = tmp29.to(tl.float32)
tmp31 = tmp27 * tmp30
tmp32 = tmp24 + tmp31
tmp35 = tmp33 + tmp34
tmp38 = tmp37.to(tl.float32)
tmp39 = tmp35 * tmp38
tmp40 = tmp32 + tmp39
tmp43 = tmp42.to(tl.float32)
tmp44 = tmp23 * tmp43
tmp45 = tmp40 + tmp44
tmp48 = tmp47.to(tl.float32)
tmp49 = tmp27 * tmp48
tmp52 = tmp51.to(tl.float32)
tmp53 = tmp49 * tmp52
tmp56 = tmp55.to(tl.float32)
tmp57 = tmp23 * tmp56
tmp58 = tmp53 + tmp57
tmp60 = tmp59.to(tl.float32)
tmp62 = tmp60 * tmp61
tmp63 = tmp62.to(tl.float32)
tmp64 = tmp27 * tmp63
tmp65 = tl.broadcast_to(tmp64, [R0_BLOCK])
tmp67 = triton_helpers.promote_to_tensor(tl.sum(tmp65, 0))
tmp69 = tmp27 * tmp68
tmp70 = tl.broadcast_to(tmp69, [R0_BLOCK])
tmp72 = triton_helpers.promote_to_tensor(tl.sum(tmp70, 0))
tmp74 = tmp49 * tmp73
tmp75 = tl.broadcast_to(tmp74, [R0_BLOCK])
tmp77 = triton_helpers.promote_to_tensor(tl.sum(tmp75, 0))
tmp78 = tmp35 * tmp63
tmp79 = tl.broadcast_to(tmp78, [R0_BLOCK])
tmp81 = triton_helpers.promote_to_tensor(tl.sum(tmp79, 0))
tmp83 = tmp35 * tmp82
tmp84 = tl.broadcast_to(tmp83, [R0_BLOCK])
tmp86 = triton_helpers.promote_to_tensor(tl.sum(tmp84, 0))
tmp87 = tmp23 * tmp63
tmp88 = tl.broadcast_to(tmp87, [R0_BLOCK])
tmp90 = triton_helpers.promote_to_tensor(tl.sum(tmp88, 0))
tmp91 = tmp23 * tmp73
tmp92 = tl.broadcast_to(tmp91, [R0_BLOCK])
tmp94 = triton_helpers.promote_to_tensor(tl.sum(tmp92, 0))
tl.store(in_out_ptr1 + (r0_1 + 1024*x0), tmp45, None)
tl.store(out_ptr1 + (r0_1 + 1024*x0), tmp58, None)
tl.store(out_ptr2 + (x0), tmp67, None)
tl.store(out_ptr3 + (x0), tmp72, None)
tl.store(out_ptr4 + (x0), tmp77, None)
tl.store(out_ptr5 + (x0), tmp81, None)
tl.store(out_ptr6 + (x0), tmp86, None)
tl.store(out_ptr7 + (x0), tmp90, None)
tl.store(out_ptr8 + (x0), tmp94, None)
''', device_str='cuda')
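# The persistent-reduction kernel above runs one program per row of 1024
# elements and fuses two recurring motifs: lambda-weighted residual
# accumulation (the scalars at indices 35, 33, 31, 34, 6 and 30 of in_ptr5)
# and a batch of per-row dot products whose [65536]-sized outputs feed the
# scalar-lambda gradients; triton_per_fused__to_copy_add_mul_sum_36 below
# follows the same pattern. A rough eager-mode sketch of the two motifs, with
# illustrative names only:
def _sketch_residual_and_row_dots(branches, lambdas, idxs, refs, acc):
    # branches: list of [rows, 1024] gradient tensors; lambdas: fp32 vector;
    # idxs: which lambda scales which branch; refs: tensors to reduce against;
    # acc: the accumulator that the kernel updates via in_out_ptr1.
    for g, i in zip(branches, idxs):
        acc = acc + g * lambdas[i]          # weighted residual accumulation
    row_dots = [(g * r).sum(dim=1) for g, r in zip(branches, refs)]  # out_ptr2..8
    return acc, row_dots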
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/bw/cbwa3hunkne3dnuzwzdtrx3cnl2c43w6e2yxbyvp2aetgwpzbih3.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.select_backward, aten._to_copy, aten.add]
# Source node to ATen node mapping:
# Graph fragment:
# %full_default_20 : [num_users=53] = call_function[target=torch.ops.aten.full.default](args = ([2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %convert_element_type_363 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_18, torch.float32), kwargs = {})
# %select_scatter_default_3 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_363, 0, 1), kwargs = {})
# %convert_element_type_364 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_19, torch.float32), kwargs = {})
# %select_scatter_default_4 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_364, 0, 0), kwargs = {})
# %add_194 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_3, %select_scatter_default_4), kwargs = {})
# %full_default_25 : [num_users=31] = call_function[target=torch.ops.aten.full.default](args = ([16, 2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %select_scatter_default_6 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_194, 0, 15), kwargs = {})
# %convert_element_type_395 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_26, torch.float32), kwargs = {})
# %select_scatter_default_10 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_395, 0, 1), kwargs = {})
# %convert_element_type_396 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_27, torch.float32), kwargs = {})
# %select_scatter_default_11 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_396, 0, 0), kwargs = {})
# %add_208 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_10, %select_scatter_default_11), kwargs = {})
# %select_scatter_default_13 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_208, 0, 14), kwargs = {})
# %add_210 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_6, %select_scatter_default_13), kwargs = {})
# %convert_element_type_427 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_34, torch.float32), kwargs = {})
# %select_scatter_default_17 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_427, 0, 1), kwargs = {})
# %convert_element_type_428 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_35, torch.float32), kwargs = {})
# %select_scatter_default_18 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_428, 0, 0), kwargs = {})
# %add_224 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_17, %select_scatter_default_18), kwargs = {})
# %select_scatter_default_20 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_224, 0, 13), kwargs = {})
# %add_226 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_210, %select_scatter_default_20), kwargs = {})
# %convert_element_type_458 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_41, torch.float32), kwargs = {})
# %select_scatter_default_23 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_458, 0, 1), kwargs = {})
# %convert_element_type_459 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_42, torch.float32), kwargs = {})
# %select_scatter_default_24 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_459, 0, 0), kwargs = {})
# %add_239 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_23, %select_scatter_default_24), kwargs = {})
# %select_scatter_default_26 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_239, 0, 12), kwargs = {})
# %add_241 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_226, %select_scatter_default_26), kwargs = {})
# %convert_element_type_489 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_48, torch.float32), kwargs = {})
# %select_scatter_default_29 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_489, 0, 1), kwargs = {})
# %convert_element_type_490 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_49, torch.float32), kwargs = {})
# %select_scatter_default_30 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_490, 0, 0), kwargs = {})
# %add_254 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_29, %select_scatter_default_30), kwargs = {})
# %select_scatter_default_32 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_254, 0, 11), kwargs = {})
# %add_256 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_241, %select_scatter_default_32), kwargs = {})
# %convert_element_type_521 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_56, torch.float32), kwargs = {})
# %select_scatter_default_36 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_521, 0, 1), kwargs = {})
# %convert_element_type_522 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_57, torch.float32), kwargs = {})
# %select_scatter_default_37 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_522, 0, 0), kwargs = {})
# %add_269 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_36, %select_scatter_default_37), kwargs = {})
# %select_scatter_default_39 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_269, 0, 10), kwargs = {})
# %add_271 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_256, %select_scatter_default_39), kwargs = {})
# %convert_element_type_553 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_64, torch.float32), kwargs = {})
# %select_scatter_default_43 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_553, 0, 1), kwargs = {})
# %convert_element_type_554 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_65, torch.float32), kwargs = {})
# %select_scatter_default_44 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_554, 0, 0), kwargs = {})
# %add_285 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_43, %select_scatter_default_44), kwargs = {})
# %select_scatter_default_46 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_285, 0, 9), kwargs = {})
# %add_287 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_271, %select_scatter_default_46), kwargs = {})
# %convert_element_type_585 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_72, torch.float32), kwargs = {})
# %select_scatter_default_50 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_585, 0, 1), kwargs = {})
# %convert_element_type_586 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_73, torch.float32), kwargs = {})
# %select_scatter_default_51 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_586, 0, 0), kwargs = {})
# %add_301 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_50, %select_scatter_default_51), kwargs = {})
# %select_scatter_default_53 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_301, 0, 8), kwargs = {})
# %add_303 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_287, %select_scatter_default_53), kwargs = {})
# %convert_element_type_597 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_75, torch.float32), kwargs = {})
# %select_scatter_default_54 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_597, 0, 1), kwargs = {})
# %convert_element_type_598 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_76, torch.float32), kwargs = {})
# %select_scatter_default_55 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_598, 0, 0), kwargs = {})
# %add_308 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_54, %select_scatter_default_55), kwargs = {})
# %select_scatter_default_56 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_308, 0, 7), kwargs = {})
# %add_309 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_303, %select_scatter_default_56), kwargs = {})
triton_poi_fused__to_copy_add_select_backward_32 = async_compile.triton('triton_poi_fused__to_copy_add_select_backward_32', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 32},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'in_ptr13': '*bf16', 'in_ptr14': '*bf16', 'in_ptr15': '*bf16', 'in_ptr16': '*bf16', 'in_ptr17': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]], (17,): [['tt.divisibility', 16]], (18,): [['tt.divisibility', 16]], (19,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_select_backward_32', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 18, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_select_backward_32(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, in_ptr15, in_ptr16, in_ptr17, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 32
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = xindex // 2
x0 = (xindex % 2)
x2 = xindex
tmp6 = tl.load(in_ptr0 + (0)).to(tl.float32)
tmp7 = tl.broadcast_to(tmp6, [XBLOCK])
tmp13 = tl.load(in_ptr1 + (0)).to(tl.float32)
tmp14 = tl.broadcast_to(tmp13, [XBLOCK])
tmp21 = tl.load(in_ptr2 + (0)).to(tl.float32)
tmp22 = tl.broadcast_to(tmp21, [XBLOCK])
tmp25 = tl.load(in_ptr3 + (0)).to(tl.float32)
tmp26 = tl.broadcast_to(tmp25, [XBLOCK])
tmp34 = tl.load(in_ptr4 + (0)).to(tl.float32)
tmp35 = tl.broadcast_to(tmp34, [XBLOCK])
tmp38 = tl.load(in_ptr5 + (0)).to(tl.float32)
tmp39 = tl.broadcast_to(tmp38, [XBLOCK])
tmp47 = tl.load(in_ptr6 + (0)).to(tl.float32)
tmp48 = tl.broadcast_to(tmp47, [XBLOCK])
tmp51 = tl.load(in_ptr7 + (0)).to(tl.float32)
tmp52 = tl.broadcast_to(tmp51, [XBLOCK])
tmp60 = tl.load(in_ptr8 + (0)).to(tl.float32)
tmp61 = tl.broadcast_to(tmp60, [XBLOCK])
tmp64 = tl.load(in_ptr9 + (0)).to(tl.float32)
tmp65 = tl.broadcast_to(tmp64, [XBLOCK])
tmp73 = tl.load(in_ptr10 + (0)).to(tl.float32)
tmp74 = tl.broadcast_to(tmp73, [XBLOCK])
tmp77 = tl.load(in_ptr11 + (0)).to(tl.float32)
tmp78 = tl.broadcast_to(tmp77, [XBLOCK])
tmp86 = tl.load(in_ptr12 + (0)).to(tl.float32)
tmp87 = tl.broadcast_to(tmp86, [XBLOCK])
tmp90 = tl.load(in_ptr13 + (0)).to(tl.float32)
tmp91 = tl.broadcast_to(tmp90, [XBLOCK])
tmp99 = tl.load(in_ptr14 + (0)).to(tl.float32)
tmp100 = tl.broadcast_to(tmp99, [XBLOCK])
tmp103 = tl.load(in_ptr15 + (0)).to(tl.float32)
tmp104 = tl.broadcast_to(tmp103, [XBLOCK])
tmp112 = tl.load(in_ptr16 + (0)).to(tl.float32)
tmp113 = tl.broadcast_to(tmp112, [XBLOCK])
tmp116 = tl.load(in_ptr17 + (0)).to(tl.float32)
tmp117 = tl.broadcast_to(tmp116, [XBLOCK])
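# decode (row, col) from the flat index: x1 = xindex // 2 selects the layer
# row and x0 = xindex % 2 the column; rows 15 down to 7 below each receive
# one (col 1, col 0) pair of reduced sums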
tmp0 = x1
tmp1 = tl.full([1], 15, tl.int32)
tmp2 = tmp0 == tmp1
tmp3 = x0
tmp4 = tl.full([1], 1, tl.int32)
tmp5 = tmp3 == tmp4
tmp8 = tmp7.to(tl.float32)
tmp9 = 0.0
tmp10 = tl.where(tmp5, tmp8, tmp9)
tmp11 = tl.full([1], 0, tl.int32)
tmp12 = tmp3 == tmp11
tmp15 = tmp14.to(tl.float32)
tmp16 = tl.where(tmp12, tmp15, tmp9)
tmp17 = tmp10 + tmp16
tmp18 = tl.where(tmp2, tmp17, tmp9)
tmp19 = tl.full([1], 14, tl.int32)
tmp20 = tmp0 == tmp19
tmp23 = tmp22.to(tl.float32)
tmp24 = tl.where(tmp5, tmp23, tmp9)
tmp27 = tmp26.to(tl.float32)
tmp28 = tl.where(tmp12, tmp27, tmp9)
tmp29 = tmp24 + tmp28
tmp30 = tl.where(tmp20, tmp29, tmp9)
tmp31 = tmp18 + tmp30
tmp32 = tl.full([1], 13, tl.int32)
tmp33 = tmp0 == tmp32
tmp36 = tmp35.to(tl.float32)
tmp37 = tl.where(tmp5, tmp36, tmp9)
tmp40 = tmp39.to(tl.float32)
tmp41 = tl.where(tmp12, tmp40, tmp9)
tmp42 = tmp37 + tmp41
tmp43 = tl.where(tmp33, tmp42, tmp9)
tmp44 = tmp31 + tmp43
tmp45 = tl.full([1], 12, tl.int32)
tmp46 = tmp0 == tmp45
tmp49 = tmp48.to(tl.float32)
tmp50 = tl.where(tmp5, tmp49, tmp9)
tmp53 = tmp52.to(tl.float32)
tmp54 = tl.where(tmp12, tmp53, tmp9)
tmp55 = tmp50 + tmp54
tmp56 = tl.where(tmp46, tmp55, tmp9)
tmp57 = tmp44 + tmp56
tmp58 = tl.full([1], 11, tl.int32)
tmp59 = tmp0 == tmp58
tmp62 = tmp61.to(tl.float32)
tmp63 = tl.where(tmp5, tmp62, tmp9)
tmp66 = tmp65.to(tl.float32)
tmp67 = tl.where(tmp12, tmp66, tmp9)
tmp68 = tmp63 + tmp67
tmp69 = tl.where(tmp59, tmp68, tmp9)
tmp70 = tmp57 + tmp69
tmp71 = tl.full([1], 10, tl.int32)
tmp72 = tmp0 == tmp71
tmp75 = tmp74.to(tl.float32)
tmp76 = tl.where(tmp5, tmp75, tmp9)
tmp79 = tmp78.to(tl.float32)
tmp80 = tl.where(tmp12, tmp79, tmp9)
tmp81 = tmp76 + tmp80
tmp82 = tl.where(tmp72, tmp81, tmp9)
tmp83 = tmp70 + tmp82
tmp84 = tl.full([1], 9, tl.int32)
tmp85 = tmp0 == tmp84
tmp88 = tmp87.to(tl.float32)
tmp89 = tl.where(tmp5, tmp88, tmp9)
tmp92 = tmp91.to(tl.float32)
tmp93 = tl.where(tmp12, tmp92, tmp9)
tmp94 = tmp89 + tmp93
tmp95 = tl.where(tmp85, tmp94, tmp9)
tmp96 = tmp83 + tmp95
tmp97 = tl.full([1], 8, tl.int32)
tmp98 = tmp0 == tmp97
tmp101 = tmp100.to(tl.float32)
tmp102 = tl.where(tmp5, tmp101, tmp9)
tmp105 = tmp104.to(tl.float32)
tmp106 = tl.where(tmp12, tmp105, tmp9)
tmp107 = tmp102 + tmp106
tmp108 = tl.where(tmp98, tmp107, tmp9)
tmp109 = tmp96 + tmp108
tmp110 = tl.full([1], 7, tl.int32)
tmp111 = tmp0 == tmp110
tmp114 = tmp113.to(tl.float32)
tmp115 = tl.where(tmp5, tmp114, tmp9)
tmp118 = tmp117.to(tl.float32)
tmp119 = tl.where(tmp12, tmp118, tmp9)
tmp120 = tmp115 + tmp119
tmp121 = tl.where(tmp111, tmp120, tmp9)
tmp122 = tmp109 + tmp121
tl.store(out_ptr0 + (x2), tmp122, xmask)
''', device_str='cuda')
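# The pointwise kernel above materializes part of the gradient for a [16, 2]
# per-layer scalar table: each pair of inputs holds two scalar reductions that
# land in one row of a zero-initialized tensor, rows 15 down to 7 here (the
# remaining rows are handled elsewhere in this file). A hedged eager-mode
# sketch, where pairs is an illustrative list of (col1, col0) scalar tensors:
def _sketch_lambda_grad_scatter(pairs):
    out = torch.zeros(16, 2, dtype=torch.float32, device='cuda')
    for row, (col1, col0) in zip(range(15, 6, -1), pairs):  # rows 15, 14, ..., 7
        out[row, 1] = col1.float()   # e.g. in_ptr0 (sum_18) lands at (15, 1)
        out[row, 0] = col0.float()   # e.g. in_ptr1 (sum_19) lands at (15, 0)
    return out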
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/4i/c4i2c47wbii5qwc7mntsdhih6p2hh5tnkpctkkyukymyjyzxkagu.py
# Topologically Sorted Source Nodes: [v_19], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_19 => convert_element_type_142
# Graph fragment:
# %mul_608 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_312, %select_45), kwargs = {})
# %convert_element_type_142 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_68, torch.float32), kwargs = {})
# %convert_element_type_614 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_608, torch.float32), kwargs = {})
# %mul_610 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_614, %convert_element_type_142), kwargs = {})
# %mul_611 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_614, %rsqrt_27), kwargs = {})
# %sum_79 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_610, [3], True), kwargs = {})
# %div_38 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_55, 128), kwargs = {})
# %pow_164 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_142, 1.0), kwargs = {})
# %mul_614 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_164, 2.0), kwargs = {})
# %mul_615 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_38, %mul_614), kwargs = {})
# %add_312 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_611, %mul_615), kwargs = {})
# %convert_element_type_615 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_312, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_33 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_33', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_33', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_33(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (60))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (60))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/t4/ct4e3vdtvk4aq7yblppd4mvqfalkrljcmgsro5womvtonc2jbebt.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {})
# %mul_638 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %select_43), kwargs = {})
triton_poi_fused_add_mul_34 = async_compile.triton('triton_poi_fused_add_mul_34', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_34', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_34(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (28))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/6h/c6h57xn6lpkcocdgjcl3vgnzzktkctorpl7zvtirarkpn7qgycrr.py
# Topologically Sorted Source Nodes: [v_16], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_16 => convert_element_type_122
# Graph fragment:
# %mul_648 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_332, %select_39), kwargs = {})
# %convert_element_type_122 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_59, torch.float32), kwargs = {})
# %convert_element_type_645 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_648, torch.float32), kwargs = {})
# %mul_650 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_645, %convert_element_type_122), kwargs = {})
# %mul_651 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_645, %rsqrt_23), kwargs = {})
# %sum_86 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_650, [3], True), kwargs = {})
# %div_42 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_59, 128), kwargs = {})
# %pow_173 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_122, 1.0), kwargs = {})
# %mul_654 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_173, 2.0), kwargs = {})
# %mul_655 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_42, %mul_654), kwargs = {})
# %add_327 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_651, %mul_655), kwargs = {})
# %convert_element_type_646 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_327, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_35 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_35', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_35', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_35(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (58))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (58))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/wl/cwlkwk35d2q4uxlip2fikqjyurywrohrez5ylmucjva6zxu36fx4.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_266 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_258, %view_252), kwargs = {})
# %mul_502 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_266, %select_66), kwargs = {})
# %mul_504 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_502, %select_63), kwargs = {})
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {})
# %mul_637 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %convert_element_type_11), kwargs = {})
# %sum_82 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_637,), kwargs = {})
# %mul_639 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %add_69), kwargs = {})
# %sum_83 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_639,), kwargs = {})
# %add_334 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_326, %view_300), kwargs = {})
# %mul_677 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %convert_element_type_11), kwargs = {})
# %sum_89 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_677,), kwargs = {})
# %mul_678 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %select_37), kwargs = {})
# %mul_679 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %add_58), kwargs = {})
# %sum_90 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_679,), kwargs = {})
# %add_337 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_504, %mul_678), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_36 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_36', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_36', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 4, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_36(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp29 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp30 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp32 = tl.load(in_ptr9 + (36))
tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK])
tmp36 = tl.load(in_ptr9 + (4))
tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK])
tmp40 = tl.load(in_ptr9 + (26))
tmp41 = tl.broadcast_to(tmp40, [R0_BLOCK])
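# the scalars at flat indices 36, 4 and 26 of in_ptr9 scale the residual
# branches combined below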
tmp2 = tmp0 + tmp1
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp2 * tmp7
tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
tmp13 = tmp2 * tmp12
tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp19 = tmp17 + tmp18
tmp20 = tmp19 * tmp7
tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0))
tmp25 = tmp19 * tmp24
tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK])
tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tmp31 = tmp29 + tmp30
tmp34 = tmp33.to(tl.float32)
tmp35 = tmp31 * tmp34
tmp38 = tmp37.to(tl.float32)
tmp39 = tmp35 * tmp38
tmp42 = tmp41.to(tl.float32)
tmp43 = tmp19 * tmp42
tmp44 = tmp39 + tmp43
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp44, None)
tl.store(out_ptr0 + (x0), tmp11, None)
tl.store(out_ptr1 + (x0), tmp16, None)
tl.store(out_ptr2 + (x0), tmp23, None)
tl.store(out_ptr3 + (x0), tmp28, None)
''', device_str='cuda')
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/wb/cwbn6yy6ck5xt3fzhic3wvkkrf6xfqeppzluritiwvusplzfnr6d.py
# Topologically Sorted Source Nodes: [v_13], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_13 => convert_element_type_102
# Graph fragment:
# %mul_688 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_352, %select_33), kwargs = {})
# %convert_element_type_102 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_50, torch.float32), kwargs = {})
# %convert_element_type_676 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_688, torch.float32), kwargs = {})
# %mul_690 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_676, %convert_element_type_102), kwargs = {})
# %mul_691 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_676, %rsqrt_19), kwargs = {})
# %sum_93 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_690, [3], True), kwargs = {})
# %div_46 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_63, 128), kwargs = {})
# %pow_182 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_102, 1.0), kwargs = {})
# %mul_694 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_182, 2.0), kwargs = {})
# %mul_695 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_46, %mul_694), kwargs = {})
# %add_343 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_691, %mul_695), kwargs = {})
# %convert_element_type_677 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_343, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_37 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_37', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_37', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_37(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (56))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (56))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
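
# A minimal eager-mode sketch (hypothetical helper, not emitted by Inductor) of
# the math this two-pass reduction fuses: the backward of an RMS-style
# normalization v = x * rsqrt(mean(x**2, dim=-1)), applied after the incoming
# gradient is scaled by a learned coefficient (fp32 slot 56 of in_ptr1). The
# head dimension is 128, so 0.0078125 == 1/128, and the rsqrt**3 factor comes
# from d(rsqrt(s))/ds = -0.5 * s**(-1.5).
def _rms_norm_backward_sketch(grad_out, x, r, c):
    # grad_out, x: (..., 128) bf16 rows; r: (..., 1) saved rsqrt; c: fp32 scalar.
    d = x.shape[-1]
    g = grad_out.float() * c                       # scale by the coefficient
    s = (g * x.float()).sum(-1, keepdim=True)      # first loop: sum(g * x)
    dx = g * r + (-0.5 * s * r**3 / d) * (2.0 * x.float())  # second loop
    return dx.bfloat16()
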
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/bv/cbvz6ftrc4j32z2qzbtvnccdakox773dpuzkgduiwvzrk5muc5rw.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_319 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_311, %view_289), kwargs = {})
# %mul_636 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_319, %select_44), kwargs = {})
# %add_321 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_306, %mul_636), kwargs = {})
# %add_334 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_326, %view_300), kwargs = {})
# %mul_676 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_334, %select_38), kwargs = {})
# %add_336 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_321, %mul_676), kwargs = {})
# %add_350 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_342, %view_311), kwargs = {})
# %mul_716 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %select_32), kwargs = {})
# %add_352 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_336, %mul_716), kwargs = {})
# %mul_718 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %select_31), kwargs = {})
triton_poi_fused_add_mul_38 = async_compile.triton('triton_poi_fused_add_mul_38', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_38', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 11, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_38(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_out_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp2 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp4 = tl.load(in_ptr2 + (29))
tmp5 = tl.broadcast_to(tmp4, [XBLOCK])
tmp9 = tl.load(in_ptr3 + (x0), None).to(tl.float32)
tmp10 = tl.load(in_ptr4 + (x0), None).to(tl.float32)
tmp12 = tl.load(in_ptr2 + (27))
tmp13 = tl.broadcast_to(tmp12, [XBLOCK])
tmp17 = tl.load(in_ptr5 + (x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr6 + (x0), None).to(tl.float32)
tmp20 = tl.load(in_ptr2 + (25))
tmp21 = tl.broadcast_to(tmp20, [XBLOCK])
tmp25 = tl.load(in_ptr2 + (24))
tmp26 = tl.broadcast_to(tmp25, [XBLOCK])
tmp3 = tmp1 + tmp2
tmp6 = tmp5.to(tl.float32)
tmp7 = tmp3 * tmp6
tmp8 = tmp0 + tmp7
tmp11 = tmp9 + tmp10
tmp14 = tmp13.to(tl.float32)
tmp15 = tmp11 * tmp14
tmp16 = tmp8 + tmp15
tmp19 = tmp17 + tmp18
tmp22 = tmp21.to(tl.float32)
tmp23 = tmp19 * tmp22
tmp24 = tmp16 + tmp23
tmp27 = tmp26.to(tl.float32)
tmp28 = tmp19 * tmp27
tl.store(in_out_ptr0 + (x0), tmp24, None)
tl.store(out_ptr0 + (x0), tmp28, None)
''', device_str='cuda')
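
# Sketch (hypothetical helper) of the pointwise fusion above: the residual-
# stream gradient accumulates three scaled skip-connection gradient pairs in a
# single pass over the 67108864 elements, and the last pair is additionally
# scaled by a fourth learned coefficient for the branch output (out_ptr0). The
# four fp32 scalars live in slots 29, 27, 25, 24 of in_ptr2.
def _residual_accum_sketch(acc, pairs, s):
    # acc: in/out gradient; pairs: three (a, b) bf16 gradient pairs; s: scalars.
    for (a, b), w in zip(pairs, s[:3]):
        acc = acc + (a + b) * w
    branch = (pairs[2][0] + pairs[2][1]) * s[3]    # out_ptr0
    return acc, branch
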
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/ae/caeu7pgawm6k6vk7vpmr2zuhpdwebvifn4flr2bzxrarimqrm2t3.py
# Topologically Sorted Source Nodes: [v_10], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_10 => convert_element_type_82
# Graph fragment:
# %mul_728 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_372, %select_27), kwargs = {})
# %convert_element_type_82 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_41, torch.float32), kwargs = {})
# %convert_element_type_707 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_728, torch.float32), kwargs = {})
# %mul_730 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_707, %convert_element_type_82), kwargs = {})
# %mul_731 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_707, %rsqrt_15), kwargs = {})
# %sum_100 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_730, [3], True), kwargs = {})
# %div_50 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_67, 128), kwargs = {})
# %pow_191 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_82, 1.0), kwargs = {})
# %mul_734 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_191, 2.0), kwargs = {})
# %mul_735 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_50, %mul_734), kwargs = {})
# %add_358 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_731, %mul_735), kwargs = {})
# %convert_element_type_708 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_358, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_39 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_39', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_39', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_39(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (54))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (54))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
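
# Same fused RMS-norm backward as triton_red_fused__to_copy_add_div_mul_pow_sum_37
# above, instantiated for v_10: coefficient slot 54 of in_ptr1 and rsqrt_15.
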
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/e7/ce7jis7nrcnkxdvhijgd3ls2ioixwrag75qnb3dbukxtzj6su45m.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_251 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_243, %view_241), kwargs = {})
# %mul_460 : [num_users=4] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_251, %select_73), kwargs = {})
# %mul_462 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mul_460, %select_70), kwargs = {})
# %add_350 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_342, %view_311), kwargs = {})
# %mul_717 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %convert_element_type_11), kwargs = {})
# %sum_96 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_717,), kwargs = {})
# %mul_719 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_350, %add_47), kwargs = {})
# %sum_97 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_719,), kwargs = {})
# %add_365 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_357, %view_322), kwargs = {})
# %mul_757 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %convert_element_type_11), kwargs = {})
# %sum_103 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_757,), kwargs = {})
# %mul_758 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %select_25), kwargs = {})
# %mul_759 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %add_36), kwargs = {})
# %sum_104 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_759,), kwargs = {})
# %add_368 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_462, %mul_758), kwargs = {})
triton_per_fused__to_copy_add_mul_sum_40 = async_compile.triton('triton_per_fused__to_copy_add_mul_sum_40', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.persistent_reduction(
size_hints={'x': 65536, 'r0_': 1024},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*fp32', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'out_ptr2': '*fp32', 'out_ptr3': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused__to_copy_add_mul_sum_40', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': True, 'num_load': 13, 'num_reduction': 4, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_per_fused__to_copy_add_mul_sum_40(in_out_ptr0, in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, out_ptr0, out_ptr1, out_ptr2, out_ptr3, xnumel, r0_numel):
xnumel = 65536
XBLOCK: tl.constexpr = 1
r0_numel = 1024
R0_BLOCK: tl.constexpr = 1024
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = tl.full([1], xoffset, tl.int32)
xmask = tl.full([R0_BLOCK], True, tl.int1)
r0_index = tl.arange(0, R0_BLOCK)[:]
r0_offset = 0
r0_mask = tl.full([R0_BLOCK], True, tl.int1)
roffset = r0_offset
rindex = r0_index
r0_1 = r0_index
x0 = xindex
tmp0 = tl.load(in_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp5 = tl.load(in_ptr3 + (x0), None, eviction_policy='evict_last')
tmp12 = tl.load(in_ptr4 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp17 = tl.load(in_ptr5 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp18 = tl.load(in_ptr6 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp24 = tl.load(in_ptr7 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp29 = tl.load(in_out_ptr0 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp30 = tl.load(in_ptr8 + (r0_1 + 1024*x0), None).to(tl.float32)
tmp32 = tl.load(in_ptr9 + (38))
tmp33 = tl.broadcast_to(tmp32, [R0_BLOCK])
tmp36 = tl.load(in_ptr9 + (2))
tmp37 = tl.broadcast_to(tmp36, [R0_BLOCK])
tmp40 = tl.load(in_ptr9 + (22))
tmp41 = tl.broadcast_to(tmp40, [R0_BLOCK])
tmp2 = tmp0 + tmp1
tmp4 = tmp3.to(tl.float32)
tmp6 = tmp4 * tmp5
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp2 * tmp7
tmp9 = tl.broadcast_to(tmp8, [R0_BLOCK])
tmp11 = triton_helpers.promote_to_tensor(tl.sum(tmp9, 0))
tmp13 = tmp2 * tmp12
tmp14 = tl.broadcast_to(tmp13, [R0_BLOCK])
tmp16 = triton_helpers.promote_to_tensor(tl.sum(tmp14, 0))
tmp19 = tmp17 + tmp18
tmp20 = tmp19 * tmp7
tmp21 = tl.broadcast_to(tmp20, [R0_BLOCK])
tmp23 = triton_helpers.promote_to_tensor(tl.sum(tmp21, 0))
tmp25 = tmp19 * tmp24
tmp26 = tl.broadcast_to(tmp25, [R0_BLOCK])
tmp28 = triton_helpers.promote_to_tensor(tl.sum(tmp26, 0))
tmp31 = tmp29 + tmp30
tmp34 = tmp33.to(tl.float32)
tmp35 = tmp31 * tmp34
tmp38 = tmp37.to(tl.float32)
tmp39 = tmp35 * tmp38
tmp42 = tmp41.to(tl.float32)
tmp43 = tmp19 * tmp42
tmp44 = tmp39 + tmp43
tl.store(in_out_ptr0 + (r0_1 + 1024*x0), tmp44, None)
tl.store(out_ptr0 + (x0), tmp11, None)
tl.store(out_ptr1 + (x0), tmp16, None)
tl.store(out_ptr2 + (x0), tmp23, None)
tl.store(out_ptr3 + (x0), tmp28, None)
''', device_str='cuda')
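
# Same four-reduction persistent pattern as the kernel opening this section,
# here with coefficient slots 38, 2, 22 of in_ptr9. Per the graph fragment, the
# row sums feed sum_96/sum_97/sum_103/sum_104, i.e. the gradients of the
# learned scalars that gate each block's contribution to the residual stream.
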
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/rm/crm2pq6wksbw6bjoc7vc22gi6qau7lq7ec3klu22afni6b5bl6ck.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten._to_copy, aten.select_backward, aten.add]
# Source node to ATen node mapping:
# Graph fragment:
# %convert_element_type_347 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_13, torch.float32), kwargs = {})
# %full_default_20 : [num_users=53] = call_function[target=torch.ops.aten.full.default](args = ([2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %select_scatter_default_1 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_347, 0, 1), kwargs = {})
# %convert_element_type_348 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_14, torch.float32), kwargs = {})
# %select_scatter_default_2 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_348, 0, 0), kwargs = {})
# %add_184 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_1, %select_scatter_default_2), kwargs = {})
# %full_default_25 : [num_users=31] = call_function[target=torch.ops.aten.full.default](args = ([16, 2], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %select_scatter_default_5 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_184, 0, 15), kwargs = {})
# %convert_element_type_379 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_21, torch.float32), kwargs = {})
# %select_scatter_default_8 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_379, 0, 1), kwargs = {})
# %convert_element_type_380 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_22, torch.float32), kwargs = {})
# %select_scatter_default_9 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_380, 0, 0), kwargs = {})
# %add_197 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_8, %select_scatter_default_9), kwargs = {})
# %select_scatter_default_12 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_197, 0, 14), kwargs = {})
# %add_209 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_5, %select_scatter_default_12), kwargs = {})
# %convert_element_type_411 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_29, torch.float32), kwargs = {})
# %select_scatter_default_15 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_411, 0, 1), kwargs = {})
# %convert_element_type_412 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_30, torch.float32), kwargs = {})
# %select_scatter_default_16 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_412, 0, 0), kwargs = {})
# %add_213 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%select_scatter_default_15, %select_scatter_default_16), kwargs = {})
# %select_scatter_default_19 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %add_213, 0, 13), kwargs = {})
# %add_225 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_209, %select_scatter_default_19), kwargs = {})
# %convert_element_type_443 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_37, torch.float32), kwargs = {})
# %select_scatter_default_22 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_443, 0, 0), kwargs = {})
# %select_scatter_default_25 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_22, 0, 12), kwargs = {})
# %add_240 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_225, %select_scatter_default_25), kwargs = {})
# %convert_element_type_474 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_44, torch.float32), kwargs = {})
# %select_scatter_default_28 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_474, 0, 0), kwargs = {})
# %select_scatter_default_31 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_28, 0, 11), kwargs = {})
# %add_255 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_240, %select_scatter_default_31), kwargs = {})
# %convert_element_type_506 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_52, torch.float32), kwargs = {})
# %select_scatter_default_35 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_506, 0, 0), kwargs = {})
# %select_scatter_default_38 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_35, 0, 10), kwargs = {})
# %add_270 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_255, %select_scatter_default_38), kwargs = {})
# %convert_element_type_538 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_60, torch.float32), kwargs = {})
# %select_scatter_default_42 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_538, 0, 0), kwargs = {})
# %select_scatter_default_45 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_42, 0, 9), kwargs = {})
# %add_286 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_270, %select_scatter_default_45), kwargs = {})
# %convert_element_type_570 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_68, torch.float32), kwargs = {})
# %select_scatter_default_49 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_570, 0, 0), kwargs = {})
# %select_scatter_default_52 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_49, 0, 8), kwargs = {})
# %add_302 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_286, %select_scatter_default_52), kwargs = {})
# %convert_element_type_613 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_78, torch.float32), kwargs = {})
# %select_scatter_default_58 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_613, 0, 0), kwargs = {})
# %select_scatter_default_61 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_58, 0, 6), kwargs = {})
# %add_323 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_302, %select_scatter_default_61), kwargs = {})
# %convert_element_type_644 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_85, torch.float32), kwargs = {})
# %select_scatter_default_64 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_644, 0, 0), kwargs = {})
# %select_scatter_default_67 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_64, 0, 5), kwargs = {})
# %add_339 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_323, %select_scatter_default_67), kwargs = {})
# %convert_element_type_675 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_92, torch.float32), kwargs = {})
# %select_scatter_default_70 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_675, 0, 0), kwargs = {})
# %select_scatter_default_73 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_70, 0, 4), kwargs = {})
# %add_354 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_339, %select_scatter_default_73), kwargs = {})
# %convert_element_type_706 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_99, torch.float32), kwargs = {})
# %select_scatter_default_76 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_20, %convert_element_type_706, 0, 0), kwargs = {})
# %select_scatter_default_79 : [num_users=1] = call_function[target=torch.ops.aten.select_scatter.default](args = (%full_default_25, %select_scatter_default_76, 0, 3), kwargs = {})
# %add_370 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_354, %select_scatter_default_79), kwargs = {})
triton_poi_fused__to_copy_add_select_backward_41 = async_compile.triton('triton_poi_fused__to_copy_add_select_backward_41', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 32},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*bf16', 'in_ptr3': '*bf16', 'in_ptr4': '*bf16', 'in_ptr5': '*bf16', 'in_ptr6': '*bf16', 'in_ptr7': '*bf16', 'in_ptr8': '*bf16', 'in_ptr9': '*bf16', 'in_ptr10': '*bf16', 'in_ptr11': '*bf16', 'in_ptr12': '*bf16', 'in_ptr13': '*bf16', 'in_ptr14': '*bf16', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]], (10,): [['tt.divisibility', 16]], (11,): [['tt.divisibility', 16]], (12,): [['tt.divisibility', 16]], (13,): [['tt.divisibility', 16]], (14,): [['tt.divisibility', 16]], (15,): [['tt.divisibility', 16]], (16,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__to_copy_add_select_backward_41', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 15, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__to_copy_add_select_backward_41(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, in_ptr6, in_ptr7, in_ptr8, in_ptr9, in_ptr10, in_ptr11, in_ptr12, in_ptr13, in_ptr14, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 32
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = xindex // 2
x0 = (xindex % 2)
x2 = xindex
tmp6 = tl.load(in_ptr0 + (0)).to(tl.float32)
tmp7 = tl.broadcast_to(tmp6, [XBLOCK])
tmp13 = tl.load(in_ptr1 + (0)).to(tl.float32)
tmp14 = tl.broadcast_to(tmp13, [XBLOCK])
tmp21 = tl.load(in_ptr2 + (0)).to(tl.float32)
tmp22 = tl.broadcast_to(tmp21, [XBLOCK])
tmp25 = tl.load(in_ptr3 + (0)).to(tl.float32)
tmp26 = tl.broadcast_to(tmp25, [XBLOCK])
tmp34 = tl.load(in_ptr4 + (0)).to(tl.float32)
tmp35 = tl.broadcast_to(tmp34, [XBLOCK])
tmp38 = tl.load(in_ptr5 + (0)).to(tl.float32)
tmp39 = tl.broadcast_to(tmp38, [XBLOCK])
tmp47 = tl.load(in_ptr6 + (0)).to(tl.float32)
tmp48 = tl.broadcast_to(tmp47, [XBLOCK])
tmp55 = tl.load(in_ptr7 + (0)).to(tl.float32)
tmp56 = tl.broadcast_to(tmp55, [XBLOCK])
tmp63 = tl.load(in_ptr8 + (0)).to(tl.float32)
tmp64 = tl.broadcast_to(tmp63, [XBLOCK])
tmp71 = tl.load(in_ptr9 + (0)).to(tl.float32)
tmp72 = tl.broadcast_to(tmp71, [XBLOCK])
tmp79 = tl.load(in_ptr10 + (0)).to(tl.float32)
tmp80 = tl.broadcast_to(tmp79, [XBLOCK])
tmp87 = tl.load(in_ptr11 + (0)).to(tl.float32)
tmp88 = tl.broadcast_to(tmp87, [XBLOCK])
tmp95 = tl.load(in_ptr12 + (0)).to(tl.float32)
tmp96 = tl.broadcast_to(tmp95, [XBLOCK])
tmp103 = tl.load(in_ptr13 + (0)).to(tl.float32)
tmp104 = tl.broadcast_to(tmp103, [XBLOCK])
tmp111 = tl.load(in_ptr14 + (0)).to(tl.float32)
tmp112 = tl.broadcast_to(tmp111, [XBLOCK])
tmp0 = x1
tmp1 = tl.full([1], 15, tl.int32)
tmp2 = tmp0 == tmp1
tmp3 = x0
tmp4 = tl.full([1], 1, tl.int32)
tmp5 = tmp3 == tmp4
tmp8 = tmp7.to(tl.float32)
tmp9 = 0.0
tmp10 = tl.where(tmp5, tmp8, tmp9)
tmp11 = tl.full([1], 0, tl.int32)
tmp12 = tmp3 == tmp11
tmp15 = tmp14.to(tl.float32)
tmp16 = tl.where(tmp12, tmp15, tmp9)
tmp17 = tmp10 + tmp16
tmp18 = tl.where(tmp2, tmp17, tmp9)
tmp19 = tl.full([1], 14, tl.int32)
tmp20 = tmp0 == tmp19
tmp23 = tmp22.to(tl.float32)
tmp24 = tl.where(tmp5, tmp23, tmp9)
tmp27 = tmp26.to(tl.float32)
tmp28 = tl.where(tmp12, tmp27, tmp9)
tmp29 = tmp24 + tmp28
tmp30 = tl.where(tmp20, tmp29, tmp9)
tmp31 = tmp18 + tmp30
tmp32 = tl.full([1], 13, tl.int32)
tmp33 = tmp0 == tmp32
tmp36 = tmp35.to(tl.float32)
tmp37 = tl.where(tmp5, tmp36, tmp9)
tmp40 = tmp39.to(tl.float32)
tmp41 = tl.where(tmp12, tmp40, tmp9)
tmp42 = tmp37 + tmp41
tmp43 = tl.where(tmp33, tmp42, tmp9)
tmp44 = tmp31 + tmp43
tmp45 = tl.full([1], 12, tl.int32)
tmp46 = tmp0 == tmp45
tmp49 = tmp48.to(tl.float32)
tmp50 = tl.where(tmp12, tmp49, tmp9)
tmp51 = tl.where(tmp46, tmp50, tmp9)
tmp52 = tmp44 + tmp51
tmp53 = tl.full([1], 11, tl.int32)
tmp54 = tmp0 == tmp53
tmp57 = tmp56.to(tl.float32)
tmp58 = tl.where(tmp12, tmp57, tmp9)
tmp59 = tl.where(tmp54, tmp58, tmp9)
tmp60 = tmp52 + tmp59
tmp61 = tl.full([1], 10, tl.int32)
tmp62 = tmp0 == tmp61
tmp65 = tmp64.to(tl.float32)
tmp66 = tl.where(tmp12, tmp65, tmp9)
tmp67 = tl.where(tmp62, tmp66, tmp9)
tmp68 = tmp60 + tmp67
tmp69 = tl.full([1], 9, tl.int32)
tmp70 = tmp0 == tmp69
tmp73 = tmp72.to(tl.float32)
tmp74 = tl.where(tmp12, tmp73, tmp9)
tmp75 = tl.where(tmp70, tmp74, tmp9)
tmp76 = tmp68 + tmp75
tmp77 = tl.full([1], 8, tl.int32)
tmp78 = tmp0 == tmp77
tmp81 = tmp80.to(tl.float32)
tmp82 = tl.where(tmp12, tmp81, tmp9)
tmp83 = tl.where(tmp78, tmp82, tmp9)
tmp84 = tmp76 + tmp83
tmp85 = tl.full([1], 6, tl.int32)
tmp86 = tmp0 == tmp85
tmp89 = tmp88.to(tl.float32)
tmp90 = tl.where(tmp12, tmp89, tmp9)
tmp91 = tl.where(tmp86, tmp90, tmp9)
tmp92 = tmp84 + tmp91
tmp93 = tl.full([1], 5, tl.int32)
tmp94 = tmp0 == tmp93
tmp97 = tmp96.to(tl.float32)
tmp98 = tl.where(tmp12, tmp97, tmp9)
tmp99 = tl.where(tmp94, tmp98, tmp9)
tmp100 = tmp92 + tmp99
tmp101 = tl.full([1], 4, tl.int32)
tmp102 = tmp0 == tmp101
tmp105 = tmp104.to(tl.float32)
tmp106 = tl.where(tmp12, tmp105, tmp9)
tmp107 = tl.where(tmp102, tmp106, tmp9)
tmp108 = tmp100 + tmp107
tmp109 = tl.full([1], 3, tl.int32)
tmp110 = tmp0 == tmp109
tmp113 = tmp112.to(tl.float32)
tmp114 = tl.where(tmp12, tmp113, tmp9)
tmp115 = tl.where(tmp110, tmp114, tmp9)
tmp116 = tmp108 + tmp115
tl.store(out_ptr0 + (x2), tmp116, xmask)
''', device_str='cuda')
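
# Sketch (hypothetical helper) of the select_backward assembly above: fifteen
# scalar reductions (in_ptr0..in_ptr14) are scattered into a zeroed (16, 2)
# coefficient-gradient tensor and summed. Rows 15, 14, 13 receive both columns;
# rows 12-8 and 6-3 receive column 0 only; rows 7, 2, 1, 0 are left zero here
# (presumably filled elsewhere).
def _assemble_coeff_grads_sketch(per_scalar_sums):
    # per_scalar_sums: {(row, col): 0-dim tensor}, one entry per in_ptr above.
    import torch
    out = torch.zeros(16, 2)
    for (row, col), v in per_scalar_sums.items():
        out[row, col] += v.float()
    return out
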
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/tm/ctmmn45fi4yzwoybueak7k3sy6rxrcom6pr6q6vovoawwcixsrod.py
# Topologically Sorted Source Nodes: [v_7], Original ATen: [aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.add]
# Source node to ATen node mapping:
# v_7 => convert_element_type_62
# Graph fragment:
# %mul_770 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_392, %select_20), kwargs = {})
# %convert_element_type_62 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_32, torch.float32), kwargs = {})
# %convert_element_type_739 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_770, torch.float32), kwargs = {})
# %mul_772 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_739, %convert_element_type_62), kwargs = {})
# %mul_773 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_739, %rsqrt_11), kwargs = {})
# %sum_108 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_772, [3], True), kwargs = {})
# %div_54 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_71, 128), kwargs = {})
# %pow_200 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_62, 1.0), kwargs = {})
# %mul_776 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_200, 2.0), kwargs = {})
# %mul_777 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_54, %mul_776), kwargs = {})
# %add_376 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_773, %mul_777), kwargs = {})
# %convert_element_type_740 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_376, torch.bfloat16), kwargs = {})
triton_red_fused__to_copy_add_div_mul_pow_sum_42 = async_compile.triton('triton_red_fused__to_copy_add_div_mul_pow_sum_42', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'out_ptr1': '*bf16', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_mul_pow_sum_42', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 7, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_mul_pow_sum_42(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr1, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (52))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (52))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
''', device_str='cuda')
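
# Same fused RMS-norm backward pattern again, here for v_7: coefficient slot 52
# of in_ptr1 and rsqrt_11.
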
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/iy/ciyfyhl44ar5iouy2c6tng6yyeqta2nvvqv2qkv5v47eqdfdiequ.py
# Topologically Sorted Source Nodes: [], Original ATen: [aten.add, aten.mul]
# Source node to ATen node mapping:
# Graph fragment:
# %add_383 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_373, %view_334), kwargs = {})
# %mul_800 : [num_users=2] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %select_18), kwargs = {})
triton_poi_fused_add_mul_43 = async_compile.triton('triton_poi_fused_add_mul_43', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*bf16', 'in_ptr2': '*fp32', 'out_ptr0': '*bf16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_43', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_mul_43(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 67108864
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = tl.full([XBLOCK], True, tl.int1)
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
tmp1 = tl.load(in_ptr1 + (x0), None).to(tl.float32)
tmp3 = tl.load(in_ptr2 + (20))
tmp4 = tl.broadcast_to(tmp3, [XBLOCK])
tmp2 = tmp0 + tmp1
tmp5 = tmp4.to(tl.float32)
tmp6 = tmp2 * tmp5
tl.store(out_ptr0 + (x0), tmp6, None)
''', device_str='cuda')
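
# Unfused tail of the residual-scaling pattern in triton_poi_fused_add_mul_38:
# out = (in_ptr0 + in_ptr1) * coefficient (fp32 slot 20 of in_ptr2), with no
# in-place accumulation.
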
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/de/cdez6a43hw7ae3gtzsdmit5bwldgblxmtsegoexxdboxn64ojfhd.py
# Topologically Sorted Source Nodes: [loss], Original ATen: [aten.nll_loss_forward, aten.add, aten.embedding_dense_backward]
# Source node to ATen node mapping:
# loss => full_default_13
# Graph fragment:
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %add_391 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_204, %view_343), kwargs = {})
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %convert_element_type_827 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_391, torch.float32), kwargs = {})
# %where_26 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_827), kwargs = {})
# %index_put_6 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_26, True), kwargs = {})
triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44 = async_compile.triton('triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'x': 67108864},
filename=__file__,
triton_meta={'signature': {'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 0, 'num_reduction': 0, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_embedding_dense_backward_nll_loss_forward_44(out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 51463168
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = 0.0
tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')
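
# This kernel only zero-fills the fp32 (50257, 1024) gradient buffer for the
# token embedding: 51463168 == 50257 * 1024. A hedged eager equivalent of the
# surrounding embedding_dense_backward (hypothetical names):
#   grad_wte = torch.zeros(50257, 1024, device='cuda')     # this kernel
#   grad_wte.index_put_([token_ids], masked_grad, accumulate=True)
# where masked_grad has masked positions zeroed (the where_26 node above).
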
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/wc/cwcnc3zg2mhcjo34dem5rdh6dr3jzso4l2wbr3kinwnai7rrc2xj.py
# Topologically Sorted Source Nodes: [loss, v_4], Original ATen: [aten.nll_loss_forward, aten.add, aten.mul, aten._to_copy, aten.sum, aten.div, aten.pow, aten.embedding_dense_backward]
# Source node to ATen node mapping:
# loss => full_default_13
# v_4 => convert_element_type_42
# Graph fragment:
# %full_default_13 : [num_users=5] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %add_391 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%view_204, %view_343), kwargs = {})
# %mul_812 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%permute_412, %select_13), kwargs = {})
# %convert_element_type_42 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%getitem_23, torch.float32), kwargs = {})
# %convert_element_type_771 : [num_users=2] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul_812, torch.float32), kwargs = {})
# %mul_814 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_771, %convert_element_type_42), kwargs = {})
# %mul_815 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_771, %rsqrt_7), kwargs = {})
# %sum_116 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%mul_814, [3], True), kwargs = {})
# %div_58 : [num_users=1] = call_function[target=torch.ops.aten.div.Scalar](args = (%expand_75, 128), kwargs = {})
# %pow_209 : [num_users=1] = call_function[target=torch.ops.aten.pow.Tensor_Scalar](args = (%convert_element_type_42, 1.0), kwargs = {})
# %mul_818 : [num_users=1] = call_function[target=torch.ops.aten.mul.Scalar](args = (%pow_209, 2.0), kwargs = {})
# %mul_819 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%div_58, %mul_818), kwargs = {})
# %add_393 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul_815, %mul_819), kwargs = {})
# %convert_element_type_772 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_393, torch.bfloat16), kwargs = {})
# %full_default_169 : [num_users=4] = call_function[target=torch.ops.aten.full.default](args = ([50257, 1024], 0), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})
# %convert_element_type_827 : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%add_391, torch.float32), kwargs = {})
# %where_26 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%unsqueeze_165, %full_default_13, %convert_element_type_827), kwargs = {})
# %index_put_6 : [num_users=1] = call_function[target=torch.ops.aten.index_put.default](args = (%full_default_169, [%convert_element_type_822], %where_26, True), kwargs = {})
triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45 = async_compile.triton('triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45', '''
import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.reduction(
size_hints={'x': 524288, 'r0_': 128},
reduction_hint=ReductionHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*bf16', 'in_ptr1': '*fp32', 'in_ptr2': '*bf16', 'in_ptr3': '*fp32', 'in_ptr4': '*i32', 'in_ptr5': '*bf16', 'out_ptr1': '*bf16', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]], (5,): [['tt.divisibility', 16]], (6,): [['tt.divisibility', 16]], (7,): [['tt.divisibility', 16]], (8,): [['tt.divisibility', 16]], (9,): [['tt.divisibility', 16]]}]},
inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45', 'mutated_arg_names': ['out_ptr2'], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 11, 'num_reduction': 1, 'backend_hash': 'C1BD6C778D5DBCD2C2E28504AEE73FCFB458E194A308F3EF755A092106C6A95D', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'coordinate_descent_tuning': True, 'coordinate_descent_search_radius': 1, 'coordinate_descent_check_all_directions': False}
)
@triton.jit
def triton_red_fused__to_copy_add_div_embedding_dense_backward_mul_nll_loss_forward_pow_sum_45(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 524288
r0_numel = 128
rnumel = r0_numel
RBLOCK: tl.constexpr = R0_BLOCK
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
r0_base = tl.arange(0, R0_BLOCK)[None, :]
rbase = r0_base
x3 = xindex
tmp1 = tl.load(in_ptr1 + (50))
tmp2 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
x0 = (xindex % 8)
x1 = xindex // 8
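    # First pass over the reduced dim: accumulate sum((scaled grad) * x),
    # i.e. the %sum_116 reduction feeding the rsqrt backward below.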
_tmp10 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp0 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp6 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp0 * tmp3
tmp5 = tmp4.to(tl.float32)
tmp7 = tmp6.to(tl.float32)
tmp8 = tmp5 * tmp7
tmp9 = tl.broadcast_to(tmp8, [XBLOCK, R0_BLOCK])
tmp11 = _tmp10 + tmp9
_tmp10 = tl.where(r0_mask, tmp11, _tmp10)
tmp10 = tl.sum(_tmp10, 1)[:, None]
tmp13 = tl.load(in_ptr1 + (50))
tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK])
tmp18 = tl.load(in_ptr3 + (x3), None, eviction_policy='evict_last')
tmp34 = tl.load(in_ptr4 + (x1), None, eviction_policy='evict_last')
tmp44 = tl.load(in_ptr1 + (77))
tmp45 = tl.broadcast_to(tmp44, [XBLOCK, R0_BLOCK])
tmp48 = tl.load(in_ptr1 + (51))
tmp49 = tl.broadcast_to(tmp48, [XBLOCK, R0_BLOCK])
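    # Second pass: recompute the loads and apply
    # grad*rsqrt + (-0.5 * sum * rsqrt**3 / 128) * 2x  (note 0.0078125 == 1/128),
    # then scatter-add the masked embedding gradient via tl.atomic_add.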
for r0_offset in range(0, r0_numel, R0_BLOCK):
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
rindex = r0_index
r0_2 = r0_index
tmp12 = tl.load(in_ptr0 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp27 = tl.load(in_ptr2 + (r0_2 + 128*x0 + 3072*x1), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp43 = tl.load(in_ptr5 + (r0_2 + 128*x3), r0_mask, eviction_policy='evict_first', other=0.0).to(tl.float32)
tmp15 = tmp14.to(tl.float32)
tmp16 = tmp12 * tmp15
tmp17 = tmp16.to(tl.float32)
tmp19 = tmp17 * tmp18
tmp20 = -0.5
tmp21 = tmp10 * tmp20
tmp22 = tmp18 * tmp18
tmp23 = tmp22 * tmp18
tmp24 = tmp21 * tmp23
tmp25 = 0.0078125
tmp26 = tmp24 * tmp25
tmp28 = tmp27.to(tl.float32)
tmp29 = 2.0
tmp30 = tmp28 * tmp29
tmp31 = tmp26 * tmp30
tmp32 = tmp19 + tmp31
tmp33 = tmp32.to(tl.float32)
tmp35 = tmp34.to(tl.int64)
tmp36 = tl.full([XBLOCK, R0_BLOCK], 50257, tl.int32)
tmp37 = tmp35 + tmp36
tmp38 = tmp35 < 0
tmp39 = tl.where(tmp38, tmp37, tmp35)
tl.device_assert((0 <= tmp39) & (tmp39 < 50257), "index out of bounds: 0 <= tmp39 < 50257")
tmp41 = tl.full([1, 1], -1, tl.int64)
tmp42 = tmp35 == tmp41
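        # tmp37..tmp39 wrap negative indices into [0, 50257); an index of -1
        # (tmp42) is treated as padding and its contribution is zeroed below.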
tmp46 = tmp45.to(tl.float32)
tmp47 = tmp43 * tmp46
tmp50 = tmp49.to(tl.float32)
tmp51 = tmp12 * tmp50
tmp52 = tmp47 + tmp51
tmp53 = tmp52.to(tl.float32)
tmp54 = 0.0
tmp55 = tl.where(tmp42, tmp54, tmp53)
tl.store(out_ptr1 + (r0_2 + 128*x0 + 3072*x1), tmp33, r0_mask)
tl.atomic_add(out_ptr2 + (r0_2 + 128*x0 + 1024*tmp39), tmp55, r0_mask, sem='relaxed')
''', device_str='cuda')
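# ---------------------------------------------------------------------------
# Editor's hedged aside (not generated output): the kernel above ends with
# tl.atomic_add rather than tl.store because several (batch, position) rows
# can map to the same embedding row, so their partial gradients must
# accumulate; sem='relaxed' suffices since only the final sum is ever read.
# A minimal sketch of that pattern, with an assumed one-program-per-row grid
# and DIM a power of two:
@triton.jit
def _scatter_add_rows_sketch(src_ptr, idx_ptr, out_ptr, DIM: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, DIM)
    vals = tl.load(src_ptr + row * DIM + cols)
    dest = tl.load(idx_ptr + row)  # target embedding row for this source row
    tl.atomic_add(out_ptr + dest * DIM + cols, vals, sem='relaxed')
# ---------------------------------------------------------------------------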
# kernel path: /tmp/torchinductor_xmfan/tmpez3lej8d/ty/ctyl36nsnpehqd4df5ksk7bsvoehmhqsdu3iojfgd3yky7vkrutv.py
# Topologically Sorted Source Nodes: [x], Original ATen: [aten._to_copy, aten.mul, aten.add, aten.sum]
# Source node to ATen node mapping:
# x => convert_element_type_10, convert_element_type_11, mul
# Graph fragment:
# %convert_element_type_10 : [num_users=3] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%unsqueeze_42, torch.float32), kwargs = {})
# %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convert_element_type_10, %rsqrt), kwargs = {})
# %convert_element_type_11 : [num_users=16] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%mul, torch.bfloat16), kwargs = {})
# %add_365 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_357, %view_322), kwargs = {})
# %mul_756 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_365, %select_26), kwargs = {})
# %add_367 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_352, %mul_756), kwargs = {})
# %add_383 : [num_users=4] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_373, %view_334), kwargs = {})
# %mul_798 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %select_19), kwargs = {})
# %mul_799 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add_383, %convert_element_type_11), kwargs = {})
# %sum_111 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%mul_799,), kwargs = {})
# %add_385 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_367, %mul_798), kwargs = {})
# %m