Notes on the implementation of several typical positional encodings in Transformer models.
Omitted.
import torch
import torch.nn as nn


class RotaryPositionalEncoding(nn.Module):
    """
    RoPE implementation. Reference: https://github.com/aju22/RoPE-PyTorch.
    """
    def __init__(self, d: int, dropout: float = 0.0, base: int = 10000):
        super().__init__()
        self.base = base
        self.d = d
        self.dropout = nn.Dropout(p=dropout)
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        # Reuse the cache when it already covers the current sequence length.
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return
        seq_len = x.shape[0]
        # theta_i = base^(-2i/d) = 1 / base^(2i/d), for i = 0, 1, ..., d/2 - 1
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
        # Position indices m = [0, 1, 2, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, device=x.device).float()
        # Outer product m * theta_i -> [seq_len, d/2]
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
        # Repeat the angles so each one covers both halves of the feature dim -> [seq_len, d]
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
        # Cache cos/sin with a singleton batch dim for broadcasting -> [seq_len, 1, d]
        self.cos_cached = idx_theta2.cos()[:, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, :]

    def _neg_half(self, x: torch.Tensor):
        """
        x: [seq_len, batch_size, d_model]
        """
        d_2 = self.d // 2
        # [x_1, ..., x_d] -> [-x_{d/2+1}, ..., -x_d, x_1, ..., x_{d/2}]
        return torch.cat([-x[:, :, d_2:], x[:, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        x: [seq_len, batch_size, d_model]
        """
        self._build_cache(x)
        neg_half_x = self._neg_half(x)
        # Rotation: component i (for i < d/2) becomes x_i * cos(m*theta_i) - x_{i+d/2} * sin(m*theta_i),
        # and component i + d/2 becomes x_{i+d/2} * cos(m*theta_i) + x_i * sin(m*theta_i).
        x_rope = (x * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
        return self.dropout(x_rope)
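A minimal usage sketch (the shapes and variable names below are illustrative assumptions, not part of the original notes): in practice the same rotation is applied to the query and key tensors before computing attention scores, so that their dot products depend only on relative positions.

# Usage sketch with assumed shapes: [seq_len, batch_size, d_model].
rope = RotaryPositionalEncoding(d=64)
q = torch.randn(128, 2, 64)  # hypothetical queries
k = torch.randn(128, 2, 64)  # hypothetical keys
q_rot, k_rot = rope(q), rope(k)
print(q_rot.shape, k_rot.shape)  # torch.Size([128, 2, 64]) torch.Size([128, 2, 64])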
As an example, using time stamps (rather than integer indices) as positions:
import math

import torch
import torch.nn as nn


class TimeAwarePositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Learnable per-dimension weights for the time intervals
        self.time_weights = nn.Parameter(torch.randn(d_model))

    def forward(self, x: torch.Tensor, time_delta: torch.Tensor) -> torch.Tensor:
        """
        x: [batch_size, seq_len, d_model]
        time_delta: [batch_size, seq_len] (cumulative time intervals)
        """
        batch_size, seq_len, _ = x.size()
        # Build a sinusoidal encoding from the time positions
        # position = time_delta.cumsum(dim=1)  # accumulate per-step intervals if needed
        position = time_delta
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2).float() * -(math.log(10000.0) / self.d_model)
        ).to(x.device)
        pe = torch.zeros(batch_size, seq_len, self.d_model).to(x.device)
        pe[..., 0::2] = torch.sin(position.unsqueeze(-1) * div_term)
        pe[..., 1::2] = torch.cos(position.unsqueeze(-1) * div_term)
        # Gate the encoding with a learnable, time-dependent weight in (0, 1)
        time_weights = torch.sigmoid(time_delta.unsqueeze(-1) * self.time_weights)
        return x + pe * time_weights
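A minimal usage sketch; the event times and shapes below are illustrative assumptions. The sequence is irregularly sampled, and time_delta carries each event's time offset from the first event (already cumulative, matching the commented-out cumsum above).

# Usage sketch with assumed shapes and time stamps.
enc = TimeAwarePositionalEncoding(d_model=16)
x = torch.randn(2, 5, 16)                               # [batch_size, seq_len, d_model]
time_delta = torch.tensor([[0.0, 0.5, 1.7, 4.0, 4.2],   # already-cumulative time offsets
                           [0.0, 2.0, 2.1, 6.5, 9.0]])  # [batch_size, seq_len]
out = enc(x, time_delta)
print(out.shape)  # torch.Size([2, 5, 16])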