今天用PyTorch参考《Python深度学习基于PyTorch》搭建了一个LSTM网络单元,在这里做一下笔记。
1.LSTM的原理
LSTM是RNN(循环神经网络)的变体,全名为长短期记忆网络(Long Short Term Memory networks)。
它的精髓在于引入了细胞状态这样一个概念,不同于RNN只考虑最近的状态,LSTM的细胞状态会决定哪些状态应该被留下来,哪些状态应该被遗忘。
具体与RNN的区别可参考这篇博文:LSTM与RNN的比较 先放一张LSTM网络的模型图:
如上图所示,可以看到这是一个网络,我们单拿出其中一个单元来进行分析,可见每一个单元都包含一系列运算,那么这些运算的意义是什么呢?下面我们来一一解释每个单元的具体内容。(1)遗忘门
ht-1 :前一个时刻的Cell的输出
xt : 当前时刻的输入
注意:中括号的意思是将ht-1与xt拼接起来,后面出现公式同理
遗忘门主要来判断上一状态中的输出对现状态的影响大小,遗忘门的输出要通过一个Sigmoid函数,Sigmoid函数的输出范围是0~1,相当于得到一个权重,后面与Ct-1相乘,以此得到上一状态输出对现状态的影响。
(2)输入门
输入门中会得到一个临界的细胞状态(Ct^),表示此状态下的备选输出,与it作用后就得到此次状态需要输出的内容。
由以上两个门就可以输出更新后的细胞状态Ct,输出公式如上图所示,需要注意这里的“ * ”符号为哈达玛乘积,就是对应矩阵元素相乘。(3)输出门
输出门具体运算过程如上图所示。这样就得到了这个时刻的输出,把这个输出再传入下一状态即可。
2.代码实现
初始化:
import torch
import torch.nn as nn
搭建一个LSTM单元:
class LSTMCell(nn.Module):
def __init__(self,input_size,hidden_size,cell_size,output_size):
super(LSTMCell,self).__init__()
self.hidden_size = hidden_size
self.cell_size = cell_size
#设定门输入输出数据的大小尺寸
self.gate = nn.Linear(input_size+hidden_size,cell_size)
self.output = nn.Linear(hidden_size,output_size)
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
#分类器-输出
self.softmax = nn.LogSoftmax(dim=1)
def forward(self,input,hidden,cell):
#拼接数据,后置的0/1 确定横向(1)还是竖向(0)拼接
combined = torch.cat((input,hidden),1)
#根据LSTM一个单元的网络图得出三个门,并进行运算
f_gate = self.sigmoid(self.gate(combined))
i_gate = self.sigmoid(self.gate(combined))
#z_state看作为Cell的中间状态
z_state = self.tanh(self.gate(combined))
o_gate = self.sigmoid(self.gate(combined))
#注意这下面的乘为哈达玛乘积,矩阵对应元素相乘
cell = torch.add(torch.mul(f_gate,cell),torch.mul(i_gate,z_state))
hidden = torch.mul(self.tanh(cell),o_gate)
output = self.output(hidden)
output = self.softmax(output)
return output,hidden,cell
def initHidden(self):
return torch.zeros(1,self.hidden_size)
def initCell(self):
return torch.zeros(1,self.cell_size)
实例化LSTMCell,并传入输入、隐含状态等进行验证:
lstmcell = LSTMCell(input_size=10,hidden_size=20,cell_size=20,output_size=10)
input = torch.randn(32,10)
h_0 = torch.randn(32,20)
c_0 = torch.randn(32,20)
output,hn,cn = lstmcell(input,h_0,c_0)
print(output.size(),hn.size(),cn.size())
输出结果:
torch.Size([32, 10]) torch.Size([32, 20]) torch.Size([32, 20])
end