@Tehada
Created June 27, 2025 13:21
Welcome to fish, the friendly interactive shell
Type help for instructions on how to use fish
mint on initial_fixes_to_reproduce_results [!⇑] via 🐍 v3.12.7
❯ diff -u ../../facebookresearch/esm/esm/modules.py mint/modules.py | diff-so-fancy
─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
renamed: ../../facebookresearch/esm/esm/modules.py to mint/modules.py
─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
@ mint/modules.py:95 @
        add_bias_kv=True,
        use_esm1b_layer_norm=False,
        use_rotary_embeddings: bool = False,
        use_multimer=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings
        self.use_multimer = use_multimer
        self._init_submodules(add_bias_kv, use_esm1b_layer_norm)

    def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
@ mint/modules.py:115 @
            add_zero_attn=False,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        if self.use_multimer:
            self.multimer_attn = MultiheadAttention(
                self.embed_dim,
                self.attention_heads,
                add_bias_kv=add_bias_kv,
                add_zero_attn=False,
                use_rotary_embeddings=False,
                no_proj=True,
            )
        self.self_attn_layer_norm = BertLayerNorm(self.embed_dim)

        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
@ mint/modules.py:138 @
    ):
        residual = x
        x = self.self_attn_layer_norm(x)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        if self.use_multimer:
            self_attn, self_v = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                before_softmax=True,
            )
            multimer_attn, multimer_v = self.multimer_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                before_softmax=True,
            )
            attn_weights = torch.where(self_attn_mask.unsqueeze(1), multimer_attn, self_attn)
            attn_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)
            attn_probs_dropout = F.dropout(
                attn_probs, p=self.self_attn.dropout, training=self.training
            )
            self_attn_probs = attn_probs_dropout.masked_fill(self_attn_mask.unsqueeze(1), 0.0)
            multimer_attn_probs = attn_probs_dropout.masked_fill(~self_attn_mask.unsqueeze(1), 0.0)
            attn_out = torch.matmul(self_attn_probs, self_v) + torch.matmul(
                multimer_attn_probs, multimer_v
            )
            attn_out = attn_out.transpose(1, 2).contiguous()
            attn_out = attn_out.view(*attn_out.shape[:2], -1)
            x = self.self_attn.out_proj(attn_out).transpose(0, 1).contiguous()
            if need_head_weights:
                attn = attn_probs.transpose(0, 1).contiguous()
            else:
                attn = attn_probs.mean(1)
        else:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=True,
                need_head_weights=need_head_weights,
                attn_mask=self_attn_mask,
            )
        x = residual + x

        residual = x
mint on initial_fixes_to_reproduce_results [!⇑] via 🐍 v3.12.7
❯
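
For context, the new code path in the diff gives each layer a second, projection-free MultiheadAttention (multimer_attn, built with no_proj=True) and mixes its raw scores with those of the regular self_attn inside forward: torch.where keeps the multimer scores where self_attn_mask is set and the regular scores elsewhere, a single softmax normalises the merged scores, and each half of the resulting probabilities is multiplied with the value tensor of the module it came from before the shared out_proj. The snippet below is a minimal, self-contained sketch of just that mixing step on plain tensors; the function name, the tensor shapes, and the reading of the mask as an inter-chain selector are assumptions for illustration, not the actual esm/mint API.

# Hedged sketch of the attention-mixing step shown in the diff (not mint's code).
# Assumed shapes: logits (batch, heads, L, L), values (batch, heads, L, head_dim),
# inter_chain_mask (batch, L, L) with True where query and key lie in different chains.
import torch
import torch.nn.functional as F

def mix_chain_attention(self_logits, multimer_logits, self_v, multimer_v,
                        inter_chain_mask, dropout_p=0.0, training=False):
    mask = inter_chain_mask.unsqueeze(1)  # broadcast over the head dimension
    # Inter-chain positions take the multimer scores, everything else keeps the
    # intra-chain scores; one softmax normalises each merged row of scores.
    logits = torch.where(mask, multimer_logits, self_logits)
    probs = F.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
    probs = F.dropout(probs, p=dropout_p, training=training)
    # Route each probability back to the value tensor of the module it came from.
    intra = probs.masked_fill(mask, 0.0)
    inter = probs.masked_fill(~mask, 0.0)
    return intra @ self_v + inter @ multimer_v  # (batch, heads, L, head_dim)

# Toy check with two pretend 3-residue chains.
b, h, L, d = 2, 4, 6, 8
mask = torch.zeros(b, L, L, dtype=torch.bool)
mask[:, :3, 3:] = True
mask[:, 3:, :3] = True
out = mix_chain_attention(torch.randn(b, h, L, L), torch.randn(b, h, L, L),
                          torch.randn(b, h, L, d), torch.randn(b, h, L, d), mask)
print(out.shape)  # torch.Size([2, 4, 6, 8])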