Birch-san
@Birch-san
Birch-san / jvp_accuracy_compare.py
Created June 30, 2025 01:58
Let's try to vibe-code JVPAttn's jvp to be as accurate as JVPAttnRef
from __future__ import annotations
from typing import Any, Literal, NamedTuple, Optional
from os import environ
import triton
import triton.language as tl
import torch
from torch import Tensor, enable_grad
from torch.autograd import Function
from torch.autograd.function import FunctionCtx
import torch.autograd.forward_ad as fwAD
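The gist's JVPAttn and JVPAttnRef classes aren't shown in the preview above, but the general shape of the comparison is a custom torch.autograd.Function that implements the jvp staticmethod so forward-mode AD can flow through it. A minimal sketch of that mechanism (the ScaleBy2 Function and the tensor shapes are illustrative assumptions, not the gist's code):

import torch
from torch.autograd import Function
import torch.autograd.forward_ad as fwAD

class ScaleBy2(Function):
    # new-style Function: forward takes only the inputs...
    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing needs saving for this linear op

    # ...and jvp receives the input tangents, returning the output tangent
    @staticmethod
    def jvp(ctx, x_t):
        return x_t * 2

with fwAD.dual_level():
    x = fwAD.make_dual(torch.randn(4), torch.randn(4))
    y = ScaleBy2.apply(x)
    primal, tangent = fwAD.unpack_dual(y)  # tangent == 2 * input tangent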
@Birch-san
Birch-san / recompile_nonrepro.py
Created June 30, 2025 00:40
Failed attempt to reproduce "torch.compile gets invalidated too easily by einops rearrange"
from __future__ import annotations
from typing import NamedTuple, Optional
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
import torch
from torch import Tensor, inference_mode
from torch.nn import Module, Linear
from torch.nn.functional import relu
from einops import rearrange
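For context, one way to probe whether rearrange invalidates a compiled graph is to make recompilation an error rather than a silent event. A hedged sketch, not taken from the gist (the toy function f and its shapes are assumptions):

import torch
import torch._dynamo as dynamo
from einops import rearrange

# raise instead of silently recompiling, so any guard invalidation is loud
dynamo.config.error_on_recompile = True

@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    return rearrange(x, "b (h d) -> b h d", h=2)

f(torch.randn(4, 6))
f(torch.randn(4, 6))  # same shape: should reuse the cached graph, no recompile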
@Birch-san
Birch-san / _06_fused_attention_blockptr_jvp.py
Last active June 29, 2025 17:08
Triton fused attention tutorial, updated with JVP support, albeit with only atol=1e-3 accuracy on the JVP.
from __future__ import annotations
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team
Extra Credits:
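The "atol=1e-3 accuracy on the JVP" refers to how closely the kernel's forward-mode tangent matches a reference. A sketch of that kind of check, assuming a hand-rolled sdpa_ref as the reference and placeholder tensor shapes (neither is taken from the gist):

import torch
import torch.autograd.forward_ad as fwAD

def sdpa_ref(q, k, v):
    # plain softmax attention; every op here supports forward-mode AD
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)
    return attn @ v

q, k, v = (torch.randn(1, 4, 128, 64) for _ in range(3))
q_t, k_t, v_t = (torch.randn_like(q) for _ in range(3))

with fwAD.dual_level():
    out = sdpa_ref(fwAD.make_dual(q, q_t), fwAD.make_dual(k, k_t), fwAD.make_dual(v, v_t))
    ref_primal, ref_tangent = fwAD.unpack_dual(out)

# out_tangent stands in for the Triton kernel's JVP output
# torch.testing.assert_close(out_tangent, ref_tangent, atol=1e-3, rtol=0)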
@Birch-san
Birch-san / jvp_parity.py
Created June 26, 2025 13:11
(Can't run this) Triton JVP attention, blind-coded for TensorDescriptor-era Triton
from __future__ import annotations
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from functools import partial
from os import environ
from typing import Any, Callable, NamedTuple, Optional
import torch
from torch import Tensor, no_grad, enable_grad
from torch.autograd import Function
from torch.autograd.function import FunctionCtx
@Birch-san
Birch-san / fused_attn.py
Created June 26, 2025 13:08
Triton fused attention tutorial code, with the blockptr-era codepath restored and newer contributions backported into it
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team
Extra Credits:
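"Blockptr-era" refers to the tutorial variant built on tl.make_block_ptr rather than the newer TensorDescriptor API. A minimal block-pointer sketch for orientation (a toy tile copy, not the attention kernel; shapes and block sizes are assumptions):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, y_ptr, M, N, stride_m, stride_n,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # block pointers describe a 2D tile of a strided tensor
    x_block = tl.make_block_ptr(base=x_ptr, shape=(M, N), strides=(stride_m, stride_n),
                                offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    y_block = tl.make_block_ptr(base=y_ptr, shape=(M, N), strides=(stride_m, stride_n),
                                offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                                block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tile = tl.load(x_block, boundary_check=(0, 1))
    tl.store(y_block, tile, boundary_check=(0, 1))

x = torch.randn(128, 128, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(128, 64), triton.cdiv(128, 64))
copy_kernel[grid](x, y, 128, 128, x.stride(0), x.stride(1), BLOCK_M=64, BLOCK_N=64)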
@Birch-san
Birch-san / naiv1_generate.py
Last active June 15, 2025 22:52
Script for generating images from NAIv1
from dataclasses import dataclass
from einops import rearrange
import re
import torch
from torch import BoolTensor, FloatTensor, IntTensor, LongTensor, inference_mode
from torch.nn.functional import pad
from itertools import islice
from typing import Generator, Iterable, Iterator, Optional, Protocol, TypeVar
from typing_extensions import override
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
from __future__ import annotations
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional
import torch
from torch import Tensor, no_grad, enable_grad
import torch.autograd.forward_ad as fwAD
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
@Birch-san
Birch-san / jvp_flops.py
Created June 14, 2025 00:12
Does linearize work? Am I using it right?
from __future__ import annotations
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from functools import partial
from typing import Callable, Generic, TypeVar
import torch
from torch import enable_grad, no_grad
import torch.autograd.forward_ad as fwAD
from torch.func import linearize
from torch.nn.attention import SDPBackend, sdpa_kernel
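For reference, torch.func.linearize returns the primal output plus a function that evaluates the JVP at the captured primal, which is what the gist is trying to benchmark. A minimal sketch with an assumed toy function f (not the gist's attention workload):

import torch
from torch.func import linearize

def f(x: torch.Tensor) -> torch.Tensor:
    return torch.sin(x) * x

x = torch.randn(8)
x_t = torch.randn(8)  # tangent direction

out, jvp_fn = linearize(f, x)   # linearize once at the primal x
tangent_out = jvp_fn(x_t)       # JVP: cheap to re-evaluate for new tangents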
from __future__ import annotations
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from functools import partial
from typing import Callable
import torch
from torch import enable_grad, no_grad
import torch.autograd.forward_ad as fwAD
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
@Birch-san
Birch-san / gist:eddad13648725d47c71799c39e8361b2
Created May 29, 2025 13:07
Example API request for generating an image using a stored NAIv4 vibe. Uses vibe files created by https://gist.github.com/Birch-san/5eb62a4a5e4a1c4447a55e3a9faf8988
#!/usr/bin/env bash
set -eo pipefail
# https://stackoverflow.com/a/12194427/5257399
create() { # fd base [qualifier [suffix [max]]]
  local fd="$1" base="$2" qualifier="${3-}" suffix="${4-.png}" max="${5-}"
  local n=0 file
  local - # ash-style local scoping of options in 4.4+
  set -o noclobber
  REPLY=