@Birch-san
Created June 30, 2025 01:58
Let's try and vibe-code JVPAttn's jvp to be as accurate as JVPAttnRef
from __future__ import annotations
from typing import Any, Literal, NamedTuple, Optional
from os import environ
import triton
import triton.language as tl
import torch
from torch import Tensor, enable_grad
from torch.autograd import Function
from torch.autograd.function import FunctionCtx
import torch.autograd.forward_ad as fwAD
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
# ---- Begin JVPAttn. It is based on the Triton fused-attention tutorial, so it may be fast, and we have a backward pass for it (not shown here). fwd has good accuracy, but jvp has poor accuracy, and we're trying to figure out why. ----
@triton.jit
def _attn_fwd_inner(acc, g_acc, #
l_i, m_i, #
mu_i, p_tv_acc, #
q, t_q, #
K_block_ptr, V_block_ptr, #
T_K_block_ptr, T_V_block_ptr, #
dtype: tl.constexpr, start_m, qk_scale, sm_scale, #
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
N_CTX: tl.constexpr, warp_specialize: tl.constexpr, #
ENABLE_JVP: tl.constexpr):
# range of values handled by this stage
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
# causal = False
else:
lo, hi = 0, N_CTX
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
# NOTE: in fp8 mode, we may want to advance the V_block_ptr differently.
# I did try advancing by (0, lo) instead for fp8, but I got an illegal memory access.
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
if ENABLE_JVP:
T_K_block_ptr = tl.advance(T_K_block_ptr, (0, lo))
T_V_block_ptr = tl.advance(T_V_block_ptr, (lo, 0))
# loop over k, v and update accumulator
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(K_block_ptr)
qk = tl.dot(q, k)
if ENABLE_JVP:
t_k = tl.load(T_K_block_ptr)
t_qk = tl.dot(t_q, k) + tl.dot(q, t_k)
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
if ENABLE_JVP:
                # Mask tangent scores with 0.0 rather than -inf: masked positions carry zero probability,
                # so their tangent contribution should simply drop out (as Claude put it, "tangents
                # should be masked with 0.0 since they represent derivatives").
                t_qk = tl.where(mask, t_qk, 0.0)
# TODO: do we need a separate row maximum for qk_t?
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
qk = qk * qk_scale - m_ij[:, None]
p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp2(m_i - m_ij)
l_i = l_i * alpha + l_ij
# -- update output accumulator --
if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128):
BM: tl.constexpr = acc.shape[0]
BN: tl.constexpr = acc.shape[1]
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
acc0 = acc0 * alpha[:, None]
acc1 = acc1 * alpha[:, None]
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
else:
acc = acc * alpha[:, None]
# update acc
v = tl.load(V_block_ptr)
# NOTE: we may need to transpose v if dtype == tl.float8e5
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31
        # NOTE: keep p in fp32 for the tangent math below. JVPAttnRef computes p_tqk and mu_ij from the
        # fp32 probabilities and only downcasts for the dot products, so downcasting p first is a likely
        # source of the jvp accuracy gap this gist is chasing.
        p_typed = p.to(dtype)
        if ENABLE_JVP:
            p_tqk = p * (t_qk * sm_scale)
# if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128):
# BM: tl.constexpr = g_acc.shape[0]
# BN: tl.constexpr = g_acc.shape[1]
# g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
# g_acc0 = g_acc0 * alpha[:, None]
# g_acc1 = g_acc1 * alpha[:, None]
# g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN])
# else:
# g_acc = g_acc * alpha[:, None]
# g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc)
g_acc = g_acc * alpha[:, None] + tl.dot(p_tqk.to(v.dtype), v)
mu_ij = tl.sum(p_tqk, 1)
mu_i = mu_i * alpha + mu_ij
t_v = tl.load(T_V_block_ptr)
            p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p_typed, t_v)
T_V_block_ptr = tl.advance(T_V_block_ptr, (BLOCK_N, 0))
T_K_block_ptr = tl.advance(T_K_block_ptr, (0, BLOCK_N))
        acc = tl.dot(p_typed, v, acc)
# update m_i and l_i
m_i = m_ij
# the fp8 PR made a change to how K and V are advanced here but I believe we already have that.
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
return acc, g_acc, l_i, m_i, mu_i, p_tv_acc
NUM_STAGES_OPTIONS = [2, 3, 4]
configs = [
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \
for BM in [64, 128]\
for BN in [32, 64, 128]\
for s in NUM_STAGES_OPTIONS \
for w in [4, 8]\
]
if "PYTEST_VERSION" in environ:
# Use a single config in testing for reproducibility
configs = [
triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4),
]
def keep(conf):
BLOCK_M = conf.kwargs["BLOCK_M"]
BLOCK_N = conf.kwargs["BLOCK_N"]
return not (BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8)
def prune_invalid_configs(configs, named_args, **kwargs):
N_CTX = kwargs["N_CTX"]
# Filter out configs where BLOCK_M > N_CTX
return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX]
@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
prune_configs_by={'early_config_prune': prune_invalid_configs})
@triton.jit
def _attn_fwd(Q, K, V, T_Q, T_K, T_V, #
sm_scale, M, Out, T_Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
stride_tqz, stride_tqh, stride_tqm, stride_tqk, #
stride_tkz, stride_tkh, stride_tkn, stride_tkk, #
stride_tvz, stride_tvh, stride_tvk, stride_tvn, #
stride_oz, stride_oh, stride_om, stride_on, #
stride_toz, stride_toh, stride_tom, stride_ton, #
Z, H, N_CTX, #
HEAD_DIM: tl.constexpr, #
BLOCK_M: tl.constexpr, #
BLOCK_N: tl.constexpr, #
FP8_OUTPUT: tl.constexpr, #
STAGE: tl.constexpr, #
warp_specialize: tl.constexpr, #
ENABLE_JVP: tl.constexpr, #
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
# block pointers
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=v_order,
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(HEAD_DIM, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N),
order=(0, 1),
)
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
if ENABLE_JVP:
# it's extremely likely we could just re-use qvk_offset, but this seems cheap so whatever
t_qvk_offset = off_z.to(tl.int64) * stride_tqz + off_h.to(tl.int64) * stride_tqh
T_Q_block_ptr = tl.make_block_ptr(
base=T_Q + t_qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_tqm, stride_tqk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
# could probably just re-use v_order here
t_v_order: tl.constexpr = (0, 1) if T_V.dtype.element_ty == tl.float8e5 else (1, 0)
T_V_block_ptr = tl.make_block_ptr(
base=T_V + t_qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_tvk, stride_tvn),
offsets=(0, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=t_v_order,
)
T_K_block_ptr = tl.make_block_ptr(
base=T_K + t_qvk_offset,
shape=(HEAD_DIM, N_CTX),
strides=(stride_tkk, stride_tkn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N),
order=(0, 1),
)
T_O_block_ptr = tl.make_block_ptr(
base=T_Out + t_qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_tom, stride_ton),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
# load q_t: it will stay in SRAM throughout
t_q = tl.load(T_Q_block_ptr)
g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32)
p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
else:
t_q = None
T_V_block_ptr = None
T_K_block_ptr = None
        # Allocate minimal dummy tensors so the return signature of _attn_fwd_inner stays consistent
g_acc = tl.zeros([1, 1], dtype=tl.float32)
mu_i = tl.zeros([1], dtype=tl.float32)
p_tv_acc = tl.zeros([1, 1], dtype=tl.float32)
# load scales
qk_scale = sm_scale
qk_scale = qk_scale * 1.44269504 # 1/log(2)
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr)
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
if STAGE & 1:
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(acc, g_acc,
l_i, m_i, #
mu_i, p_tv_acc, #
q, t_q, #
K_block_ptr, V_block_ptr, #
T_K_block_ptr, T_V_block_ptr, #
dtype, start_m, qk_scale, sm_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
4 - STAGE, offs_m, offs_n, N_CTX, #
warp_specialize,
ENABLE_JVP)
# stage 2: on-band
if STAGE & 2:
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(acc, g_acc, #
l_i, m_i, #
mu_i, p_tv_acc, #
q, t_q, #
K_block_ptr, V_block_ptr, #
T_K_block_ptr, T_V_block_ptr, #
dtype, start_m, qk_scale, sm_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
2, offs_m, offs_n, N_CTX, #
warp_specialize,
ENABLE_JVP)
# epilogue
m_i += tl.math.log2(l_i)
acc = acc / l_i[:, None]
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(m_ptrs, m_i)
tl.store(O_block_ptr, acc.to(Out.type.element_ty))
if ENABLE_JVP:
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc
t_y_out = t_p_v + p_tv_acc / l_i[:, None]
tl.store(T_O_block_ptr, t_y_out.to(T_Out.type.element_ty))
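# --- Added for reference (not part of the original gist): a minimal eager-mode oracle for the math the
# epilogue above assembles. Assuming non-causal attention (as exercised in main() below), with
# S = scale*Q@K^T, tS = scale*(tQ@K^T + Q@tK^T), P = softmax(S) and Y = P@V, the JVP is
#   tY = (P * tS) @ V - rowsum(P * tS) * Y + P @ tV
# which corresponds block-wise to g_acc / l_i, mu_i / l_i and p_tv_acc / l_i. Running this in float64
# gives a high-precision target when debugging the Triton jvp. The helper name is my own.
def _eager_attn_jvp(q: Tensor, k: Tensor, v: Tensor, tq: Tensor, tk: Tensor, tv: Tensor, scale: float) -> tuple[Tensor, Tensor]:
    s = q @ k.transpose(-2, -1) * scale
    ts = (tq @ k.transpose(-2, -1) + q @ tk.transpose(-2, -1)) * scale
    p = s.softmax(dim=-1)
    y = p @ v
    mu = (p * ts).sum(dim=-1, keepdim=True)
    t_y = (p * ts) @ v - mu * y + p @ tv
    return y, t_y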
class JVPAttn(Function):
class Grid(NamedTuple):
M_BLOCKS: int
Z_H: int
ONE: Literal[1]
class FnCtx(FunctionCtx):
sm_scale: float
HEAD_DIM_K: int
causal: bool
class FwdOutCtxContrib(NamedTuple):
o_t: Optional[Tensor]
M: Tensor
HEAD_DIM_K: int
sm_scale: float
class FwdOut(NamedTuple):
o: Tensor
ctx: JVPAttn.FwdOutCtxContrib
class JVPOut(NamedTuple):
o: Tensor
ctx: None
class BwdOut(NamedTuple):
q: Tensor
k: Tensor
v: Tensor
q_t: None
k_t: None
v_t: None
causal: None
sm_scale: None
warp_specialize: None
USE_TMA: None
class Strides(NamedTuple):
z: int
h: int
n_ctx: int
head_dim: int
@staticmethod
def forward(
q: Tensor,
k: Tensor,
v: Tensor,
q_t: Optional[Tensor],
k_t: Optional[Tensor],
v_t: Optional[Tensor],
causal: bool,
sm_scale: Optional[float],
warp_specialize=True,
) -> JVPAttn.FwdOut:
# shape constraints
Z, H, N_CTX, HEAD_DIM_Q = q.shape
HEAD_DIM_K = k.shape[-1]
# when v is in float8_e5m2 it is transposed.
HEAD_DIM_V = v.shape[-1]
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
assert HEAD_DIM_K in {16, 32, 64, 128, 256}
if sm_scale is None:
sm_scale = HEAD_DIM_K**-.5
o = torch.empty_like(q)
ENABLE_JVP = q_t is not None
o_t: Optional[Tensor] = torch.empty_like(q_t) if ENABLE_JVP else None
stage = 3 if causal else 1
extra_kern_args = {}
if warp_specialize:
# we need more registers if we're doing JVP
if (HEAD_DIM_K == 128 and q.dtype == torch.float16) or ENABLE_JVP:
extra_kern_args["maxnreg"] = 168
else:
# TODO: I think for backwards pass of dim=128 this is too low for H100; register allocation fails
extra_kern_args["maxnreg"] = 80
M = torch.empty((Z, H, N_CTX), device=q.device, dtype=torch.float32)
if hasattr(triton, 'set_allocator'):
def alloc_fn(size: int, align: int, _):
return torch.empty(size, dtype=torch.int8, device="cuda")
triton.set_allocator(alloc_fn)
Z_H = Z * H
def grid(META: dict[str, Any]) -> JVPAttn.Grid:
return JVPAttn.Grid(triton.cdiv(N_CTX, META["BLOCK_M"]), Z_H, 1)
def strides_zhnd(t: Tensor) -> JVPAttn.Strides:
return JVPAttn.Strides(t.stride(0), t.stride(1), t.stride(2), t.stride(3))
_attn_fwd[grid](
q, k, v, q_t, k_t, v_t, #
sm_scale, M, o, o_t, #
*strides_zhnd(q), #
*strides_zhnd(k), #
*strides_zhnd(v), #
*strides_zhnd(q if q_t is None else q_t), #
*strides_zhnd(k if k_t is None else k_t), #
*strides_zhnd(v if v_t is None else v_t), #
*strides_zhnd(o), #
*strides_zhnd(o if o_t is None else o_t), #
Z, H, #
N_CTX=N_CTX, #
HEAD_DIM=HEAD_DIM_K, #
FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
STAGE=stage, #
warp_specialize=warp_specialize, #
ENABLE_JVP=ENABLE_JVP, #
**extra_kern_args)
return JVPAttn.FwdOut(o, JVPAttn.FwdOutCtxContrib(o_t, M, HEAD_DIM_K, sm_scale))
@staticmethod
def setup_context(ctx: JVPAttn.FnCtx, inputs, outputs: JVPAttn.FwdOut) -> Tensor:
(
q,
k,
v,
q_t,
k_t,
v_t,
causal,
sm_scale,
warp_specialize,
) = inputs
o, (o_t, M, HEAD_DIM_K, sm_scale) = outputs
ctx.save_for_forward(o_t)
ctx.save_for_backward(q, k, v, o, M)
ctx.sm_scale = sm_scale
ctx.HEAD_DIM_K = HEAD_DIM_K
ctx.causal = causal
return o
@staticmethod
def fwd(
q: Tensor,
k: Tensor,
v: Tensor,
causal = False,
sm_scale: Optional[float] = None,
warp_specialize=True,
) -> Tensor:
"""
This is not an autograd convention, it's a workaround to get type-hinting and kwarg support
"""
out: JVPAttn.FwdOut = JVPAttn.apply(q, k, v, None, None, None, causal, sm_scale, warp_specialize)
a, _ = out
return a
@staticmethod
def fwd_dual(
q: Tensor,
k: Tensor,
v: Tensor,
causal = False,
sm_scale: Optional[float] = None,
warp_specialize=True,
) -> Tensor:
"""
This is not an autograd convention, it's a workaround for invoking
JVPAttn::forward with the right arguments when you have a dual tensor input.
"""
q_p, q_t = fwAD.unpack_dual(q)
k_p, k_t = fwAD.unpack_dual(k)
v_p, v_t = fwAD.unpack_dual(v)
        # we pass the dual-tensor args themselves to ensure jvp() will be called, but we also pass
        # the tangents separately, as forward() demotes dual-tensor args to primals for some reason
out: JVPAttn.FwdOut = JVPAttn.apply(q, k, v, q_t, k_t, v_t, causal, sm_scale, warp_specialize)
a, _ = out
return a
@staticmethod
def jvp(ctx: JVPAttn.FnCtx, gq: Tensor, gk: Tensor, gv: Tensor, *_) -> JVPAttn.JVPOut:
return JVPAttn.JVPOut(ctx.saved_for_forward[0], None)
# ---- Begin JVPAttnRef, which was based on vibe-coding; probably less optimized, and we don't have a backwards pass for it. fwd and jvp both have good accuracy. ----
@triton.autotune(
configs=[
# Ultra-conservative configs for maximum compatibility
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16}, num_warps=2, num_stages=1),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16}, num_warps=2, num_stages=1),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32}, num_warps=2, num_stages=1),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32}, num_warps=4, num_stages=1),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16}, num_warps=4, num_stages=1),
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64}, num_warps=4, num_stages=1),
],
key=['B', 'H', 'L', 'D_head'],
)
@triton.jit
def _flash_attention_jvp_multihead_kernel(
# Input tensors
Q, K, V, T_Q, T_K, T_V,
M,
# Output tensors
Y, T_Y,
# Tensor strides
stride_qb, stride_qh, stride_ql, stride_qd,
stride_kb, stride_kh, stride_kl, stride_kd,
stride_vb, stride_vh, stride_vl, stride_vd,
stride_tqb, stride_tqh, stride_tql, stride_tqd,
stride_tkb, stride_tkh, stride_tkl, stride_tkd,
stride_tvb, stride_tvh, stride_tvl, stride_tvd,
stride_yb, stride_yh, stride_yl, stride_yd,
stride_tyb, stride_tyh, stride_tyl, stride_tyd,
# Problem dimensions
B: tl.constexpr, H: tl.constexpr, L: tl.constexpr, D_head: tl.constexpr,
# Scale factor
scale: tl.constexpr,
# Block sizes
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
"""
Flash Attention JVP kernel following the reference implementation pattern.
Grid: (B*H, triton.cdiv(L, BLOCK_M))
"""
# Get program IDs
pid_bh = tl.program_id(0) # Combined batch and head index
pid_m = tl.program_id(1) # Query block index
# Decompose batch and head indices
pid_b = pid_bh // H
pid_h = pid_bh % H
# Compute offsets
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, D_head)
# Base pointers for this (batch, head)
q_base = Q + pid_b * stride_qb + pid_h * stride_qh
k_base = K + pid_b * stride_kb + pid_h * stride_kh
v_base = V + pid_b * stride_vb + pid_h * stride_vh
tq_base = T_Q + pid_b * stride_tqb + pid_h * stride_tqh
tk_base = T_K + pid_b * stride_tkb + pid_h * stride_tkh
tv_base = T_V + pid_b * stride_tvb + pid_h * stride_tvh
y_base = Y + pid_b * stride_yb + pid_h * stride_yh
ty_base = T_Y + pid_b * stride_tyb + pid_h * stride_tyh
# Load query block
q_ptrs = q_base + offs_m[:, None] * stride_ql + offs_d[None, :] * stride_qd
tq_ptrs = tq_base + offs_m[:, None] * stride_tql + offs_d[None, :] * stride_tqd
mask_m = offs_m < L
q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0)
tq = tl.load(tq_ptrs, mask=mask_m[:, None], other=0.0)
# Initialize accumulators following Flash Attention pattern
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
g_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32)
p_tv_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
# Scale factor for exp2 optimization (like reference)
qk_scale = scale * 1.44269504 # 1/log(2)
# Loop over key/value blocks
for start_n in range(0, L, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
offs_n_curr = start_n + offs_n
mask_n = offs_n_curr < L
# Load key and value blocks
k_ptrs = k_base + offs_n_curr[:, None] * stride_kl + offs_d[None, :] * stride_kd
v_ptrs = v_base + offs_n_curr[:, None] * stride_vl + offs_d[None, :] * stride_vd
tk_ptrs = tk_base + offs_n_curr[:, None] * stride_tkl + offs_d[None, :] * stride_tkd
tv_ptrs = tv_base + offs_n_curr[:, None] * stride_tvl + offs_d[None, :] * stride_tvd
k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)
v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)
tk = tl.load(tk_ptrs, mask=mask_n[:, None], other=0.0)
tv = tl.load(tv_ptrs, mask=mask_n[:, None], other=0.0)
# Compute attention scores
qk = tl.dot(q, tl.trans(k))
tqk = tl.dot(tq, tl.trans(k)) + tl.dot(q, tl.trans(tk))
# Mask invalid positions first
qk = tl.where(mask_n[None, :], qk, float('-inf'))
tqk = tl.where(mask_n[None, :], tqk, 0.0)
# Online softmax computation following Flash Attention
m_ij = tl.maximum(m_i, tl.max(qk * scale, 1))
qk = qk * qk_scale - m_ij[:, None] # Scale and subtract max
p = tl.math.exp2(qk) # Use exp2 like reference
# Correction factor
alpha = tl.math.exp2(m_i - m_ij)
l_ij = tl.sum(p, 1)
# Update normalization
l_i = l_i * alpha + l_ij
# NOTE: this downcast of p is a new change compared to the reference implementation
# Cast p back to input dtype for matmul
p_typed = p.to(q.dtype)
# Update output accumulator
acc = acc * alpha[:, None] + tl.dot(p_typed, v)
# JVP accumulator: (p * tqk) @ v
p_tqk = p * (tqk * scale) # Apply scale to tangent scores
# NOTE: this downcast of p_tqk is a new change compared to the reference implementation
p_tqk_typed = p_tqk.to(q.dtype) # Cast tangent weights too
g_acc = g_acc * alpha[:, None] + tl.dot(p_tqk_typed, v)
# Update mu: sum(p * tqk)
mu_ij = tl.sum(p_tqk, 1)
mu_i = mu_i * alpha + mu_ij
# Update p @ tv accumulator
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p_typed, tv)
# Update max
m_i = m_ij
# Final computation - add log normalization and divide
m_i += tl.math.log2(l_i)
y_out = acc / l_i[:, None]
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * y_out
t_y_out = t_p_v + p_tv_acc / l_i[:, None]
# Store outputs
y_ptrs = y_base + offs_m[:, None] * stride_yl + offs_d[None, :] * stride_yd
ty_ptrs = ty_base + offs_m[:, None] * stride_tyl + offs_d[None, :] * stride_tyd
tl.store(y_ptrs, y_out, mask=mask_m[:, None])
tl.store(ty_ptrs, t_y_out, mask=mask_m[:, None])
    # M is laid out (B, H, L): index per (batch, head) row, and mask to the valid sequence length
    m_ptrs = M + pid_bh * L + offs_m
    tl.store(m_ptrs, m_i, mask=mask_m)
class JVPAttnRef(Function):
class FnCtx(FunctionCtx):
sm_scale: float
head_dim: int
class FwdOutCtxContrib(NamedTuple):
t_y: Optional[Tensor]
M: Tensor
head_dim: int
sm_scale: float
class FwdOut(NamedTuple):
o: Tensor
ctx: JVPAttnRef.FwdOutCtxContrib
class JVPOut(NamedTuple):
y: Tensor
ctx: None
@staticmethod
def forward(
q_p: Tensor,
k_p: Tensor,
v_p: Tensor,
q_t: Tensor,
k_t: Tensor,
v_t: Tensor,
scale: Optional[float] = None,
) -> Tensor:
head_dim = k_p.size(-1)
scale = head_dim**-.5 if scale is None else scale
device = q_p.device
dtype = q_p.dtype
B, H, L, D_head = q_p.shape
# Check minimum dimension requirements for Triton
if D_head < 16:
raise ValueError(f"D_head must be >= 16 for efficient Triton kernel, got {D_head}")
# Ensure input shapes are correct
assert q_p.shape == (B, H, L, D_head), f"Q shape mismatch: {q_p.shape}"
assert k_p.shape == (B, H, L, D_head), f"K shape mismatch: {k_p.shape}"
assert v_p.shape == (B, H, L, D_head), f"V shape mismatch: {v_p.shape}"
assert q_t.shape == (B, H, L, D_head), f"t_Q shape mismatch: {q_t.shape}"
assert k_t.shape == (B, H, L, D_head), f"t_K shape mismatch: {k_t.shape}"
assert v_t.shape == (B, H, L, D_head), f"t_V shape mismatch: {v_t.shape}"
M = torch.empty((B, H, L), device=q_p.device, dtype=torch.float32)
# Create output tensors
y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device)
t_y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device)
# Make tensors contiguous
Qc = q_p.contiguous()
Kc = k_p.contiguous()
Vc = v_p.contiguous()
t_Qc = q_t.contiguous()
t_Kc = k_t.contiguous()
t_Vc = v_t.contiguous()
# Compute strides
stride_qb, stride_qh, stride_ql, stride_qd = Qc.stride()
stride_kb, stride_kh, stride_kl, stride_kd = Kc.stride()
stride_vb, stride_vh, stride_vl, stride_vd = Vc.stride()
stride_tqb, stride_tqh, stride_tql, stride_tqd = t_Qc.stride()
stride_tkb, stride_tkh, stride_tkl, stride_tkd = t_Kc.stride()
stride_tvb, stride_tvh, stride_tvl, stride_tvd = t_Vc.stride()
stride_yb, stride_yh, stride_yl, stride_yd = y.stride()
stride_tyb, stride_tyh, stride_tyl, stride_tyd = t_y.stride()
        # Use a block-based grid like Flash Attention. The grid must be a callable so it picks up
        # whatever BLOCK_M the autotuner selects; hard-coding BLOCK_M=64 here would launch too few
        # blocks (and leave queries unwritten) whenever the autotuner picks a smaller tile.
        grid = lambda META: (B * H, triton.cdiv(L, META["BLOCK_M"]))
_flash_attention_jvp_multihead_kernel[grid](
Qc, Kc, Vc, t_Qc, t_Kc, t_Vc,
M,
y, t_y,
stride_qb, stride_qh, stride_ql, stride_qd,
stride_kb, stride_kh, stride_kl, stride_kd,
stride_vb, stride_vh, stride_vl, stride_vd,
stride_tqb, stride_tqh, stride_tql, stride_tqd,
stride_tkb, stride_tkh, stride_tkl, stride_tkd,
stride_tvb, stride_tvh, stride_tvl, stride_tvd,
stride_yb, stride_yh, stride_yl, stride_yd,
stride_tyb, stride_tyh, stride_tyl, stride_tyd,
B, H, L, D_head,
scale,
)
return JVPAttnRef.FwdOut(y, JVPAttnRef.FwdOutCtxContrib(t_y, M, head_dim, scale))
@staticmethod
def setup_context(ctx: JVPAttnRef.FnCtx, inputs, outputs: JVPAttnRef.FwdOut) -> Tensor:
(
q,
k,
v,
q_t,
k_t,
v_t,
sm_scale,
) = inputs
y, (t_y, M, head_dim, sm_scale) = outputs
ctx.save_for_forward(t_y)
ctx.save_for_backward(q, k, v, y, M)
ctx.sm_scale = sm_scale
ctx.head_dim = head_dim
return y
@staticmethod
def fwd(
q: Tensor,
k: Tensor,
v: Tensor,
sm_scale: Optional[float] = None,
) -> Tensor:
"""
This is not an autograd convention, it's a workaround to get type-hinting and kwarg support
"""
out: JVPAttnRef.FwdOut = JVPAttnRef.apply(q, k, v, None, None, None, sm_scale)
a, _ = out
return a
@staticmethod
def fwd_dual(
Q: Tensor,
K: Tensor,
V: Tensor,
scale: Optional[float] = None,
) -> Tensor:
"""
This is not an autograd convention, it's a workaround for invoking
JVPAttn::forward with the right arguments when you have a dual tensor input.
"""
q_p, q_t = fwAD.unpack_dual(Q)
k_p, k_t = fwAD.unpack_dual(K)
v_p, v_t = fwAD.unpack_dual(V)
        # we pass the dual-tensor args themselves to ensure jvp() will be called, but we also pass
        # the tangents separately, as forward() demotes dual-tensor args to primals for some reason
out: JVPAttnRef.FwdOut = JVPAttnRef.apply(Q, K, V, q_t, k_t, v_t, scale)
a, _ = out
return a
@staticmethod
def jvp(ctx: JVPAttnRef.FnCtx, gq: Tensor, gk: Tensor, gv: Tensor, *_) -> JVPAttnRef.JVPOut:
return JVPAttnRef.JVPOut(ctx.saved_for_forward[0], None)
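# --- Added for debugging (not part of the original gist): a sketch that measures each Triton jvp against
# the float64 _eager_attn_jvp oracle defined above, independently of SDPA's own fp16 math. Shapes mirror
# main() below; the function name and its defaults are my own. Call it manually (or from main()) if useful.
def _report_jvp_errors(bsz: int = 1, heads: int = 5, seq_len: int = 128, head_dim: int = 64) -> None:
    gen = torch.Generator(device='cuda').manual_seed(0)
    q_p, k_p, v_p, q_t, k_t, v_t = (
        torch.randn(bsz, heads, seq_len, head_dim, device='cuda', dtype=torch.float16, generator=gen)
        for _ in range(6)
    )
    _, t_y_ref = _eager_attn_jvp(*(t.double() for t in (q_p, k_p, v_p, q_t, k_t, v_t)), scale=head_dim ** -0.5)
    with fwAD.dual_level(), enable_grad():
        for name, fwd_dual in (('JVPAttn', JVPAttn.fwd_dual), ('JVPAttnRef', JVPAttnRef.fwd_dual)):
            q = fwAD.make_dual(q_p.clone(), q_t.clone())
            k = fwAD.make_dual(k_p.clone(), k_t.clone())
            v = fwAD.make_dual(v_p.clone(), v_t.clone())
            _, t_y = fwAD.unpack_dual(fwd_dual(q, k, v))
            print(f'{name} jvp max abs err vs fp64 oracle:', (t_y.double() - t_y_ref).abs().max().item())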
class QKV(NamedTuple):
q: Tensor
k: Tensor
v: Tensor
class UnpackedDualQKV(NamedTuple):
primal: QKV
tangent: QKV
def main() -> None:
device = torch.device('cuda')
dtype = torch.float16
seed = 42
gen = torch.Generator(device=device)
bsz = 1
model_dim = 320
head_dim = 64
seq_len = 128
heads = model_dim // head_dim
q_p, q_t, k_p, k_t, v_p, v_t = (torch.randn(bsz, heads, seq_len, head_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed + ix)) for ix in range(6))
def make_qkv(q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor) -> QKV:
return QKV(
q=fwAD.make_dual(q_p, q_t).requires_grad_(),
k=fwAD.make_dual(k_p, k_t).requires_grad_(),
v=fwAD.make_dual(v_p, v_t).requires_grad_(),
)
with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad():
q0, k0, v0 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone())
sdpa_out = scaled_dot_product_attention(q0, k0, v0)
sdpa_out.retain_grad()
sdpa_op, sdpa_ot = fwAD.unpack_dual(sdpa_out)
q1, k1, v1 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone())
dual_out = JVPAttnRef.fwd_dual(q1, k1, v1)
dual_out.retain_grad()
dual_op, dual_ot = fwAD.unpack_dual(dual_out)
torch.testing.assert_close(dual_op, sdpa_op, atol=5e-4, rtol=1e-5)
torch.testing.assert_close(dual_ot, sdpa_ot, atol=5e-4, rtol=1e-5)
q2, k2, v2 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone())
dual_out = JVPAttn.fwd_dual(q2, k2, v2)
dual_out.retain_grad()
dual_op, dual_ot = fwAD.unpack_dual(dual_out)
torch.testing.assert_close(dual_op, sdpa_op, atol=5e-4, rtol=1e-5)
# TODO: improve this accuracy
torch.testing.assert_close(dual_ot, sdpa_ot, atol=1e-3, rtol=1e-5)
if __name__ == "__main__":
main()