@hzhu212
Created March 18, 2025 02:37
Positional Encoding

Notes on several typical positional-encoding implementations used in Transformer models.

1. Default positional encoding
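
A minimal sketch follows, assuming "default" refers to the fixed sinusoidal encoding from "Attention Is All You Need". The class name SinusoidalPositionalEncoding, the max_len parameter, and the [seq_len, batch_size, d_model] input layout (chosen to match the RoPE module in section 2) are illustrative choices, not part of the original notes.

import math

import torch
import torch.nn as nn


class SinusoidalPositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding:
        PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    Assumes d_model is even.
    """
    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1).float()  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)  # fixed table, not a learnable parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [seq_len, batch_size, d_model]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)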

2. Rotary positional encoding (RoPE)

import torch
import torch.nn as nn


class RotaryPositionalEncoding(nn.Module):
    """
    RoPE implementation. Reference: https://github.com/aju22/RoPE-PyTorch.
    """
    def __init__(self, d: int, dropout: float = 0.0, base: int = 10000):
        super().__init__()
        self.base = base
        self.d = d
        self.dropout = nn.Dropout(p=dropout)
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        # Rebuild the cache only when the incoming sequence is longer than the cached one
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return
        seq_len = x.shape[0]
        # theta_i = base^(-2i/d) for i in [0, d/2)
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
        # Position indices [0, 1, ..., seq_len - 1]
        seq_idx = torch.arange(seq_len, device=x.device).float()
        # Outer product m * theta_i -> [seq_len, d/2]
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
        # Duplicate along the feature dim: [seq_len, d/2] -> [seq_len, d]
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
        # Cache cos/sin with a broadcast dim for the batch: [seq_len, 1, d]
        self.cos_cached = idx_theta2.cos()[:, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, :]

    def _neg_half(self, x: torch.Tensor):
        """
        x: [seq_len, batch_size, d_model]
        """
        d_2 = self.d // 2
        # [x_1, ..., x_d] -> [-x_{d/2+1}, ..., -x_d, x_1, ..., x_{d/2}]
        return torch.cat([-x[:, :, d_2:], x[:, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        x: [seq_len, batch_size, d_model]
        """
        self._build_cache(x)
        neg_half_x = self._neg_half(x)
        # Rotation per pair: out_i = x_i * cos(m*theta_i) - x_{i+d/2} * sin(m*theta_i) for the first half,
        # and out_{i+d/2} = x_{i+d/2} * cos(m*theta_i) + x_i * sin(m*theta_i) for the second half
        x_rope = (x * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
        return self.dropout(x_rope)
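
A quick shape check of the module above (the values are illustrative; in practice RoPE is applied to the query/key projections inside attention rather than to the raw token embeddings):

rope = RotaryPositionalEncoding(d=64)
q = torch.randn(10, 2, 64)  # [seq_len, batch_size, d_model]
q_rot = rope(q)
print(q_rot.shape)  # torch.Size([10, 2, 64])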

3. Weighted positional encoding

Example using timestamps as positions:

import math

import torch
import torch.nn as nn


class TimeAwarePositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        # Learnable weights for the time intervals
        self.time_weights = nn.Parameter(torch.randn(d_model))

    def forward(self, x: torch.Tensor, time_delta: torch.Tensor) -> torch.Tensor:
        """
        x: [batch_size, seq_len, d_model]
        time_delta: [batch_size, seq_len] (cumulative time intervals)
        """
        batch_size, seq_len, _ = x.size()

        # Build the sinusoidal encoding from the time positions
        # position = time_delta.cumsum(dim=1)  # use this if time_delta holds per-step deltas instead
        position = time_delta
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * -(math.log(10000.0) / self.d_model)).to(x.device)

        pe = torch.zeros(batch_size, seq_len, self.d_model).to(x.device)
        pe[..., 0::2] = torch.sin(position.unsqueeze(-1) * div_term)
        pe[..., 1::2] = torch.cos(position.unsqueeze(-1) * div_term)

        # Gate the encoding with a learned, time-dependent weight per feature
        time_weights = torch.sigmoid(time_delta.unsqueeze(-1) * self.time_weights)
        return x + pe * time_weights
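
An illustrative call, assuming time_delta already holds cumulative timestamps per the docstring (uncomment the cumsum line in forward if it holds per-step deltas instead):

enc = TimeAwarePositionalEncoding(d_model=64)
x = torch.randn(2, 10, 64)  # [batch_size, seq_len, d_model]
time_delta = torch.rand(2, 10).cumsum(dim=1)  # monotonically increasing timestamps
out = enc(x, time_delta)
print(out.shape)  # torch.Size([2, 10, 64])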