文章目录

  • 写在前面——定义
  • 维度符号
  • 输入数据
  • KQV矩阵
  • 算法核心
  • attention核心
  • By the way……
  • pytorch 实现


写在前面——定义

维度符号

字母

B

U

E

H

Dkq

Dv

含义

batch 大小

组数据长度(例如:一句话有多少个字,一时间序列包含多少天数据)

数据表示维度(例如:一个字用多少维数据表示,一天数据包含多少个不同方面的数据)

多头attention机制中的头数

每个头中KQ矩阵用多少维数据表示

每个头中V矩阵用多少维数据表示

注:Dkq和Dv一般都是相等的,但是可以不相等,不影响计算过程

举例:

  • 定义堆长度为15,使用50天天气预报数据预测下一天天气情况,使用降雨量、降雪量、日照时间、风速4个数据表示每天天气情况;此处 B=15,U=50,E=4.
  • 使用8头attention机制,其中每个头KQ矩阵进行从4个神经元到64个神经元的全连接扩展,每个头V矩阵进行从4个神经元到32个神经元的全连接扩展;此处 H=8,Dkq=64,Dv=32.

输入数据

一般情况下,输入数据以堆为单位进行输入

# x维度为:[B,U,E]

KQV矩阵

通过线性回归方式生成 Python attention python attention库_数据

# 使用 pytorch 库
LQ = nn.Linear(E, H * Dkq)
LK = nn.Linear(E, H * Dkq)
LV = nn.Linear(E, H * Dv)

# x 维度:[B,U,E]
# view 函数可改变数据维度,作用类似 numpy 的 reshape
# transpose 函数调换对应维度位置
Q = LQ(x).view(B,U,H,-1).transpose(1, 2)
K = LK(x).view(B,U,H,-1).transpose(1, 2)
V = LV(x).view(B,U,H,-1).transpose(1, 2)

经过上述操作,生成矩阵的维度为:

Q

K

V

[B,H,U,Dqk]

[B,H,U,Dqk]

[B,H,U,Dv]

算法核心

attention核心

Python attention python attention库_attention_02
根据上述公式:

  1. Python attention python attention库_Python attention_03即矩阵相乘,矩阵维度相乘过程为:Python attention python attention库_attention_04其中Python attention python attention库_pytorch_05矩阵进行转置而后进行矩阵相乘,因此必须要求Python attention python attention库_attention_06矩阵生成维度相同。
  2. 定义成绩矩阵: Python attention python attention库_attention_07 维度为:Python attention python attention库_attention_08。可看成句子里面每个字与整句话每个字(包括自己)的相关度,或者每天数据与一组天数每天(包括自己)的相关度。
  3. 成绩矩阵进行 Python attention python attention库_Python attention_09 (归一化)处理后,使用其对 Python attention python attention库_数据表示_10 矩阵进行调整,维度变化为Python attention python attention库_数据表示_11因此 Python attention python attention库_数据表示_10 矩阵和 Python attention python attention库_attention_06
  4. 简略代码如下:
# 使用 pytorch 库
# 其中 scale 为公式中的系数 根号(dk)
scores = nn.Softmax(dim=-1)(torch.matmul(Q, K.transpose(-1, -2)) * .scale)
Z = torch.matmul(scores, V)

By the way……

最后使用线性回归处理 Python attention python attention库_attention_14

fully_connection= nn.Linear(Dv * H, E)
# transpose 函数调换 H 和 U 的位置
# reshape 函数调整最终维度
output = fully_connection(Z.transpose(1, 2).reshape(B, U, E))

pytorch 实现

import torch
import torch.nn.functional as f

class Attention(nn.Module, ABC):
    def __init__(self, embed_dim: int, n_heads: int, kq_dim: int = None, v_dim: int = None,
                 mask_flag: bool = False, factor: int = 5, scale: float = None):
        super(FullAttention, self).__init__()
        if embed_dim < n_heads:
            raise Exception("embedding dimension must greater than heads number.")

        self.kq_dim = kq_dim or (embed_dim // n_heads)
        self.v_dim = v_dim or (embed_dim // n_heads)
        self.n_heads = n_heads
        self.factor = factor

        self.matrix_Q = nn.Linear(embed_dim, self.kq_dim * n_heads, bias=False)
        self.matrix_K = nn.Linear(embed_dim, self.kq_dim * n_heads, bias=False)
        self.matrix_V = nn.Linear(embed_dim, self.v_dim * n_heads, bias=False)

        self.scale = scale or 1.0 / math.sqrt(self.kq_dim)
        self.mask_flag = mask_flag

        self.fully_con = nn.Linear(self.v_dim * n_heads, embed_dim, bias=False)

    def forward(self, queries: torch.tensor, keys: torch.tensor, values: torch.tensor):
        batch, unit_q, _ = queries.shape
        _, unit_v, _ = values.shape
        heads = self.n_heads

        q_vector = self.matrix_Q(queries).view(batch, unit_q, heads, -1).transpose(1, 2)
        k_vector = self.matrix_K(keys).view(batch, unit_v, heads, -1).transpose(1, 2)
        v_vector = self.matrix_V(values).view(batch, unit_v, heads, -1).transpose(1, 2)

        scores = torch.matmul(q_vector, k_vector.transpose(-1, -2)) * self.scale

		# 此处是 mask 操作,有机会再介绍~~~
        if self.mask_flag:
            mask = triangular_mask(unit_q, unit_v)
            if torch.cuda.is_available():
                mask = mask.cuda()
            scores.masked_fill_(mask, -np.inf)

        scores = nn.Softmax(dim=-1)(scores)
        z_vector = torch.matmul(scores, v_vector)
        full_output = self.fully_con(z_vector.transpose(1, 2).reshape(batch, unit_q, -1))

        return full_output, scores