@Birch-san
Created June 26, 2025 13:11
(Can't run this) triton JVP attn blind-code for TensorDescriptor-era triton
from __future__ import annotations
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from functools import partial
from os import environ
from typing import Any, Callable, NamedTuple, Optional
import torch
from torch import Tensor, no_grad, enable_grad
from torch.autograd import Function
from torch.autograd.function import FunctionCtx
import torch.autograd.forward_ad as fwAD
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from torch.nn import MSELoss
from torch.utils.flop_counter import FlopCounterMode
import triton
import triton.language as tl
from triton.testing import do_bench
try:
from triton.tools.tensor_descriptor import TensorDescriptor
HAS_TENSOR_DESC = True
except ModuleNotFoundError:
HAS_TENSOR_DESC = False
NiladicFn = Callable[[], None]
def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float:
iters_per_second = 1e3 / ms_per_iter
return iters_per_second * flop_count
def fmt_flops(flops: int) -> str:
return f"{flops / 1e12:5.1f} TFLOP/s"
# Python *please* bring back support for generic NamedTuples
def get_flop_count(f: Callable[[], Any], display_ops=True) -> int:
flop_counter = FlopCounterMode(display=display_ops)
with flop_counter:
f()
return flop_counter.get_total_flops()
# --- Typical attn fwd kernel from Triton tutorial ---
# DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
def supports_host_descriptor():
return HAS_TENSOR_DESC and is_cuda() and torch.cuda.get_device_capability()[0] >= 9
def is_blackwell():
return is_cuda() and torch.cuda.get_device_capability()[0] == 10
@triton.jit
def _attn_fwd_inner(
acc, g_acc,
l_i, m_i,
mu_i, p_tv_acc,
q, desc_k, desc_v, #
q_t, desc_k_t, desc_v_t, #
offset_y, dtype: tl.constexpr, start_m, qk_scale, #
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
N_CTX: tl.constexpr, warp_specialize: tl.constexpr,
):
# range of values handled by this stage
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
# causal = False
else:
lo, hi = 0, N_CTX
offsetkv_y = offset_y + lo
# loop over k, v and update accumulator
for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = desc_k.load([offsetkv_y, 0]).T
k_t = desc_k_t.load([offsetkv_y, 0]).T
qk = tl.dot(q, k)
qk_t = tl.dot(q_t, k) + tl.dot(q, k_t) # k and k_t are already transposed by the .T above
qk_t = qk_t * qk_scale * 0.6931471824645996 # qk_scale * ln(2) = sm_scale: the log2(e) factor is only for exp2, it must not leak into the tangent
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
qk_t += tl.where(mask, 0, -1.0e6)
# TODO: do we need a separate row maximum for qk_t?
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
qk = qk * qk_scale - m_ij[:, None]
p = tl.math.exp2(qk)
p_tqk = p * qk_t
# -- compute correction factor
alpha = tl.math.exp2(m_i - m_ij)
l_ij = tl.sum(p, 1)
# -- update output accumulator --
if warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
BM: tl.constexpr = acc.shape[0]
BN: tl.constexpr = acc.shape[1]
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
acc0 = acc0 * alpha[:, None]
acc1 = acc1 * alpha[:, None]
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
else:
acc = acc * alpha[:, None]
# prepare p and v for the dot
v = desc_v.load([offsetkv_y, 0])
v_t = desc_v_t.load([offsetkv_y, 0])
p = p.to(dtype)
g_acc = g_acc * alpha[:, None] + tl.dot(p_tqk.to(dtype), v) # cast to v's dtype so the dot operands match
# note that this non transposed v for FP8 is only supported on Blackwell
acc = tl.dot(p, v, acc)
# update m_i and l_i
# place this at the end of the loop to reduce register pressure
mu_ij = tl.sum(p_tqk, 1)
mu_i = mu_i * alpha + mu_ij
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, v_t)
l_i = l_i * alpha + l_ij
m_i = m_ij
offsetkv_y += BLOCK_N
# return the JVP accumulators too, otherwise the caller's epilogue never sees their updated values
return acc, g_acc, l_i, m_i, mu_i, p_tv_acc
def _host_descriptor_pre_hook(nargs):
BLOCK_M = nargs["BLOCK_M"]
BLOCK_N = nargs["BLOCK_N"]
HEAD_DIM = nargs["HEAD_DIM"]
if not HAS_TENSOR_DESC or not isinstance(nargs["desc_q"], TensorDescriptor):
return
nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]
if is_hip():
NUM_STAGES_OPTIONS = [1]
elif supports_host_descriptor():
NUM_STAGES_OPTIONS = [2, 3, 4]
else:
NUM_STAGES_OPTIONS = [2, 3, 4]
configs = [
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \
for BM in [64, 128]\
for BN in [64, 128]\
for s in NUM_STAGES_OPTIONS \
for w in [4, 8]\
]
if "PYTEST_VERSION" in environ:
# Use a single config in testing for reproducibility
configs = [
triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook),
]
def keep(conf):
BLOCK_M = conf.kwargs["BLOCK_M"]
BLOCK_N = conf.kwargs["BLOCK_N"]
return not (torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8)
def prune_invalid_configs(configs, named_args, **kwargs):
N_CTX = kwargs["N_CTX"]
# Filter out configs where BLOCK_M > N_CTX
return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX]
if HAS_TENSOR_DESC:
@triton.jit
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
if isinstance(desc_or_ptr, tl.tensor_descriptor):
return desc_or_ptr
else:
return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)
else:
@triton.jit
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
# If we don't have tensor descriptors, just return the pointer
return desc_or_ptr
# --- Typical attn forward kernel from Triton tutorial ---
# https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py
# MIT Licensed
# https://github.com/triton-lang/triton/blob/main/LICENSE
# with modifications to support JVP
@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
prune_configs_by={'early_config_prune': prune_invalid_configs})
@triton.jit
def _attn_fwd(sm_scale, M, #
Z, H,
desc_q, desc_k, desc_v,
desc_q_t, desc_k_t, desc_v_t,
desc_o, desc_o_t,
N_CTX, #
HEAD_DIM: tl.constexpr, #
BLOCK_M: tl.constexpr, #
BLOCK_N: tl.constexpr, #
FP8_OUTPUT: tl.constexpr, #
STAGE: tl.constexpr, #
warp_specialize: tl.constexpr, #
):
dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0) # query block index
off_hz = tl.program_id(1) # combined batch and head index
off_z = off_hz // H
off_h = off_hz % H
y_dim = Z * H * N_CTX
desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_M, HEAD_DIM])
desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_N, HEAD_DIM])
desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_N, HEAD_DIM])
desc_q_t = _maybe_make_tensor_desc(desc_q_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_M, HEAD_DIM])
desc_v_t = _maybe_make_tensor_desc(desc_v_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_N, HEAD_DIM])
desc_k_t = _maybe_make_tensor_desc(desc_k_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_N, HEAD_DIM])
desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_M, HEAD_DIM])
desc_o_t = _maybe_make_tensor_desc(desc_o_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_M, HEAD_DIM])
offset_y = off_z * (N_CTX * H) + off_h * N_CTX
qo_offset_y = offset_y + start_m * BLOCK_M
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32)
p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
# load scales
qk_scale = sm_scale
qk_scale *= 1.44269504 # log2(e) = 1/ln(2), so exp2 can be used instead of exp
# load q: it will stay in SRAM throughout
q = desc_q.load([qo_offset_y, 0])
q_t = desc_q_t.load([qo_offset_y, 0])
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
if STAGE & 1:
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(
acc, g_acc,
l_i, m_i,
mu_i, p_tv_acc,
q, desc_k, desc_v, #
q_t, desc_k_t, desc_v_t, #
offset_y, dtype, start_m, qk_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
4 - STAGE, offs_m, offs_n, N_CTX, #
warp_specialize,
)
# stage 2: on-band
if STAGE & 2:
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(
acc, g_acc,
l_i, m_i,
mu_i, p_tv_acc,
q, desc_k, desc_v, #
q_t, desc_k_t, desc_v_t, #
offset_y, dtype, start_m, qk_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
2, offs_m, offs_n, N_CTX, #
warp_specialize,
)
# epilogue
m_i += tl.math.log2(l_i)
acc = acc / l_i[:, None]
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc
t_y_out = t_p_v + p_tv_acc / l_i[:, None]
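# Intended math for the tangent output (my annotation, not from the tutorial):
#   O  = P.V / l,                           P = unnormalized softmax weights, l = rowsum(P)
#   dO = (P*dS).V/l - (rowsum(P*dS)/l).O + P.dV/l
#      ~  g_acc/l   -    (mu_i/l).acc     + p_tv_acc/l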
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(m_ptrs, m_i)
desc_o.store([qo_offset_y, 0], acc.to(dtype))
desc_o_t.store([qo_offset_y, 0], t_y_out.to(dtype))
# --- Typical attn backward kernel from Triton tutorial ---
# https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py
# MIT Licensed
# https://github.com/triton-lang/triton/blob/main/LICENSE
@triton.jit
def _attn_bwd_preprocess(O, DO, #
Delta, #
Z, H, N_CTX, #
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr #
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_hz = tl.program_id(1)
off_n = tl.arange(0, HEAD_DIM)
# load
o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])
do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(Delta + off_hz * N_CTX + off_m, delta)
# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(dk, dv, #
Q, k, v, sm_scale, #
DO, #
M, D, #
# shared by Q/K/V/DO.
stride_tok, stride_d, #
H, N_CTX, BLOCK_M1: tl.constexpr, #
BLOCK_N1: tl.constexpr, #
HEAD_DIM: tl.constexpr, #
# Filled in by the wrapper.
start_n, start_m, num_steps, #
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M1)
offs_n = start_n + tl.arange(0, BLOCK_N1)
offs_k = tl.arange(0, HEAD_DIM)
qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
curr_m = start_m
step_m = BLOCK_M1
for blk_idx in range(num_steps):
qT = tl.load(qT_ptrs)
# Load m before computing qk to reduce pipeline stall.
offs_m = curr_m + tl.arange(0, BLOCK_M1)
m = tl.load(M + offs_m)
qkT = tl.dot(k, qT)
pT = tl.math.exp2(qkT - m[None, :])
# Autoregressive masking.
if MASK:
mask = (offs_m[None, :] >= offs_n[:, None])
pT = tl.where(mask, pT, 0.0)
do = tl.load(do_ptrs)
# Compute dV.
ppT = pT
ppT = ppT.to(tl.float16)
dv += tl.dot(ppT, do)
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
dsT = pT * (dpT - Di[None, :])
dsT = dsT.to(tl.float16)
dk += tl.dot(dsT, tl.trans(qT))
# Increment pointers.
curr_m += step_m
qT_ptrs += step_m * stride_tok
do_ptrs += step_m * stride_tok
return dk, dv
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(dq, q, K, V, #
do, m, D,
# shared by Q/K/V/DO.
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M2: tl.constexpr, #
BLOCK_N2: tl.constexpr, #
HEAD_DIM: tl.constexpr,
# Filled in by the wrapper.
start_m, start_n, num_steps, #
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M2)
offs_n = start_n + tl.arange(0, BLOCK_N2)
offs_k = tl.arange(0, HEAD_DIM)
kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
curr_n = start_n
step_n = BLOCK_N2
for blk_idx in range(num_steps):
kT = tl.load(kT_ptrs)
vT = tl.load(vT_ptrs)
qk = tl.dot(q, kT)
p = tl.math.exp2(qk - m)
# Autoregressive masking.
if MASK:
offs_n = curr_n + tl.arange(0, BLOCK_N2)
mask = (offs_m[:, None] >= offs_n[None, :])
p = tl.where(mask, p, 0.0)
# Compute dP and dS.
dp = tl.dot(do, vT).to(tl.float32)
ds = p * (dp - Di[:, None])
ds = ds.to(tl.float16)
# Compute dQ.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq += tl.dot(ds, tl.trans(kT))
# Increment pointers.
curr_n += step_n
kT_ptrs += step_n * stride_tok
vT_ptrs += step_n * stride_tok
return dq
@triton.jit
def _attn_bwd(Q, K, V, sm_scale, #
DO, #
DQ, DK, DV, #
M, D,
# shared by Q/K/V/DO.
stride_z, stride_h, stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M1: tl.constexpr, #
BLOCK_N1: tl.constexpr, #
BLOCK_M2: tl.constexpr, #
BLOCK_N2: tl.constexpr, #
BLK_SLICE_FACTOR: tl.constexpr, #
HEAD_DIM: tl.constexpr):
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
bhid = tl.program_id(2)
off_chz = (bhid * N_CTX).to(tl.int64)
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
pid = tl.program_id(0)
# offset pointers for batch/head
Q += adj
K += adj
V += adj
DO += adj
DQ += adj
DK += adj
DV += adj
M += off_chz
D += off_chz
# load scales
offs_k = tl.arange(0, HEAD_DIM)
start_n = pid * BLOCK_N1
start_m = start_n
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
offs_n = start_n + tl.arange(0, BLOCK_N1)
dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
# load K and V: they stay in SRAM throughout the inner loop.
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
num_steps = BLOCK_N1 // MASK_BLOCK_M1
dk, dv = _attn_bwd_dkdv(dk, dv, #
Q, k, v, sm_scale, #
DO, #
M, D, #
stride_tok, stride_d, #
H, N_CTX, #
MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, #
start_n, start_m, num_steps, #
MASK=True #
)
start_m += num_steps * MASK_BLOCK_M1
num_steps = (N_CTX - start_m) // BLOCK_M1
# Compute dK and dV for non-masked blocks.
dk, dv = _attn_bwd_dkdv( #
dk, dv, #
Q, k, v, sm_scale, #
DO, #
M, D, #
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M1, BLOCK_N1, HEAD_DIM, #
start_n, start_m, num_steps, #
MASK=False #
)
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dv_ptrs, dv)
# Write back dK.
dk *= sm_scale
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dk_ptrs, dk)
# THIS BLOCK DOES DQ:
start_m = pid * BLOCK_M2
end_n = start_m + BLOCK_M2
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
offs_m = start_m + tl.arange(0, BLOCK_M2)
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
m = tl.load(M + offs_m)
m = m[:, None]
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps = BLOCK_M2 // MASK_BLOCK_N2
dq = _attn_bwd_dq(dq, q, K, V, #
do, m, D, #
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, #
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
MASK=True #
)
end_n -= num_steps * MASK_BLOCK_N2
# stage 2
num_steps = end_n // BLOCK_N2
dq = _attn_bwd_dq(dq, q, K, V, #
do, m, D, #
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M2, BLOCK_N2, HEAD_DIM, #
start_m, end_n - num_steps * BLOCK_N2, num_steps, #
MASK=False #
)
# Write back dQ.
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
dq *= LN2
tl.store(dq_ptrs, dq)
class JVPAttn(Function):
class FnCtx(FunctionCtx):
sm_scale: float
HEAD_DIM: int
causal: bool
@staticmethod
def forward(
ctx: FnCtx,
q: Tensor,
k: Tensor,
v: Tensor,
q_t: Tensor,
k_t: Tensor,
v_t: Tensor,
causal = False,
sm_scale: Optional[float] = None,
warp_specialize=True,
) -> Tensor:
# shape constraints
HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
sm_scale = HEAD_DIM_K ** -.5 if sm_scale is None else sm_scale
# when v is in float8_e5m2 it is transposed.
HEAD_DIM_V = v.shape[-1]
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
assert HEAD_DIM_K in {16, 32, 64, 128, 256}
o = torch.empty_like(q)
o_t = torch.empty_like(q_t)
stage = 3 if causal else 1
extra_kern_args = {}
# Tuning for AMD target
if is_hip():
waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
if supports_host_descriptor():
# Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor
y_dim = q.shape[0] * q.shape[1] * q.shape[2]
dummy_block = [1, 1]
desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
desc_q_t = TensorDescriptor(q_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
desc_v_t = TensorDescriptor(v_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
desc_k_t = TensorDescriptor(k_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
desc_o_t = TensorDescriptor(o_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
else:
desc_q = q
desc_v = v
desc_k = k
desc_o = o
desc_q_t = q_t
desc_v_t = v_t
desc_k_t = k_t
desc_o_t = o_t
if hasattr(triton, "set_allocator"):
def alloc_fn(size: int, align: int, _):
return torch.empty(size, dtype=torch.int8, device="cuda")
triton.set_allocator(alloc_fn)
def grid(META):
return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
ctx.grid = grid
if is_cuda() and warp_specialize:
if HEAD_DIM_K == 128 and q.dtype == torch.float16:
extra_kern_args["maxnreg"] = 168
else:
extra_kern_args["maxnreg"] = 80
_attn_fwd[grid](
sm_scale, M, #
q.shape[0], q.shape[1], #
desc_q, desc_k, desc_v,
desc_q_t, desc_k_t, desc_v_t,
desc_o, desc_o_t, #
N_CTX=q.shape[2], #
HEAD_DIM=HEAD_DIM_K, #
FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
STAGE=stage, #
warp_specialize=warp_specialize, #
**extra_kern_args)
ctx.save_for_forward(o_t)
ctx.save_for_backward(q, k, v, o, M)
ctx.sm_scale = sm_scale
ctx.HEAD_DIM = HEAD_DIM_K
ctx.causal = causal
return o
@staticmethod
def fwd_dual(
Q: Tensor,
K: Tensor,
V: Tensor,
causal = False,
sm_scale: Optional[float] = None,
warp_specialize=True,
) -> Tensor:
"""
This is not an autograd convention, it's a workaround for invoking
JVPAttn::forward with the right arguments when you have a dual tensor input.
"""
q_p, q_t = fwAD.unpack_dual(Q)
k_p, k_t = fwAD.unpack_dual(K)
v_p, v_t = fwAD.unpack_dual(V)
# We pass some dual-tensor args to ensure jvp() will be called,
# but we also pass the tangents separately, as forward() demotes dual-tensor args to their primals for some reason.
a = JVPAttn.apply(Q, K, V, q_t, k_t, v_t, causal, sm_scale, warp_specialize)
return a
@staticmethod
def jvp(ctx: JVPAttn.FnCtx, gq: Tensor, gk: Tensor, gv: Tensor, *_) -> Tensor:
return ctx.saved_for_forward[0]
@staticmethod
def backward(ctx: JVPAttn.FnCtx, do: Tensor) -> tuple[Optional[Tensor], ...]:
q, k, v, o, M = ctx.saved_tensors
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
PRE_BLOCK = 128
NUM_WARPS, NUM_STAGES = 4, 5
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
PRE_BLOCK = 128
assert N_CTX % PRE_BLOCK == 0
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
delta = torch.empty_like(M)
_attn_bwd_preprocess[pre_grid](
o, do, #
delta, #
BATCH, N_HEAD, N_CTX, #
BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM #
)
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
_attn_bwd[grid](
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #
M, delta, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
N_HEAD, N_CTX, #
BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
HEAD_DIM=ctx.HEAD_DIM, #
num_warps=NUM_WARPS, #
num_stages=NUM_STAGES #
)
# one gradient per forward() input: q, k, v, q_t, k_t, v_t, causal, sm_scale, warp_specialize
return dq, dk, dv, None, None, None, None, None, None
@dataclass
class Args:
bsz: int
model_dim: int
head_dim: int
seq_len: int
@staticmethod
def get_parser() -> ArgumentParser:
parser = ArgumentParser()
parser.add_argument("--bsz", default=1, type=int)
parser.add_argument("--model-dim", default=320, type=int)
parser.add_argument("--head-dim", default=64, type=int)
parser.add_argument("--seq-len", default=128, type=int)
return parser
@staticmethod
def from_namespace(namespace: Namespace) -> Args:
args = Args(**vars(namespace))
return args
def main(args: Args) -> None:
device = torch.device('cuda')
dtype = torch.float16
seed = 42
gen = torch.Generator(device=device)
heads = args.model_dim // args.head_dim
q_p, q_t, k_p, k_t, v_p, v_t, target = (torch.randn(args.bsz, heads, args.seq_len, args.head_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed + ix)) for ix in range(7))
# for t in (q_p, k_p, v_p):
# t.requires_grad = True
# t.retain_grad()
# loss_fn = MSELoss()
# if we use MSELoss, we get this error:
# ZeroTensors are immutable. Please use the materialized zero tensor obtained using .clone() if you want a mutable tensor.
def loss_fn(out: Tensor, target: Tensor) -> Tensor:
return (out - target).square().mean()
def make_qkv(q_p, k_p, v_p, q_t, k_t, v_t):
q, k, v = (fwAD.make_dual(p, t) for p, t in zip((q_p, k_p, v_p), (q_t, k_t, v_t)))
for t in (q, k, v):
t.requires_grad = True
t.retain_grad()
return q, k, v
with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad():
q0, k0, v0 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone())
sdpa_out = scaled_dot_product_attention(q0, k0, v0)
sdpa_out.retain_grad()
sdpa_op, sdpa_ot = fwAD.unpack_dual(sdpa_out)
loss0: Tensor = loss_fn(sdpa_out, target)
loss0.backward()
q1, k1, v1 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone())
ag_out = JVPAttn.fwd_dual(q1, k1, v1)
ag_out.retain_grad()
ag_op, ag_ot = fwAD.unpack_dual(ag_out)
torch.testing.assert_close(ag_op, sdpa_op, atol=5e-4, rtol=1e-5)
torch.testing.assert_close(ag_ot, sdpa_ot, atol=5e-4, rtol=1e-5)
loss2: Tensor = loss_fn(ag_out, target)
torch.testing.assert_close(loss2, loss0, atol=5e-4, rtol=1e-5)
loss2.backward()
torch.testing.assert_close(q1.grad, q0.grad, atol=5e-4, rtol=1e-5)
torch.testing.assert_close(k1.grad, k0.grad, atol=5e-4, rtol=1e-5)
torch.testing.assert_close(v1.grad, v0.grad, atol=5e-4, rtol=1e-5)
pass
pass
if __name__ == "__main__":
parser = Args.get_parser()
args_untyped: Namespace = parser.parse_args()
args: Args = Args.from_namespace(args_untyped)
main(args)
@Birch-san (Author)
I took the almost-latest Triton fused attention tutorial code (it doesn't have the fp8 fixes; I didn't realize a newer commit was available) and blind-coded "how would I add JVP to this?".
Then I tried running it and found that my Triton version doesn't support TensorDescriptor, so I won't be able to verify whether it works.

I will start again with a block-pointer-based implementation (https://gist.github.com/Birch-san/0e852a42a933d3a1c0fcae21ccd15200).
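For reference, here is a minimal, unfused PyTorch sketch (my own annotation, not part of the gist; non-causal, and the function name is made up) of the JVP the kernel is trying to fuse. The three terms of the tangent correspond to the kernel's `g_acc`, `mu_i` and `p_tv_acc` accumulators:

```python
import torch
from torch import Tensor
from typing import Optional

def attn_jvp_reference(
    q: Tensor, k: Tensor, v: Tensor,        # primals
    q_t: Tensor, k_t: Tensor, v_t: Tensor,  # tangents
    sm_scale: Optional[float] = None,
) -> tuple[Tensor, Tensor]:
    """Returns (primal output, tangent output) of softmax attention."""
    sm_scale = q.shape[-1] ** -0.5 if sm_scale is None else sm_scale
    s = q @ k.transpose(-2, -1) * sm_scale   # scores
    p = s.softmax(dim=-1)                    # probabilities
    # tangent of the scores: dS = scale * (dQ @ K^T + Q @ dK^T)
    s_t = (q_t @ k.transpose(-2, -1) + q @ k_t.transpose(-2, -1)) * sm_scale
    # softmax JVP: dP = P * (dS - rowsum(P * dS))
    p_t = p * (s_t - (p * s_t).sum(dim=-1, keepdim=True))
    # product rule: dO = dP @ V + P @ dV
    return p @ v, p_t @ v + p @ v_t
```

Under `fwAD.dual_level()`, this should agree with what `main()` already checks `JVPAttn.fwd_dual` against (SDPA with the MATH backend).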
