参考:

attention-is-all-you-need-pytorch

NLP 中的Mask全解

Transformer代码详解-pytorch版

Transformer模型结构

Transformer模型结构如下图:

linear attention的pytorch实现 pytorch multi head attention_矩阵相乘

 

  • Transformer的整体结构就是分成Encoder和Decoder两部分,并且两部分之间是有联系的,可以注意到Encoder的输出是Decoder第二个Multi-head Attention中和的输入。
  • Encoder和Decoder分别由N个EncoderLayer和DecoderLayer组成。N默认为6个。
  • EncoderLayer由两个SubLayers组成,分别是Multi-head Attention和Feed Forward。DecoderLayer则是由三个SubLayers组成,分别是Masked Multi-head Attention,Multi-head Attention和Feed Forward。
  • Multi-head Attention是用ScaledDotProductAttention和Linear组成。Feed Forward是由Linear组成。
  • Add & Norm指的是残差连接之后再进行LayerNorm。

各模块结构结构

Multi-head Attention结构

linear attention的pytorch实现 pytorch multi head attention_矩阵相乘_02

Feed Forward结构

linear attention的pytorch实现 pytorch multi head attention_Mask_03

EncoderLayer结构

linear attention的pytorch实现 pytorch multi head attention_Mask_04

DecoderLayer结构

linear attention的pytorch实现 pytorch multi head attention_矩阵相乘_05

Encoder结构

linear attention的pytorch实现 pytorch multi head attention_Mask_06

Decoder结构

linear attention的pytorch实现 pytorch multi head attention_矩阵相乘_07

 

ScaledDotProductAttention模块

ScaledDotProductAttention做的是一个attention计算。公式如下:

linear attention的pytorch实现 pytorch multi head attention_矩阵相乘_08

输入q k v,可以q先除以根号d_k(d_k默认为64,根号d_k就为8),再与k的转置相乘,再经过softmax,最后与v相乘。下图的操作和公式所做的东西是一样的。

linear attention的pytorch实现 pytorch multi head attention_权重_09

 

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        # 其实就是论文中的根号d_k
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        # sz_b: batch_size 批量大小
        # len_q,len_k,len_v: 序列长度 在这里他们都相等
        # n_head: 多头注意力 默认为8
        # d_k,d_v: k v 的dim(维度) 默认都是64
        # 此时q的shape为(sz_b, n_head, len_q, d_k) (sz_b, 8, len_q, 64)
        # 此时k的shape为(sz_b, n_head, len_k, d_k) (sz_b, 8, len_k, 64)
        # 此时v的shape为(sz_b, n_head, len_k, d_v) (sz_b, 8, len_k, 64)
        # q先除以self.temperature(论文中的根号d_k) k交换最后两个维度(这样才可以进行矩阵相乘) 最后两个张量进行矩阵相乘
        # attn的shape为(sz_b, n_head, len_q, len_k)
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            # 用-1e9代替0 -1e9是一个很大的负数 经过softmax之后接近与0
            # 其一:去除掉各种padding在训练过程中的影响
            # 其二,将输入进行遮盖,避免decoder看到后面要预测的东西。(只用在decoder中)
            attn = attn.masked_fill(mask == 0, -1e9)

        # 先在attn的最后一个维度做softmax 再dropout 得到注意力分数
        attn = self.dropout(F.softmax(attn, dim=-1))
        # 最后attn与v进行矩阵相乘
        # output的shape为(sz_b, 8, len_q, 64)
        output = torch.matmul(attn, v)
        # 返回 output和注意力分数
        return output, attn

MultiHeadAttention和PositionwiseFeedForward模块

MultiHeadAttention做的是将q k v先经过线性层投影,再做ScaledDotProductAttention ,最后经过一个线性层。也就是下图的操作:

linear attention的pytorch实现 pytorch multi head attention_矩阵相乘_10

对应着Transformer的模块是:

linear attention的pytorch实现 pytorch multi head attention_矩阵相乘_02

PositionwiseFeedForward其实就是MLP。对应着Transformer的模块是:

linear attention的pytorch实现 pytorch multi head attention_Mask_03

 

# q k v 先经过不同的线性层 再用ScaledDotProductAttention 最后再经过一个线性层
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        # 这里的n_head, d_model, d_k, d_v分别默认为8, 512, 64, 64
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        # len_q, len_k, len_v 为输入的序列长度
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        # 用作残差连接
        residual = q

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        # q k v 分别经过一个线性层再改变维度
        # 由(sz_b, len_q, n_head*d_k) => (sz_b, len_q, n_head, d_k) (sz_b, len_q, 8*64) => (sz_b, len_q, 8, 64)
        q = self.w_qs(q).view(sz_b, len_q,