(Can't run this) triton JVP attn blind-code for TensorDescriptor-era triton
from __future__ import annotations | |
from argparse import ArgumentParser, Namespace | |
from dataclasses import dataclass | |
from functools import partial | |
from os import environ | |
from typing import Any, Callable, NamedTuple, Optional | |
import torch | |
from torch import Tensor, no_grad, enable_grad | |
from torch.autograd import Function | |
from torch.autograd.function import FunctionCtx | |
import torch.autograd.forward_ad as fwAD | |
from torch.nn.attention import SDPBackend, sdpa_kernel | |
from torch.nn.functional import scaled_dot_product_attention | |
from torch.nn import MSELoss | |
from torch.utils.flop_counter import FlopCounterMode | |
import triton | |
import triton.language as tl | |
from triton.testing import do_bench | |
try: | |
from triton.tools.tensor_descriptor import TensorDescriptor | |
HAS_TENSOR_DESC = True | |
except ModuleNotFoundError: | |
HAS_TENSOR_DESC = False | |
NiladicFn = Callable[[], None] | |
def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float: | |
iters_per_second = 1e3 / ms_per_iter | |
return iters_per_second * flop_count | |
def fmt_flops(flops: int) -> str: | |
return f"{flops / 1e12:5.1f} TFLOP/s" | |
# Python *please* bring back support for generic NamedTuples | |
def get_flop_count(f: Callable[[], Any], display_ops=True) -> int: | |
flop_counter = FlopCounterMode(display=display_ops) | |
with flop_counter: | |
f() | |
return flop_counter.get_total_flops() | |
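# Example (hypothetical usage, not from the original gist): combine the FLOP counter
# above with triton's do_bench to report throughput for some niladic attention
# callable `attn_fn`:
#   flop_count = get_flop_count(attn_fn, display_ops=False)
#   ms_per_iter = do_bench(attn_fn)
#   print(fmt_flops(mpi_to_flops(ms_per_iter, flop_count)))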
# --- Typical attn fwd kernel from Triton tutorial --- | |
# DEVICE = triton.runtime.driver.active.get_active_torch_device() | |
def is_hip(): | |
return triton.runtime.driver.active.get_current_target().backend == "hip" | |
def is_cuda(): | |
return triton.runtime.driver.active.get_current_target().backend == "cuda" | |
def supports_host_descriptor(): | |
return HAS_TENSOR_DESC and is_cuda() and torch.cuda.get_device_capability()[0] >= 9 | |
def is_blackwell(): | |
return is_cuda() and torch.cuda.get_device_capability()[0] == 10 | |
@triton.jit | |
def _attn_fwd_inner( | |
acc, g_acc, | |
l_i, m_i, | |
mu_i, p_tv_acc, | |
q, desc_k, desc_v, # | |
q_t, desc_k_t, desc_v_t, # | |
offset_y, dtype: tl.constexpr, start_m, qk_scale, # | |
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # | |
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # | |
N_CTX: tl.constexpr, warp_specialize: tl.constexpr, | |
): | |
# range of values handled by this stage | |
if STAGE == 1: | |
lo, hi = 0, start_m * BLOCK_M | |
elif STAGE == 2: | |
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M | |
lo = tl.multiple_of(lo, BLOCK_M) | |
# causal = False | |
else: | |
lo, hi = 0, N_CTX | |
offsetkv_y = offset_y + lo | |
# loop over k, v and update accumulator | |
for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize): | |
start_n = tl.multiple_of(start_n, BLOCK_N) | |
# -- compute qk ---- | |
k = desc_k.load([offsetkv_y, 0]).T | |
k_t = desc_k_t.load([offsetkv_y, 0]).T | |
qk = tl.dot(q, k) | |
        qk_t = tl.dot(q_t, k) + tl.dot(q, k_t)  # product rule; k and k_t were already loaded transposed above
        qk_t = qk_t * (qk_scale * 0.6931471824645996)  # = sm_scale; qk_scale carries a log2(e) factor for exp2 that the tangent logits should not
if STAGE == 2: | |
mask = offs_m[:, None] >= (start_n + offs_n[None, :]) | |
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) | |
qk_t += tl.where(mask, 0, -1.0e6) | |
# TODO: do we need a separate row maximum for qk_t? | |
m_ij = tl.maximum(m_i, tl.max(qk, 1)) | |
qk -= m_ij[:, None] | |
else: | |
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) | |
qk = qk * qk_scale - m_ij[:, None] | |
p = tl.math.exp2(qk) | |
p_tqk = p * qk_t | |
# -- compute correction factor | |
alpha = tl.math.exp2(m_i - m_ij) | |
l_ij = tl.sum(p, 1) | |
# -- update output accumulator -- | |
if warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128: | |
BM: tl.constexpr = acc.shape[0] | |
BN: tl.constexpr = acc.shape[1] | |
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() | |
acc0 = acc0 * alpha[:, None] | |
acc1 = acc1 * alpha[:, None] | |
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) | |
else: | |
acc = acc * alpha[:, None] | |
# prepare p and v for the dot | |
v = desc_v.load([offsetkv_y, 0]) | |
v_t = desc_v_t.load([offsetkv_y, 0]) | |
p = p.to(dtype) | |
g_acc = g_acc * alpha[:, None] + tl.dot(p_tqk, v) | |
# note that this non transposed v for FP8 is only supported on Blackwell | |
acc = tl.dot(p, v, acc) | |
# update m_i and l_i | |
# place this at the end of the loop to reduce register pressure | |
mu_ij = tl.sum(p_tqk, 1) | |
mu_i = mu_i * alpha + mu_ij | |
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, v_t) | |
l_i = l_i * alpha + l_ij | |
m_i = m_ij | |
offsetkv_y += BLOCK_N | |
    # return all streaming accumulators, including the tangent ones, so the caller's
    # epilogue sees the values updated in this loop rather than its initial zeros
    return acc, g_acc, l_i, m_i, mu_i, p_tv_acc
def _host_descriptor_pre_hook(nargs): | |
BLOCK_M = nargs["BLOCK_M"] | |
BLOCK_N = nargs["BLOCK_N"] | |
HEAD_DIM = nargs["HEAD_DIM"] | |
if not HAS_TENSOR_DESC or not isinstance(nargs["desc_q"], TensorDescriptor): | |
return | |
nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM] | |
nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] | |
nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM] | |
nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM] | |
if is_hip(): | |
NUM_STAGES_OPTIONS = [1] | |
elif supports_host_descriptor(): | |
NUM_STAGES_OPTIONS = [2, 3, 4] | |
else: | |
NUM_STAGES_OPTIONS = [2, 3, 4] | |
configs = [ | |
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \ | |
for BM in [64, 128]\ | |
for BN in [64, 128]\ | |
for s in NUM_STAGES_OPTIONS \ | |
for w in [4, 8]\ | |
] | |
if "PYTEST_VERSION" in environ: | |
# Use a single config in testing for reproducibility | |
configs = [ | |
triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook), | |
] | |
def keep(conf): | |
BLOCK_M = conf.kwargs["BLOCK_M"] | |
BLOCK_N = conf.kwargs["BLOCK_N"] | |
return not (torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8) | |
def prune_invalid_configs(configs, named_args, **kwargs): | |
N_CTX = kwargs["N_CTX"] | |
# Filter out configs where BLOCK_M > N_CTX | |
return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX] | |
if HAS_TENSOR_DESC: | |
@triton.jit | |
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape): | |
if isinstance(desc_or_ptr, tl.tensor_descriptor): | |
return desc_or_ptr | |
else: | |
return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape) | |
else: | |
@triton.jit | |
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape): | |
# If we don't have tensor descriptors, just return the pointer | |
return desc_or_ptr | |
# --- Typical attn forward kernel from Triton tutorial --- | |
# https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py | |
# MIT Licensed | |
# https://github.com/triton-lang/triton/blob/main/LICENSE | |
# with modifications to support JVP | |
@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"], | |
prune_configs_by={'early_config_prune': prune_invalid_configs}) | |
@triton.jit | |
def _attn_fwd(sm_scale, M, # | |
Z, H, | |
desc_q, desc_k, desc_v, | |
desc_q_t, desc_k_t, desc_v_t, | |
desc_o, desc_o_t, | |
N_CTX, # | |
HEAD_DIM: tl.constexpr, # | |
BLOCK_M: tl.constexpr, # | |
BLOCK_N: tl.constexpr, # | |
FP8_OUTPUT: tl.constexpr, # | |
STAGE: tl.constexpr, # | |
warp_specialize: tl.constexpr, # | |
): | |
dtype = tl.float8e5 if FP8_OUTPUT else tl.float16 | |
tl.static_assert(BLOCK_N <= HEAD_DIM) | |
    start_m = tl.program_id(0)  # query block index (along N_CTX)
    off_hz = tl.program_id(1)  # combined batch and head index
off_z = off_hz // H | |
off_h = off_hz % H | |
y_dim = Z * H * N_CTX | |
desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_N, HEAD_DIM]) | |
desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_N, HEAD_DIM]) | |
desc_q_t = _maybe_make_tensor_desc(desc_q_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
desc_v_t = _maybe_make_tensor_desc(desc_v_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_N, HEAD_DIM]) | |
desc_k_t = _maybe_make_tensor_desc(desc_k_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_N, HEAD_DIM]) | |
desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
desc_o_t = _maybe_make_tensor_desc(desc_o_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
offset_y = off_z * (N_CTX * H) + off_h * N_CTX | |
qo_offset_y = offset_y + start_m * BLOCK_M | |
# initialize offsets | |
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
offs_n = tl.arange(0, BLOCK_N) | |
# initialize pointer to m and l | |
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") | |
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 | |
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
# load scales | |
qk_scale = sm_scale | |
    qk_scale *= 1.44269504  # log2(e) = 1/ln(2), so exp2 can stand in for exp
# load q: it will stay in SRAM throughout | |
q = desc_q.load([qo_offset_y, 0]) | |
q_t = desc_q_t.load([qo_offset_y, 0]) | |
# stage 1: off-band | |
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE | |
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE | |
if STAGE & 1: | |
        acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(
acc, g_acc, | |
l_i, m_i, | |
mu_i, p_tv_acc, | |
q, desc_k, desc_v, # | |
q_t, desc_k_t, desc_v_t, # | |
offset_y, dtype, start_m, qk_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
4 - STAGE, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
) | |
# stage 2: on-band | |
if STAGE & 2: | |
        acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(
acc, g_acc, | |
l_i, m_i, | |
mu_i, p_tv_acc, | |
q, desc_k, desc_v, # | |
q_t, desc_k_t, desc_v_t, # | |
offset_y, dtype, start_m, qk_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
2, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
) | |
# epilogue | |
m_i += tl.math.log2(l_i) | |
acc = acc / l_i[:, None] | |
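    # The tangent recombination below appears intended to compute the JVP of the
    # attention output. Writing P = softmax(q @ k^T * sm_scale) and dS for the
    # tangent of the scaled logits, the loop maintained per-row streaming sums
    #   g_acc    ~ (p * dS) @ V   (with unnormalized p),
    #   mu_i     ~ rowsum(p * dS),
    #   p_tv_acc ~ p @ dV,
    # so dividing by l_i yields dP @ V = (P * dS) @ V - rowsum(P * dS) * (P @ V)
    # and P @ dV; their sum t_y_out is the tangent of the output P @ V.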
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc | |
t_y_out = t_p_v + p_tv_acc / l_i[:, None] | |
m_ptrs = M + off_hz * N_CTX + offs_m | |
tl.store(m_ptrs, m_i) | |
desc_o.store([qo_offset_y, 0], acc.to(dtype)) | |
desc_o_t.store([qo_offset_y, 0], t_y_out.to(dtype)) | |
# --- Typical attn backward kernel from Triton tutorial --- | |
# https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py | |
# MIT Licensed | |
# https://github.com/triton-lang/triton/blob/main/LICENSE | |
@triton.jit | |
def _attn_bwd_preprocess(O, DO, # | |
Delta, # | |
Z, H, N_CTX, # | |
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # | |
): | |
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) | |
off_hz = tl.program_id(1) | |
off_n = tl.arange(0, HEAD_DIM) | |
# load | |
o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) | |
do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) | |
delta = tl.sum(o * do, axis=1) | |
# write-back | |
tl.store(Delta + off_hz * N_CTX + off_m, delta) | |
# The main inner-loop logic for computing dK and dV. | |
@triton.jit | |
def _attn_bwd_dkdv(dk, dv, # | |
Q, k, v, sm_scale, # | |
DO, # | |
M, D, # | |
# shared by Q/K/V/DO. | |
stride_tok, stride_d, # | |
H, N_CTX, BLOCK_M1: tl.constexpr, # | |
BLOCK_N1: tl.constexpr, # | |
HEAD_DIM: tl.constexpr, # | |
# Filled in by the wrapper. | |
start_n, start_m, num_steps, # | |
MASK: tl.constexpr): | |
offs_m = start_m + tl.arange(0, BLOCK_M1) | |
offs_n = start_n + tl.arange(0, BLOCK_N1) | |
offs_k = tl.arange(0, HEAD_DIM) | |
qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d | |
do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d | |
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. | |
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) | |
curr_m = start_m | |
step_m = BLOCK_M1 | |
for blk_idx in range(num_steps): | |
qT = tl.load(qT_ptrs) | |
# Load m before computing qk to reduce pipeline stall. | |
offs_m = curr_m + tl.arange(0, BLOCK_M1) | |
m = tl.load(M + offs_m) | |
qkT = tl.dot(k, qT) | |
pT = tl.math.exp2(qkT - m[None, :]) | |
# Autoregressive masking. | |
if MASK: | |
mask = (offs_m[None, :] >= offs_n[:, None]) | |
pT = tl.where(mask, pT, 0.0) | |
do = tl.load(do_ptrs) | |
# Compute dV. | |
ppT = pT | |
ppT = ppT.to(tl.float16) | |
dv += tl.dot(ppT, do) | |
# D (= delta) is pre-divided by ds_scale. | |
Di = tl.load(D + offs_m) | |
# Compute dP and dS. | |
dpT = tl.dot(v, tl.trans(do)).to(tl.float32) | |
dsT = pT * (dpT - Di[None, :]) | |
dsT = dsT.to(tl.float16) | |
dk += tl.dot(dsT, tl.trans(qT)) | |
# Increment pointers. | |
curr_m += step_m | |
qT_ptrs += step_m * stride_tok | |
do_ptrs += step_m * stride_tok | |
return dk, dv | |
# the main inner-loop logic for computing dQ | |
@triton.jit | |
def _attn_bwd_dq(dq, q, K, V, # | |
do, m, D, | |
# shared by Q/K/V/DO. | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M2: tl.constexpr, # | |
BLOCK_N2: tl.constexpr, # | |
HEAD_DIM: tl.constexpr, | |
# Filled in by the wrapper. | |
start_m, start_n, num_steps, # | |
MASK: tl.constexpr): | |
offs_m = start_m + tl.arange(0, BLOCK_M2) | |
offs_n = start_n + tl.arange(0, BLOCK_N2) | |
offs_k = tl.arange(0, HEAD_DIM) | |
kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d | |
vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d | |
# D (= delta) is pre-divided by ds_scale. | |
Di = tl.load(D + offs_m) | |
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. | |
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) | |
curr_n = start_n | |
step_n = BLOCK_N2 | |
for blk_idx in range(num_steps): | |
kT = tl.load(kT_ptrs) | |
vT = tl.load(vT_ptrs) | |
qk = tl.dot(q, kT) | |
p = tl.math.exp2(qk - m) | |
# Autoregressive masking. | |
if MASK: | |
offs_n = curr_n + tl.arange(0, BLOCK_N2) | |
mask = (offs_m[:, None] >= offs_n[None, :]) | |
p = tl.where(mask, p, 0.0) | |
# Compute dP and dS. | |
dp = tl.dot(do, vT).to(tl.float32) | |
ds = p * (dp - Di[:, None]) | |
ds = ds.to(tl.float16) | |
# Compute dQ. | |
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled. | |
dq += tl.dot(ds, tl.trans(kT)) | |
# Increment pointers. | |
curr_n += step_n | |
kT_ptrs += step_n * stride_tok | |
vT_ptrs += step_n * stride_tok | |
return dq | |
@triton.jit | |
def _attn_bwd(Q, K, V, sm_scale, # | |
DO, # | |
DQ, DK, DV, # | |
M, D, | |
# shared by Q/K/V/DO. | |
stride_z, stride_h, stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M1: tl.constexpr, # | |
BLOCK_N1: tl.constexpr, # | |
BLOCK_M2: tl.constexpr, # | |
BLOCK_N2: tl.constexpr, # | |
BLK_SLICE_FACTOR: tl.constexpr, # | |
HEAD_DIM: tl.constexpr): | |
LN2: tl.constexpr = 0.6931471824645996 # = ln(2) | |
bhid = tl.program_id(2) | |
off_chz = (bhid * N_CTX).to(tl.int64) | |
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) | |
pid = tl.program_id(0) | |
# offset pointers for batch/head | |
Q += adj | |
K += adj | |
V += adj | |
DO += adj | |
DQ += adj | |
DK += adj | |
DV += adj | |
M += off_chz | |
D += off_chz | |
# load scales | |
offs_k = tl.arange(0, HEAD_DIM) | |
start_n = pid * BLOCK_N1 | |
start_m = start_n | |
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR | |
offs_n = start_n + tl.arange(0, BLOCK_N1) | |
dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) | |
dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) | |
# load K and V: they stay in SRAM throughout the inner loop. | |
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
num_steps = BLOCK_N1 // MASK_BLOCK_M1 | |
dk, dv = _attn_bwd_dkdv(dk, dv, # | |
Q, k, v, sm_scale, # | |
DO, # | |
M, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # | |
start_n, start_m, num_steps, # | |
MASK=True # | |
) | |
start_m += num_steps * MASK_BLOCK_M1 | |
num_steps = (N_CTX - start_m) // BLOCK_M1 | |
# Compute dK and dV for non-masked blocks. | |
dk, dv = _attn_bwd_dkdv( # | |
dk, dv, # | |
Q, k, v, sm_scale, # | |
DO, # | |
M, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M1, BLOCK_N1, HEAD_DIM, # | |
start_n, start_m, num_steps, # | |
MASK=False # | |
) | |
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d | |
tl.store(dv_ptrs, dv) | |
# Write back dK. | |
dk *= sm_scale | |
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d | |
tl.store(dk_ptrs, dk) | |
# THIS BLOCK DOES DQ: | |
start_m = pid * BLOCK_M2 | |
end_n = start_m + BLOCK_M2 | |
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR | |
offs_m = start_m + tl.arange(0, BLOCK_M2) | |
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) | |
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
m = tl.load(M + offs_m) | |
m = m[:, None] | |
# Compute dQ for masked (diagonal) blocks. | |
# NOTE: This code scans each row of QK^T backward (from right to left, | |
# but inside each call to _attn_bwd_dq, from left to right), but that's | |
# not due to anything important. I just wanted to reuse the loop | |
# structure for dK & dV above as much as possible. | |
num_steps = BLOCK_M2 // MASK_BLOCK_N2 | |
dq = _attn_bwd_dq(dq, q, K, V, # | |
do, m, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # | |
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # | |
MASK=True # | |
) | |
end_n -= num_steps * MASK_BLOCK_N2 | |
# stage 2 | |
num_steps = end_n // BLOCK_N2 | |
dq = _attn_bwd_dq(dq, q, K, V, # | |
do, m, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M2, BLOCK_N2, HEAD_DIM, # | |
start_m, end_n - num_steps * BLOCK_N2, num_steps, # | |
MASK=False # | |
) | |
# Write back dQ. | |
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d | |
dq *= LN2 | |
tl.store(dq_ptrs, dq) | |
class JVPAttn(Function): | |
class FnCtx(FunctionCtx): | |
sm_scale: float | |
HEAD_DIM: int | |
causal: bool | |
@staticmethod | |
def forward( | |
ctx: FnCtx, | |
q: Tensor, | |
k: Tensor, | |
v: Tensor, | |
q_t: Tensor, | |
k_t: Tensor, | |
v_t: Tensor, | |
causal = False, | |
sm_scale: Optional[float] = None, | |
warp_specialize=True, | |
) -> Tensor: | |
# shape constraints | |
HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] | |
sm_scale = HEAD_DIM_K ** -.5 if sm_scale is None else sm_scale | |
# when v is in float8_e5m2 it is transposed. | |
HEAD_DIM_V = v.shape[-1] | |
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V | |
assert HEAD_DIM_K in {16, 32, 64, 128, 256} | |
o = torch.empty_like(q) | |
o_t = torch.empty_like(q_t) | |
stage = 3 if causal else 1 | |
extra_kern_args = {} | |
# Tuning for AMD target | |
if is_hip(): | |
waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 | |
extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} | |
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) | |
if supports_host_descriptor(): | |
# Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor | |
y_dim = q.shape[0] * q.shape[1] * q.shape[2] | |
dummy_block = [1, 1] | |
desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_q_t = TensorDescriptor(q_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_v_t = TensorDescriptor(v_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_k_t = TensorDescriptor(k_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_o_t = TensorDescriptor(o_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
else: | |
desc_q = q | |
desc_v = v | |
desc_k = k | |
desc_o = o | |
desc_q_t = q_t | |
desc_v_t = v_t | |
desc_k_t = k_t | |
desc_o_t = o_t | |
if hasattr(triton, "set_allocator"): | |
def alloc_fn(size: int, align: int, _): | |
return torch.empty(size, dtype=torch.int8, device="cuda") | |
triton.set_allocator(alloc_fn) | |
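            # As in the upstream tutorial: device-side TMA tensor descriptors need a
            # global-memory workspace, which Triton requests through this allocator.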
def grid(META): | |
return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1) | |
ctx.grid = grid | |
if is_cuda() and warp_specialize: | |
if HEAD_DIM_K == 128 and q.dtype == torch.float16: | |
extra_kern_args["maxnreg"] = 168 | |
else: | |
extra_kern_args["maxnreg"] = 80 | |
_attn_fwd[grid]( | |
sm_scale, M, # | |
q.shape[0], q.shape[1], # | |
desc_q, desc_k, desc_v, | |
desc_q_t, desc_k_t, desc_v_t, | |
desc_o, desc_o_t, # | |
N_CTX=q.shape[2], # | |
HEAD_DIM=HEAD_DIM_K, # | |
FP8_OUTPUT=q.dtype == torch.float8_e5m2, # | |
STAGE=stage, # | |
warp_specialize=warp_specialize, # | |
**extra_kern_args) | |
ctx.save_for_forward(o_t) | |
ctx.save_for_backward(q, k, v, o, M) | |
ctx.sm_scale = sm_scale | |
ctx.HEAD_DIM = HEAD_DIM_K | |
ctx.causal = causal | |
return o | |
@staticmethod | |
def fwd_dual( | |
Q: Tensor, | |
K: Tensor, | |
V: Tensor, | |
causal = False, | |
sm_scale: Optional[float] = None, | |
warp_specialize=True, | |
) -> Tensor: | |
""" | |
This is not an autograd convention, it's a workaround for invoking | |
JVPAttn::forward with the right arguments when you have a dual tensor input. | |
""" | |
q_p, q_t = fwAD.unpack_dual(Q) | |
k_p, k_t = fwAD.unpack_dual(K) | |
v_p, v_t = fwAD.unpack_dual(V) | |
# we pass some dualtensor args to ensure jvp() will be called | |
# but we also pass tangents separately, as forward() demotes dual tensor args to primals for some reason | |
a = JVPAttn.apply(Q, K, V, q_t, k_t, v_t, causal, sm_scale, warp_specialize) | |
return a | |
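    # Example (hypothetical usage): wrap primals/tangents as dual tensors under
    # forward-mode AD and call fwd_dual; unpack_dual then yields both the primal
    # output and its JVP.
    #   with fwAD.dual_level():
    #       q = fwAD.make_dual(q_p, q_t)
    #       k = fwAD.make_dual(k_p, k_t)
    #       v = fwAD.make_dual(v_p, v_t)
    #       out = JVPAttn.fwd_dual(q, k, v, causal=False)
    #       out_primal, out_tangent = fwAD.unpack_dual(out)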
@staticmethod | |
def jvp(ctx: JVPAttn.FnCtx, gq: Tensor, gk: Tensor, gv: Tensor, *_) -> Tensor: | |
return ctx.saved_for_forward[0] | |
@staticmethod | |
    def backward(ctx: JVPAttn.FnCtx, do: Tensor) -> tuple[Optional[Tensor], ...]:
q, k, v, o, M = ctx.saved_tensors | |
assert do.is_contiguous() | |
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() | |
dq = torch.empty_like(q) | |
dk = torch.empty_like(k) | |
dv = torch.empty_like(v) | |
BATCH, N_HEAD, N_CTX = q.shape[:3] | |
PRE_BLOCK = 128 | |
NUM_WARPS, NUM_STAGES = 4, 5 | |
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 | |
BLK_SLICE_FACTOR = 2 | |
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) | |
arg_k = k | |
arg_k = arg_k * (ctx.sm_scale * RCP_LN2) | |
PRE_BLOCK = 128 | |
assert N_CTX % PRE_BLOCK == 0 | |
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) | |
delta = torch.empty_like(M) | |
_attn_bwd_preprocess[pre_grid]( | |
o, do, # | |
delta, # | |
BATCH, N_HEAD, N_CTX, # | |
BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM # | |
) | |
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) | |
_attn_bwd[grid]( | |
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # | |
M, delta, # | |
q.stride(0), q.stride(1), q.stride(2), q.stride(3), # | |
N_HEAD, N_CTX, # | |
BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # | |
BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # | |
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # | |
HEAD_DIM=ctx.HEAD_DIM, # | |
num_warps=NUM_WARPS, # | |
num_stages=NUM_STAGES # | |
) | |
        # one gradient per forward() input: q, k, v, q_t, k_t, v_t, causal, sm_scale, warp_specialize
        return dq, dk, dv, None, None, None, None, None, None
@dataclass | |
class Args: | |
bsz: int | |
model_dim: int | |
head_dim: int | |
seq_len: int | |
@staticmethod | |
def get_parser() -> ArgumentParser: | |
parser = ArgumentParser() | |
parser.add_argument("--bsz", default=1, type=int) | |
parser.add_argument("--model-dim", default=320, type=int) | |
parser.add_argument("--head-dim", default=64, type=int) | |
parser.add_argument("--seq-len", default=128, type=int) | |
return parser | |
@staticmethod | |
def from_namespace(namespace: Namespace) -> Args: | |
args = Args(**vars(namespace)) | |
return args | |
def main(args: Args) -> None: | |
device = torch.device('cuda') | |
dtype = torch.float16 | |
seed = 42 | |
gen = torch.Generator(device=device) | |
heads = args.model_dim // args.head_dim | |
q_p, q_t, k_p, k_t, v_p, v_t, target = (torch.randn(args.bsz, heads, args.seq_len, args.head_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed + ix)) for ix in range(7)) | |
# for t in (q_p, k_p, v_p): | |
# t.requires_grad = True | |
# t.retain_grad() | |
# loss_fn = MSELoss() | |
# if we use MSELoss, we get this error: | |
# ZeroTensors are immutable. Please use the materialized zero tensor obtained using .clone() if you want a mutable tensor. | |
def loss_fn(out: Tensor, target: Tensor) -> Tensor: | |
return (out - target).square().mean() | |
def make_qkv(q_p, k_p, v_p, q_t, k_t, v_t): | |
q, k, v = (fwAD.make_dual(p, t) for p, t in zip((q_p, k_p, v_p), (q_t, k_t, v_t))) | |
for t in (q, k, v): | |
t.requires_grad = True | |
t.retain_grad() | |
return q, k, v | |
with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad(): | |
q0, k0, v0 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone()) | |
sdpa_out = scaled_dot_product_attention(q0, k0, v0) | |
sdpa_out.retain_grad() | |
sdpa_op, sdpa_ot = fwAD.unpack_dual(sdpa_out) | |
loss0: Tensor = loss_fn(sdpa_out, target) | |
loss0.backward() | |
q1, k1, v1 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone()) | |
ag_out = JVPAttn.fwd_dual(q1, k1, v1) | |
ag_out.retain_grad() | |
ag_op, ag_ot = fwAD.unpack_dual(ag_out) | |
torch.testing.assert_close(ag_op, sdpa_op, atol=5e-4, rtol=1e-5) | |
torch.testing.assert_close(ag_ot, sdpa_ot, atol=5e-4, rtol=1e-5) | |
loss2: Tensor = loss_fn(ag_out, target) | |
torch.testing.assert_close(loss2, loss0, atol=5e-4, rtol=1e-5) | |
loss2.backward() | |
torch.testing.assert_close(q1.grad, q0.grad, atol=5e-4, rtol=1e-5) | |
torch.testing.assert_close(k1.grad, k0.grad, atol=5e-4, rtol=1e-5) | |
torch.testing.assert_close(v1.grad, v0.grad, atol=5e-4, rtol=1e-5) | |
pass | |
pass | |
if __name__ == "__main__": | |
parser = Args.get_parser() | |
args_untyped: Namespace = parser.parse_args() | |
args: Args = Args.from_namespace(args_untyped) | |
main(args) |
I took the almost-latest Triton fused-attention tutorial code (it doesn't have the fp8 fixes; I didn't realize a newer commit was available) and blind-coded "how would I add JVP to this".
When I tried running it, I found that my Triton version doesn't support TensorDescriptor, so I won't be able to verify whether it works.
I will start again from a block-pointer-based implementation (https://gist.github.com/Birch-san/0e852a42a933d3a1c0fcae21ccd15200).
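For reference, a minimal sketch of the kind of block-pointer load that would replace a TensorDescriptor load in such a rewrite, assuming it follows the older block-pointer tutorial (names like kv_offset, stride_kk and stride_kn are illustrative, not from this gist):

K_block_ptr = tl.make_block_ptr(
    base=K + kv_offset,                 # per-(batch, head) base pointer
    shape=(HEAD_DIM, N_CTX),
    strides=(stride_kk, stride_kn),
    offsets=(0, 0),
    block_shape=(HEAD_DIM, BLOCK_N),
    order=(0, 1),
)
k = tl.load(K_block_ptr)                              # one [HEAD_DIM, BLOCK_N] tile of K^T
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))   # step to the next tile along the sequence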