import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftMultiTokenEmbedding(nn.Module):
    """
    A MultiToken embedding layer that performs soft merging of adjacent tokens
    without requiring attention mask modifications.
    """

    def __init__(self, vocab_size, embedding_dim, dropout_rate=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.dropout = dropout_rate

        # Base token embedding - standard embedding layer
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)

        # Controllable merge detection with limited capacity
        self.merge_detector = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim // 2),
            nn.Dropout(dropout_rate),
            nn.GELU(),
            nn.Linear(embedding_dim // 2, 1),
            nn.Sigmoid()
        )

        # Token fusion with regularization
        self.token_fusion = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.Dropout(dropout_rate),
            nn.LayerNorm(embedding_dim)
        )

        # Context enrichment
        self.context_gate = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.Sigmoid()
        )
        self.context_transform = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.Dropout(dropout_rate),
            nn.LayerNorm(embedding_dim)
        )

        self.init_weights()
    def init_weights(self):
        """Conservative initialization that doesn't bias the model too much"""
        # Reset token embedding to ensure proper scale
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)

        # Merge detector - initialize to rarely activate at first
        nn.init.normal_(self.merge_detector[0].weight, std=0.02)
        nn.init.constant_(self.merge_detector[0].bias, 0.0)
        nn.init.normal_(self.merge_detector[3].weight, std=0.02)
        nn.init.constant_(self.merge_detector[3].bias, -3.0)  # Very conservative merge bias

        # Token fusion - standard initialization
        nn.init.normal_(self.token_fusion[0].weight, std=0.02)
        nn.init.constant_(self.token_fusion[0].bias, 0.0)

        # Context gate and transform
        nn.init.normal_(self.context_gate[0].weight, std=0.02)
        nn.init.constant_(self.context_gate[0].bias, 0.0)
        nn.init.normal_(self.context_transform[0].weight, std=0.02)
        nn.init.constant_(self.context_transform[0].bias, 0.0)
    def merge_tokens(self, token_embeds):
        """Process pairs of tokens to create fused representations without modifying the input in place"""
        batch_size, seq_len, embed_dim = token_embeds.shape
        if seq_len <= 1:
            return token_embeds

        # Create fresh tensor for results
        merged_embeds = torch.zeros_like(token_embeds)

        # Get adjacent pairs
        left_tokens = token_embeds[:, :-1]   # All but last
        right_tokens = token_embeds[:, 1:]   # All but first

        # Create pair representation
        token_pairs = torch.cat([left_tokens, right_tokens], dim=-1)

        # Compute merge confidence - how much should these tokens be merged
        merge_scores = self.merge_detector(token_pairs)

        # Apply strong threshold - only consider significant merges
        merge_mask = (merge_scores > 0.5).float()
        filtered_scores = merge_scores * merge_mask

        # Create fused representations regardless of scores
        # (the scores will determine how much they're used)
        fused_tokens = self.token_fusion(token_pairs)

        # Apply fusion with controlled effect (max 30% contribution)
        fusion_strength = 0.3 * filtered_scores

        # First token is passed through unchanged (it can only be the left token of a pair)
        merged_embeds[:, 0:1] = token_embeds[:, 0:1]

        # Process middle tokens (each is the right token of the previous pair
        # and the left token of the next pair)
        for i in range(1, seq_len - 1):
            # Contribution as right token of the previous pair
            right_contrib = fusion_strength[:, i - 1:i] * fused_tokens[:, i - 1:i]
            # Contribution as left token of the next pair
            left_contrib = fusion_strength[:, i:i + 1] * fused_tokens[:, i:i + 1]
            # Base contribution from the original token
            base_weight = 1.0 - fusion_strength[:, i - 1:i] - fusion_strength[:, i:i + 1]
            # Combine contributions (writes into the fresh output tensor, not the input)
            merged_embeds[:, i:i + 1] = (
                base_weight * token_embeds[:, i:i + 1] + right_contrib + left_contrib
            )

        # Process last token (it can only be the right token of a pair)
        last_fusion = fusion_strength[:, -1:]
        merged_embeds[:, -1:] = (
            (1 - last_fusion) * token_embeds[:, -1:] +
            last_fusion * fused_tokens[:, -1:]
        )

        # Apply dropout for regularization
        merged_embeds = F.dropout(merged_embeds, p=self.dropout, training=self.training)
        return merged_embeds
    def add_context(self, token_embeds):
        """Add contextual information from the previous token via a gated mechanism, without modifying the input in place"""
        batch_size, seq_len, embed_dim = token_embeds.shape
        if seq_len <= 1:
            return token_embeds

        # Create fresh tensor for results
        context_embeds = torch.zeros_like(token_embeds)

        # First token remains unchanged (no previous context)
        context_embeds[:, 0:1] = token_embeds[:, 0:1]

        # For the rest, add context from the previous token
        current_tokens = token_embeds[:, 1:]
        prev_tokens = token_embeds[:, :-1]

        # Create context pairs
        context_pairs = torch.cat([current_tokens, prev_tokens], dim=-1)

        # Compute context gate - how much context to include
        context_gate = self.context_gate(context_pairs)

        # Transform the context information
        context_features = self.context_transform(context_pairs)

        # Apply gated context mechanism - limited to 20% contribution
        context_contribution = 0.2 * context_gate

        # Combine with original embeddings (writes into the fresh output tensor, not the input)
        context_embeds[:, 1:] = (
            (1 - context_contribution) * token_embeds[:, 1:] +
            context_contribution * context_features
        )

        # Apply dropout for regularization
        context_embeds = F.dropout(context_embeds, p=self.dropout, training=self.training)
        return context_embeds
    def forward(self, input_ids):
        """Transform token IDs into contextual embeddings"""
        # Get base embeddings
        embeds = self.token_embedding(input_ids)
        # Apply token merging
        embeds = self.merge_tokens(embeds)
        # Apply contextual information
        embeds = self.add_context(embeds)
        return embeds
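

# Usage sketch (not part of the original gist): a minimal example of how the layer
# might be exercised. The vocabulary size, embedding width, and the random (2, 16)
# batch of token ids below are illustrative assumptions, not values from the
# original code.
if __name__ == "__main__":
    torch.manual_seed(0)

    vocab_size, embedding_dim = 1000, 64  # assumed toy hyperparameters
    layer = SoftMultiTokenEmbedding(vocab_size, embedding_dim, dropout_rate=0.1)

    input_ids = torch.randint(0, vocab_size, (2, 16))  # (batch, seq_len)

    # Sequence length is preserved: merging is "soft", so no attention mask
    # changes are needed downstream.
    layer.eval()  # disable dropout for a deterministic shape check
    with torch.no_grad():
        embeds = layer(input_ids)
    print(embeds.shape)  # torch.Size([2, 16, 64])

    # The layer is differentiable end to end, so a dummy loss can backpropagate
    # into the embedding table and the merge/fusion sub-networks.
    layer.train()
    loss = layer(input_ids).pow(2).mean()
    loss.backward()
    print(layer.token_embedding.weight.grad is not None)  # True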