@debashishc
Created December 16, 2024 13:31
Python script demonstrating how to transfer weights from a SelfAttention_v2 instance to a SelfAttention_v1 instance in PyTorch so that both modules produce identical outputs.
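The script assumes SelfAttention_v1 and SelfAttention_v2 are already defined: v1 stores its projection matrices as raw nn.Parameter tensors of shape (d_in, d_out), while v2 wraps the same projections in bias-free nn.Linear layers. Below is a minimal, illustrative sketch of such definitions; the actual classes and the d_in/d_out values are not shown in this gist.

import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        # Raw weight matrices, shape (d_in, d_out)
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        # nn.Linear stores its weight as (d_out, d_in)
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values

d_in, d_out = 3, 2  # hypothetical dimensions; the gist does not specify them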
sa_v1 = SelfAttention_v1(d_in, d_out)
sa_v2 = SelfAttention_v2(d_in, d_out)

# Transfer weights from sa_v2 to sa_v1.
# nn.Linear keeps its weight as (d_out, d_in), so transpose it to match
# v1's (d_in, d_out) parameters.
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

# Pass the same input through both modules and compare the outputs
x = torch.randn(10, d_in)  # batch of 10 input vectors
output_v1 = sa_v1(x)
output_v2 = sa_v2(x)
print(torch.allclose(output_v1, output_v2, atol=1e-6))  # True if the transfer worked
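A note on the design: nn.Linear computes x @ W.T and stores W with shape (d_out, d_in), which is why the transfer copies sa_v2.W_query.weight.T rather than the weight itself. The transfer is needed because freshly constructed instances of the two classes initialize their weights differently, so they would not otherwise agree. With the weights copied over as sketched above, the allclose check should print True.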