BiDAF attention layer
A PyTorch implementation of the bi-directional attention flow layer (Seo et al., 2017), followed by a GRU + max-pool readout per sentence.
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.softmax below
class BidafAttnModule(nn.Module):
    def __init__(self, hidden_size):
        super(BidafAttnModule, self).__init__()
        # projects the question vectors before the similarity matrix is computed
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.ReLU()
        )
        # currently unused; see the commented-out call in forward()
        self.fc_final = nn.Sequential(
            nn.Linear(4*hidden_size, hidden_size),
            nn.ReLU()
        )
        # batch_first=True so the GRU runs over the token dimension:
        # forward() feeds it tensors of shape (batch*sen, token, 4*hidden)
        self.gru = nn.GRU(4*hidden_size, hidden_size, batch_first=True)
    def forward(self, contexts, questions):
        """
        contexts -> #batch, #sen, #token, #hidden -> #batch, #sen * #token, #hidden
        questions -> #batch, #seq, #hidden
        attn_q -> #batch, #sen*#token, #hidden
        attn_c -> #batch, 1, #hidden (expanded to match attn_q)
        G -> #batch, #sen * #token, 4*#hidden
        """
        batch_num, sen_num, token_num, hidden_size = contexts.size()
        contexts = contexts.view(batch_num, sen_num*token_num, hidden_size)
        S = self.get_similarity_matrix(contexts, questions)
        attn_q = self.get_context_aware_query(S, questions)
        attn_c = self.get_query_aware_context(S, contexts)
        attn_c = attn_c.expand_as(attn_q)
        # concatenate the context with both attention streams (BiDAF's G matrix)
        G = torch.cat([contexts, attn_q, contexts*attn_q, contexts*attn_c], -1)
        # run the GRU over the tokens of each sentence, then max-pool over tokens
        G, _ = self.gru(G.view(batch_num*sen_num, token_num, -1))
        G = torch.max(G, 1)[0]
        # G = self.fc_final(G)
        G = G.view(batch_num, sen_num, hidden_size)
        return G
    def get_similarity_matrix(self, contexts, questions):
        """
        contexts -> #batch, #sen * #token, #hidden
        questions -> #batch, #seq, #hidden
        S -> #batch, #sen*#token, #seq
        S_tj = h * W * u  (a bilinear similarity; W is the fc projection above)
        """
        batch_num, _, hidden_size = contexts.size()
        # project the question vectors, then dot them against every context token
        questions = self.fc(questions.contiguous().view(-1, hidden_size)).view(batch_num, -1, hidden_size)
        S = torch.bmm(contexts, questions.transpose(1, 2))
        return S
    def get_context_aware_query(self, S, questions):
        """
        S -> #batch, #sen*#token, #seq
        score -> #batch, #sen*#token, #seq
        questions -> #batch, #seq, #hidden
        context_aware_questions -> #batch, #sen*#token, #hidden
        """
        # context-to-query attention: softmax over question tokens, then a
        # weighted sum of question vectors for every context token
        score = F.softmax(S, dim=2)
        context_aware_questions = torch.bmm(score, questions)
        return context_aware_questions
    def get_query_aware_context(self, S, contexts):
        """
        S -> #batch, #sen*#token, #seq
        S_max -> #batch, #sen*#token
        score -> #batch, #sen*#token -> #batch, 1, #sen*#token
        contexts -> #batch, #sen * #token, #hidden
        query_aware_contexts -> #batch, 1, #hidden
        """
        # query-to-context attention: max over question tokens, softmax over
        # context tokens, then a single weighted sum of context vectors
        S_max = torch.max(S, 2)[0]
        score = F.softmax(S_max, dim=1)
        query_aware_contexts = torch.bmm(score.unsqueeze(1), contexts)
        return query_aware_contexts
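
A minimal usage sketch to show the expected tensor shapes. The sizes below are hypothetical, and the random tensors stand in for real encoder outputs; only BidafAttnModule comes from the file above.

# Hypothetical smoke test: random encodings in place of a real encoder.
hidden_size = 128
batch, n_sen, n_token, q_len = 2, 5, 20, 10

model = BidafAttnModule(hidden_size)
contexts = torch.randn(batch, n_sen, n_token, hidden_size)   # #batch, #sen, #token, #hidden
questions = torch.randn(batch, q_len, hidden_size)           # #batch, #seq, #hidden
out = model(contexts, questions)
print(out.shape)  # torch.Size([2, 5, 128]) -> #batch, #sen, #hidden

The module returns one pooled vector per sentence, i.e. the token dimension is consumed by the GRU + max-pool readout.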