Created June 30, 2025 01:58
Let's try and vibe-code JVPAttn's jvp to be as accurate as JVPAttnRef
from __future__ import annotations | |
from typing import Any, Literal, NamedTuple, Optional | |
from os import environ | |
import triton | |
import triton.language as tl | |
import torch | |
from torch import Tensor, enable_grad | |
from torch.autograd import Function | |
from torch.autograd.function import FunctionCtx | |
import torch.autograd.forward_ad as fwAD | |
from torch.nn.attention import SDPBackend, sdpa_kernel | |
from torch.nn.functional import scaled_dot_product_attention | |
# ---- Begin JVPAttn, which is based on the Triton fused-attention tutorial, so it should be fast, and it has a backward pass (not shown here). fwd has good accuracy, but jvp has poor accuracy, and we're trying to figure out why. ----
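# The helper below is an illustrative sketch added for reference only (it is not part of the
# original kernels and nothing here calls it). It spells out the math both kernels accumulate
# online. For y = softmax(scale * q @ k^T) @ v with tangents (t_q, t_k, t_v):
#   t_s = scale * (t_q @ k^T + q @ t_k^T)
#   t_y = (p * t_s) @ v - rowsum(p * t_s) * y + p @ t_v,  where p = softmax(scale * q @ k^T).
# The three terms correspond to g_acc, mu_i and p_tv_acc in the kernels.
def _naive_attn_jvp_reference(q: Tensor, k: Tensor, v: Tensor, t_q: Tensor, t_k: Tensor, t_v: Tensor, scale: Optional[float] = None) -> tuple[Tensor, Tensor]:
    """Materializes the full attention matrix; only meant for small shapes / sanity checks."""
    scale = q.size(-1) ** -0.5 if scale is None else scale
    s = q @ k.transpose(-2, -1) * scale
    t_s = (t_q @ k.transpose(-2, -1) + q @ t_k.transpose(-2, -1)) * scale
    p = s.softmax(dim=-1)
    y = p @ v
    # softmax JVP: t_p = p * (t_s - rowsum(p * t_s)), hence the mu correction term below
    mu = (p * t_s).sum(dim=-1, keepdim=True)
    t_y = (p * t_s) @ v - mu * y + p @ t_v
    return y, t_y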
@triton.jit | |
def _attn_fwd_inner(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
K_block_ptr, V_block_ptr, # | |
T_K_block_ptr, T_V_block_ptr, # | |
dtype: tl.constexpr, start_m, qk_scale, sm_scale, # | |
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # | |
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # | |
N_CTX: tl.constexpr, warp_specialize: tl.constexpr, # | |
ENABLE_JVP: tl.constexpr): | |
# range of values handled by this stage | |
if STAGE == 1: | |
lo, hi = 0, start_m * BLOCK_M | |
elif STAGE == 2: | |
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M | |
lo = tl.multiple_of(lo, BLOCK_M) | |
# causal = False | |
else: | |
lo, hi = 0, N_CTX | |
K_block_ptr = tl.advance(K_block_ptr, (0, lo)) | |
# NOTE: in fp8 mode, we may want to advance the V_block_ptr differently. | |
# I did try advancing by (0, lo) instead for fp8, but I got an illegal memory access. | |
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 | |
V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) | |
if ENABLE_JVP: | |
T_K_block_ptr = tl.advance(T_K_block_ptr, (0, lo)) | |
T_V_block_ptr = tl.advance(T_V_block_ptr, (lo, 0)) | |
# loop over k, v and update accumulator | |
for start_n in range(lo, hi, BLOCK_N): | |
start_n = tl.multiple_of(start_n, BLOCK_N) | |
# -- compute qk ---- | |
k = tl.load(K_block_ptr) | |
qk = tl.dot(q, k) | |
if ENABLE_JVP: | |
t_k = tl.load(T_K_block_ptr) | |
t_qk = tl.dot(t_q, k) + tl.dot(q, t_k) | |
if STAGE == 2: | |
mask = offs_m[:, None] >= (start_n + offs_n[None, :]) | |
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) | |
if ENABLE_JVP: | |
# Claude says "tangents should be masked with 0.0 since they represent derivatives". | |
t_qk = tl.where(mask, t_qk, 0.0) | |
# TODO: do we need a separate row maximum for qk_t? | |
m_ij = tl.maximum(m_i, tl.max(qk, 1)) | |
qk -= m_ij[:, None] | |
else: | |
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) | |
qk = qk * qk_scale - m_ij[:, None] | |
p = tl.math.exp2(qk) | |
l_ij = tl.sum(p, 1) | |
# -- update m_i and l_i | |
alpha = tl.math.exp2(m_i - m_ij) | |
l_i = l_i * alpha + l_ij | |
# -- update output accumulator -- | |
if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128): | |
BM: tl.constexpr = acc.shape[0] | |
BN: tl.constexpr = acc.shape[1] | |
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() | |
acc0 = acc0 * alpha[:, None] | |
acc1 = acc1 * alpha[:, None] | |
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) | |
else: | |
acc = acc * alpha[:, None] | |
# update acc | |
v = tl.load(V_block_ptr) | |
# NOTE: we may need to transpose v if dtype == tl.float8e5 | |
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 | |
p = p.to(dtype) | |
if ENABLE_JVP: | |
p_tqk = p * (t_qk * sm_scale) | |
# if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128): | |
# BM: tl.constexpr = g_acc.shape[0] | |
# BN: tl.constexpr = g_acc.shape[1] | |
# g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() | |
# g_acc0 = g_acc0 * alpha[:, None] | |
# g_acc1 = g_acc1 * alpha[:, None] | |
# g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN]) | |
# else: | |
# g_acc = g_acc * alpha[:, None] | |
# g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc) | |
g_acc = g_acc * alpha[:, None] + tl.dot(p_tqk.to(v.dtype), v) | |
mu_ij = tl.sum(p_tqk, 1) | |
mu_i = mu_i * alpha + mu_ij | |
t_v = tl.load(T_V_block_ptr) | |
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, t_v) | |
T_V_block_ptr = tl.advance(T_V_block_ptr, (BLOCK_N, 0)) | |
T_K_block_ptr = tl.advance(T_K_block_ptr, (0, BLOCK_N)) | |
acc = tl.dot(p, v, acc) | |
# update m_i and l_i | |
m_i = m_ij | |
# the fp8 PR made a change to how K and V are advanced here but I believe we already have that. | |
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 | |
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) | |
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) | |
return acc, g_acc, l_i, m_i, mu_i, p_tv_acc | |
NUM_STAGES_OPTIONS = [2, 3, 4] | |
configs = [ | |
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ | |
for BM in [64, 128]\ | |
for BN in [32, 64, 128]\ | |
for s in NUM_STAGES_OPTIONS \ | |
for w in [4, 8]\ | |
] | |
if "PYTEST_VERSION" in environ: | |
# Use a single config in testing for reproducibility | |
configs = [ | |
triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4), | |
] | |
def keep(conf): | |
BLOCK_M = conf.kwargs["BLOCK_M"] | |
BLOCK_N = conf.kwargs["BLOCK_N"] | |
return not (BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8) | |
def prune_invalid_configs(configs, named_args, **kwargs): | |
N_CTX = kwargs["N_CTX"] | |
# Filter out configs where BLOCK_M > N_CTX | |
return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX] | |
@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"], | |
prune_configs_by={'early_config_prune': prune_invalid_configs}) | |
@triton.jit | |
def _attn_fwd(Q, K, V, T_Q, T_K, T_V, # | |
sm_scale, M, Out, T_Out, # | |
stride_qz, stride_qh, stride_qm, stride_qk, # | |
stride_kz, stride_kh, stride_kn, stride_kk, # | |
stride_vz, stride_vh, stride_vk, stride_vn, # | |
stride_tqz, stride_tqh, stride_tqm, stride_tqk, # | |
stride_tkz, stride_tkh, stride_tkn, stride_tkk, # | |
stride_tvz, stride_tvh, stride_tvk, stride_tvn, # | |
stride_oz, stride_oh, stride_om, stride_on, # | |
stride_toz, stride_toh, stride_tom, stride_ton, # | |
Z, H, N_CTX, # | |
HEAD_DIM: tl.constexpr, # | |
BLOCK_M: tl.constexpr, # | |
BLOCK_N: tl.constexpr, # | |
FP8_OUTPUT: tl.constexpr, # | |
STAGE: tl.constexpr, # | |
warp_specialize: tl.constexpr, # | |
ENABLE_JVP: tl.constexpr, # | |
): | |
tl.static_assert(BLOCK_N <= HEAD_DIM) | |
dtype = tl.float8e5 if FP8_OUTPUT else tl.float16 | |
start_m = tl.program_id(0) | |
off_hz = tl.program_id(1) | |
off_z = off_hz // H | |
off_h = off_hz % H | |
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh | |
# block pointers | |
Q_block_ptr = tl.make_block_ptr( | |
base=Q + qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_qm, stride_qk), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) | |
V_block_ptr = tl.make_block_ptr( | |
base=V + qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_vk, stride_vn), | |
offsets=(0, 0), | |
block_shape=(BLOCK_N, HEAD_DIM), | |
order=v_order, | |
) | |
K_block_ptr = tl.make_block_ptr( | |
base=K + qvk_offset, | |
shape=(HEAD_DIM, N_CTX), | |
strides=(stride_kk, stride_kn), | |
offsets=(0, 0), | |
block_shape=(HEAD_DIM, BLOCK_N), | |
order=(0, 1), | |
) | |
O_block_ptr = tl.make_block_ptr( | |
base=Out + qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_om, stride_on), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
# initialize offsets | |
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
offs_n = tl.arange(0, BLOCK_N) | |
# initialize pointer to m and l | |
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") | |
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 | |
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
if ENABLE_JVP: | |
# it's extremely likely we could just re-use qvk_offset, but this seems cheap so whatever | |
t_qvk_offset = off_z.to(tl.int64) * stride_tqz + off_h.to(tl.int64) * stride_tqh | |
T_Q_block_ptr = tl.make_block_ptr( | |
base=T_Q + t_qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_tqm, stride_tqk), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
# could probably just re-use v_order here | |
t_v_order: tl.constexpr = (0, 1) if T_V.dtype.element_ty == tl.float8e5 else (1, 0) | |
T_V_block_ptr = tl.make_block_ptr( | |
base=T_V + t_qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_tvk, stride_tvn), | |
offsets=(0, 0), | |
block_shape=(BLOCK_N, HEAD_DIM), | |
order=t_v_order, | |
) | |
T_K_block_ptr = tl.make_block_ptr( | |
base=T_K + t_qvk_offset, | |
shape=(HEAD_DIM, N_CTX), | |
strides=(stride_tkk, stride_tkn), | |
offsets=(0, 0), | |
block_shape=(HEAD_DIM, BLOCK_N), | |
order=(0, 1), | |
) | |
T_O_block_ptr = tl.make_block_ptr( | |
base=T_Out + t_qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_tom, stride_ton), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
# load q_t: it will stay in SRAM throughout | |
t_q = tl.load(T_Q_block_ptr) | |
g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
else: | |
t_q = None | |
T_V_block_ptr = None | |
T_K_block_ptr = None | |
        # Allocate minimal dummy tensors to keep the return signature of _attn_fwd_inner consistent
g_acc = tl.zeros([1, 1], dtype=tl.float32) | |
mu_i = tl.zeros([1], dtype=tl.float32) | |
p_tv_acc = tl.zeros([1, 1], dtype=tl.float32) | |
# load scales | |
qk_scale = sm_scale | |
    qk_scale = qk_scale * 1.44269504  # log2(e) = 1/ln(2)
# load q: it will stay in SRAM throughout | |
q = tl.load(Q_block_ptr) | |
# stage 1: off-band | |
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE | |
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE | |
if STAGE & 1: | |
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(acc, g_acc, | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
K_block_ptr, V_block_ptr, # | |
T_K_block_ptr, T_V_block_ptr, # | |
dtype, start_m, qk_scale, sm_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
4 - STAGE, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
ENABLE_JVP) | |
# stage 2: on-band | |
if STAGE & 2: | |
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
K_block_ptr, V_block_ptr, # | |
T_K_block_ptr, T_V_block_ptr, # | |
dtype, start_m, qk_scale, sm_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
2, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
ENABLE_JVP) | |
# epilogue | |
m_i += tl.math.log2(l_i) | |
acc = acc / l_i[:, None] | |
m_ptrs = M + off_hz * N_CTX + offs_m | |
tl.store(m_ptrs, m_i) | |
tl.store(O_block_ptr, acc.to(Out.type.element_ty)) | |
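    # NOTE: the JVP epilogue below divides the tangent accumulators by the softmax denominator
    # l_i. With p the normalized softmax and t_s the tangent of the scaled scores,
    # g_acc / l_i -> (p * t_s) @ v, mu_i / l_i -> rowsum(p * t_s), p_tv_acc / l_i -> p @ t_v,
    # so t_y = (p * t_s) @ v - rowsum(p * t_s) * y + p @ t_v.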
if ENABLE_JVP: | |
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc | |
t_y_out = t_p_v + p_tv_acc / l_i[:, None] | |
tl.store(T_O_block_ptr, t_y_out.to(T_Out.type.element_ty)) | |
class JVPAttn(Function): | |
class Grid(NamedTuple): | |
M_BLOCKS: int | |
Z_H: int | |
ONE: Literal[1] | |
class FnCtx(FunctionCtx): | |
sm_scale: float | |
HEAD_DIM_K: int | |
causal: bool | |
class FwdOutCtxContrib(NamedTuple): | |
o_t: Optional[Tensor] | |
M: Tensor | |
HEAD_DIM_K: int | |
sm_scale: float | |
class FwdOut(NamedTuple): | |
o: Tensor | |
ctx: JVPAttn.FwdOutCtxContrib | |
class JVPOut(NamedTuple): | |
o: Tensor | |
ctx: None | |
class BwdOut(NamedTuple): | |
q: Tensor | |
k: Tensor | |
v: Tensor | |
q_t: None | |
k_t: None | |
v_t: None | |
causal: None | |
sm_scale: None | |
warp_specialize: None | |
USE_TMA: None | |
class Strides(NamedTuple): | |
z: int | |
h: int | |
n_ctx: int | |
head_dim: int | |
@staticmethod | |
def forward( | |
q: Tensor, | |
k: Tensor, | |
v: Tensor, | |
q_t: Optional[Tensor], | |
k_t: Optional[Tensor], | |
v_t: Optional[Tensor], | |
causal: bool, | |
sm_scale: Optional[float], | |
warp_specialize=True, | |
) -> JVPAttn.FwdOut: | |
# shape constraints | |
Z, H, N_CTX, HEAD_DIM_Q = q.shape | |
HEAD_DIM_K = k.shape[-1] | |
# when v is in float8_e5m2 it is transposed. | |
HEAD_DIM_V = v.shape[-1] | |
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V | |
assert HEAD_DIM_K in {16, 32, 64, 128, 256} | |
if sm_scale is None: | |
sm_scale = HEAD_DIM_K**-.5 | |
o = torch.empty_like(q) | |
ENABLE_JVP = q_t is not None | |
o_t: Optional[Tensor] = torch.empty_like(q_t) if ENABLE_JVP else None | |
stage = 3 if causal else 1 | |
extra_kern_args = {} | |
if warp_specialize: | |
# we need more registers if we're doing JVP | |
if (HEAD_DIM_K == 128 and q.dtype == torch.float16) or ENABLE_JVP: | |
extra_kern_args["maxnreg"] = 168 | |
else: | |
                # TODO: I think this is too low for the backward pass with head_dim=128 on H100; register allocation fails
extra_kern_args["maxnreg"] = 80 | |
M = torch.empty((Z, H, N_CTX), device=q.device, dtype=torch.float32) | |
if hasattr(triton, 'set_allocator'): | |
def alloc_fn(size: int, align: int, _): | |
return torch.empty(size, dtype=torch.int8, device="cuda") | |
triton.set_allocator(alloc_fn) | |
Z_H = Z * H | |
def grid(META: dict[str, Any]) -> JVPAttn.Grid: | |
return JVPAttn.Grid(triton.cdiv(N_CTX, META["BLOCK_M"]), Z_H, 1) | |
def strides_zhnd(t: Tensor) -> JVPAttn.Strides: | |
return JVPAttn.Strides(t.stride(0), t.stride(1), t.stride(2), t.stride(3)) | |
_attn_fwd[grid]( | |
q, k, v, q_t, k_t, v_t, # | |
sm_scale, M, o, o_t, # | |
*strides_zhnd(q), # | |
*strides_zhnd(k), # | |
*strides_zhnd(v), # | |
*strides_zhnd(q if q_t is None else q_t), # | |
*strides_zhnd(k if k_t is None else k_t), # | |
*strides_zhnd(v if v_t is None else v_t), # | |
*strides_zhnd(o), # | |
*strides_zhnd(o if o_t is None else o_t), # | |
Z, H, # | |
N_CTX=N_CTX, # | |
HEAD_DIM=HEAD_DIM_K, # | |
FP8_OUTPUT=q.dtype == torch.float8_e5m2, # | |
STAGE=stage, # | |
warp_specialize=warp_specialize, # | |
ENABLE_JVP=ENABLE_JVP, # | |
**extra_kern_args) | |
return JVPAttn.FwdOut(o, JVPAttn.FwdOutCtxContrib(o_t, M, HEAD_DIM_K, sm_scale)) | |
@staticmethod | |
def setup_context(ctx: JVPAttn.FnCtx, inputs, outputs: JVPAttn.FwdOut) -> Tensor: | |
( | |
q, | |
k, | |
v, | |
q_t, | |
k_t, | |
v_t, | |
causal, | |
sm_scale, | |
warp_specialize, | |
) = inputs | |
o, (o_t, M, HEAD_DIM_K, sm_scale) = outputs | |
ctx.save_for_forward(o_t) | |
ctx.save_for_backward(q, k, v, o, M) | |
ctx.sm_scale = sm_scale | |
ctx.HEAD_DIM_K = HEAD_DIM_K | |
ctx.causal = causal | |
return o | |
@staticmethod | |
def fwd( | |
q: Tensor, | |
k: Tensor, | |
v: Tensor, | |
causal = False, | |
sm_scale: Optional[float] = None, | |
warp_specialize=True, | |
) -> Tensor: | |
""" | |
This is not an autograd convention, it's a workaround to get type-hinting and kwarg support | |
""" | |
out: JVPAttn.FwdOut = JVPAttn.apply(q, k, v, None, None, None, causal, sm_scale, warp_specialize) | |
a, _ = out | |
return a | |
@staticmethod | |
def fwd_dual( | |
q: Tensor, | |
k: Tensor, | |
v: Tensor, | |
causal = False, | |
sm_scale: Optional[float] = None, | |
warp_specialize=True, | |
) -> Tensor: | |
""" | |
This is not an autograd convention, it's a workaround for invoking | |
JVPAttn::forward with the right arguments when you have a dual tensor input. | |
""" | |
q_p, q_t = fwAD.unpack_dual(q) | |
k_p, k_t = fwAD.unpack_dual(k) | |
v_p, v_t = fwAD.unpack_dual(v) | |
        # we pass some dual-tensor args to ensure jvp() will be called,
        # but we also pass the tangents separately, as forward() demotes dual-tensor args to primals for some reason
out: JVPAttn.FwdOut = JVPAttn.apply(q, k, v, q_t, k_t, v_t, causal, sm_scale, warp_specialize) | |
a, _ = out | |
return a | |
@staticmethod | |
def jvp(ctx: JVPAttn.FnCtx, gq: Tensor, gk: Tensor, gv: Tensor, *_) -> JVPAttn.JVPOut: | |
return JVPAttn.JVPOut(ctx.saved_for_forward[0], None) | |
# ---- Begin JVPAttnRef, which came out of vibe-coding; it is probably less optimized, and we don't have a backward pass for it. fwd and jvp both have good accuracy. ----
@triton.autotune( | |
configs=[ | |
# Ultra-conservative configs for maximum compatibility | |
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16}, num_warps=2, num_stages=1), | |
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16}, num_warps=2, num_stages=1), | |
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32}, num_warps=2, num_stages=1), | |
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32}, num_warps=4, num_stages=1), | |
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16}, num_warps=4, num_stages=1), | |
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64}, num_warps=4, num_stages=1), | |
], | |
key=['B', 'H', 'L', 'D_head'], | |
) | |
@triton.jit | |
def _flash_attention_jvp_multihead_kernel( | |
# Input tensors | |
Q, K, V, T_Q, T_K, T_V, | |
M, | |
# Output tensors | |
Y, T_Y, | |
# Tensor strides | |
stride_qb, stride_qh, stride_ql, stride_qd, | |
stride_kb, stride_kh, stride_kl, stride_kd, | |
stride_vb, stride_vh, stride_vl, stride_vd, | |
stride_tqb, stride_tqh, stride_tql, stride_tqd, | |
stride_tkb, stride_tkh, stride_tkl, stride_tkd, | |
stride_tvb, stride_tvh, stride_tvl, stride_tvd, | |
stride_yb, stride_yh, stride_yl, stride_yd, | |
stride_tyb, stride_tyh, stride_tyl, stride_tyd, | |
# Problem dimensions | |
B: tl.constexpr, H: tl.constexpr, L: tl.constexpr, D_head: tl.constexpr, | |
# Scale factor | |
scale: tl.constexpr, | |
# Block sizes | |
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, | |
): | |
""" | |
Flash Attention JVP kernel following the reference implementation pattern. | |
Grid: (B*H, triton.cdiv(L, BLOCK_M)) | |
""" | |
# Get program IDs | |
pid_bh = tl.program_id(0) # Combined batch and head index | |
pid_m = tl.program_id(1) # Query block index | |
# Decompose batch and head indices | |
pid_b = pid_bh // H | |
pid_h = pid_bh % H | |
    # Compute offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N) | |
offs_d = tl.arange(0, D_head) | |
# Base pointers for this (batch, head) | |
q_base = Q + pid_b * stride_qb + pid_h * stride_qh | |
k_base = K + pid_b * stride_kb + pid_h * stride_kh | |
v_base = V + pid_b * stride_vb + pid_h * stride_vh | |
tq_base = T_Q + pid_b * stride_tqb + pid_h * stride_tqh | |
tk_base = T_K + pid_b * stride_tkb + pid_h * stride_tkh | |
tv_base = T_V + pid_b * stride_tvb + pid_h * stride_tvh | |
y_base = Y + pid_b * stride_yb + pid_h * stride_yh | |
ty_base = T_Y + pid_b * stride_tyb + pid_h * stride_tyh | |
# Load query block | |
q_ptrs = q_base + offs_m[:, None] * stride_ql + offs_d[None, :] * stride_qd | |
tq_ptrs = tq_base + offs_m[:, None] * stride_tql + offs_d[None, :] * stride_tqd | |
mask_m = offs_m < L | |
q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0) | |
tq = tl.load(tq_ptrs, mask=mask_m[:, None], other=0.0) | |
# Initialize accumulators following Flash Attention pattern | |
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") | |
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32) | |
g_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32) | |
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
p_tv_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32) | |
    # Scale factor for the exp2 optimization (as in the reference): exp(x) == exp2(x * log2(e))
    qk_scale = scale * 1.44269504  # log2(e) = 1/ln(2)
# Loop over key/value blocks | |
for start_n in range(0, L, BLOCK_N): | |
start_n = tl.multiple_of(start_n, BLOCK_N) | |
offs_n_curr = start_n + offs_n | |
mask_n = offs_n_curr < L | |
# Load key and value blocks | |
k_ptrs = k_base + offs_n_curr[:, None] * stride_kl + offs_d[None, :] * stride_kd | |
v_ptrs = v_base + offs_n_curr[:, None] * stride_vl + offs_d[None, :] * stride_vd | |
tk_ptrs = tk_base + offs_n_curr[:, None] * stride_tkl + offs_d[None, :] * stride_tkd | |
tv_ptrs = tv_base + offs_n_curr[:, None] * stride_tvl + offs_d[None, :] * stride_tvd | |
k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0) | |
v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0) | |
tk = tl.load(tk_ptrs, mask=mask_n[:, None], other=0.0) | |
tv = tl.load(tv_ptrs, mask=mask_n[:, None], other=0.0) | |
# Compute attention scores | |
qk = tl.dot(q, tl.trans(k)) | |
tqk = tl.dot(tq, tl.trans(k)) + tl.dot(q, tl.trans(tk)) | |
# Mask invalid positions first | |
qk = tl.where(mask_n[None, :], qk, float('-inf')) | |
tqk = tl.where(mask_n[None, :], tqk, 0.0) | |
# Online softmax computation following Flash Attention | |
m_ij = tl.maximum(m_i, tl.max(qk * scale, 1)) | |
qk = qk * qk_scale - m_ij[:, None] # Scale and subtract max | |
p = tl.math.exp2(qk) # Use exp2 like reference | |
# Correction factor | |
alpha = tl.math.exp2(m_i - m_ij) | |
l_ij = tl.sum(p, 1) | |
# Update normalization | |
l_i = l_i * alpha + l_ij | |
# NOTE: this downcast of p is a new change compared to the reference implementation | |
# Cast p back to input dtype for matmul | |
p_typed = p.to(q.dtype) | |
# Update output accumulator | |
acc = acc * alpha[:, None] + tl.dot(p_typed, v) | |
# JVP accumulator: (p * tqk) @ v | |
p_tqk = p * (tqk * scale) # Apply scale to tangent scores | |
# NOTE: this downcast of p_tqk is a new change compared to the reference implementation | |
p_tqk_typed = p_tqk.to(q.dtype) # Cast tangent weights too | |
g_acc = g_acc * alpha[:, None] + tl.dot(p_tqk_typed, v) | |
# Update mu: sum(p * tqk) | |
mu_ij = tl.sum(p_tqk, 1) | |
mu_i = mu_i * alpha + mu_ij | |
# Update p @ tv accumulator | |
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p_typed, tv) | |
# Update max | |
m_i = m_ij | |
# Final computation - add log normalization and divide | |
m_i += tl.math.log2(l_i) | |
y_out = acc / l_i[:, None] | |
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * y_out | |
t_y_out = t_p_v + p_tv_acc / l_i[:, None] | |
# Store outputs | |
y_ptrs = y_base + offs_m[:, None] * stride_yl + offs_d[None, :] * stride_yd | |
ty_ptrs = ty_base + offs_m[:, None] * stride_tyl + offs_d[None, :] * stride_tyd | |
tl.store(y_ptrs, y_out, mask=mask_m[:, None]) | |
tl.store(ty_ptrs, t_y_out, mask=mask_m[:, None]) | |
    # M is laid out as (B, H, L): index by the (batch*head) row, then the query position
    m_ptrs = M + pid_bh * L + offs_m
    tl.store(m_ptrs, m_i, mask=mask_m)
class JVPAttnRef(Function): | |
class FnCtx(FunctionCtx): | |
sm_scale: float | |
head_dim: int | |
class FwdOutCtxContrib(NamedTuple): | |
t_y: Optional[Tensor] | |
M: Tensor | |
head_dim: int | |
sm_scale: float | |
class FwdOut(NamedTuple): | |
o: Tensor | |
ctx: JVPAttnRef.FwdOutCtxContrib | |
class JVPOut(NamedTuple): | |
y: Tensor | |
ctx: None | |
@staticmethod | |
def forward( | |
q_p: Tensor, | |
k_p: Tensor, | |
v_p: Tensor, | |
q_t: Tensor, | |
k_t: Tensor, | |
v_t: Tensor, | |
scale: Optional[float] = None, | |
) -> Tensor: | |
head_dim = k_p.size(-1) | |
scale = head_dim**-.5 if scale is None else scale | |
device = q_p.device | |
dtype = q_p.dtype | |
B, H, L, D_head = q_p.shape | |
# Check minimum dimension requirements for Triton | |
if D_head < 16: | |
raise ValueError(f"D_head must be >= 16 for efficient Triton kernel, got {D_head}") | |
# Ensure input shapes are correct | |
assert q_p.shape == (B, H, L, D_head), f"Q shape mismatch: {q_p.shape}" | |
assert k_p.shape == (B, H, L, D_head), f"K shape mismatch: {k_p.shape}" | |
assert v_p.shape == (B, H, L, D_head), f"V shape mismatch: {v_p.shape}" | |
assert q_t.shape == (B, H, L, D_head), f"t_Q shape mismatch: {q_t.shape}" | |
assert k_t.shape == (B, H, L, D_head), f"t_K shape mismatch: {k_t.shape}" | |
assert v_t.shape == (B, H, L, D_head), f"t_V shape mismatch: {v_t.shape}" | |
M = torch.empty((B, H, L), device=q_p.device, dtype=torch.float32) | |
# Create output tensors | |
y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device) | |
t_y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device) | |
# Make tensors contiguous | |
Qc = q_p.contiguous() | |
Kc = k_p.contiguous() | |
Vc = v_p.contiguous() | |
t_Qc = q_t.contiguous() | |
t_Kc = k_t.contiguous() | |
t_Vc = v_t.contiguous() | |
# Compute strides | |
stride_qb, stride_qh, stride_ql, stride_qd = Qc.stride() | |
stride_kb, stride_kh, stride_kl, stride_kd = Kc.stride() | |
stride_vb, stride_vh, stride_vl, stride_vd = Vc.stride() | |
stride_tqb, stride_tqh, stride_tql, stride_tqd = t_Qc.stride() | |
stride_tkb, stride_tkh, stride_tkl, stride_tkd = t_Kc.stride() | |
stride_tvb, stride_tvh, stride_tvl, stride_tvd = t_Vc.stride() | |
stride_yb, stride_yh, stride_yl, stride_yd = y.stride() | |
stride_tyb, stride_tyh, stride_tyl, stride_tyd = t_y.stride() | |
        # Use a block-based grid like Flash Attention. BLOCK_M is chosen by autotuning, so the
        # grid must be a callable over the autotuned meta-parameters to ensure all queries are covered.
        def grid(META: dict[str, Any]) -> tuple[int, int]:
            return (B * H, triton.cdiv(L, META["BLOCK_M"]))
_flash_attention_jvp_multihead_kernel[grid]( | |
Qc, Kc, Vc, t_Qc, t_Kc, t_Vc, | |
M, | |
y, t_y, | |
stride_qb, stride_qh, stride_ql, stride_qd, | |
stride_kb, stride_kh, stride_kl, stride_kd, | |
stride_vb, stride_vh, stride_vl, stride_vd, | |
stride_tqb, stride_tqh, stride_tql, stride_tqd, | |
stride_tkb, stride_tkh, stride_tkl, stride_tkd, | |
stride_tvb, stride_tvh, stride_tvl, stride_tvd, | |
stride_yb, stride_yh, stride_yl, stride_yd, | |
stride_tyb, stride_tyh, stride_tyl, stride_tyd, | |
B, H, L, D_head, | |
scale, | |
) | |
return JVPAttnRef.FwdOut(y, JVPAttnRef.FwdOutCtxContrib(t_y, M, head_dim, scale)) | |
@staticmethod | |
def setup_context(ctx: JVPAttnRef.FnCtx, inputs, outputs: JVPAttnRef.FwdOut) -> Tensor: | |
( | |
q, | |
k, | |
v, | |
q_t, | |
k_t, | |
v_t, | |
sm_scale, | |
) = inputs | |
y, (t_y, M, head_dim, sm_scale) = outputs | |
ctx.save_for_forward(t_y) | |
ctx.save_for_backward(q, k, v, y, M) | |
ctx.sm_scale = sm_scale | |
ctx.head_dim = head_dim | |
return y | |
@staticmethod | |
def fwd( | |
q: Tensor, | |
k: Tensor, | |
v: Tensor, | |
sm_scale: Optional[float] = None, | |
) -> Tensor: | |
""" | |
This is not an autograd convention, it's a workaround to get type-hinting and kwarg support | |
""" | |
out: JVPAttnRef.FwdOut = JVPAttnRef.apply(q, k, v, None, None, None, sm_scale) | |
a, _ = out | |
return a | |
@staticmethod | |
def fwd_dual( | |
Q: Tensor, | |
K: Tensor, | |
V: Tensor, | |
scale: Optional[float] = None, | |
) -> Tensor: | |
""" | |
This is not an autograd convention, it's a workaround for invoking | |
JVPAttn::forward with the right arguments when you have a dual tensor input. | |
""" | |
q_p, q_t = fwAD.unpack_dual(Q) | |
k_p, k_t = fwAD.unpack_dual(K) | |
v_p, v_t = fwAD.unpack_dual(V) | |
        # we pass some dual-tensor args to ensure jvp() will be called,
        # but we also pass the tangents separately, as forward() demotes dual-tensor args to primals for some reason
out: JVPAttnRef.FwdOut = JVPAttnRef.apply(Q, K, V, q_t, k_t, v_t, scale) | |
a, _ = out | |
return a | |
@staticmethod | |
def jvp(ctx: JVPAttnRef.FnCtx, gq: Tensor, gk: Tensor, gv: Tensor, *_) -> JVPAttnRef.JVPOut: | |
return JVPAttnRef.JVPOut(ctx.saved_for_forward[0], None) | |
class QKV(NamedTuple): | |
q: Tensor | |
k: Tensor | |
v: Tensor | |
class UnpackedDualQKV(NamedTuple): | |
primal: QKV | |
tangent: QKV | |
def main() -> None: | |
device = torch.device('cuda') | |
dtype = torch.float16 | |
seed = 42 | |
gen = torch.Generator(device=device) | |
bsz = 1 | |
model_dim = 320 | |
head_dim = 64 | |
seq_len = 128 | |
heads = model_dim // head_dim | |
q_p, q_t, k_p, k_t, v_p, v_t = (torch.randn(bsz, heads, seq_len, head_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed + ix)) for ix in range(6)) | |
def make_qkv(q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor) -> QKV: | |
return QKV( | |
q=fwAD.make_dual(q_p, q_t).requires_grad_(), | |
k=fwAD.make_dual(k_p, k_t).requires_grad_(), | |
v=fwAD.make_dual(v_p, v_t).requires_grad_(), | |
) | |
with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad(): | |
q0, k0, v0 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone()) | |
sdpa_out = scaled_dot_product_attention(q0, k0, v0) | |
sdpa_out.retain_grad() | |
sdpa_op, sdpa_ot = fwAD.unpack_dual(sdpa_out) | |
q1, k1, v1 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone()) | |
dual_out = JVPAttnRef.fwd_dual(q1, k1, v1) | |
dual_out.retain_grad() | |
dual_op, dual_ot = fwAD.unpack_dual(dual_out) | |
torch.testing.assert_close(dual_op, sdpa_op, atol=5e-4, rtol=1e-5) | |
torch.testing.assert_close(dual_ot, sdpa_ot, atol=5e-4, rtol=1e-5) | |
q2, k2, v2 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone()) | |
dual_out = JVPAttn.fwd_dual(q2, k2, v2) | |
dual_out.retain_grad() | |
dual_op, dual_ot = fwAD.unpack_dual(dual_out) | |
torch.testing.assert_close(dual_op, sdpa_op, atol=5e-4, rtol=1e-5) | |
# TODO: improve this accuracy | |
torch.testing.assert_close(dual_ot, sdpa_ot, atol=1e-3, rtol=1e-5) | |
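        # Optional debugging aid (not in the original): to quantify the jvp gap rather than only
        # asserting, one could print the worst-case error, e.g.
        #   print("jvp max abs err:", (dual_ot - sdpa_ot).abs().max().item())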
if __name__ == "__main__": | |
main() |