Created
July 17, 2025 16:44
-
-
Save Isotr0py/eef7470ff176a28ac40340b883cf1abe to your computer and use it in GitHub Desktop.
Benchmark jina embeddings v4 vision pooling kernels
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import contextmanager | |
import triton | |
import triton.language as tl | |
import torch | |
import torch.nn.functional as F | |
from array import array | |
# Special token ids used by the pooling code in this file. By name they are
# the image start/end delimiters and the per-image-token id; presumably these
# are the Qwen2-VL-style vocabulary ids used by jina-embeddings-v4 — confirm
# against the model's tokenizer.
# NOTE(review): VISION_TOKEN_ID (151655) lies outside the id range
# [VISION_START_TOKEN_ID, VISION_END_TOKEN_ID] (151652..151653) that the
# pooling masks test, so those masks only ever match the two delimiter ids —
# confirm that pooling only the delimiters is intended.
VISION_START_TOKEN_ID: int = 151652
VISION_END_TOKEN_ID: int = 151653
VISION_TOKEN_ID: int = 151655
@contextmanager
def time_counter(enable: bool, label: str):
    """Print how long the body of the ``with`` statement took.

    When ``enable`` is true, the body is timed with ``time.perf_counter`` and
    ``-- {label} = {seconds}`` is printed after it finishes. When ``enable`` is
    false, the body simply runs and nothing is printed. Nothing is printed if
    the body raises.

    This measures host-side elapsed time only: asynchronous device work (e.g.
    queued CUDA kernels) is counted only if the body waits for it to finish.

    Args:
        enable: Whether to time and report the body.
        label: Text printed alongside the elapsed time.
    """
    if not enable:
        yield
        return
    import time

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # adjustable system clock and can have coarse resolution.
    start_time = time.perf_counter()
    yield
    elapsed_time = time.perf_counter() - start_time
    print(f"-- {label} = {elapsed_time}")
@triton.jit
def extract_vision_tokens_kernel(
    hidden_states_ptr,
    token_ids_ptr,
    output_ptr,
    seq_start,
    seq_len,
    hidden_size,
    vision_start_id: tl.constexpr,
    vision_end_id: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Average one hidden-state column over the rows whose token id is in range.

    Each program handles the single column given by its program id; programs
    with an id >= ``hidden_size`` exit immediately. The program walks rows
    ``seq_start .. seq_start + seq_len - 1`` one at a time. A row is selected
    when its token id lies in ``[vision_start_id, vision_end_id]``. The
    selected values of this column are summed and divided by the number of
    selected rows, and the result (0.0 when no row was selected) is stored at
    ``output_ptr[col]``.

    The same row index addresses ``token_ids_ptr`` and ``hidden_states_ptr``.
    Element ``(row, col)`` of the hidden states is read at
    ``row * hidden_size + col``, i.e. the buffer is assumed to be contiguous
    row-major with rows ``hidden_size`` wide. ``BLOCK_SIZE`` is accepted but
    not used by this kernel.
    """
    col = tl.program_id(0)
    if col >= hidden_size:
        return
    total = 0.0
    n_selected = 0
    for step in range(seq_len):
        row = seq_start + step
        tok = tl.load(token_ids_ptr + row)
        in_range = (tok >= vision_start_id) & (tok <= vision_end_id)
        if in_range:
            total += tl.load(hidden_states_ptr + row * hidden_size + col)
            n_selected += 1
    # The division is discarded by the select when nothing was selected.
    mean = tl.where(n_selected > 0, total / n_selected, 0.0)
    tl.store(output_ptr + col, mean)
def _apply_vision_pooling_optimized(
    hidden_size: int,
    hidden_states: torch.Tensor,
    token_ids_list: list[array],
    prompt_lens: torch.Tensor,
) -> list[torch.Tensor]:
    """Pool each sequence, using the Triton kernel when it has vision tokens.

    Sequences are laid out back to back along dim 0 of ``hidden_states``.
    Sequence ``k`` occupies the next ``prompt_lens[k]`` rows, and its token ids
    are the first ``prompt_lens[k]`` entries of ``token_ids_list[k]``. This is
    the same convention ``_apply_vision_pooling_pytorch`` uses.

    If any of a sequence's ids lies in
    ``[VISION_START_TOKEN_ID, VISION_END_TOKEN_ID]``,
    ``extract_vision_tokens_kernel`` averages the rows with such ids.
    Otherwise all rows are averaged. Each result is L2-normalized.

    The kernel reads element ``(row, col)`` at ``row * hidden_size + col``, so
    rows are assumed to be ``hidden_size`` wide — TODO confirm
    ``hidden_states.shape[1] == hidden_size`` at call sites.

    Returns:
        One pooled, normalized tensor per sequence, in input order.
    """
    pooled_outputs = []
    offset = 0
    device = hidden_states.device
    for token_ids, prompt_len in zip(token_ids_list, prompt_lens):
        prompt_len = int(prompt_len.item())
        seq_states = hidden_states[offset:offset + prompt_len]
        # Only the first prompt_len ids correspond to this sequence's rows;
        # using more would let has_vision disagree with what the kernel scans.
        token_tensor = torch.tensor(list(token_ids[:prompt_len]),
                                    dtype=torch.long,
                                    device=device)
        has_vision = torch.any((token_tensor >= VISION_START_TOKEN_ID)
                               & (token_tensor <= VISION_END_TOKEN_ID))
        if has_vision:
            output = torch.zeros(hidden_size,
                                 device=device,
                                 dtype=hidden_states.dtype)
            # The kernel indexes token ids and hidden rows with the same
            # seq_start. token_tensor starts at this sequence's first token,
            # so pass this sequence's rows with seq_start=0. Passing the
            # global offset would read past the end of token_tensor for every
            # sequence after the first.
            extract_vision_tokens_kernel[(hidden_size, )](
                seq_states.contiguous(),
                token_tensor,
                output,
                0,
                prompt_len,
                hidden_size,
                VISION_START_TOKEN_ID,
                VISION_END_TOKEN_ID,
                BLOCK_SIZE=1024,
            )
        else:
            # No vision tokens: plain mean pooling over the whole sequence.
            output = seq_states.mean(dim=0)
        # F.normalize divides by max(norm, eps), so an all-zero vector stays
        # zero. It does not remove NaNs (e.g. the mean of an empty sequence).
        output = F.normalize(output, p=2, dim=-1, eps=1e-12)
        pooled_outputs.append(output)
        offset += prompt_len
    return pooled_outputs
def _apply_vision_pooling_pytorch(
    hidden_size: int,
    hidden_states: torch.Tensor,
    token_ids_list: list[array],
    prompt_lens: torch.Tensor,
) -> list[torch.Tensor]:
    """Pure-PyTorch reference implementation of the per-sequence pooling.

    Sequences are laid out back to back along dim 0 of ``hidden_states``.
    Sequence ``k`` occupies the next ``prompt_lens[k]`` rows, and its token ids
    are the first ``prompt_lens[k]`` entries of ``token_ids_list[k]``.

    For each sequence, if any token id falls in
    ``[VISION_START_TOKEN_ID, VISION_END_TOKEN_ID]``, the rows with such ids
    are mean-pooled; otherwise all rows are mean-pooled. Each pooled vector is
    L2-normalized with ``F.normalize(p=2, eps=1e-12)``.

    ``hidden_size`` is not used by this implementation (presumably kept so it
    shares a signature with ``_apply_vision_pooling_optimized``).

    NOTE(review): with the constants as defined in this file the id range
    contains only the start/end delimiter ids, not VISION_TOKEN_ID — confirm
    that is the intended selection.

    Returns:
        One pooled, normalized tensor per sequence, in input order.
    """
    results: list[torch.Tensor] = []
    row_start = 0
    for ids, length_tensor in zip(token_ids_list, prompt_lens):
        length = int(length_tensor.item())
        row_end = row_start + length
        rows = hidden_states[row_start:row_end]
        ids_on_device = torch.tensor(list(ids[:length]),
                                     dtype=torch.long,
                                     device=hidden_states.device)
        is_vision = ((ids_on_device >= VISION_START_TOKEN_ID)
                     & (ids_on_device <= VISION_END_TOKEN_ID))
        # Average every row unless a non-empty set of vision rows is found.
        source = rows
        if is_vision.any():
            selected = rows[is_vision]
            if selected.numel() != 0:
                source = selected
        results.append(
            F.normalize(source.mean(dim=0), p=2, dim=-1, eps=1e-12))
        row_start = row_end
    return results
hidden_size = 2048
seq_len_per_image = [512, 1024, 8192, 16384]
text_seq_len = 2048
num_images = [1, 2, 4]
dtype = torch.bfloat16
# Benchmark the Triton path against the PyTorch reference and check that
# both produce the same pooled embeddings. Requires a CUDA device.
for num_image in num_images:
    for img_seq_len in seq_len_per_image:
        print()
        print(f"Image sequence length: {img_seq_len}, Text sequence length: {text_seq_len}, Number of images: {num_image}")
        # Simulate one prompt: num_image images, each framed by start/end
        # delimiters around img_seq_len image tokens, followed by text
        # tokens (123 is an arbitrary non-vision id).
        image_block = ([VISION_START_TOKEN_ID]
                       + [VISION_TOKEN_ID] * img_seq_len
                       + [VISION_END_TOKEN_ID])
        token_ids_list = [image_block * num_image + [123] * text_seq_len]
        # Derive the length from the tokens so hidden_states has exactly one
        # row per token (each image contributes img_seq_len + 2 tokens).
        prompt_len = len(token_ids_list[0])
        hidden_states = torch.randn(prompt_len, hidden_size, device='cuda', dtype=dtype)
        # Lengths need an integer dtype. bfloat16 keeps only 8 bits of
        # precision, so e.g. 2562 would round to 2560 and truncate the prompt.
        prompt_lens = torch.tensor([prompt_len], device='cuda', dtype=torch.long)
        # Warm up both paths so Triton JIT compilation and one-time CUDA
        # setup are not included in the timings.
        _apply_vision_pooling_optimized(hidden_size, hidden_states, token_ids_list, prompt_lens)
        _apply_vision_pooling_pytorch(hidden_size, hidden_states, token_ids_list, prompt_lens)
        torch.cuda.synchronize()
        with time_counter(enable=True, label="triton vision pooling"):
            for _ in range(100):
                optimized_outputs = _apply_vision_pooling_optimized(
                    hidden_size, hidden_states, token_ids_list, prompt_lens)
            # GPU work is asynchronous; wait for it so the timer covers it.
            torch.cuda.synchronize()
        with time_counter(enable=True, label="pytorch vision pooling"):
            for _ in range(100):
                pytorch_outputs = _apply_vision_pooling_pytorch(
                    hidden_size, hidden_states, token_ids_list, prompt_lens)
            torch.cuda.synchronize()
        assert len(optimized_outputs) == len(pytorch_outputs)
        for opt_out, pt_out in zip(optimized_outputs, pytorch_outputs):
            torch.testing.assert_close(opt_out, pt_out)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment