@Isotr0py · Created July 17, 2025 16:44

Benchmark Jina Embeddings v4 vision pooling kernels
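
A standalone script comparing a Triton kernel against a pure-PyTorch fallback for mean-pooling vision tokens out of Jina Embeddings v4 hidden states. Both paths are timed over 100 iterations per configuration and checked for agreement.
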
from contextlib import contextmanager
from array import array

import torch
import torch.nn.functional as F
import triton
import triton.language as tl

VISION_START_TOKEN_ID = 151652
VISION_END_TOKEN_ID = 151653
VISION_TOKEN_ID = 151655

@contextmanager
def time_counter(enable: bool, label: str):
    if enable:
        import time
        torch.cuda.synchronize()  # drain pending GPU work before starting the clock
        start_time = time.time()
        yield
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
        elapsed_time = time.time() - start_time
        print(f"-- {label} = {elapsed_time}")
    else:
        yield

@triton.jit
def extract_vision_tokens_kernel(
    hidden_states_ptr,
    token_ids_ptr,
    output_ptr,
    seq_start,
    seq_len,
    hidden_size,
    vision_start_id: tl.constexpr,
    vision_end_id: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # currently unused; kept for launch-API compatibility
):
    """Triton kernel to extract and mean-pool vision tokens efficiently."""
    pid = tl.program_id(0)
    if pid >= hidden_size:
        return

    # Accumulate hidden values for token IDs in [vision_start_id, vision_end_id]
    vision_count = 0
    accumulator = 0.0
    for i in range(seq_len):
        token_id = tl.load(token_ids_ptr + seq_start + i)
        if token_id >= vision_start_id and token_id <= vision_end_id:
            hidden_val = tl.load(hidden_states_ptr +
                                 (seq_start + i) * hidden_size + pid)
            accumulator += hidden_val
            vision_count += 1

    # Store the mean-pooled result for this hidden dimension
    result = accumulator / vision_count if vision_count > 0 else 0.0
    tl.store(output_ptr + pid, result)

def _apply_vision_pooling_optimized(
    hidden_size: int,
    hidden_states: torch.Tensor,
    token_ids_list: list[array],
    prompt_lens: torch.Tensor,
) -> list[torch.Tensor]:
    """Apply optimized vision token pooling using Triton kernels."""
    pooled_outputs = []
    offset = 0
    device = hidden_states.device

    for token_ids, prompt_len in zip(token_ids_list, prompt_lens):
        prompt_len = int(prompt_len.item())

        # Convert token IDs to a tensor the kernel can read
        token_tensor = torch.tensor(list(token_ids),
                                    dtype=torch.long,
                                    device=device)

        # Allocate the output tensor
        output = torch.zeros(hidden_size,
                             device=device,
                             dtype=hidden_states.dtype)

        # Check for vision tokens
        has_vision = torch.any((token_tensor >= VISION_START_TOKEN_ID)
                               & (token_tensor <= VISION_END_TOKEN_ID))

        if has_vision:
            # Use the Triton kernel for vision token extraction
            grid = (hidden_size, )
            extract_vision_tokens_kernel[grid](
                hidden_states,
                token_tensor,
                output,
                offset,
                prompt_len,
                hidden_size,
                VISION_START_TOKEN_ID,
                VISION_END_TOKEN_ID,
                BLOCK_SIZE=1024,
            )
        else:
            # Regular mean pooling for text-only prompts
            seq_states = hidden_states[offset:offset + prompt_len]
            output = seq_states.mean(dim=0)

        # L2-normalize the pooled embedding
        output = F.normalize(output, p=2, dim=-1, eps=1e-12)
        pooled_outputs.append(output)
        offset += prompt_len

    return pooled_outputs

def _apply_vision_pooling_pytorch(
    hidden_size: int,
    hidden_states: torch.Tensor,
    token_ids_list: list[array],
    prompt_lens: torch.Tensor,
) -> list[torch.Tensor]:
    """PyTorch fallback for vision token pooling."""
    pooled_outputs = []
    offset = 0

    for token_ids, prompt_len in zip(token_ids_list, prompt_lens):
        prompt_len = int(prompt_len.item())

        # Extract sequence states and tokens
        seq_states = hidden_states[offset:offset + prompt_len]
        # Convert array to tensor for processing
        seq_tokens = torch.tensor(list(token_ids[:prompt_len]),
                                  dtype=torch.long,
                                  device=hidden_states.device)

        # Check for vision tokens
        vision_mask = ((seq_tokens >= VISION_START_TOKEN_ID) &
                       (seq_tokens <= VISION_END_TOKEN_ID))

        if vision_mask.any():
            # Pool only vision tokens (the mask guarantees at least one row)
            pooled = seq_states[vision_mask].mean(dim=0)
        else:
            # Pool all tokens for text-only prompts
            pooled = seq_states.mean(dim=0)

        # Normalize embeddings
        pooled = F.normalize(pooled, p=2, dim=-1, eps=1e-12)
        pooled_outputs.append(pooled)
        offset += prompt_len

    return pooled_outputs

hidden_size = 2048
seq_len_per_image = [512, 1024, 8192, 16384]
text_seq_len = 2048
num_images = [1, 2, 4]
dtype = torch.bfloat16

# Benchmark: compare the Triton kernel against the PyTorch fallback
for num_image in num_images:
    for img_seq_len in seq_len_per_image:
        print()
        print(f"Image sequence length: {img_seq_len}, "
              f"Text sequence length: {text_seq_len}, "
              f"Number of images: {num_image}")

        # Simulate hidden states and token IDs: each image contributes
        # <vision_start> + img_seq_len vision tokens + <vision_end>
        prompt_len = num_image * (img_seq_len + 2) + text_seq_len
        hidden_states = torch.randn(prompt_len, hidden_size,
                                    device='cuda', dtype=dtype)
        token_ids_list = [
            array('l', ([VISION_START_TOKEN_ID] +
                        [VISION_TOKEN_ID] * img_seq_len +
                        [VISION_END_TOKEN_ID]) * num_image +
                  [123] * text_seq_len)
        ]
        prompt_lens = torch.tensor([prompt_len], device='cuda',
                                   dtype=torch.long)

        # Warm-up call so Triton compilation is excluded from the timing
        _apply_vision_pooling_optimized(hidden_size, hidden_states,
                                        token_ids_list, prompt_lens)

        with time_counter(enable=True, label="triton vision pooling"):
            for _ in range(100):
                optimized_outputs = _apply_vision_pooling_optimized(
                    hidden_size, hidden_states, token_ids_list, prompt_lens)

        with time_counter(enable=True, label="pytorch vision pooling"):
            for _ in range(100):
                pytorch_outputs = _apply_vision_pooling_pytorch(
                    hidden_size, hidden_states, token_ids_list, prompt_lens)

        # Sanity check: both implementations must agree
        for opt_out, pt_out in zip(optimized_outputs, pytorch_outputs):
            torch.testing.assert_close(opt_out, pt_out)
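
To reproduce: save the script (e.g. as benchmark_vision_pooling.py; the filename is not part of the gist) and run it with Python on a CUDA-capable machine with torch and triton installed. Each configuration prints the wall-clock time of 100 iterations for both implementations, and torch.testing.assert_close verifies that the two paths produce matching embeddings.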