We now move on to implementing the Transformer. Before writing the actual model, let's look at a small trick that Transformer implementations rely on; understanding it is the key to reading other people's Transformer code.

1. The Matrix Multiplication Trick

Recall from the post on how the Transformer works that the output of the attention layer at position $t$ is a weighted sum of the token vectors at position $t$ and the positions before it:

$\mathrm{out}_t = \sum_{i \le t} w_{t,i}\, x_i$

We could compute this sum directly, but that would be inefficient; the usual approach is to turn the summation into a single matrix multiplication. To keep things simple, assume for now that we just want the average of the vectors. The most straightforward implementation looks like this:

    def startup(self, args={}):
        print('AppExp v0.0.1')
        torch.manual_seed(1337)
        B, T, C = 4, 8, 2 # B: batch size; T: sequence length; C: channels, i.e. the embedding dimension
        X = torch.randn(B, T, C)
        xbow1 = self.sum1(X, B, T, C)
        print(xbow1)
        xbow2 = self.sum2(X, B, T, C)
        rst = torch.allclose(xbow1, xbow2)
        print(f'xbow1 == xbow2 => {rst};')
        xbow3 = self.sum3(X, B, T, C)
        rst = torch.allclose(xbow1, xbow3)
        print(f'xbow1 == xbow3 => {rst};')

    def sum1(self, X, B, T, C):
        # explicit loop: average the current and all preceding token vectors
        xbow = torch.zeros((B, T, C)) # bag of words
        for b in range(B):
            for t in range(T):
                xprev = X[b, :t+1] # (t+1, C)
                xbow[b, t] = torch.mean(xprev, 0) # (C,)
        return xbow

    def sum2(self, X, B, T, C):
        # matrix form: a row-normalized lower-triangular matrix performs the averaging
        wei = torch.tril(torch.ones(T, T)) # Note1
        wei = wei / wei.sum(1, keepdim=True) # Note2
        return wei @ X

    def sum3(self, X, B, T, C):
        # softmax form: mask future positions with -inf, then softmax each row
        tril = torch.tril(torch.ones(T, T))
        wei = torch.zeros((T, T))
        wei = wei.masked_fill(tril==0, float('-inf')) # Note3
        wei = F.softmax(wei, dim=-1) # Note4
        return wei @ X

All three methods produce identical results, but the matrix versions are far more efficient than the explicit loop, and the softmax form in sum3 is exactly the pattern we will reuse when implementing the Transformer.

  • Note1: torch.tril produces a lower-triangular matrix:
tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
  • Note2: each element is divided by the sum of its row, so every row becomes a set of uniform averaging weights:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
  • The product wei @ X: wei has shape (8, 8) and X has shape (4, 8, 2). Under PyTorch's batched matrix-multiplication rules, wei is broadcast across the batch dimension, so the same (8, 8) matrix multiplies each (8, 2) slice of X in the ordinary matrix sense, and the results are stacked back into a (4, 8, 2) tensor (see the short check after these notes).
  • Note3: because wei starts out as all zeros, masking the positions where tril == 0 with -inf gives:
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
  • Note4: since $e^{-\infty} = 0$, the softmax assigns zero weight to the masked positions, and the remaining entries in each row are all equal, so after the softmax we get exactly the same matrix as in Note2:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
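To see the broadcasting described above in action, here is a quick standalone check (the shapes B, T, C match the earlier example):

import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
X = torch.randn(B, T, C)
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)   # the Note2 averaging matrix

# (T, T) @ (B, T, C): wei is broadcast across the batch dimension, so the
# same (T, T) matrix multiplies every (T, C) slice of X
out = wei @ X                                          # (B, T, C)
manual = torch.stack([wei @ X[b] for b in range(B)])   # explicit per-batch products
print(torch.allclose(out, manual))                     # True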

The point of this operation is the following: we assume that the word at the current position depends only on the words before it, and we model that dependence, for now, as simply averaging the vectors of all preceding words (including the current one). Each position then carries not only its own information but also information from everything before it, which removes the limitation of the earlier BigramLanguageModel, where each prediction could only look at the single previous token.
To prepare for the Transformer, the BigramLanguageModel class has been modified slightly; for the complete code, see:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.0.6

2. The Self-Attention Mechanism

    def startup(self, args={}):
        print('AppExp v0.0.1')
        torch.manual_seed(1337)
        B, T, C = 4, 8, AppRegistry.n_embed # B: batch size; T: sequence length; C: channels, i.e. the embedding dimension
        X = torch.randn(B, T, C)
        self.self_attention(X, B, T, C)

    def self_attention(self, X, B, T, C):
        W_K = nn.Linear(C, AppRegistry.head_size, bias=False)
        W_Q = nn.Linear(C, AppRegistry.head_size, bias=False)
        W_V = nn.Linear(C, AppRegistry.head_size, bias=False)
        k = W_K(X) # (B, T, h)
        q = W_Q(X) # (B, T, h)
        wei = q @ k.transpose(-2, -1) / (AppRegistry.head_size**0.5) # (B, T, h) @ (B, h, T) => (B, T, T)
        tril = torch.tril(torch.ones(T, T))
        wei = wei.masked_fill(tril==0, float('-inf')) # mask out future positions
        wei = F.softmax(wei, dim=-1) # normalize the affinities into attention weights
        v = W_V(X)
        out = wei @ v # weighted sum of the value vectors
        print(f'out: {out.shape};')

Following the previous section: for an input X we define three weight matrices $W_K$, $W_Q$ and $W_V$. K (the key) plays the role of an index that makes each token easy to look up, and Q (the query) represents what the current token is looking for. For a single input this means

$k = X W_K, \quad q = X W_Q, \quad v = X W_V$

and the affinity of the query at position $i$ for the key at position $j$ is

$w_{ij} = \dfrac{q_i \cdot k_j}{\sqrt{d_k}}$

where $d_k$ is the dimension of the attention head (head_size). Dividing by $\sqrt{d_k}$ keeps the affinities from growing with the head dimension, so the softmax stays more evenly spread instead of collapsing onto a single position.
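A quick way to see the effect of the scaling (a standalone sketch with made-up shapes): without it, the dot products have variance roughly equal to head_size and the softmax saturates towards one-hot rows.

import torch
import torch.nn.functional as F

torch.manual_seed(1337)
head_size = 16
q = torch.randn(8, head_size)
k = torch.randn(8, head_size)

raw = q @ k.T                     # un-scaled affinities, variance ~ head_size
scaled = raw / head_size ** 0.5   # scaled affinities, variance ~ 1
print(raw.var().item(), scaled.var().item())

# the un-scaled softmax concentrates its weight much more sharply
print(F.softmax(raw, dim=-1)[0].max().item())
print(F.softmax(scaled, dim=-1)[0].max().item())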

3. The Self-Attention Head

class Head(nn.Module):
    def __init__(self, block_size=8, n_embed=32, head_size=16):
        super(Head, self).__init__()
        self.W_K = nn.Linear(n_embed, head_size, bias=False)
        self.W_Q = nn.Linear(n_embed, head_size, bias=False)
        self.W_V = nn.Linear(n_embed, head_size, bias=False)
        # causal mask of shape (block_size, block_size); registering it as a
        # buffer makes it move to the right device together with the module
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, X):
        B, T, C = X.shape
        k = self.W_K(X) # (B, T, h)
        q = self.W_Q(X) # (B, T, h)
        # scale by the square root of the actual key dimension
        wei = q @ k.transpose(-2, -1) / (k.shape[-1]**0.5) # (B, T, h) @ (B, h, T) => (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf')) # mask out future positions
        wei = F.softmax(wei, dim=-1)
        v = self.W_V(X)
        return wei @ v
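A minimal shape check for the Head module above (assuming the same imports used throughout the repository: torch, torch.nn as nn and torch.nn.functional as F):

torch.manual_seed(1337)
head = Head(block_size=8, n_embed=32, head_size=16)
X = torch.randn(4, 8, 32)   # (B, T, n_embed)
out = head(X)
print(out.shape)            # torch.Size([4, 8, 16]) => (B, T, head_size)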

Add the self-attention head to the model:

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # vocabulary size, embedding dimension
        self.token_embedding_table = nn.Embedding(AppRegistry.vocab_size, AppRegistry.n_embed)
        self.position_embedding_table = nn.Embedding(AppRegistry.block_size, AppRegistry.n_embed)
        self.sa_head = Head(head_size=AppRegistry.n_embed, n_embed=AppRegistry.n_embed, block_size=AppRegistry.block_size)
        self.lm_head = nn.Linear(AppRegistry.n_embed, AppRegistry.vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C) C=n_embed
        pos_emb = self.position_embedding_table(torch.arange(T, device=AppRegistry.device))
        x = tok_emb + pos_emb # (B, T, C)
        x = self.sa_head(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        return logits
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -AppRegistry.block_size:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] # (B, T, C) => (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
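As a quick plumbing check, the model can already generate tokens even before training (a sketch assuming AppRegistry is configured as in the repository; an untrained model will of course produce random text):

model = BigramLanguageModel().to(AppRegistry.device)
idx = torch.zeros((1, 1), dtype=torch.long, device=AppRegistry.device)  # a single start token
out = model.generate(idx, max_new_tokens=20)
print(out.shape)   # torch.Size([1, 21]): the prompt plus 20 newly sampled token ids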

To verify that the program works, run the training loop. For the complete code, see:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.0.7

4. Multi-Head Attention

Add a multi-head attention module:

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super(MultiHeadAttention, self).__init__()
        self.heads = nn.ModuleList([Head(AppRegistry.block_size, AppRegistry.n_embed, head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(AppRegistry.n_embed, AppRegistry.n_embed)

    def forward(self, X):
        # run all heads in parallel and concatenate their outputs along the
        # channel dimension; num_heads * head_size == n_embed
        X = torch.cat([h(X) for h in self.heads], dim=-1)
        return self.proj(X)
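A minimal shape check (assuming AppRegistry.n_embed == 32 and AppRegistry.block_size == 8 as in the earlier examples): four heads of size n_embed // 4 = 8 are concatenated back into an n_embed-dimensional output.

torch.manual_seed(1337)
mha = MultiHeadAttention(num_heads=4, head_size=AppRegistry.n_embed // 4)
X = torch.randn(4, 8, AppRegistry.n_embed)   # (B, T, n_embed)
out = mha(X)
print(out.shape)   # torch.Size([4, 8, 32]): back to (B, T, n_embed)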

Use multi-head attention in the BigramLanguageModel:

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # vocabulary size, embedding dimension
        self.token_embedding_table = nn.Embedding(AppRegistry.vocab_size, AppRegistry.n_embed)
        self.position_embedding_table = nn.Embedding(AppRegistry.block_size, AppRegistry.n_embed)
        self.sa_heads = MultiHeadAttention(num_heads=4, head_size=AppRegistry.n_embed//4)
        self.lm_head = nn.Linear(AppRegistry.n_embed, AppRegistry.vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C) C=n_embed
        pos_emb = self.position_embedding_table(torch.arange(T, device=AppRegistry.device))
        x = tok_emb + pos_emb # (B, T, C)
        x = self.sa_heads(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        return logits

For the complete code, see:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.0.8

5. Adding the Feed-Forward Network

class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super(FeedForward, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
        )

    def forward(self, X):
        return self.net(X)
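The inner layer expands the dimension by a factor of 4 and then projects back, the same ratio as in the original Transformer. A quick standalone check of the shapes (32 is simply the n_embed value used in the earlier examples):

torch.manual_seed(1337)
ffwd = FeedForward(n_embed=32)
X = torch.randn(4, 8, 32)   # (B, T, n_embed)
print(ffwd(X).shape)        # torch.Size([4, 8, 32]): the shape is preserved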

Use the feed-forward network:

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # vocabulary size, embedding dimension
        self.token_embedding_table = nn.Embedding(AppRegistry.vocab_size, AppRegistry.n_embed)
        self.position_embedding_table = nn.Embedding(AppRegistry.block_size, AppRegistry.n_embed)
        self.sa_heads = MultiHeadAttention(num_heads=4, head_size=AppRegistry.n_embed//4)
        self.ffwd = FeedForward(AppRegistry.n_embed)
        self.lm_head = nn.Linear(AppRegistry.n_embed, AppRegistry.vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C) C=n_embed
        pos_emb = self.position_embedding_table(torch.arange(T, device=AppRegistry.device))
        x = tok_emb + pos_emb # (B, T, C)
        x = self.sa_heads(x)
        x = self.ffwd(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        return logits

For the complete code, see:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.0.9

6. Adding Blocks

Next we add the Transformer block (called TransformerEncoderBlock here, although the causal mask inside Head actually makes it a decoder-style block), which combines multi-head attention with the feed-forward network:

class TransformerEncoderBlock(nn.Module):
    def __init__(self, n_embed, n_head):
        super(TransformerEncoderBlock, self).__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embed)

    def forward(self, X):
        X = self.sa(X)
        X = self.ffwd(X)
        return X

Use the blocks:

class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # vocabulary size, embedding dimension
        self.token_embedding_table = nn.Embedding(AppRegistry.vocab_size, AppRegistry.n_embed)
        self.position_embedding_table = nn.Embedding(AppRegistry.block_size, AppRegistry.n_embed)
        self.blocks = nn.Sequential(
            TransformerEncoderBlock(n_embed=AppRegistry.n_embed, n_head=4),
            TransformerEncoderBlock(n_embed=AppRegistry.n_embed, n_head=4),
            TransformerEncoderBlock(n_embed=AppRegistry.n_embed, n_head=4),
        )
        self.lm_head = nn.Linear(AppRegistry.n_embed, AppRegistry.vocab_size)

    def forward(self, idx):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, C) C=n_embed
        pos_emb = self.position_embedding_table(torch.arange(T, device=AppRegistry.device))
        x = tok_emb + pos_emb # (B, T, C)
        x = self.blocks(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        return logits

For the complete code, see:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.0.10

7. Adding Residual Connections and Dropout

Adding residual connections and dropout, and retuning the hyperparameters, gives us the final version. Because the changes touch many places and are somewhat scattered, please refer to the repository for the details:

git clone https://gitee.com/yt7589/hwcgpt.git
cd hwcgpt
git checkout v0.1.0
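For reference, this is roughly what the block ends up looking like once residual connections are added. This is a sketch under my own assumptions, not a copy of the v0.1.0 code: it also includes the LayerNorms that usually accompany the residual connections, and dropout would typically be placed inside MultiHeadAttention (after self.proj) and inside FeedForward (after the last Linear), e.g. nn.Dropout(0.2).

class TransformerEncoderBlock(nn.Module):
    def __init__(self, n_embed, n_head):
        super(TransformerEncoderBlock, self).__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, X):
        X = X + self.sa(self.ln1(X))    # residual connection around self-attention
        X = X + self.ffwd(self.ln2(X))  # residual connection around the feed-forward net
        return X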