@youkaichao
Created November 6, 2024 20:28
inplace embedding: an nn.Embedding lookup written into a preallocated buffer with torch.index_select(..., out=...)
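import torch
import torch.nn as nn

# an Embedding module containing 10 embedding vectors of size 3
embedding = nn.Embedding(10, 3)
# disable autograd on the weight: functions called with out= do not support autograd
embedding.weight.requires_grad_(False)
# a batch of 4 indices (named `indices` to avoid shadowing the builtin `input`)
indices = torch.tensor([1, 2, 4, 5], dtype=torch.long)
# regular lookup: allocates a new (4, 3) tensor
output = embedding(indices)
# preallocated buffer with the same dtype as the weight
output_buffer = torch.empty(4, 3, dtype=embedding.weight.dtype)
# same lookup, but the selected rows are written into output_buffer
torch.index_select(embedding.weight, dim=0, index=indices, out=output_buffer)
assert torch.allclose(output, output_buffer)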
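The snippet disables gradients on the weight with requires_grad_(False) so that the out= call is allowed. A minimal sketch of an alternative, assuming the goal is to reuse one preallocated buffer across several equally sized batches, is to keep the weight trainable and wrap the lookup in torch.no_grad() instead. The batch data and the loop below are illustrative only and are not part of the original gist.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same setup as above: 10 embedding vectors of size 3.
# requires_grad stays True on the weight here; torch.no_grad() is used instead
# of requires_grad_(False) to make the out= call legal.
embedding = nn.Embedding(10, 3)

# Illustrative stream of equally sized index batches (made-up data).
batches = [
    torch.tensor([1, 2, 4, 5]),
    torch.tensor([0, 9, 3, 3]),
    torch.tensor([7, 6, 8, 1]),
]

# One preallocated buffer, reused for every batch.
output_buffer = torch.empty(4, 3, dtype=embedding.weight.dtype)
buffer_ptr = output_buffer.data_ptr()

with torch.no_grad():
    for indices in batches:
        # Write the selected rows of the weight matrix into the existing buffer.
        torch.index_select(embedding.weight, 0, indices, out=output_buffer)
        # The result matches the functional lookup, which allocates a fresh tensor.
        assert torch.equal(output_buffer, F.embedding(indices, embedding.weight))
        # No reallocation: the buffer still points at the same storage.
        assert output_buffer.data_ptr() == buffer_ptr
```

The buffer shape is fixed at (4, 3). A batch of a different size would make index_select resize output_buffer (PyTorch warns when a non-empty out= tensor is resized), which defeats the purpose of preallocating.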