Transformer的简单实现

概述

本文将教你如何在PyTorch中实现Transformer模型。首先,我们将介绍Transformer模型的整体流程,并使用表格形式展示每个步骤的详细说明。然后,我们将逐步指导你实现每个步骤所需的代码,并为每行代码提供注释解释其功能。

整体流程

下表展示了Transformer模型的整体流程。我们将按照从上到下的顺序逐步实现每个步骤。

步骤 描述
输入嵌入(Input Embedding) 将输入序列编码为向量表示的嵌入形式
位置编码(Positional Encoding) 将位置信息嵌入到输入向量中,以保留序列的顺序信息
自注意力机制(Self-Attention) 计算输入序列中每个单词相对于其它单词的重要性,并进行加权求和
前馈神经网络(Feedforward Neural Network) 对每个位置的向量进行非线性变换,以增强模型的表达能力
输出层(Output Layer) 将模型输出映射到目标词汇表的概率分布上

输入嵌入(Input Embedding)

输入嵌入的目标是将输入序列编码为向量表示的嵌入形式。在Transformer模型中,我们使用嵌入层将输入序列中的每个词映射为一个固定维度的向量表示。

示例代码如下所示:

import torch
import torch.nn as nn

class InputEmbedding(nn.Module):
    def __init__(self, input_size, embedding_size):
        super(InputEmbedding, self).__init__()
        self.embedding = nn.Embedding(input_size, embedding_size)

    def forward(self, input):
        return self.embedding(input)

代码解释:

  • 首先,我们导入了PyTorch库并定义了一个名为InputEmbedding的类,该类继承自nn.Module
  • 在类的构造函数中,我们初始化了一个nn.Embedding对象,该对象将输入大小和嵌入大小作为参数。
  • 在前向传播函数中,我们将输入序列传递给嵌入层,并返回嵌入向量。

位置编码(Positional Encoding)

位置编码的目标是将位置信息嵌入到输入向量中,以保留序列的顺序信息。在Transformer模型中,我们使用正弦和余弦函数生成位置编码。

示例代码如下所示:

import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, max_seq_len, embedding_size):
        super(PositionalEncoding, self).__init__()
        self.embedding_size = embedding_size

        # 计算位置编码矩阵
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_size, 2) * -(math.log(10000.0) / embedding_size))
        self.position_encoding = torch.zeros(max_seq_len, embedding_size)
        self.position_encoding[:, 0::2] = torch.sin(position * div_term)
        self.position_encoding[:, 1::2] = torch.cos(position * div_term)

    def forward(self, input):
        return input + self.position_encoding[:input.size(1), :]

代码解释:

  • 首先,我们导入了PyTorch库并定义了一个名为PositionalEncoding的类,该类继承自nn.Module
  • 在类的构造函数中,我们首先计算了位置编码矩阵。我们通过计算正弦和余弦函数的值,并将它们与位置相乘得到位置编码矩阵。
  • 在前向传播函数中,我们将位置编码矩阵与输入向量相加,并返回结果。

自注意力机制(Self-Attention)

自注意力机制可以计算输入序列中每个单词相对于其它单词的重要性,并进行加权求和。在Transformer模型中,我们使用多头注意力机制来计算自注意力。

示例代码如下所示: