Table of Contents

  • Prerequisite knowledge
  • 1: The Encoder-Decoder architecture
  • 2: How the LSTM works
  • 3: The attention mechanism
  • A machine translation model with Luong attention + LSTM
  • Model
  • Data
  • Training
  • A machine translation model with Bahdanau attention + LSTM
  • Model
  • Data
  • Training


Prerequisite knowledge

1: The Encoder-Decoder architecture

The encoder is, as the name suggests, an encoder, but I prefer to think of it as a general-purpose feature extractor. Intuitively, the encoder extracts and processes the information in the input X in some high-dimensional space.

The decoder's job is to take the information that the encoder has processed in that high-dimensional space and turn it back into something humans can understand.

A structure in which information passes through an encoder and then a decoder is called the Encoder-Decoder architecture.
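To make the idea concrete, here is a minimal, generic sketch of the pattern. TinyEncoder/TinyDecoder and all sizes here are illustrative placeholders, not the models built later in this post:

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, x):
        # compress the input sequence into a hidden representation ("understanding" of x)
        _, h = self.rnn(self.embed(x))
        return h

class TinyDecoder(nn.Module):
    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, y, h):
        # turn the encoder's representation back into scores over human-readable tokens
        o, _ = self.rnn(self.embed(y), h)
        return self.out(o)

x = torch.randint(0, 100, (2, 7))            # a batch of source token ids
y = torch.randint(0, 100, (2, 5))            # a batch of (shifted) target token ids
logits = TinyDecoder()(y, TinyEncoder()(x))  # logits.shape = (2, 5, 100)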

2: How the LSTM works

The LSTM evolved from the RNN. Plain RNNs suffer from exploding and vanishing gradients (the chain rule turns the gradient into a product of many terms), so they struggle to learn long sentences. The LSTM addresses this with a gating mechanism. Inside the LSTM there are two streams: the hidden state and the cell state. The cell state is the main line; it carries the overall long-term information. The hidden state, together with the input x, decides how far each gate opens. As the figure below shows, the cell state goes through two operations to become the next cell state: 1) it passes through the forget gate, and 2) new input information is added. This is similar to how we read: while taking in new information, we forget the older, less important parts. The next hidden state is produced from the next cell state by applying tanh and the output gate.

[Figure: the internal structure of an LSTM cell, showing the forget, input, and output gates acting on the cell state and hidden state]
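To make the gating concrete, here is a minimal sketch of one LSTM step written out by hand, following the standard formulation (the weight names and sizes are illustrative):

import torch

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One manual LSTM step. W, U, b hold the parameters of the
    input/forget/cell/output gates stacked along dim 0 (4*hidden each)."""
    gates = x_t @ W.T + h_prev @ U.T + b                               # (batch, 4*hidden)
    i, f, g, o = gates.chunk(4, dim=-1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)     # gate openings in (0, 1)
    g = torch.tanh(g)                                                  # candidate new information
    c_t = f * c_prev + i * g                                           # forget old info, add new info
    h_t = o * torch.tanh(c_t)                                          # hidden state from cell state
    return h_t, c_t

hidden = 8
x = torch.randn(2, 4)                 # (batch, input_size)
h0 = c0 = torch.zeros(2, hidden)
W = torch.randn(4 * hidden, 4)
U = torch.randn(4 * hidden, hidden)
b = torch.zeros(4 * hidden)
h1, c1 = lstm_step(x, h0, c0, W, U, b)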

3: The attention mechanism

At its core, attention is just a weighted average; different ways of computing the weights give rise to the different attention variants.
In machine translation, attention takes a weighted average over all of the encoder's output tokens and feeds the result to the decoder.

  • Luong attention (multiplicative attention)
  • Bahdanau attention (additive attention)

    In terms of implementation details (a short sketch comparing the two score functions follows this list):
  • Luong attention computes the attention weights from the decoder's hidden state at the current time step and the encoder outputs. The weights are used to take a weighted average of the encoder outputs, giving c_t; c_t is then concatenated with the decoder hidden state at the current time step and passed through a feed-forward network for classification.
  • Bahdanau attention computes the attention weights from the decoder's hidden state at the previous time step and the encoder outputs. The weights are used to take a weighted average of the encoder outputs, giving c_t; c_t is then concatenated with the input embedding and fed into the LSTM as its input at the next time step.
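A minimal sketch contrasting the two score functions on dummy tensors (only to illustrate the formulas; masking and the surrounding plumbing are handled by the real classes below):

import torch
import torch.nn as nn

batch, en_len, hidden = 2, 6, 16
enc_out = torch.randn(batch, en_len, hidden)   # encoder outputs
dec_h   = torch.randn(batch, 1, hidden)        # one decoder hidden state

# Luong ("multiplicative") score: s_t^T W_a h_i
W_a = nn.Linear(hidden, hidden, bias=False)
luong_score = torch.bmm(dec_h, W_a(enc_out).transpose(1, 2))               # (batch, 1, en_len)

# Bahdanau ("additive") score: v^T tanh(W_1 h_i + W_2 s_{t-1})
W_1 = nn.Linear(hidden, hidden, bias=False)
W_2 = nn.Linear(hidden, hidden, bias=False)
v   = nn.Linear(hidden, 1, bias=False)
bahdanau_score = v(torch.tanh(W_1(enc_out) + W_2(dec_h))).transpose(1, 2)  # (batch, 1, en_len)

# Either way, the weights come from a softmax over encoder positions,
# and the context vector is a weighted average of the encoder outputs.
weights = torch.softmax(luong_score, dim=-1)
c_t = torch.bmm(weights, enc_out)              # (batch, 1, hidden)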

A machine translation model with Luong attention + LSTM

Model

The model uses the Encoder-Decoder architecture with an LSTM and Luong attention.
Details to watch out for:

  • The Attention class needs a mask. The input sentences are padded, and when computing attention we should only attend to the real tokens, not the pad positions. So before the softmax, add a very large negative number (e.g. -1e6) to the scores at the pad positions; after the softmax, the weights at those positions become vanishingly small. Note: do not assume that setting the pad scores to 0 would work just as well. It would not! (A quick sanity check follows this list.)
  • The attention computation uses two ingredients: the encoder output and the decoder output. This differs slightly from the paper, which uses the decoder hidden state.
  • When running the RNN you can use nn.utils.rnn.pack_padded_sequence, which has two benefits.
    1: The final hidden state is more faithful. Our sentences are padded, and a plain RNN keeps stepping all the way to the end of the padded sequence, so the pad positions also pass through the RNN and distort its output.
    2: It speeds up the RNN, since computation stops automatically at the pad positions.
  • After running the RNN on a packed sequence, use the inverse function nn.utils.rnn.pad_packed_sequence to restore the output to its original padded layout.
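The warning about the mask value can be checked in a couple of lines; this standalone snippet shows that a pad position whose score is merely set to 0 still receives a noticeable share of the attention weight, while a large negative score drives it to essentially zero:

import torch

scores = torch.tensor([2.0, 1.0, 0.5, 0.0])  # suppose the last position is padding
print(torch.softmax(scores.index_fill(0, torch.tensor([3]), 0.0), dim=0))   # pad still gets ~8% of the weight
print(torch.softmax(scores.index_fill(0, torch.tensor([3]), -1e6), dim=0))  # pad gets essentially 0 weight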
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, dropout=0.2):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.rnn = nn.LSTM(embedding_size, hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, x, x_len):  # x.shape = (batch_size,seq_len)
        embed = self.embedding(x)  # embed.shape = (batch_size,seq_len,embedding_size)
        embed = self.dropout(embed)
        # lengths must be a CPU tensor for pack_padded_sequence in recent PyTorch versions
        packed_embed = nn.utils.rnn.pack_padded_sequence(embed, x_len.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, hid = self.rnn(packed_embed)

        # restore the padded layout; with enforce_sorted=False the LSTM already
        # returns both the output and the hidden states in the original batch order
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        hidden, cell = hid

        # combine the last forward and backward directions into a single hidden/cell state
        hidden = self.dense(torch.cat([hidden[-1], hidden[-2]], dim=-1)).unsqueeze(0)
        cell = self.dense(torch.cat([cell[-1], cell[-2]], dim=-1)).unsqueeze(0)
        # out.shape = (batch_size,seq_len,2*hidden_size)
        # hidden.shape = (1,batch_size,hidden_size)
        return out, (hidden, cell)
class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.W_a = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    def forward(self, decoder_output, context, mask):
        # context.shape = (batch_size,en_seq_len,2*hidden_size)
        # decoder_output = (batch_size,de_seq_len,hidden_size)
        context_in = self.W_a(context).transpose(1, 2)  # context_in.shape = (batch_size,hidden_size,en_seq_len)

        attention = torch.bmm(decoder_output, context_in)  # attention.shape = (batch_size,de_seq_len,en_seq_len)

        attention = attention.masked_fill(mask, -1e6)

        attention_softmaxed = torch.softmax(attention, dim=2)  # attention.shape = (batch_size,de_seq_len,en_seq_len)

        weighted_sum = torch.bmm(attention_softmaxed,
                                 context)  # weighted_sum.shape = (batch_size,de_seq_len,2*hidden_size)

        return weighted_sum




class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size)
        self.rnn = nn.LSTM(embedding_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.attention = Attention(hidden_size)
        self.dense = nn.Linear(hidden_size * 3, vocab_size)

    def forward(self, y, y_len, hidden, context, x_len):
        embed = self.dropout(self.embed(y))
        # lengths must be a CPU tensor for pack_padded_sequence in recent PyTorch versions
        packed_embed = nn.utils.rnn.pack_padded_sequence(embed, y_len.cpu(), batch_first=True, enforce_sorted=False)

        packed_rnn, hidden = self.rnn(packed_embed, hidden)

        # out and hidden come back in the original batch order (enforce_sorted=False)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_rnn, batch_first=True)

        mask = self.generate_mask_for_attention(y_len, x_len)
        c_t = self.attention(out, context, mask)
        out = self.dense(torch.tanh(torch.cat([c_t, out], dim=2)))

        out = F.log_softmax(out, 2)
        return out, hidden

    def generate_mask_for_attention(self, de_len, en_len):
        # returns True where attention should be masked out,
        # shape (batch_size, de_seq_len, en_seq_len)
        device = en_len.device
        de_len = de_len.to(device)
        de_mask = torch.arange(de_len.max(), device=device)[None, :] < de_len[:, None]  # (batch, de_seq_len)
        en_mask = torch.arange(en_len.max(), device=device)[None, :] < en_len[:, None]  # (batch, en_seq_len)
        mask = (de_mask[:, :, None] * en_mask[:, None, :]).logical_not()
        return mask


class seq_2_seq_withattention(nn.Module):
    def __init__(self,encoder,decoder):
        super(seq_2_seq_withattention, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self,x,xlen,y,ylen):
        en_out,hid = self.encoder(x,xlen)
        out,hid = self.decoder(y,ylen,hid,en_out,xlen)

        return out
    def translate(self,x,y_input,x_len,max_len = 20):
        encoder_out,hid = self.encoder(x,x_len)
        prediction = []
        batch_size = x.shape[0]
        y_len = torch.ones(batch_size)
        y_len = y_len.long()
        for i in range(max_len):
            out,hid = self.decoder(y_input,y_len,hid,encoder_out,x_len)
            y_input = torch.argmax(out,dim=-1).view(batch_size,1)
            prediction.append(chs_idx_to_word[int(y_input[0][0])])
        return prediction
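A rough usage sketch for translate. It assumes the data-side objects built in the next section (enu_dic, chs_dic, and the index-to-word map chs_idx_to_word used inside translate) already exist and that the model is trained; translate_sentence and the example sentence are purely illustrative:

import nltk
import torch

def translate_sentence(model, sentence, enu_dic, chs_dic, device):
    tokens = ["BOS"] + nltk.word_tokenize(sentence.lower()) + ["EOS"]
    idx = [enu_dic.get(w, enu_dic["UNK"]) for w in tokens]   # unknown words map to UNK
    x = torch.tensor([idx], dtype=torch.long, device=device)
    x_len = torch.tensor([len(idx)], device=device)
    bos = torch.tensor([[chs_dic["BOS"]]], dtype=torch.long, device=device)
    model.eval()
    with torch.no_grad():
        words = model.translate(x, bos, x_len, max_len=20)
    # keep everything up to the first EOS, if one is produced
    return "".join(words[:words.index("EOS")] if "EOS" in words else words)

# print(translate_sentence(model, "how are you ?", enu_dic, chs_dic, device))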

Data

Data preprocessing for machine translation,
using English-to-Chinese as the example:
src (source language): English, enu
tgt (target language): Chinese, chs

  • Build a vocabulary for both src and tgt, since the Chinese and English vocabularies are different. You can choose the tokenization granularity: each character can be a token, or each word; for Latin-script languages you can even use subword units. Here English is tokenized with nltk.word_tokenize and Chinese is split character by character (see load_data below). The vocabulary also needs four special tokens: "UNK" for unknown words, "PAD" for padding positions, "BOS" for begin-of-sentence, and "EOS" for end-of-sentence.
import numpy as np
import nltk
from collections import Counter

# special-token indices (assumed values; index 0 is used as PAD throughout this code)
PAD_IDX = 0
UNK_IDX = 1


def load_data(file_path):
    enu = []
    chs = []

    with open(file_path) as f:
        for line in f:
            sentence_pair = line.split("\t")
            # English: nltk word tokens; Chinese: one character per token ([:-1] drops the trailing newline)
            enu.append(["BOS"] + nltk.word_tokenize(sentence_pair[0].lower()) + ["EOS"])
            chs.append(["BOS"] + [word for word in sentence_pair[1]][:-1] + ["EOS"])
    return enu, chs
def build_dic(sources, most_common=20000):
    word_count = Counter()
    for source in sources:
        for sentence in source:
            for word in sentence:
                word_count[word] += 1
    # indices 0 and 1 are reserved for PAD and UNK
    dic = {}
    for i, (word, count) in enumerate(word_count.most_common(most_common)):
        dic[word] = i + 2
    dic["UNK"] = UNK_IDX
    dic["PAD"] = PAD_IDX
    return dic, len(dic)
  • Use the vocabulary to encode the text into indices. Sentences are padded or truncated to the 80th-percentile length.
def sentence_word_to_index(sentences, dic):
    sentence_idx = []
    sentences_lengths = [len(sentence) for sentence in sentences]
    # pad / truncate every sentence to the 80th-percentile length
    eighty_percent_length = int(np.percentile(sentences_lengths, 80))
    for sentence in sentences:
        sentence_with_pad = [PAD_IDX for _ in range(eighty_percent_length)]
        for i, word in enumerate(sentence):
            if i >= eighty_percent_length:
                break
            sentence_with_pad[i] = dic.get(word, UNK_IDX)  # unknown words map to UNK
        sentence_idx.append(sentence_with_pad)
    sentences_lengths = [min(length, eighty_percent_length) for length in sentences_lengths]
    return np.array(sentence_idx), np.array(sentences_lengths)
  • Create minibatches for training.
    Here the data is shuffled and minibatched by hand; you could also generate minibatches with PyTorch's Dataset and DataLoader (from torch.utils.data import Dataset, DataLoader) — a sketch of that alternative follows the code below.
def get_mini_batch(batch_size,n,shuffle = True):
    batch_idx = np.arange(0,n,batch_size)
    if shuffle:
        np.random.shuffle(batch_idx)
    mini_batches = []
    for idx in batch_idx:
        mini_batches.append([idx,min(idx+batch_size,n)])
    return mini_batches


def prepare_data(enu,chs,enu_dic,chs_dic):
    batch_data = []
    enu_onehot,enu_len = sentence_word_to_index(enu,enu_dic)
    chs_onehot,chs_len = sentence_word_to_index(chs,chs_dic)

    mini_batch = get_mini_batch(batch_size=64,n = len(enu))
    for batch in mini_batch:
        mb_enu = enu_onehot[batch[0]:batch[1],:]
        mb_chs = chs_onehot[batch[0]:batch[1],:]
        enu_l = enu_len[batch[0]:batch[1]]
        chs_l = chs_len[batch[0]:batch[1]]
        batch_data.append((mb_enu,enu_l,mb_chs,chs_l))
    return batch_data
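As mentioned above, the same batching can be done with PyTorch's Dataset and DataLoader. A sketch of that alternative (the class name and structure are illustrative; the hand-rolled minibatching above behaves the same way):

from torch.utils.data import Dataset, DataLoader

class TranslationDataset(Dataset):
    def __init__(self, src_onehot, src_len, tgt_onehot, tgt_len):
        self.src, self.src_len = src_onehot, src_len
        self.tgt, self.tgt_len = tgt_onehot, tgt_len

    def __len__(self):
        return len(self.src)

    def __getitem__(self, i):
        return self.src[i], self.src_len[i], self.tgt[i], self.tgt_len[i]

# dataset = TranslationDataset(enu_onehot, enu_len, chs_onehot, chs_len)
# loader = DataLoader(dataset, batch_size=64, shuffle=True)
# for mb_enu, enu_l, mb_chs, chs_l in loader:
#     ...  # batches come out as tensors (numpy arrays are converted by the default collate)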

Training

The loss function is cross-entropy. Instead of calling torch.nn.CrossEntropyLoss directly, the loss is computed by hand.
The cross-entropy formula is loss = -Σ p·log(q), where p is the true distribution (the ground truth) and q is the model's output.
In machine translation training, p is usually a one-hot distribution (in knowledge distillation it is not), so the formula simplifies to loss = -log(q), where q is the predicted probability of the true label.
Decomposing torch.nn.CrossEntropyLoss by hand:
Step 1: compute log(softmax(decoder_out)). (The decoder's last step already does this.)
Step 2: pick out the decoder output at the true label, using gather.
Step 3: negate the values from step 2 (multiply by -1).
Note that padding has to be handled here as well: pad positions should not contribute to the loss. Build a 0/1 mask (0 at pad positions, 1 at real tokens), multiply the per-token loss by the mask, and only then take the average.
If you prefer the built-in CrossEntropyLoss, set reduction to 'none' and multiply the result by the mask yourself, following the same steps (a sketch is given after the class below).

class LanguageModelLoss(nn.Module):
    def __init__(self):
        super(LanguageModelLoss, self).__init__()


    def forward(self,pre,label,mask):
        # pre.shape   = (batch_size, seq_len, vocab_size) -- log-probabilities
        # label.shape = (batch_size, seq_len)
        # mask.shape  = (batch_size, seq_len) -- 1/True at real tokens, 0/False at pad
        pre = pre.view(-1,pre.size(2))
        label = label.contiguous().view(-1,1)
        mask = mask.contiguous().view(-1,1).float()

        out = -pre.gather(1,label)*mask   # negative log-likelihood at the true labels
        out = torch.sum(out)/torch.sum(mask)
        return out
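For comparison, here is a sketch of the same masked loss built from PyTorch's own losses. Since the decoder above already outputs log-probabilities, nn.NLLLoss(reduction='none') is the exact drop-in; nn.CrossEntropyLoss(reduction='none') would be used the same way but expects raw logits, because it applies log_softmax internally:

import torch
import torch.nn as nn

def masked_nll_loss(pre, label, mask):
    # pre: (batch, seq_len, vocab) log-probabilities
    # label: (batch, seq_len) target indices; mask: (batch, seq_len) 1 at real tokens, 0 at pad
    per_token = nn.NLLLoss(reduction="none")(pre.reshape(-1, pre.size(-1)), label.reshape(-1))
    mask = mask.reshape(-1).float()
    return (per_token * mask).sum() / mask.sum()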



embedding_size = 500
hidden_size = 500
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(enu_vocab_size,embedding_size,hidden_size)
decoder = Decoder(chs_vocab_size,embedding_size,hidden_size)
model = seq_2_seq_withattention(encoder,decoder).to(device)

loss_fn = LanguageModelLoss().to(device)
optimizer = torch.optim.Adam(model.parameters())
model_path = "./mt_model.pth"
def train(model,data,epochs):
    model.train()
    eval_loss = 100
    for epoch in range(epochs):
        for i,(mb_x,mb_x_len,mb_y,mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device)
            mb_x_len = torch.from_numpy(mb_x_len).to(device)
            mb_y= torch.from_numpy(mb_y).to(device)
            mb_y_input = mb_y[:,:-1].to(device)
            mb_y_output = mb_y[:,1:].to(device)
            mb_y_len = torch.from_numpy(mb_y_len-1).to(device)
            pre = model(mb_x,mb_x_len,mb_y_input,mb_y_len)
            # the decoder output is only as long as the longest target in the batch,
            # so trim the padded labels to the same length
            mb_y_output = mb_y_output[:, :pre.size(1)]
            mask = torch.arange(mb_y_len.max(),device = device)[None,:]<mb_y_len[:,None]
            loss = loss_fn(pre,mb_y_output,mask)

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(),5.)
            optimizer.step()
            print("epoch:{} iter:{} loss:{}".format(epoch,i,loss.item()))

            if i%10 == 0:
                result = eval(model,dev_data)
                if result<eval_loss:
                    eval_loss = result
                    torch.save(model.state_dict(),model_path)

            model.train()


def eval(model,data):
    model.eval()
    with torch.no_grad():
        for i,(mb_x,mb_x_len,mb_y,mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device)
            mb_x_len = torch.from_numpy(mb_x_len).to(device)
            mb_y= torch.from_numpy(mb_y).to(device)
            mb_y_input = mb_y[:,:-1].to(device)
            mb_y_output = mb_y[:,1:].to(device)
            mb_y_len = torch.from_numpy(mb_y_len-1).to(device)
            pre = model(mb_x,mb_x_len,mb_y_input,mb_y_len)
            mb_y_output = mb_y_output[:, :pre.size(1)]  # trim labels to the decoder output length
            mask = torch.arange(mb_y_len.max(),device = device)[None,:]<mb_y_len[:,None]
            loss = loss_fn(pre,mb_y_output,mask)
            print("eval_loss",loss.item())

            # only the first dev batch is used as a quick validation signal
            return loss

A machine translation model with Bahdanau attention + LSTM

Model

The overall architecture is similar to the Luong-attention model; the difference lies in how attention is computed.
Points to note:

  • The attention weights are computed from the decoder's hidden state at the previous time step and the encoder outputs, giving the weighted-average context vector c_t. c_t is concatenated with the embedding and used as the decoder's input at the next time step.
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, dropout=0.2):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.rnn = nn.LSTM(embedding_size, hidden_size, batch_first=True, bidirectional=True, num_layers=2)
        self.dropout = nn.Dropout(dropout)
        self.dense = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, x, x_len):  # x.shape = (batch_size,seq_len)
        embed = self.embedding(x)  # embed.shape = (batch_size,seq_len,embedding_size)
        embed = self.dropout(embed)
        # lengths must be a CPU tensor for pack_padded_sequence in recent PyTorch versions
        packed_embed = nn.utils.rnn.pack_padded_sequence(embed, x_len.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, hid = self.rnn(packed_embed)

        # with enforce_sorted=False the LSTM already returns out and hid in the original batch order
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        hidden, cell = hid

        # combine the forward and backward directions of the top layer into one state
        hidden = self.dense(torch.cat([hidden[-1], hidden[-2]], dim=-1)).unsqueeze(0)
        cell = self.dense(torch.cat([cell[-1], cell[-2]], dim=-1)).unsqueeze(0)
        # out.shape = (batch_size,seq_len,2*hidden_size)
        # hidden.shape = (1,batch_size,hidden_size)
        return out, (hidden, cell)


class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.W_a = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.W_b = nn.Linear(hidden_size,hidden_size,bias=False)
        self.W_c = nn.Linear(hidden_size,1)
        self.dense = nn.Linear(2*hidden_size,hidden_size)

    def forward(self, decoder_hidden, context, mask):
        # context.shape = (batch_size,encoder_seq_len,2*hidden_size)
        # decoder_hidden.shape = (1,batch_size,hidden_size)
        context_in = self.W_a(context)  # context_in.shape = (batch_size,en_seq_len,hidden_size)
        hidden_in = self.W_b(decoder_hidden).transpose(0,1) # hidden_in.shape = (batch_size,1,hidden_size)
        attention = self.W_c(torch.tanh(context_in+hidden_in))
        # attention.shape = (batch_size,en_seq_len,1)
        attention = attention.masked_fill(mask,-1e6)
        attention_softmaxed = torch.softmax(attention,1)
        weighted_sum = torch.bmm(attention_softmaxed.transpose(1,2),context)
        # weighted_sum.shape = (batch_size,1,2*hidden_size)
        '''
        weighted_sum = context*attention_softmaxed # weighted_sum.shape = (batch_size,encoder_seq_len,2*hidden_size)

        context_out = self.dense(torch.sum(weighted_sum,dim=1)) # context_out = (batch_size,1,2*hidden_size)
        return context_out
        '''

        return self.dense(weighted_sum)





class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embedding_size)
        self.rnn = nn.LSTM(embedding_size + hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.attention = Attention(hidden_size)
        self.dense = nn.Linear(hidden_size, vocab_size)

    def forward(self, y, hidden, context, x_len):
        # The decoder decodes one token at a time, so y is a single token per
        # sentence rather than a whole target sentence as before.
        # y.shape = (batch_size,)
        embed = self.dropout(self.embed(y))  # embed.shape = (batch_size,embedding_size)
        mask = torch.arange(max(x_len), device=x_len.device)[None, :] < x_len[:, None]
        mask = mask.logical_not()
        # attention uses the decoder hidden state from the previous time step
        c_t = self.attention(hidden[0], context, mask.unsqueeze(2)).squeeze(1)  # c_t.shape = (batch_size,hidden_size)
        lstm_in = torch.cat([c_t, embed], dim=-1).unsqueeze(1)  # (batch_size,1,hidden_size+embedding_size)
        out, hid = self.rnn(lstm_in, hidden)
        # out.shape = (batch_size,1,hidden_size)
        # hid is a tuple (h, c), each of shape (1,batch_size,hidden_size)
        out = self.dense(out)  # out.shape = (batch_size,1,vocab_size)
        out = torch.nn.functional.log_softmax(out, dim=-1)
        return out, hid

Data

The data pipeline is the same as in the Luong-attention part above.

Training

Training is a bit more involved than with Luong attention. The Luong-attention decoder processes the whole target sentence in one pass, and the loss is computed over all outputs at once. That does not work for Bahdanau attention: its decoder feeds the attention context c_t back into the LSTM, so the LSTM input at every step depends on the previous hidden state. We therefore run the decoder step by step, sum the per-step losses over the sentence, and backpropagate once at the end.

embedding_size = 1000
hidden_size = 500
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(enu_vocab_size,embedding_size,hidden_size).to(device)
decoder = Decoder(chs_vocab_size,embedding_size,hidden_size).to(device)


loss_fn = LanguageModelLoss().to(device)
optimizer = torch.optim.Adam([{'params':encoder.parameters()},{'params':decoder.parameters()}])
#optimizer = torch.optim.Adam(encoder.parameters())
encoder_path = "./mt_encoder_model.pth"
decoder_path = "./mt_decoder_model.pth"
def train_step(mb_x, x_len, mb_y):
    y_len = mb_y.shape[1]
    encoder_out, encoder_hid = encoder(mb_x, x_len)
    decoder_hid = encoder_hid
    decoder_input = mb_y[:, 0]        # start decoding from the BOS token
    batch_loss = 0
    for i in range(1, y_len):
        decoder_out, new_hid = decoder(decoder_input, decoder_hid, encoder_out, x_len)
        # detach the hidden state so gradients do not flow back across decoder steps
        decoder_hid = (new_hid[0].detach(), new_hid[1].detach())
        # mask.shape = (batch_size,1); 0 is the PAD index
        mask = (decoder_input != 0).unsqueeze(1)
        decoder_input = mb_y[:, i]    # teacher forcing: feed the ground-truth token

        loss = loss_fn(decoder_out, decoder_input, mask)
        batch_loss += loss
    # accumulate the loss over the whole sentence, then backpropagate once
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()
    return batch_loss

def eval_step(mb_x, x_len, mb_y):
    y_len = mb_y.shape[1]
    encoder_out, encoder_hid = encoder(mb_x, x_len)
    decoder_hid = encoder_hid
    decoder_input = mb_y[:, 0]
    batch_loss = 0
    for i in range(1, y_len):
        decoder_out, decoder_hid = decoder(decoder_input, decoder_hid, encoder_out, x_len)
        # mask.shape = (batch_size,1); 0 is the PAD index
        mask = (decoder_input != 0).unsqueeze(1)
        decoder_input = mb_y[:, i]

        loss = loss_fn(decoder_out, decoder_input, mask)
        batch_loss += loss
    return batch_loss



def train(encoder,decoder,data,epochs):
    encoder.train()
    decoder.train()
    eval_loss = 2500
    for epoch in range(epochs):
        for i,(mb_x,mb_x_len,mb_y,mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device)
            mb_x_len = torch.from_numpy(mb_x_len).to(device)
            mb_y= torch.from_numpy(mb_y).to(device)

            batchloss = train_step(mb_x,mb_x_len,mb_y)
            print("epoch:{} iter:{} loss:{}".format(epoch,i,batchloss))
            if i%10 == 0:
                result = eval(encoder,decoder,dev_data)
                if result<eval_loss:
                    eval_loss = result
                    torch.save(encoder.state_dict(),encoder_path)
                    torch.save(decoder.state_dict(), decoder_path)
            encoder.train()
            decoder.train()


def eval(encoder,decoder,data):
    encoder.eval()
    decoder.eval()
    eval_loss = 0
    with torch.no_grad():
        for i,(mb_x,mb_x_len,mb_y,mb_y_len) in enumerate(data):
            mb_x = torch.from_numpy(mb_x).to(device)
            mb_x_len = torch.from_numpy(mb_x_len).to(device)
            mb_y = torch.from_numpy(mb_y).to(device)

            eval_loss += eval_step(mb_x, mb_x_len, mb_y)
    print("Eval_loss:", eval_loss)
    return eval_loss