这次是完胜了：cummax（MaxState）对比 attention
import torch
import numpy as np
class MaxState(torch.nn.Module):
    """Gated cumulative-max token mixer used in place of self-attention.

    The width is split into ``heads`` heads and a running (prefix) maximum is
    taken along the sequence axis, so the output at position ``t`` only
    depends on inputs at positions ``<= t``.

    Args:
        hidden_dim: model width; must be divisible by ``heads``.
        heads: number of heads.
        win: accepted for interface compatibility; not used by this module.
    """

    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()
        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
        self.head_size = hidden_dim // heads
        # Creation order (head0, head1, head2) fixes parameter order and the
        # sequence of random initialisations, so it is kept as-is.
        self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)  # values
        self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)  # gate exponent (maxed with head0)
        self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)  # per-token upper bound
        self.head_num = heads
        self.hidden = hidden_dim

    def forward(self, input_data, state=None):
        """Mix tokens with a gated prefix maximum.

        Args:
            input_data: tensor reshaped below as (batch, seq, heads * head_size).
            state: returned unchanged.

        Returns:
            (mixed tensor of shape (batch, seq, hidden), state)
        """
        batch, seq_len = input_data.shape[0], input_data.shape[1]
        n_heads, head_dim = self.head_num, self.head_size

        def to_heads(t):
            # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
            return t.reshape([batch, seq_len, n_heads, head_dim]).transpose(1, 2)

        values = self.head0(input_data)
        gate = torch.stack([self.head1(input_data), values], dim=-1).max(dim=-1).values
        scale = head_dim ** 0.5
        # Running maximum over the sequence axis (dim 2 after the head split).
        mixed = torch.cummax(to_heads(values) * (torch.exp(to_heads(gate)) + scale), dim=2).values
        mixed = mixed.transpose(1, 2).reshape([batch, seq_len, -1])
        # Clamp from above by a per-token bound sqrt(head_dim) - exp(head2(x)).
        ceiling = scale - torch.exp(self.head2(input_data))
        mixed = torch.stack([ceiling, mixed], dim=-1).min(dim=-1).values
        return mixed, state
class KAttention(torch.nn.Module):
    """Causal multi-head scaled dot-product self-attention.

    Args:
        hidden_dim: model width; must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, hidden_dim, heads):
        super(KAttention, self).__init__()
        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
        self.head_size = hidden_dim // heads
        self.q = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads

    def forward(self, x, state=None):
        """Attend each position to itself and all earlier positions.

        Args:
            x: tensor reshaped below as (batch, seq, heads * head_size).
            state: returned unchanged.

        Returns:
            (output of shape (batch, seq, hidden), state)
        """
        b, s, h, d = x.shape[0], x.shape[1], self.head_num, self.head_size

        def split_heads(t):
            # (batch, seq, hidden) -> (batch, heads, seq, head_size)
            return t.reshape([b, s, h, d]).permute([0, 2, 1, 3])

        q = split_heads(self.q(x))
        k = split_heads(self.k(x))
        v = split_heads(self.v(x))
        qk = (q @ k.transpose(-2, -1)) / d ** 0.5
        # Lower-triangular (incl. diagonal) causal mask, built on the input's own
        # device rather than a module-level global, so the layer runs on CPU or
        # any GPU; masked_fill also preserves qk's dtype (e.g. fp16).
        causal = torch.ones(s, s, dtype=torch.bool, device=x.device).tril()
        qk = qk.masked_fill(~causal, float("-inf"))
        qkv = torch.nn.functional.softmax(qk, -1) @ v
        qkv = qkv.permute([0, 2, 1, 3]).reshape([b, s, -1])
        return qkv, state
class FeedForward(torch.nn.Module):
    """Gated position-wise feed-forward: ``ffn2(ffn1(x) * relu(gate(x)))``.

    The inner width is twice ``hidden_size``.
    """

    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()
        inner = hidden_size * 2
        # Registration order (ffn1, ffn2, gate) is kept so parameter order and
        # initialisation sequence are unchanged.
        self.ffn1 = torch.nn.Linear(hidden_size, inner)
        self.ffn2 = torch.nn.Linear(inner, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, inner)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        gated = self.ffn1(x) * self.relu(self.gate(x))
        return self.ffn2(gated)
class MemoryBlock(torch.nn.Module):
    """Linear map whose weight is a learned matrix plus a fixed identity.

    ``forward`` computes ``x @ (fc + I)``, which equals ``x @ fc + x``
    (a residual-style linear layer).
    """

    def __init__(self, hidden_dim):
        super(MemoryBlock, self).__init__()
        # Learnable weight, Xavier (Glorot) uniform initialisation.
        self.fc = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.fc)
        # Fixed identity as a non-persistent buffer: it moves with .to()/.cuda()
        # and needs no global device, while staying out of state_dict so
        # existing checkpoints keep the same keys.
        self.register_buffer("mem", torch.eye(hidden_dim), persistent=False)

    def forward(self, x):
        return x @ (self.fc + self.mem)
class DecoderLayer(torch.nn.Module):
    """One decoder block: MaxState token mixing, then a gated FFN whose
    output is added to the block input and post-normalised with LayerNorm.
    """

    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()
        # Token mixer. The third argument is MaxState's `win`, which the
        # MaxState class in this file does not use. Other mixers that were
        # tried here: MaskMultiHeadAttention, KAttention.
        self.self_attention = MaxState(hidden_size, num_heads, 8)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        # seq_len is accepted for call-site compatibility and ignored.
        mixed, new_state = self.self_attention(x, state)
        out = self.layer_norm(x + self.ffn(mixed))
        return out, new_state
class SamOut(torch.nn.Module):
    """Decoder-only language model built from stacked DecoderLayer blocks.

    Args:
        voc_size: vocabulary size; token id 3 is the embedding padding index.
        hidden_size: model width.
        num_heads: heads per DecoderLayer.
        num_layers: number of DecoderLayer blocks.
    """

    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        # 1024 learned positions; longer sequences combine two lookups (pos_forward).
        self.pos = torch.nn.Embedding(1024, hidden_size)
        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size, False)
        # Per-layer projection of [position ; hidden] (2 * hidden) back to hidden.
        self.down = torch.nn.ModuleList(
            [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])

    def state_forward(self, state, pos, x):
        """Run all decoder layers, re-injecting position info before each one.

        Args:
            state: list with one entry per layer, or None for a fresh list of Nones.
            pos: position embeddings as returned by ``pos_forward`` (1, seq, hidden).
            x: hidden states (batch, seq, hidden).

        Returns:
            (final hidden states, per-layer state list)
        """
        if state is None:
            state = [None] * len(self.decoder_layers)
        # Broadcast positions over the batch on their own device (no global needed).
        pos = pos.expand(x.shape[0], pos.shape[-2], pos.shape[-1])
        for i, (decoder_layer, down) in enumerate(zip(self.decoder_layers, self.down)):
            x = down(torch.concat([pos, x], -1))
            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x
        return x, state

    def pos_forward(self, x):
        """Return position embeddings of shape (1, seq, hidden) for ``x``'s length.

        For 1024 or more positions the embedding is emb(p % 1024) + emb(p // 1024);
        this supports lengths below 1024 * 1024.
        """
        seq_len = x.shape[1]
        # Index tensor lives with the embedding table, whatever device that is.
        idx = torch.arange(seq_len, device=self.pos.weight.device)
        if seq_len >= 1024:
            pos = self.pos(idx % 1024) + self.pos(idx // 1024)
        else:
            pos = self.pos(idx)
        return pos.unsqueeze(0)

    def forward(self, x0):
        """Full-sequence pass over token ids; returns (logits, per-layer states)."""
        logits, state = self.one_forward(x0, state=None)
        return logits, state

    def one_forward(self, x, state=None, seq_len=None):
        """Embed ids, add positions, run the decoder stack and project to logits.

        ``seq_len`` is accepted for compatibility and ignored.
        """
        x = self.em(x)
        pos = self.pos_forward(x)
        x, state = self.state_forward(state, pos, x)
        return self.head(x), state
# Default compute device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
if __name__ == '__main__':
    # Smoke test: small model on a random batch of 2 sequences of length 104.
    net = SamOut(235, 256, 16, 4)
    net.to(device)
    with torch.no_grad():
        logits, _ = net(torch.randint(0, 200, [2, 8 * 13]).to(device))
    print(logits.shape)
#
从图上可以看出：cummax 的 loss 虽然前期暂时落后，但后期低于 attention；
此外，训练显存节约 3–5 GB；
训练 100 轮的时间节约 0.5 小时以上；
推理时显存占用保持不变。
下面是 attention（KAttention）和 MaxState 两种实现的模型代码：
import torch
import numpy as np
class MaxState(torch.nn.Module):
    """Gated cumulative-max token mixer used in place of self-attention.

    The width is split into ``heads`` heads and a running (prefix) maximum is
    taken along the sequence axis, so the output at position ``t`` only
    depends on inputs at positions ``<= t``.

    Args:
        hidden_dim: model width; must be divisible by ``heads``.
        heads: number of heads.
        win: accepted for interface compatibility; not used by this module.
    """

    def __init__(self, hidden_dim, heads, win):
        super(MaxState, self).__init__()
        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
        self.head_size = hidden_dim // heads
        # Creation order is kept: it fixes parameter order and init sequence.
        self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)  # values
        self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)  # gate exponent (maxed with head0)
        self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)  # per-token lower bound
        self.head_num = heads
        self.hidden = hidden_dim

    def forward(self, input_data, state=None):
        """Mix tokens with a gated prefix maximum.

        Args:
            input_data: tensor reshaped below as (batch, seq, heads * head_size).
            state: returned unchanged.

        Returns:
            (mixed tensor of shape (batch, seq, hidden), state)
        """
        batch, seq_len = input_data.shape[0], input_data.shape[1]
        head_shape = [batch, seq_len, self.head_num, self.head_size]

        values = self.head0(input_data)
        gate = torch.stack([self.head1(input_data), values], dim=-1).max(dim=-1).values
        values = values.reshape(head_shape).transpose(1, 2)
        gate = gate.reshape(head_shape).transpose(1, 2)
        # Running maximum over the sequence axis (dim 2 after the head split).
        running = torch.cummax(values * torch.exp(gate), dim=2).values
        merged = running.transpose(1, 2).reshape([batch, seq_len, -1])
        # Element-wise max with a per-token projection acts as a floor.
        floor = self.head2(input_data)
        result = torch.stack([floor, merged], dim=-1).max(dim=-1).values
        return result, state
class KAttention(torch.nn.Module):
    """Causal multi-head scaled dot-product self-attention.

    Args:
        hidden_dim: model width; must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, hidden_dim, heads):
        super(KAttention, self).__init__()
        assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
        self.head_size = hidden_dim // heads
        self.q = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.k = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.v = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.head_num = heads

    def forward(self, x, state=None):
        """Attend each position to itself and all earlier positions.

        Args:
            x: tensor reshaped below as (batch, seq, heads * head_size).
            state: returned unchanged.

        Returns:
            (output of shape (batch, seq, hidden), state)
        """
        b, s, h, d = x.shape[0], x.shape[1], self.head_num, self.head_size

        def split_heads(t):
            # (batch, seq, hidden) -> (batch, heads, seq, head_size)
            return t.reshape([b, s, h, d]).permute([0, 2, 1, 3])

        q = split_heads(self.q(x))
        k = split_heads(self.k(x))
        v = split_heads(self.v(x))
        qk = (q @ k.transpose(-2, -1)) / d ** 0.5
        # Lower-triangular (incl. diagonal) causal mask, built on the input's own
        # device rather than a module-level global, so the layer runs on CPU or
        # any GPU; masked_fill also preserves qk's dtype (e.g. fp16).
        causal = torch.ones(s, s, dtype=torch.bool, device=x.device).tril()
        qk = qk.masked_fill(~causal, float("-inf"))
        qkv = torch.nn.functional.softmax(qk, -1) @ v
        qkv = qkv.permute([0, 2, 1, 3]).reshape([b, s, -1])
        return qkv, state
class FeedForward(torch.nn.Module):
    """Gated position-wise feed-forward: ``ffn2(ffn1(x) * relu(gate(x)))``.

    The inner width is twice ``hidden_size``.
    """

    def __init__(self, hidden_size):
        super(FeedForward, self).__init__()
        inner = hidden_size * 2
        # Registration order (ffn1, ffn2, gate) is kept so parameter order and
        # initialisation sequence are unchanged.
        self.ffn1 = torch.nn.Linear(hidden_size, inner)
        self.ffn2 = torch.nn.Linear(inner, hidden_size)
        self.gate = torch.nn.Linear(hidden_size, inner)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        gated = self.ffn1(x) * self.relu(self.gate(x))
        return self.ffn2(gated)
class MemoryBlock(torch.nn.Module):
    """Linear map whose weight is a learned matrix plus a fixed identity.

    ``forward`` computes ``x @ (fc + I)``, which equals ``x @ fc + x``
    (a residual-style linear layer).
    """

    def __init__(self, hidden_dim):
        super(MemoryBlock, self).__init__()
        # Learnable weight, Xavier (Glorot) uniform initialisation.
        self.fc = torch.nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        torch.nn.init.xavier_uniform_(self.fc)
        # Fixed identity as a non-persistent buffer: it moves with .to()/.cuda()
        # and needs no global device, while staying out of state_dict so
        # existing checkpoints keep the same keys.
        self.register_buffer("mem", torch.eye(hidden_dim), persistent=False)

    def forward(self, x):
        return x @ (self.fc + self.mem)
class DecoderLayer(torch.nn.Module):
    """One decoder block: MaxState token mixing, then a gated FFN whose
    output is added to the block input and post-normalised with LayerNorm.
    """

    def __init__(self, hidden_size, num_heads):
        super(DecoderLayer, self).__init__()
        # Token mixer. The third argument is MaxState's `win`, which the
        # MaxState class in this file does not use. Other mixers that were
        # tried here: MaskMultiHeadAttention, KAttention.
        self.self_attention = MaxState(hidden_size, num_heads, 8)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = torch.nn.LayerNorm(hidden_size)

    def forward(self, x, state=None, seq_len=None):
        # seq_len is accepted for call-site compatibility and ignored.
        mixed, new_state = self.self_attention(x, state)
        out = self.layer_norm(x + self.ffn(mixed))
        return out, new_state
class SamOut(torch.nn.Module):
    """Decoder-only language model built from stacked DecoderLayer blocks.

    Args:
        voc_size: vocabulary size; token id 3 is the embedding padding index.
        hidden_size: model width.
        num_heads: heads per DecoderLayer.
        num_layers: number of DecoderLayer blocks.
    """

    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super(SamOut, self).__init__()
        self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
        # 1024 learned positions; longer sequences combine two lookups (pos_forward).
        self.pos = torch.nn.Embedding(1024, hidden_size)
        self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
        self.head = torch.nn.Linear(hidden_size, voc_size, False)
        # Per-layer projection of [position ; hidden] (2 * hidden) back to hidden.
        self.down = torch.nn.ModuleList(
            [torch.nn.Linear(2 * hidden_size, hidden_size, False) for _ in range(num_layers)])

    def state_forward(self, state, pos, x):
        """Run all decoder layers, re-injecting position info before each one.

        Args:
            state: list with one entry per layer, or None for a fresh list of Nones.
            pos: position embeddings as returned by ``pos_forward`` (1, seq, hidden).
            x: hidden states (batch, seq, hidden).

        Returns:
            (final hidden states, per-layer state list)
        """
        if state is None:
            state = [None] * len(self.decoder_layers)
        # Broadcast positions over the batch on their own device (no global needed).
        pos = pos.expand(x.shape[0], pos.shape[-2], pos.shape[-1])
        for i, (decoder_layer, down) in enumerate(zip(self.decoder_layers, self.down)):
            x = down(torch.concat([pos, x], -1))
            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x
        return x, state

    def pos_forward(self, x):
        """Return position embeddings of shape (1, seq, hidden) for ``x``'s length.

        For 1024 or more positions the embedding is emb(p % 1024) + emb(p // 1024);
        this supports lengths below 1024 * 1024.
        """
        seq_len = x.shape[1]
        # Index tensor lives with the embedding table, whatever device that is.
        idx = torch.arange(seq_len, device=self.pos.weight.device)
        if seq_len >= 1024:
            pos = self.pos(idx % 1024) + self.pos(idx // 1024)
        else:
            pos = self.pos(idx)
        return pos.unsqueeze(0)

    def forward(self, x0):
        """Full-sequence pass over token ids; returns (logits, per-layer states)."""
        logits, state = self.one_forward(x0, state=None)
        return logits, state

    def one_forward(self, x, state=None, seq_len=None):
        """Embed ids, add positions, run the decoder stack and project to logits.

        ``seq_len`` is accepted for compatibility and ignored.
        """
        x = self.em(x)
        pos = self.pos_forward(x)
        x, state = self.state_forward(state, pos, x)
        return self.head(x), state
# Default compute device: CUDA when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
if __name__ == '__main__':
    # Smoke test: small model on a random batch of 2 sequences of length 104.
    net = SamOut(235, 256, 16, 4)
    net.to(device)
    with torch.no_grad():
        logits, _ = net(torch.randint(0, 200, [2, 8 * 13]).to(device))
    print(logits.shape)
#
以下是训练代码：
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from glob import glob
from tqdm import tqdm
# from model_e import SamOut
# from model_f import SamOut
from model_g import SamOut
import json
def train():
    """Train SamOut on the sequences from gen_voc() for 100 epochs.

    Side effects: writes the model weights to "model.pth" every 8000 counted
    samples and once more when training ends, and pickles the list of
    per-epoch mean losses to "loss911" after every epoch.
    """
    train_data, voc = gen_voc()
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    net = SamOut(len(voc), 512, 32, 8)
    # To resume: net.load_state_dict(torch.load("model.pth"))
    net.to(run_device)
    opt = torch.optim.Adam(params=net.parameters(), lr=0.00003)
    # Id 3 is "<|pos|>", the padding token gen_voc appends.
    loss_func0 = torch.nn.CrossEntropyLoss(ignore_index=3)
    batch_size = 100
    bar = tqdm(range(100))
    steps = 0
    epoch_loss = []
    for epoch in bar:
        np.random.shuffle(train_data)
        loss_list = []
        for i in range(0, len(train_data), batch_size):
            # Convert the batch once, straight to integer ids on the target device.
            batch = torch.tensor(train_data[i:i + batch_size], dtype=torch.long, device=run_device)
            # Next-token prediction: inputs drop the last id, targets drop the first.
            out0, _ = net(batch[:, :-1])
            loss = loss_func0(out0.reshape([-1, out0.shape[-1]]), batch[:, 1:].reshape([-1]))
            loss_list.append(loss.item())
            bar.set_description("epoch___{}____loss___{:.6f}____steps___{}".format(epoch, np.mean(loss_list), steps))
            opt.zero_grad()
            loss.backward()
            opt.step()
            # Counts samples using the nominal batch size.
            steps += batch_size
            if steps % 8000 == 0:
                torch.save(net.state_dict(), "model.pth")
        epoch_loss.append(np.mean(loss_list))
        pd.to_pickle(epoch_loss, "loss911")
    # Persist the final weights even when the last step was not a multiple of 8000.
    torch.save(net.state_dict(), "model.pth")
def gen_voc():
    """Build the character vocabulary and padded id sequences from JSON files.

    Reads every file matched by "train/*" and "test/*". Each file is parsed as a
    JSON object with a "question" and an optional "answer"; their characters are
    the tokens, joined by "<|bos|>" and wrapped in "<|sos|>" ... "<|eos|>".

    Returns:
        (train_list, voc): ``train_list`` holds id sequences right-padded with the
        "<|pos|>" id to length 256 (sequences longer than 255 tokens are dropped);
        ``voc`` maps id -> token, starting with the four special tokens followed by
        the sorted set of seen tokens.
    """
    # NOTE(review): files under test/ are also turned into training samples here.
    paths = glob("train/*") + glob("test/*")
    q_list = []
    voc_dict = set()
    for path in tqdm(paths):
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if "answer" not in data:
            q = list(data["question"])
        else:
            q = list(data["question"]) + ["<|bos|>"] + list(data["answer"])
        voc_dict.update(q)
        q_list.append(["<|sos|>"] + q + ["<|eos|>"])
    voc = ["<|sss|>", "<|sos|>", "<|eos|>", "<|pos|>"] + sorted(voc_dict)
    # O(1) token -> id lookup (list.index is O(vocab) per token). setdefault keeps
    # the first occurrence, matching list.index semantics.
    token_to_id = {}
    for idx, tok in enumerate(voc):
        token_to_id.setdefault(tok, idx)
    pad_id = token_to_id["<|pos|>"]
    train_list = []
    for q in tqdm(q_list):
        if len(q) <= 255:
            train_list.append([token_to_id[tok] for tok in q] + [pad_id] * (256 - len(q)))
    return train_list, voc
def show_loss(names=("loss911", "loss9111"), labels=("cummax", "attention")):
    """Plot pickled per-epoch loss curves on one figure and show it.

    Only the requested pickles are read, so missing files from unrelated runs
    do not break plotting. Going by the earlier legend, other recorded runs are
    loss0 "sin", loss1 "+", loss2 "nor", loss3 "state", loss4 "pos", loss5 "s",
    loss6 "mem", loss7 "mm" (was plotted divided by 3), loss8 "8", loss9 "9",
    loss91 "91"; pass their names/labels to compare them.

    Args:
        names: pickle file names, each holding a sequence of loss values.
        labels: legend entries, in the same order as ``names``.
    """
    for name in names:
        plt.plot(np.array(pd.read_pickle(name)))
    plt.legend(list(labels))
    plt.show()
if __name__ == '__main__':
    # Show the recorded loss curves first, then start a training run.
    # (val() and eval_data() entry points are disabled.)
    for entry_point in (show_loss, train):
        entry_point()