import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

# Module-level dropout rate. In the original project this is a global set
# elsewhere; it defaults to 0 here so the module is self-contained.
my_dropout_p = 0.0

def dropout(x, p=0.0, training=False):
    """Thin wrapper around F.dropout; the original project substitutes its own variant."""
    return F.dropout(x, p=p, training=training)

class GetAttentionHiddens(nn.Module):
    def __init__(self, input_size, attention_hidden_size, similarity_attention=False):
super(GetAttentionHiddens, self).__init__()
self.scoring = AttentionScore(input_size, attention_hidden_size, similarity_score=similarity_attention)
def forward(self, x1, x2, x2_mask, x3=None, scores=None, return_scores=False, drop_diagonal=False):
"""
Using x1, x2 to calculate attention score, but x1 will take back info from x3.
If x3 is not specified, x1 will attend on x2.
x1: [batch, len1, x1_input_size]
x2: [batch, len2, x2_input_size]
x2_mask: [batch, len2]
x3: [batch, len2, x3_input_size]
"""
if x3 is None:
x3 = x2
if scores is None:
scores = self.scoring(x1, x2)
        # Mask padding positions so they receive zero attention weight.
        x2_mask = x2_mask.unsqueeze(1).expand_as(scores)
        scores = scores.masked_fill(x2_mask.bool(), -float('inf'))
        if drop_diagonal:
            # In the self-attention case, forbid attending to one's own position.
            assert scores.size(1) == scores.size(2)
            diag_mask = torch.eye(scores.size(1), dtype=torch.bool, device=scores.device)
            scores = scores.masked_fill(diag_mask.unsqueeze(0).expand_as(scores), -float('inf'))
# Normalize with softmax
alpha = F.softmax(scores, dim=2)
# Take weighted average
matched_seq = alpha.bmm(x3)
if return_scores:
return matched_seq, scores
else:
            return matched_seq  # [batch, len1, x3_input_size]
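
# A minimal usage sketch (hypothetical sizes; AttentionScore is defined below,
# so the names inside this helper are only resolved when it is called):
def _demo_get_attention_hiddens():
    batch, len1, len2, hidden = 2, 5, 7, 16
    layer = GetAttentionHiddens(input_size=hidden, attention_hidden_size=32)
    x1 = torch.randn(batch, len1, hidden)
    x2 = torch.randn(batch, len2, hidden)
    # True marks padding; pretend the last two x2 positions are padding.
    x2_mask = torch.zeros(batch, len2, dtype=torch.bool)
    x2_mask[:, -2:] = True
    matched = layer(x1, x2, x2_mask)
    assert matched.shape == (batch, len1, hidden)
    # Self-attention over x2 with drop_diagonal: a position never attends to itself.
    self_attended = layer(x2, x2, torch.zeros(batch, len2, dtype=torch.bool), drop_diagonal=True)
    assert self_attended.shape == (batch, len2, hidden)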

class AttentionScore(nn.Module):
    """
    Correlation score between two sequences through a shared projection:
        s_ij = ReLU(W x1_i)^T * D * ReLU(W x2_j)
    where W is a shared linear projection and D is diagonal: a learned vector
    in the default mode, or a fixed 1/sqrt(hidden) scaling when
    similarity_score=True.
    """
    def __init__(self, input_size, attention_hidden_size, similarity_score=False):
        super(AttentionScore, self).__init__()
        self.linear = nn.Linear(input_size, attention_hidden_size, bias=False)
        if similarity_score:
            # Fixed 1/sqrt(d) scaling, i.e. scaled dot-product similarity; not trained.
            self.linear_final = Parameter(torch.ones(1, 1, 1) / (attention_hidden_size ** 0.5), requires_grad=False)
        else:
            # Learned diagonal weight D, broadcast over the batch and length dims.
            self.linear_final = Parameter(torch.ones(1, 1, attention_hidden_size), requires_grad=True)
def forward(self, x1, x2):
"""
x1: [batch, len1, input_size]
x2: [batch, len2, input_size]
scores: [batch, len1, len2]
<the scores are not masked>
"""
x1 = dropout(x1, p=my_dropout_p, training=self.training)
x2 = dropout(x2, p=my_dropout_p, training=self.training)
        # Shared projection + ReLU, applied position-wise (nn.Linear operates
        # on the last dimension, so no reshaping is needed).
        x1_rep = F.relu(self.linear(x1))
        x2_rep = F.relu(self.linear(x2))
        # Scale x2 by the diagonal weight D, then take batched dot products.
        final_v = self.linear_final.expand_as(x2_rep)
        x2_rep_v = final_v * x2_rep
        scores = x1_rep.bmm(x2_rep_v.transpose(1, 2))
return scores # [batch, len1, len2]
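
# A hedged smoke test (assumed sizes; not part of the original module). Run
# this file directly to execute it.
if __name__ == "__main__":
    torch.manual_seed(0)
    score_layer = AttentionScore(input_size=16, attention_hidden_size=32)
    a, b = torch.randn(2, 5, 16), torch.randn(2, 7, 16)
    s = score_layer(a, b)
    assert s.shape == (2, 5, 7)  # one score per (x1 position, x2 position) pair
    _demo_get_attention_hiddens()
    print("attention modules OK")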