多头注意力与Masking

在自然语言处理和计算机视觉等任务中,多头注意力机制已经被广泛应用。它是一种能够捕捉输入序列中不同部分的相关性的机制,通过对输入序列进行多次注意力计算,获得不同的注意力权重向量。然而,在使用多头注意力时,为了正确地计算注意力权重,我们需要使用mask(掩码)来过滤掉无效的输入。

多头注意力

多头注意力机制是Transformer模型的核心组件之一,它由Scaled Dot-Product Attention构成。Scaled Dot-Product Attention通过计算查询向量(Q)、键向量(K)和值向量(V)之间的注意力权重,生成加权的值向量。

在多头注意力机制中,我们将输入序列分别映射到不同的注意力头(即不同的线性变换),然后将每个头得到的注意力权重进行拼接,并通过线性变换得到最终的输出。

import torch
import torch.nn as nn

class MultiheadAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_heads):
        super(MultiheadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        
        self.query_proj = nn.Linear(input_dim, hidden_dim)
        self.key_proj = nn.Linear(input_dim, hidden_dim)
        self.value_proj = nn.Linear(input_dim, hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)
        
    def forward(self, query, key, value):
        batch_size = query.size(0)
        
        # 线性变换
        query = self.query_proj(query)
        key = self.key_proj(key)
        value = self.value_proj(value)
        
        # 拆分为多个注意力头
        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Scaled Dot-Product Attention
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
        attention_weights = nn.Softmax(dim=-1)(scores)
        attention_output = torch.matmul(attention_weights, value)
        
        # 拼接并线性变换
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        output = self.output_proj(attention_output)
        
        return output, attention_weights

Masking

在使用多头注意力时,我们需要使用mask来过滤掉无效的输入。在自然语言处理中,一种常见的mask是padding mask,用于过滤掉填充部分的输入。另一种常见的mask是sequence mask,用于过滤掉未来的输入,即在解码过程中不允许看到未来的信息。

对于padding mask,我们可以使用如下代码将填充部分的注意力权重设置为一个较小的值,从而在计算注意力加权和时将其抑制:

def create_padding_mask(seq):
    seq = torch.eq(seq, 0)
    return seq.unsqueeze(1).unsqueeze(2)

对于sequence mask,我们可以使用如下代码将未来的位置的注意力权重设置为一个较小的值,从而在计算注意力加权和时将其抑制:

def create_sequence_mask(seq):
    seq_len = seq.size(1)
    mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool), diagonal=1)
    return mask.unsqueeze(0)

使用Mask的多头注意力

接下来,我们将使用mask来过滤输入并应用多头注意力。

input_dim = 512
hidden_dim = 256
num_heads = 8

attention = MultiheadAttention(input_dim, hidden_dim, num_heads)

# 假设我们有一个输入序列
seq = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 6, 7, 0], [8, 9, 0, 0, 0]])

# 创建padding mask
padding_mask = create_padding_mask(seq)

# 创建sequence mask