SamOut finally surpasses transformers (attention) in PyTorch

This time it is a complete win.

import torch
import numpy as np


class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.head_num = heads

        self.hidden = hidden_dim

    def forward(self, input_data, state=None):
        # b: batch size, s: sequence length, k: number of heads, h: per-head size
        b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size

        out = self.head0(input_data)
        # Element-wise maximum of two projections acts as the gating signal.
        out1 = torch.max(torch.concat([self.head1(input_data).unsqueeze(-1), out.unsqueeze(-1)], -1), -1)[0]

        out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])

        # Causal aggregation: cumulative maximum along the sequence dimension,
        # scaled by exp(out1) plus sqrt(head_size).
        out = torch.cummax(out * (torch.exp(out1) + h ** 0.5), 2)[0]

        out = out.permute([0, 2, 1, 3])
        out = out.reshape([b, s, -1])
        # Clamp from above with a third projection (element-wise minimum).
        out = torch.min(torch.concat(
            [h ** 0.5 - torch.exp(self.head2(input_data).unsqueeze(-1)), out.unsqueeze(-1)],
            -1), -1)[0]

        return out, state


class KAttention(torch.nn.Module):
    def __init__(self, hidden_dim, heads):
        super(KAttention, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.q = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        # self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads

    def forward(self, x, state=None):
        b, s, h, d = x.shape[0], x.shape[1], self.head_num, self.head_size
        q = self.q(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        k = self.k(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        v = self.v(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        qk = (q @ k.permute([0, 1, 3, 2])) / d ** 0.5
        # Causal mask: mask.T is lower-triangular, so position t only attends to positions <= t.
        mask = torch.triu(torch.ones(s, s).to(device))
        qk = torch.where(mask.T == 1, qk, torch.Tensor([-float('inf')]).to(device))
        qkv = torch.nn.functional.softmax(qk, -1) @ v
        #             v + torch.arange(1, 3 * s, 3).reshape([1, 1, -1, 1]).to(device) / s / 3)
        qkv = qkv.permute([0, 2, 1, 3]).reshape([b, s, -1])
        #
        return qkv, state


class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()

        self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x1 = self.ffn1(x)

        x2 = self.relu(self.gate(x))

        x = x1 * x2

        x = self.ffn2(x)
        return x


class MemoryBlock(torch.nn.Module):
    def __init__(self, hidden_dim):
        super(MemoryBlock, self).__init__()

        # Initialize the weights with Xavier initialization
        self.fc = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.fc)

        # self.mem = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.mem = torch.eye(hidden_dim).to(device)
        # torch.nn.init.xavier_uniform_(self.mem)

        # self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        x = x @ (self.fc + self.mem)
        # x = self.sig(x)
        return x


class DecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()
        # self.self_attention = MaskMultiHeadAttention(hidden_size, num_heads)
        self.self_attention = MaxState(hidden_size, num_heads, 8)
        # self.self_attention = KAttention(hidden_size, num_heads)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        x1, state = self.self_attention(x, state)
        x = self.layer_norm(self.ffn(x1) + x)

        return x, state


class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        self.pos = torch.nn.Embedding(1024, hidden_size)

        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size, False)
        # self.head_state = torch.nn.Linear(hidden_size, num_layers, False)

        self.down = torch.nn.ModuleList(
            [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])

    def state_forward(self, state, pos, x):
        if state is None:
            state = [None] * len(self.decoder_layers)
        for i, decoder_layer in enumerate(self.decoder_layers):
            # Re-inject the positional signal at every layer, then project back down to hidden_size.
            x = self.down[i](torch.concat([torch.zeros([x.shape[0], 1, 1]).to(device) + pos, x], -1))
            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x
        return x, state

    def pos_forward(self, x):
        # Two-level positional encoding: sequences longer than 1024 combine a coarse (index // 1024)
        # and a fine (index % 1024) embedding from the same 1024-slot table.
        if x.shape[1] >= 1024:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0)
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos
        else:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0)
        return pos

    def forward(self, x0):
        x0, _ = self.one_forward(x0, state=None)

        return x0, _

    def one_forward(self, x, state=None, seq_len=None):
        x = self.em(x)

        pos = self.pos_forward(x)

        x, state = self.state_forward(state, pos, x)

        return self.head(x), state


device = "cuda"
if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4)
    net.to(device)
    net(torch.randint(0, 200, [2, 8 * 13]).to(device))
    #
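
The core trick in MaxState is that torch.cummax along the sequence dimension is already causal: position t only aggregates values from positions 0..t, so no triangular mask is needed. A minimal sketch illustrating this property (the toy values below are mine, not from the original post):

import torch

# Toy sequence of length 4 with a single channel.
x = torch.tensor([[1.0], [3.0], [2.0], [5.0]])

# torch.cummax returns, at each step t, the running maximum over positions 0..t,
# which is the causal aggregation MaxState relies on instead of masked attention.
running_max, _ = torch.cummax(x, dim=0)
print(running_max.squeeze(-1))  # tensor([1., 3., 3., 5.])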

[Figure: training-loss curves for the cummax (MaxState) model vs. the attention model]


As the loss curves show, cummax lags behind attention at first, but in later epochs its loss drops below attention's.

Training also uses about 3-5 GB less GPU memory,

100 epochs finish more than half an hour faster,

and at inference time the GPU memory footprint stays completely constant.
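
The memory and timing numbers above come from the author's own runs. A rough sketch of how such a comparison could be reproduced (the helper name profile_forward_backward is mine; it assumes the SamOut definition above and a CUDA device):

import time
import torch

def profile_forward_backward(net, batch, steps=10):
    # Rough probe of per-step time and peak GPU memory for one model variant.
    torch.cuda.reset_peak_memory_stats()
    start = time.time()
    for _ in range(steps):
        out, _ = net(batch)
        out.float().mean().backward()
        net.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    elapsed = (time.time() - start) / steps
    peak_gb = torch.cuda.max_memory_allocated() / 1024 ** 3
    return elapsed, peak_gb

# Example (swap DecoderLayer's self_attention to KAttention for the attention baseline):
# net = SamOut(235, 256, 16, 4).to(device)
# batch = torch.randint(0, 200, [2, 8 * 13]).to(device)
# print(profile_forward_backward(net, batch))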

Attention vs. MaxState, the code used for the comparison:

import torch
import numpy as np


class MaxState(torch.nn.Module):
    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)

        self.head_num = heads

        self.hidden = hidden_dim


    def forward(self, input_data, state=None):
        # b: batch size, s: sequence length, k: number of heads, h: per-head size
        b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size

        out = self.head0(input_data)
        out1 = torch.max(torch.concat([self.head1(input_data).unsqueeze(-1), out.unsqueeze(-1)], -1), -1)[0]

        out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])
        out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])

        # Plain cummax variant used in the comparison: no sqrt(head_size) offset
        # and a final element-wise max instead of the min used in the first listing.
        out = torch.cummax(out * torch.exp(out1), 2)[0]

        out = out.permute([0, 2, 1, 3])
        out = out.reshape([b, s, -1])

        out = torch.max(torch.concat([self.head2(input_data).unsqueeze(-1), out.unsqueeze(-1)], -1), -1)[0]

        return out, state


class KAttention(torch.nn.Module):
    def __init__(self, hidden_dim, heads):
        super(KAttention, self).__init__()

        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

        self.head_size = hidden_dim // heads
        self.q = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        # self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads

    def forward(self, x, state=None):
        b, s, h, d = x.shape[0], x.shape[1], self.head_num, self.head_size
        q = self.q(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        k = self.k(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        v = self.v(x).reshape([b, s, h, d]).permute([0, 2, 1, 3])
        qk = (q @ k.permute([0, 1, 3, 2])) / d ** 0.5
        mask = torch.triu(torch.ones(s, s).to(device))
        qk = torch.where(mask.T == 1, qk, torch.Tensor([-float('inf')]).to(device))
        qkv = torch.nn.functional.softmax(qk, -1) @ v
        #             v + torch.arange(1, 3 * s, 3).reshape([1, 1, -1, 1]).to(device) / s / 3)
        qkv = qkv.permute([0, 2, 1, 3]).reshape([b, s, -1])
        #
        return qkv, state


class FeedForward(torch.nn.Module):
    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()

        self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x1 = self.ffn1(x)

        x2 = self.relu(self.gate(x))

        x = x1 * x2

        x = self.ffn2(x)
        return x


class MemoryBlock(torch.nn.Module):
    def __init__(self, hidden_dim):
        super(MemoryBlock, self).__init__()

        # Initialize the weights with Xavier initialization
        self.fc = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.fc)

        # self.mem = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.mem = torch.eye(hidden_dim).to(device)
        # torch.nn.init.xavier_uniform_(self.mem)

        # self.sig = torch.nn.Sigmoid()

    def forward(self, x):
        x = x @ (self.fc + self.mem)
        # x = self.sig(x)
        return x


class DecoderLayer(torch.nn.Module):
    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()
        # self.self_attention = MaskMultiHeadAttention(hidden_size, num_heads)
        self.self_attention = MaxState(hidden_size, num_heads, 8)
        # self.self_attention = KAttention(hidden_size, num_heads)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        x1, state = self.self_attention(x, state)
        x = self.layer_norm(self.ffn(x1) + x)

        return x, state


class SamOut(torch.nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        self.pos = torch.nn.Embedding(1024, hidden_size)

        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size, False)
        # self.head_state = torch.nn.Linear(hidden_size, num_layers, False)

        self.down = torch.nn.ModuleList(
            [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])

    def state_forward(self, state, pos, x):
        if state is None:
            state = [None] * len(self.decoder_layers)
        for i, decoder_layer in enumerate(self.decoder_layers):
            # Re-inject the positional signal at every layer, then project back down to hidden_size.
            x = self.down[i](torch.concat([torch.zeros([x.shape[0], 1, 1]).to(device) + pos, x], -1))
            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x
        return x, state

    def pos_forward(self, x):
        if x.shape[1] >= 1024:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) // 1024).unsqueeze(0)
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device) % 1024).unsqueeze(0) + pos

        else:
            pos = self.pos(torch.arange(0, x.shape[1]).long().to(device)).unsqueeze(0)
        return pos

    def forward(self, x0):
        x0, _ = self.one_forward(x0, state=None)

        return x0, _

    def one_forward(self, x, state=None, seq_len=None):
        x = self.em(x)

        pos = self.pos_forward(x)

        x, state = self.state_forward(state, pos, x)

        return self.head(x), state


device = "cuda"
if __name__ == '__main__':
    net = SamOut(235, 256, 16, 4)
    net.to(device)
    net(torch.randint(0, 200, [2, 8 * 13]).to(device))
    #
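
One thing worth noting about this comparison: MaxState and KAttention have identical per-layer parameter counts, since each uses three bias-free hidden_dim x hidden_dim projections, so the curves compare the mixing mechanism rather than the model size. A quick check (a sketch assuming the class definitions above):

def count_params(module):
    # Total number of trainable parameters in a module.
    return sum(p.numel() for p in module.parameters())

# Example (assumes the classes above are defined and device = "cuda"):
# ms = MaxState(256, 16, 8).to(device)
# ka = KAttention(256, 16).to(device)
# print(count_params(ms), count_params(ka))  # both 3 * 256 * 256 = 196608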

Training code:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from glob import glob

from tqdm import tqdm
# from model_e import SamOut
# from model_f import SamOut
from model_g import SamOut
import json


def train():
    train_data, voc = gen_voc()

    net = SamOut(len(voc), 512, 32, 8)
    # net.load_state_dict(torch.load("model.pth"))
    net.to("cuda")
    opt = torch.optim.Adam(params=net.parameters(), lr=0.00003)
    loss_func0 = torch.nn.CrossEntropyLoss(ignore_index=3)
    # loss_func1 = torch.nn.CrossEntropyLoss(ignore_index=3)
    # loss_func2 = torch.nn.CrossEntropyLoss(ignore_index=3)

    bar = tqdm(range(100))
    steps = 0
    epoch_loss = []

    # label_index2 = label_index1[::2]
    # input_index2 = label_index2 - 1
    for epoch in bar:
        np.random.shuffle(train_data)
        loss_list = []
        for i in range(0, len(train_data), 100):
            j = i + 100
            input_one = train_data[i:j]
            # input_one = []
            # input_two = []
            # input_thr = []
            #
            # for i in data:
            #     input_one.append(i[0])
            #     input_two.append(i[1])
            #     input_thr.append(i[2])

            out0, _ = net(torch.Tensor(input_one)[:, :-1].int().to("cuda"))
            loss = loss_func0(out0.reshape([-1, out0.shape[-1]]),
                              torch.Tensor(input_one)[:, 1:].reshape([-1]).long().to("cuda"))
            # loss += torch.nn.functional.mse_loss(torch.sum(out0, -1),
            #                                      torch.Tensor(input_one)[:, 1:].long().to("cuda") / len(
            #                                          voc))
            # loss += loss_func1(out1[:, :-1].reshape([-1, out0.shape[-1]]),
            #                    torch.cat(label, 0)[:, 1:].reshape([-1]).long().to("cuda"))
            # loss /= 2

            loss_list.append(loss.item())
            bar.set_description("epoch___{}____loss___{:.6f}____steps___{}".format(epoch, np.mean(loss_list), steps))
            opt.zero_grad()
            loss.backward()
            opt.step()
            steps += 100
            if steps % 8000 == 0:
                torch.save(net.state_dict(), "model.pth")
        epoch_loss.append(np.mean(loss_list))
        pd.to_pickle(epoch_loss, "loss911")


def gen_voc():
    paths = glob("train/*") + glob("test/*")
    q_list = []
    voc_dict = set()
    len_list = []
    for path in tqdm(paths):
        with open(path, "r", encoding="utf-8") as f:
            data = f.read()
            data = json.loads(data)
        if "answer" not in data:
            q = list(data["question"])
        else:
            q = list(data["question"]) + ["<|bos|>"] + list(data["answer"])
        voc_dict.update(set(q))
        q = ["<|sos|>"] + q + ["<|eos|>"]
        len_list.append(len(q))

        q_list.append(q)
    voc = ["<|sss|>", "<|sos|>", "<|eos|>", "<|pos|>"] + sorted(voc_dict)

    train_list = []
    for q in tqdm(q_list):
        if len(q) <= 255:
            # Map characters to ids and pad every sample to length 256 with "<|pos|>".
            ids = [voc.index(ch) for ch in q] + [voc.index("<|pos|>")] * (256 - len(q))
            train_list.append(ids)

    return train_list, voc


def show_loss():
    loss0 = pd.read_pickle("loss0")
    loss1 = pd.read_pickle("loss1")
    loss2 = pd.read_pickle("loss2")
    loss3 = pd.read_pickle("loss3")
    loss4 = pd.read_pickle("loss4")
    loss5 = pd.read_pickle("loss5")
    loss6 = pd.read_pickle("loss6")
    loss7 = pd.read_pickle("loss7")
    loss8 = pd.read_pickle("loss8")
    loss9 = pd.read_pickle("loss9")
    loss91 = pd.read_pickle("loss91")
    loss911 = pd.read_pickle("loss911")
    loss9111 = pd.read_pickle("loss9111")
    # plt.plot(loss0)
    # plt.plot(loss1)
    # plt.plot(loss2)
    # plt.plot(loss3)
    # plt.plot(loss4)
    # plt.plot(loss5)
    # plt.plot(loss6)
    # plt.plot(np.array(loss7) / 3)
    # plt.plot(np.array(loss8))
    # plt.plot(np.array(loss9))
    # plt.plot(np.array(loss91))
    plt.plot(np.array(loss911))
    plt.plot(np.array(loss9111))

    # plt.legend(["sin", "+", "nor", "state", "pos", "s", "mem", "mm", "8", "9", "91", "cummax","attention"])
    plt.legend(["cummax","attention"])
    plt.show()


if __name__ == '__main__':
    show_loss()
    train()
    # val()
    # eval_data()
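
The training script above does not include the inference/val() code. For completeness, here is a hypothetical greedy-decoding sketch (the name greedy_generate and its defaults are mine, not from the original; it simply re-feeds the growing prefix each step, whereas a truly constant-memory version would carry the per-layer running-max state returned by MaxState):

import torch

@torch.no_grad()
def greedy_generate(net, prompt_ids, max_new_tokens=32, eos_id=2):
    # eos_id=2 matches "<|eos|>" in the vocabulary built by gen_voc().
    net.eval()
    ids = prompt_ids
    for _ in range(max_new_tokens):
        logits, _ = net(ids)
        next_id = logits[:, -1].argmax(-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
        if next_id.item() == eos_id:
            break
    return ids

# Example usage with a trained model and a tokenized prompt:
# out_ids = greedy_generate(net, torch.tensor([[1, 42, 17]]).to("cuda"))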