import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftMultiTokenEmbedding(nn.Module):
    """
    A multi-token embedding layer that performs soft merging of adjacent tokens
    without requiring attention mask modifications.
    """

    def __init__(self, vocab_size, embedding_dim, dropout_rate=0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.dropout = dropout_rate

        # Base token embedding - standard embedding layer
        self.token_embedding = nn.Embedding(vocab_size, embedding_dim)

        # Controllable merge detection with limited capacity
        self.merge_detector = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim // 2),
            nn.Dropout(dropout_rate),
            nn.GELU(),
            nn.Linear(embedding_dim // 2, 1),
            nn.Sigmoid()
        )

        # Token fusion with regularization
        self.token_fusion = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.Dropout(dropout_rate),
            nn.LayerNorm(embedding_dim)
        )

        # Context enrichment
        self.context_gate = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.Sigmoid()
        )
        self.context_transform = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.Dropout(dropout_rate),
            nn.LayerNorm(embedding_dim)
        )

        self.init_weights()

    def init_weights(self):
        """Conservative initialization that doesn't bias the model too much."""
        # Reset token embedding to ensure proper scale
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)

        # Merge detector - initialize to rarely activate at first
        nn.init.normal_(self.merge_detector[0].weight, std=0.02)
        nn.init.constant_(self.merge_detector[0].bias, 0.0)
        nn.init.normal_(self.merge_detector[3].weight, std=0.02)
        nn.init.constant_(self.merge_detector[3].bias, -3.0)  # Very conservative merge bias

        # Token fusion - standard initialization
        nn.init.normal_(self.token_fusion[0].weight, std=0.02)
        nn.init.constant_(self.token_fusion[0].bias, 0.0)

        # Context gate and transform
        nn.init.normal_(self.context_gate[0].weight, std=0.02)
        nn.init.constant_(self.context_gate[0].bias, 0.0)
        nn.init.normal_(self.context_transform[0].weight, std=0.02)
        nn.init.constant_(self.context_transform[0].bias, 0.0)

    def merge_tokens(self, token_embeds):
        """Process pairs of tokens to create fused representations without in-place operations."""
        batch_size, seq_len, embed_dim = token_embeds.shape
        if seq_len <= 1:
            return token_embeds

        # Create fresh tensor for results
        merged_embeds = torch.zeros_like(token_embeds)

        # Get adjacent pairs
        left_tokens = token_embeds[:, :-1]   # All but last
        right_tokens = token_embeds[:, 1:]   # All but first

        # Create pair representation
        token_pairs = torch.cat([left_tokens, right_tokens], dim=-1)

        # Compute merge confidence - how strongly each adjacent pair should be merged
        merge_scores = self.merge_detector(token_pairs)

        # Apply a hard threshold - only consider significant merges
        merge_mask = (merge_scores > 0.5).float()
        filtered_scores = merge_scores * merge_mask

        # Create fused representations for every pair regardless of score
        # (the scores determine how much each fused vector is used)
        fused_tokens = self.token_fusion(token_pairs)

        # Apply fusion with controlled effect (at most 30% contribution per pair)
        fusion_strength = 0.3 * filtered_scores

        # First token is kept unchanged; the fused content of pair (0, 1)
        # is routed to token 1 in the loop below
        merged_embeds[:, 0:1] = token_embeds[:, 0:1]

        # Middle tokens participate in two pairs: as the right token of pair
        # (i-1, i) and as the left token of pair (i, i+1)
        for i in range(1, seq_len - 1):
            # Contribution as right token of the previous pair
            right_contrib = fusion_strength[:, i - 1:i] * fused_tokens[:, i - 1:i]
            # Contribution as left token of the next pair
            left_contrib = fusion_strength[:, i:i + 1] * fused_tokens[:, i:i + 1]
            # Base contribution from the original token
            base_weight = 1.0 - fusion_strength[:, i - 1:i] - fusion_strength[:, i:i + 1]
            # Combine contributions without modifying any tensors in-place
            merged_embeds[:, i:i + 1] = (
                base_weight * token_embeds[:, i:i + 1] + right_contrib + left_contrib
            )

        # Last token can only be the right token of the final pair
        last_fusion = fusion_strength[:, -1:]
        merged_embeds[:, -1:] = (
            (1 - last_fusion) * token_embeds[:, -1:] +
            last_fusion * fused_tokens[:, -1:]
        )

        # Apply dropout for regularization
        merged_embeds = F.dropout(merged_embeds, p=self.dropout, training=self.training)
        return merged_embeds
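
    # Note on the blending above (added remark): each pair's fusion strength is
    # capped at 0.3, so a middle token's base weight is at least
    # 1 - 0.3 - 0.3 = 0.4; the original embedding always keeps at least 40% of
    # the merged output, and boundary tokens keep at least 70%.
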
    def add_context(self, token_embeds):
        """Add contextual information using a gated mechanism without in-place operations."""
        batch_size, seq_len, embed_dim = token_embeds.shape
        if seq_len <= 1:
            return token_embeds

        # Create fresh tensor for results
        context_embeds = torch.zeros_like(token_embeds)

        # First token remains unchanged (no previous context)
        context_embeds[:, 0:1] = token_embeds[:, 0:1]

        # For the rest, add context from the previous token
        current_tokens = token_embeds[:, 1:]
        prev_tokens = token_embeds[:, :-1]

        # Create context pairs
        context_pairs = torch.cat([current_tokens, prev_tokens], dim=-1)

        # Compute context gate - how much context to include
        context_gate = self.context_gate(context_pairs)

        # Transform the context information
        context_features = self.context_transform(context_pairs)

        # Apply the gated context mechanism - limited to 20% contribution
        context_contribution = 0.2 * context_gate

        # Combine with the original embeddings (without in-place modifications)
        context_embeds[:, 1:] = (
            (1 - context_contribution) * token_embeds[:, 1:] +
            context_contribution * context_features
        )

        # Apply dropout for regularization
        context_embeds = F.dropout(context_embeds, p=self.dropout, training=self.training)
        return context_embeds

    def forward(self, input_ids):
        """Transform token IDs into contextual embeddings."""
        # Get base embeddings
        embeds = self.token_embedding(input_ids)
        # Apply token merging
        embeds = self.merge_tokens(embeds)
        # Apply contextual information
        embeds = self.add_context(embeds)
        return embeds
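

# --- Minimal usage sketch (added; not part of the original gist) ---
# Shows a forward and backward pass with assumed toy sizes. The layer takes
# (batch, seq_len) token IDs and returns (batch, seq_len, embedding_dim)
# vectors, the same interface as nn.Embedding.
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size, embedding_dim = 1000, 64  # assumed, arbitrary values
    layer = SoftMultiTokenEmbedding(vocab_size, embedding_dim, dropout_rate=0.1)

    input_ids = torch.randint(0, vocab_size, (2, 16))  # batch of 2, sequence length 16
    embeds = layer(input_ids)
    print(embeds.shape)  # torch.Size([2, 16, 64])

    # Gradients flow through the soft merge and context gates
    embeds.sum().backward()
    print(layer.token_embedding.weight.grad is not None)  # True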