Introduction

This article implements, in PyTorch, the attention scoring methods commonly used in Seq2Seq models.

Attention methods

[Figure: the four attention score functions (dot, general, concat, bahdanau)]

Combining the papers Effective Approaches to Attention-based Neural Machine Translation and NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE gives the four ways of computing attention shown above.
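For reference, the four score functions can be written out as below. This is a sketch read off the PyTorch implementation in this article rather than copied from the papers, so the notation is an assumption: $s$ stands for the decoder's current hidden state (`last_hidden`), $h_i$ for the encoder output at time step $i$, and $[h_i; s]$ for concatenation.

```latex
\begin{aligned}
\text{dot:}      \quad & e_i = s^\top h_i \\
\text{general:}  \quad & e_i = (W_a s)^\top h_i \\
\text{concat:}   \quad & e_i = v_a^\top \tanh\left(W_a [h_i; s]\right) \\
\text{bahdanau:} \quad & e_i = v_a^\top \tanh\left(W_a s + U_a h_i\right)
\end{aligned}
```

The first three follow the Luong-style scores and the last one follows the additive form of Bahdanau et al.; $W_a$, $U_a$ and $v_a$ are learned parameters.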

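The weight $\alpha_i$ assigned to each encoder output $h_i$ is computed with the following formula:
$$\alpha_i = \frac{\exp(e_i)}{\sum_{k} \exp(e_k)}$$
where
$$e_i = \mathrm{score}(s, h_i)$$
is the score between the decoder's current hidden state $s$ and the encoder output $h_i$.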

See (Paper translation) NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE.
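As a small worked example of this normalization, the sketch below applies a softmax to a made-up score tensor and compares it with the formula written out by hand. The score values and the variable names are illustrative assumptions, not taken from the article.

```python
import torch

# Hypothetical raw scores e_i for one decoder step: [seq_len=3, batch_size=2]
scores = torch.tensor([[1.0, 0.0],
                       [2.0, 0.0],
                       [3.0, 0.0]])

# Normalize over the time-step dimension (dim=0 in this layout) so that
# the weights of each batch element sum to 1.
weights = torch.softmax(scores, dim=0)

# The same formula written out by hand: exp(e_i) / sum_k exp(e_k)
manual = torch.exp(scores) / torch.exp(scores).sum(dim=0, keepdim=True)

print(weights)
print(torch.allclose(weights, manual))
```

In the second batch column all scores are equal, so every time step receives the same weight of 1/3.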

Code implementation

import torch
import torch.nn as nn


class Attention(nn.Module):

    def __init__(self, hidden_size, method='dot'):
        super(Attention, self).__init__()

        self.method = method
        self.hidden_size = hidden_size

        if self.method not in ['dot', 'general', 'concat', 'bahdanau']:
            raise ValueError(f"{self.method} is not an appropriate attention method.")

        # va is initialized explicitly: torch.FloatTensor(1, hidden_size) would
        # allocate uninitialized memory, which can contain arbitrary values.
        if self.method == 'general':
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
        elif self.method == 'concat':
            self.Wa = nn.Linear(hidden_size * 2, hidden_size, bias=False)
            self.va = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, hidden_size)))
        elif self.method == 'bahdanau':
            self.Wa = nn.Linear(hidden_size, hidden_size, bias=False)
            self.Ua = nn.Linear(hidden_size, hidden_size, bias=False)
            self.va = nn.Parameter(nn.init.xavier_uniform_(torch.empty(1, hidden_size)))

    def _score(self, last_hidden, encoder_outputs):
        '''
        :param last_hidden: output of the decoder's last layer (if it has several layers),
            [1, batch_size, hidden_size]. The decoder processes one time step at a time
            and has a single direction: D=1
        :param encoder_outputs: encoder hidden states for all time steps,
            [seq_len, batch_size, hidden_size]
        '''

        if self.method == 'dot':
            # last_hidden * encoder_outputs broadcasts to [seq_len, batch_size, hidden_size]
            # sum(x, dim=2) sums over dimension 2 and removes it, giving [seq_len, batch_size]
            # For each batch element, this is the score between the decoder's current
            # time step and every encoder time step
            # Compute e_i
            return torch.sum(last_hidden * encoder_outputs, dim=2)  # [seq_len, batch_size]
        elif self.method == 'general':
            energy = self.Wa(last_hidden)  # [1, batch_size, hidden_size]
            # [seq_len, batch_size, hidden_size] * [1, batch_size, hidden_size]
            # = [seq_len, batch_size, hidden_size] (element-wise, with broadcasting)
            return torch.sum(encoder_outputs * energy, dim=2)  # [seq_len, batch_size]

        elif self.method == 'concat':
            # last_hidden.expand(encoder_outputs.size(0), -1, -1) repeats last_hidden
            # seq_len times along dimension 0, [seq_len, batch_size, hidden_size],
            # so that it can be concatenated with encoder_outputs
            # cat(*, dim=2)                [seq_len, batch_size, hidden_size * 2]
            # energy = tanh(self.Wa(*))    [seq_len, batch_size, hidden_size]
            energy = torch.tanh(self.Wa(torch.cat(
                (encoder_outputs, last_hidden.expand(encoder_outputs.size(0), -1, -1)), dim=2)))
            return torch.sum(self.va * energy, dim=2)  # [seq_len, batch_size]

        else:  # method == 'bahdanau'
            # self.Wa(last_hidden)       [1, batch_size, hidden_size]
            # self.Ua(encoder_outputs)   [seq_len, batch_size, hidden_size]
            # torch.tanh(*)              [seq_len, batch_size, hidden_size]
            energy = torch.tanh(self.Wa(last_hidden) + self.Ua(encoder_outputs))
            return torch.sum(self.va * energy, dim=2)  # [seq_len, batch_size]

    def forward(self, last_hidden, encoder_outputs):
        # Attention scores, see _score; every method returns [seq_len, batch_size]
        attn_energies = self._score(last_hidden, encoder_outputs)
        # Transpose to [batch_size, seq_len]
        attn_energies = attn_energies.t()
        # Softmax turns the scores into weights. We want one weight per time step,
        # so the softmax runs along the time-step dimension; the shape is unchanged.
        # Compute α_i as in formula (6) above.
        # unsqueeze(1) inserts a dimension at dim=1, giving [batch_size, 1, seq_len]
        return torch.softmax(attn_energies, dim=1).unsqueeze(1)
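The article stops at the attention weights. A common next step, sketched here as an assumption about how a decoder would consume them, is to take the weighted sum of the encoder outputs to obtain a context vector. To keep the snippet self-contained it recomputes the 'dot' scores inline instead of using the Attention class, and it uses random tensors as stand-ins for real model outputs; the name `context` is hypothetical.

```python
import torch

torch.manual_seed(0)
seq_len, batch_size, hidden_size = 5, 2, 4

# Random stand-ins for real encoder/decoder outputs (illustration only)
encoder_outputs = torch.randn(seq_len, batch_size, hidden_size)
last_hidden = torch.randn(1, batch_size, hidden_size)

# 'dot' scores, computed the same way as in the dot branch of _score
scores = torch.sum(last_hidden * encoder_outputs, dim=2)      # [seq_len, batch_size]

# Same post-processing as forward(): transpose, softmax over time steps, unsqueeze
attn_weights = torch.softmax(scores.t(), dim=1).unsqueeze(1)  # [batch_size, 1, seq_len]

# Weighted sum of encoder outputs:
# [batch_size, 1, seq_len] @ [batch_size, seq_len, hidden_size] -> [batch_size, 1, hidden_size]
context = torch.bmm(attn_weights, encoder_outputs.transpose(0, 1))

print(attn_weights.shape)  # torch.Size([2, 1, 5])
print(context.shape)       # torch.Size([2, 1, 4])
```

The `[batch_size, 1, seq_len]` shape returned by forward is what makes the single `torch.bmm` call possible; only `encoder_outputs` needs to be transposed to batch-first.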