Created December 16, 2024 13:31
Python script demonstrating how to transfer weights from a SelfAttention_v2 instance to a SelfAttention_v1 instance in PyTorch to ensure identical outputs.
import torch

# Example dimensions (arbitrary; the original gist leaves d_in and d_out undefined)
d_in, d_out = 3, 2

sa_v1 = SelfAttention_v1(d_in, d_out)
sa_v2 = SelfAttention_v2(d_in, d_out)

# Transfer weights from sa_v2 to sa_v1.
# nn.Linear stores its weight as (d_out, d_in), while v1's parameters
# are (d_in, d_out), so transpose before copying.
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

x = torch.randn(10, d_in)  # Batch of 10 input vectors
output_v1 = sa_v1(x)
output_v2 = sa_v2(x)
print(torch.allclose(output_v1, output_v2, atol=1e-6))  # Expected: True
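The snippet assumes `SelfAttention_v1` and `SelfAttention_v2` are already defined. A minimal sketch of what those classes presumably look like, assuming (as the weight transfer implies) that v1 stores its projection matrices as raw `nn.Parameter` tensors of shape `(d_in, d_out)` while v2 uses `nn.Linear` layers:

```python
import torch
import torch.nn as nn


class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention with raw (d_in, d_out) weight matrices."""

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


class SelfAttention_v2(nn.Module):
    """Same computation, but the weights live inside nn.Linear layers,
    which store them transposed as (d_out, d_in)."""

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values
```

With these definitions, copying each `nn.Linear` weight transposed into the matching `nn.Parameter` makes the two modules compute identical projections, which is why the `allclose` check in the gist passes.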