Background:
RNN (Recurrent Neural Network) is commonly translated in China as 循环神经网络 ("cyclic neural network") or 递归神经网络 ("recursive neural network"). In my humble opinion, neither rendering is reasonable; it would better be called a (deep) shared-parameter temporal neural network (elaborated below).
The RNN formula (from the pytorch rnn docs):
\begin{align} h_t &= \tanh(W_{ih}x_t + b_{ih} + W_{hh}h_{(t-1)} + b_{hh}) \end{align}
This formula captures each RNN layer's input-side computation (what pytorch's internals loosely call the "input gate", igates) and its hidden-state computation; it is the computation rule of every single layer, where
$W_{ih}$ and $W_{hh}$ are that layer's weight matrices for the input term and the hidden-state term, respectively,
$b_{ih}$ and $b_{hh}$ are the corresponding biases, and
$h_{(t-1)}$ is the hidden state from the previous timestep (for the first element of a sequence, $h_{(t-1)}$, i.e. $h_0$, can be randomly initialized).
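To make the shapes concrete, formula (1) can be written as a single tensor computation for one timestep. This is a minimal sketch with made-up sizes; the shape convention $W_{ih}$: (hidden_size, input_size) follows the pytorch rnn docs:

import torch

input_size, hidden_size = 10, 20
x_t = torch.randn(input_size)                  # the sample at timestep t
h_prev = torch.zeros(hidden_size)              # h_{(t-1)}; h_0 may also be random
W_ih = torch.randn(hidden_size, input_size)    # input weights
W_hh = torch.randn(hidden_size, hidden_size)   # hidden-state weights
b_ih, b_hh = torch.randn(hidden_size), torch.randn(hidden_size)
h_t = torch.tanh(W_ih @ x_t + b_ih + W_hh @ h_prev + b_hh)  # formula (1)
print(h_t.shape)  # torch.Size([20])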
Two things to note:
1. $x_t$: for the first layer, $x_t$ is naturally the training sample; but what is $x_t$ for the second layer, still the training sample? Keep that question in mind as you read on.
2. Every RNN layer computes formula (1), and every layer has exactly four parameter tensors: $W_{ih}$, $W_{hh}$, $b_{ih}$, $b_{hh}$, no matter what the sequence length or the sample dimension is; $x_t$ and $h_{(t-1)}$ are intermediate results, not parameters. So a one-layer RNN has 4 parameter tensors, a two-layer RNN has 8, and an N-layer RNN has 4N (each weight being a matrix and each bias a vector). Once you know a model's parameter count, its structure is essentially determined, which is why this point is so important (see the quick check after this list).
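The 4N count is easy to verify: nn.RNN exposes exactly four named parameter tensors per layer, weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k}, bias_hh_l{k}. A quick check (assuming the default unidirectional nn.RNN; a bidirectional one doubles the count):

import torch
from torch import nn

rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=3)
for name, p in rnn.named_parameters():
    print(name, tuple(p.shape))
# weight_ih_l0 (20, 10), weight_hh_l0 (20, 20), bias_ih_l0 (20,), bias_hh_l0 (20,), ...
print(len(list(rnn.named_parameters())))  # 4 * num_layers = 12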
Model structure:
There are many diagrams online that explain the RNN model structure.
Some look like this (Figure 1 below):
Some like this (Figure 2):
Some like this (Figure 3):
And some like this (Figure 4):
In my humble opinion, the one that is easiest to understand is Figure 5 below, from Scofield on Zhihu:
The animation below shows the computation of a multi-layer RNN on samples of dimension 3. Of everything here, I found this animation the most enlightening. Thanks to 刘大力 on Zhihu; it comes from his Zhihu article (if the animation does not play, please follow that link).
The figures and animation above illustrate the computation graph as the sample at each timestep t of the sequence flows through the RNN model, but I do not agree with how they depict the model output Y. A training sample spans t timesteps, and the sample at each timestep has multiple dimensions; in the animation, t = 5 and dimension = 3.
In my view, a one-layer RNN model has exactly four parameter tensors, $W_{ih}$, $W_{hh}$, $b_{ih}$ and $b_{hh}$, and the training samples at different timesteps share these same parameters (the same parameter variables; their values can of course change during training). The elements of the sequence are combined with these four parameters in chronological order, and the hidden state $h_{(t-1)}$ comes from the previous timestep's computation, i.e. a temporal dependency. This is why a one-layer RNN would better be called a shared-parameter temporal neural network, and a multi-layer RNN a deep shared-parameter temporal neural network. Such names would not spread well, but they are easy to understand.
In a two-layer RNN model, the training sample at the first timestep is first combined with the first layer's four parameters via formula (1); the resulting hidden state $h_{(t)_{l0}}$ is stored and also passed to the second layer's formula (1) as its $x_t$, producing the hidden state $h_{(t)_{l1}}$, which is stored as well. The two stored values $h_{(t)_{l0}}$ and $h_{(t)_{l1}}$ then take part in the second timestep's computation (becoming $h_{(t-1)_{l0}}$ and $h_{(t-1)_{l1}}$), i.e. they serve as $h_{(t-1)}$ in the first and second layer's formula (1) respectively. Deeper RNNs work the same way. Hence an RNN with L layers keeps L hidden states $h_{(t)_{lx}}$, which is exactly what the shape of h_n in the pytorch rnn docs reflects. The key point: each $h_{(t)_{l}}$ is used in two places, as the next layer's $x_t$ and as the next timestep's $h_{(t-1)_{l}}$. Many RNN structure diagrams online fail to make this clear.
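Before the diagram, the two-layer dataflow just described can also be written out directly. This is a minimal, self-contained sketch with hypothetical tensor names (one sample, no batching); the implementation verified against pytorch follows below:

import torch

dim, hidden, seq_len = 3, 4, 5
x = torch.randn(seq_len, dim)  # one sample: 5 timesteps, dimension 3
# layer 0 maps dim -> hidden; layer 1 maps hidden -> hidden
W_ih0, W_hh0 = torch.randn(hidden, dim), torch.randn(hidden, hidden)
b_ih0, b_hh0 = torch.randn(hidden), torch.randn(hidden)
W_ih1, W_hh1 = torch.randn(hidden, hidden), torch.randn(hidden, hidden)
b_ih1, b_hh1 = torch.randn(hidden), torch.randn(hidden)
h_l0, h_l1 = torch.zeros(hidden), torch.zeros(hidden)  # one hidden state per layer
outputs = []
for t in range(seq_len):
    # layer 0: its x_t is the training sample at timestep t
    h_l0 = torch.tanh(W_ih0 @ x[t] + b_ih0 + W_hh0 @ h_l0 + b_hh0)
    # layer 1: its x_t is layer 0's fresh hidden state h_l0
    h_l1 = torch.tanh(W_ih1 @ h_l0 + b_ih1 + W_hh1 @ h_l1 + b_hh1)
    outputs.append(h_l1)  # the model output at each t comes from the top layer
# h_l0 and h_l1 together correspond to pytorch's h_n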
With that in mind, now look at structure diagram 7 below (the same parameters at every timestep; each hidden state goes to two places):
Does it make more sense now?
Code verification: comparing a hand-written RNN against the official pytorch RNN (one layer):
import torch
from torch import nn

# network parameters
input_size = 10
hidden_size = 20
num_layers = 1  # fixed for this demo: one layer only

# data parameters
seq_len = 5
batch_size = 3
data_dim = input_size

# input data
data = torch.randn(seq_len, batch_size, data_dim)

# official rnn in pytorch
ornn = nn.RNN(input_size, hidden_size, num_layers)

# init hidden state
h0 = torch.randn(num_layers, batch_size, hidden_size)

# rnn implemented by myself
class MyRNN():
    def __init__(self):
        # reuse the official rnn's weights and biases so the final
        # results can be compared directly
        self.W_ih = torch.nn.Parameter(ornn.weight_ih_l0.T)
        self.b_ih = torch.nn.Parameter(ornn.bias_ih_l0)
        self.W_hh = torch.nn.Parameter(ornn.weight_hh_l0.T)
        self.b_hh = torch.nn.Parameter(ornn.bias_hh_l0)
        self.ht = torch.nn.Parameter(h0)
        self.myoutput = []

    def forward(self, x):  # x shape: (seq_len, batch_size, data_dim)
        for i in range(seq_len):  # this loop over timesteps is the KEY to understanding RNN. Important!
            igates = torch.matmul(x[i], self.W_ih) + self.b_ih
            hgates = torch.matmul(self.ht, self.W_hh) + self.b_hh
            self.ht = torch.tanh(igates + hgates)  # this line is formula (1). Important!
            self.myoutput.append(self.ht)
        return self.ht, self.myoutput

myrnn = MyRNN()
myht, myoutput = myrnn.forward(data)
official_output, official_hn = ornn(data, h0)
print('myht:')
print(myht)
print('official_hn:')
print(official_hn)
print("--" * 40)
print('myoutput:')
print(myoutput)
print('official_output:')
print(official_output)
Output:
myht:
tensor([[[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713,
0.3459, -0.0215, 0.1415, -0.3584, 0.0867, 0.5099, 0.5690,
-0.2088, 0.2408, 0.5786, 0.2619, 0.5168, 0.3505],
[-0.1239, -0.6814, -0.0618, -0.2011, 0.2991, 0.2501, -0.4828,
-0.2206, 0.4686, 0.8130, -0.2484, 0.4848, -0.4787, -0.1967,
-0.2221, 0.5839, 0.1737, -0.4704, -0.0062, 0.3153],
[ 0.2130, 0.4595, 0.5993, 0.5096, -0.0535, -0.2761, 0.1230,
0.4049, -0.3259, 0.3720, 0.2531, -0.5754, 0.0438, -0.1637,
-0.0579, 0.1214, 0.7484, -0.8096, -0.2217, 0.1534]]],
grad_fn=<TanhBackward>)
official_hn:
tensor([[[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713,
0.3459, -0.0215, 0.1415, -0.3584, 0.0867, 0.5099, 0.5690,
-0.2088, 0.2408, 0.5786, 0.2619, 0.5168, 0.3505],
[-0.1239, -0.6814, -0.0618, -0.2011, 0.2991, 0.2501, -0.4828,
-0.2206, 0.4686, 0.8130, -0.2484, 0.4848, -0.4787, -0.1967,
-0.2221, 0.5839, 0.1737, -0.4704, -0.0062, 0.3153],
[ 0.2130, 0.4595, 0.5993, 0.5096, -0.0535, -0.2761, 0.1230,
0.4049, -0.3259, 0.3720, 0.2531, -0.5754, 0.0438, -0.1637,
-0.0579, 0.1214, 0.7484, -0.8096, -0.2217, 0.1534]]],
grad_fn=<StackBackward>)
--------------------------------------------------------------------------------
myoutput:
[tensor([[[ 0.1838, -0.5729, 0.7425, -0.1386, 0.4525, -0.0928, 0.4676,
0.1947, -0.2111, -0.2790, -0.3584, 0.1215, -0.5577, 0.3709,
0.9216, 0.0695, 0.0420, -0.5991, -0.8501, 0.4155],
[-0.0024, -0.5132, -0.6784, 0.7312, -0.1101, -0.4194, 0.1185,
0.4437, -0.5395, 0.8785, -0.6332, -0.5439, -0.4265, 0.1511,
-0.0327, 0.4625, -0.4097, -0.9240, -0.6085, 0.3099],
[ 0.1994, 0.6158, 0.9422, 0.8493, -0.6427, 0.0086, 0.0350,
0.1801, -0.8858, 0.4427, -0.2625, 0.7059, -0.4321, 0.5412,
0.5879, 0.5385, -0.2290, -0.8183, -0.4205, -0.7687]]],
grad_fn=<TanhBackward>), tensor([[[ 0.3837, -0.0271, 0.1710, 0.5887, -0.1873, -0.0959, 0.3320,
0.0613, 0.3565, -0.7313, -0.2641, -0.8821, 0.7630, 0.2369,
0.5095, -0.7738, 0.0350, 0.1001, 0.4966, 0.4144],
[-0.1493, -0.3873, 0.6141, 0.1870, 0.0825, -0.0518, 0.0583,
0.3066, 0.6362, 0.1345, -0.2821, 0.0061, -0.3376, -0.2284,
0.1351, 0.3951, 0.0039, -0.6607, -0.1473, 0.6156],
[ 0.8971, -0.1361, 0.0733, 0.5407, -0.5882, -0.4531, 0.2926,
0.5090, 0.4893, -0.2589, 0.1735, -0.1201, -0.0110, -0.4264,
0.3931, 0.0637, 0.5885, 0.4706, 0.1418, 0.3165]]],
grad_fn=<TanhBackward>), tensor([[[ 0.3517, -0.7295, -0.0883, -0.6818, 0.3883, 0.3556, -0.1627,
-0.1085, 0.6256, 0.8205, -0.6915, 0.5160, -0.0390, 0.3519,
-0.0271, 0.0300, 0.0965, -0.3939, -0.0956, 0.2624],
[ 0.5152, 0.0578, 0.4200, -0.6778, -0.6455, -0.1427, -0.2189,
0.1818, -0.1449, 0.1035, -0.6252, 0.7734, -0.5083, 0.6138,
0.4270, 0.5684, 0.6656, 0.5341, -0.0336, 0.6554],
[-0.2308, 0.4569, 0.2901, -0.1212, -0.4826, -0.2699, 0.2559,
-0.3331, -0.0299, 0.0830, 0.2832, -0.5203, -0.0953, -0.3784,
-0.1478, -0.1610, -0.3416, -0.7735, 0.4389, 0.4663]]],
grad_fn=<TanhBackward>), tensor([[[-0.4632, -0.7146, -0.0497, -0.4927, 0.0778, 0.6394, 0.0383,
-0.6022, 0.4774, -0.0682, -0.1731, -0.5328, -0.2757, 0.1885,
0.6235, -0.0990, -0.3720, 0.0275, 0.4964, 0.7343],
[ 0.7086, -0.7316, -0.7619, 0.4543, -0.0888, 0.5574, 0.1033,
0.4042, 0.4909, 0.2489, -0.6275, -0.9121, 0.4050, 0.5086,
0.1161, -0.3312, -0.0297, 0.0204, 0.0442, 0.5536],
[-0.0543, 0.0078, -0.8657, 0.6617, -0.2335, -0.0423, 0.2600,
0.1319, -0.2510, 0.3286, 0.1862, -0.6161, -0.1817, 0.4460,
-0.6628, -0.1969, 0.5526, -0.8781, -0.4859, 0.6430]]],
grad_fn=<TanhBackward>), tensor([[[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713,
0.3459, -0.0215, 0.1415, -0.3584, 0.0867, 0.5099, 0.5690,
-0.2088, 0.2408, 0.5786, 0.2619, 0.5168, 0.3505],
[-0.1239, -0.6814, -0.0618, -0.2011, 0.2991, 0.2501, -0.4828,
-0.2206, 0.4686, 0.8130, -0.2484, 0.4848, -0.4787, -0.1967,
-0.2221, 0.5839, 0.1737, -0.4704, -0.0062, 0.3153],
[ 0.2130, 0.4595, 0.5993, 0.5096, -0.0535, -0.2761, 0.1230,
0.4049, -0.3259, 0.3720, 0.2531, -0.5754, 0.0438, -0.1637,
-0.0579, 0.1214, 0.7484, -0.8096, -0.2217, 0.1534]]],
grad_fn=<TanhBackward>)]
official_output:
tensor([[[ 0.1838, -0.5729, 0.7425, -0.1386, 0.4525, -0.0928, 0.4676,
0.1947, -0.2111, -0.2790, -0.3584, 0.1215, -0.5577, 0.3709,
0.9216, 0.0695, 0.0420, -0.5991, -0.8501, 0.4155],
[-0.0024, -0.5132, -0.6784, 0.7312, -0.1101, -0.4194, 0.1185,
0.4437, -0.5395, 0.8785, -0.6332, -0.5439, -0.4265, 0.1511,
-0.0327, 0.4625, -0.4097, -0.9240, -0.6085, 0.3099],
[ 0.1994, 0.6158, 0.9422, 0.8493, -0.6427, 0.0086, 0.0350,
0.1801, -0.8858, 0.4427, -0.2625, 0.7059, -0.4321, 0.5412,
0.5879, 0.5385, -0.2290, -0.8183, -0.4205, -0.7687]],
[[ 0.3837, -0.0271, 0.1710, 0.5887, -0.1873, -0.0959, 0.3320,
0.0613, 0.3565, -0.7313, -0.2641, -0.8821, 0.7630, 0.2369,
0.5095, -0.7738, 0.0350, 0.1001, 0.4966, 0.4144],
[-0.1493, -0.3873, 0.6141, 0.1870, 0.0825, -0.0518, 0.0583,
0.3066, 0.6362, 0.1345, -0.2821, 0.0061, -0.3376, -0.2284,
0.1351, 0.3951, 0.0039, -0.6607, -0.1473, 0.6156],
[ 0.8971, -0.1361, 0.0733, 0.5407, -0.5882, -0.4531, 0.2926,
0.5090, 0.4893, -0.2589, 0.1735, -0.1201, -0.0110, -0.4264,
0.3931, 0.0637, 0.5885, 0.4706, 0.1418, 0.3165]],
[[ 0.3517, -0.7295, -0.0883, -0.6818, 0.3883, 0.3556, -0.1627,
-0.1085, 0.6256, 0.8205, -0.6915, 0.5160, -0.0390, 0.3519,
-0.0271, 0.0300, 0.0965, -0.3939, -0.0956, 0.2624],
[ 0.5152, 0.0578, 0.4200, -0.6778, -0.6455, -0.1427, -0.2189,
0.1818, -0.1449, 0.1035, -0.6252, 0.7734, -0.5083, 0.6138,
0.4270, 0.5684, 0.6656, 0.5341, -0.0336, 0.6554],
[-0.2308, 0.4569, 0.2901, -0.1212, -0.4826, -0.2699, 0.2559,
-0.3331, -0.0299, 0.0830, 0.2832, -0.5203, -0.0953, -0.3784,
-0.1478, -0.1610, -0.3416, -0.7735, 0.4389, 0.4663]],
[[-0.4632, -0.7146, -0.0497, -0.4927, 0.0778, 0.6394, 0.0383,
-0.6022, 0.4774, -0.0682, -0.1731, -0.5328, -0.2757, 0.1885,
0.6235, -0.0990, -0.3720, 0.0275, 0.4964, 0.7343],
[ 0.7086, -0.7316, -0.7619, 0.4543, -0.0888, 0.5574, 0.1033,
0.4042, 0.4909, 0.2489, -0.6275, -0.9121, 0.4050, 0.5086,
0.1161, -0.3312, -0.0297, 0.0204, 0.0442, 0.5536],
[-0.0543, 0.0078, -0.8657, 0.6617, -0.2335, -0.0423, 0.2600,
0.1319, -0.2510, 0.3286, 0.1862, -0.6161, -0.1817, 0.4460,
-0.6628, -0.1969, 0.5526, -0.8781, -0.4859, 0.6430]],
[[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713,
0.3459, -0.0215, 0.1415, -0.3584, 0.0867, 0.5099, 0.5690,
-0.2088, 0.2408, 0.5786, 0.2619, 0.5168, 0.3505],
[-0.1239, -0.6814, -0.0618, -0.2011, 0.2991, 0.2501, -0.4828,
-0.2206, 0.4686, 0.8130, -0.2484, 0.4848, -0.4787, -0.1967,
-0.2221, 0.5839, 0.1737, -0.4704, -0.0062, 0.3153],
[ 0.2130, 0.4595, 0.5993, 0.5096, -0.0535, -0.2761, 0.1230,
0.4049, -0.3259, 0.3720, 0.2531, -0.5754, 0.0438, -0.1637,
-0.0579, 0.1214, 0.7484, -0.8096, -0.2217, 0.1534]]],
grad_fn=<StackBackward>)
The two match, which confirms the model structure is correct.
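Rather than eyeballing the printed tensors, the match can also be checked programmatically. A small addition, assuming it is appended to the script above (atol is loosened slightly since the two implementations accumulate floating-point error in different orders):

print(torch.allclose(myht, official_hn, atol=1e-6))                    # True
print(torch.allclose(torch.cat(myoutput), official_output, atol=1e-6)) # True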
Code verification: comparing a hand-written RNN against the official pytorch RNN (two layers):
import torch
from torch import nn

# network parameters
input_size = 10
hidden_size = 20
num_layers = 2

# data parameters
seq_len = 5
batch_size = 3
data_dim = input_size
data = torch.randn(seq_len, batch_size, data_dim)

# original official rnn in pytorch
ornn = nn.RNN(input_size, hidden_size, num_layers)
h0 = torch.randn(num_layers, batch_size, hidden_size)

class MyRNN():
    def __init__(self):
        # layer 0 parameters, copied from the official rnn
        self.W_ih = torch.nn.Parameter(ornn.weight_ih_l0.T)
        self.b_ih = torch.nn.Parameter(ornn.bias_ih_l0)
        self.W_hh = torch.nn.Parameter(ornn.weight_hh_l0.T)
        self.b_hh = torch.nn.Parameter(ornn.bias_hh_l0)
        self.ht = torch.nn.Parameter(h0)
        self.myoutput = []
        if num_layers == 2:
            # one initial hidden state per layer
            self.ht = torch.nn.Parameter(h0[0])
            self.ht1 = torch.nn.Parameter(h0[1])
            # layer 1 parameters, copied from the official rnn
            self.W_ih_l1 = torch.nn.Parameter(ornn.weight_ih_l1.T)
            self.b_ih_l1 = torch.nn.Parameter(ornn.bias_ih_l1)
            self.W_hh_l1 = torch.nn.Parameter(ornn.weight_hh_l1.T)
            self.b_hh_l1 = torch.nn.Parameter(ornn.bias_hh_l1)

    def forward(self, x):  # x: (seq_len, batch_size, data_dim)
        for i in range(seq_len):
            # the first layer: apply formula (1)
            igates = torch.matmul(x[i], self.W_ih) + self.b_ih
            hgates = torch.matmul(self.ht, self.W_hh) + self.b_hh  # ht comes from the previous timestep
            self.ht = torch.tanh(igates + hgates)  # ht update
            if num_layers == 2:
                # the second layer: apply formula (1)
                igates = torch.matmul(self.ht, self.W_ih_l1) + self.b_ih_l1  # its x_t is the first layer's ht. Important!
                hgates = torch.matmul(self.ht1, self.W_hh_l1) + self.b_hh_l1  # ht1 comes from the previous timestep
                self.ht1 = torch.tanh(igates + hgates)  # ht1 update
                ht_final_layer = [self.ht, self.ht1]
                self.myoutput.append(self.ht1)  # important: only ht1, the output of the last layer
        return ht_final_layer, self.myoutput

myrnn = MyRNN()
myht, myoutput = myrnn.forward(data)
official_output, official_hn = ornn(data, h0)
print('myht:')
print(myht)
print('official_hn:')
print(official_hn)
print("--" * 40)
print('myoutput:')
print(myoutput)
print('official_output')
print(official_output)
Output:
myht:
[tensor([[-0.0386, 0.0588, 0.3025, -0.6304, 0.2505, -0.2632, 0.0101, -0.6417,
0.2560, -0.1788, 0.3951, -0.3890, 0.5895, 0.1630, 0.1462, -0.6854,
-0.1802, -0.3126, -0.8059, -0.1910],
[-0.3681, 0.2041, 0.2560, 0.6034, 0.1888, 0.0478, 0.4822, 0.0652,
-0.7043, -0.2169, 0.2462, 0.1334, -0.1881, 0.4579, -0.0285, 0.1425,
0.3664, 0.4980, 0.2442, -0.5373],
[-0.5242, -0.0747, -0.4040, -0.0835, 0.6314, 0.1566, 0.2049, -0.1784,
-0.2990, -0.3908, -0.2911, -0.2110, 0.6358, 0.4597, 0.0701, -0.3386,
0.5218, -0.5246, -0.3237, 0.0551]], grad_fn=<TanhBackward>), tensor([[-8.1422e-02, 1.8842e-02, 1.0144e-01, 3.8074e-02, -5.5781e-01,
3.0895e-01, -4.5643e-01, 1.6013e-01, -2.9781e-01, 1.9961e-01,
-1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01, 3.6203e-01,
5.9737e-02, 2.3827e-01, -1.1631e-01, -7.2619e-02, -1.0607e-01],
[-5.0548e-05, -1.7537e-01, -5.8831e-02, 1.4870e-01, -6.3641e-01,
-8.8198e-02, 3.6017e-01, -7.5253e-02, -1.5342e-01, -1.3452e-01,
-1.0299e-01, 5.4613e-02, 2.3648e-01, -4.3949e-01, 3.1918e-02,
-1.2566e-01, 5.4776e-01, 1.7124e-01, 4.7584e-01, -1.5331e-01],
[-2.6558e-01, 2.2919e-01, 5.0820e-01, 1.0909e-01, -4.1980e-01,
-1.5907e-01, 2.2342e-01, 5.8021e-02, -2.7018e-01, -3.4850e-01,
-1.8512e-01, 6.1239e-03, 2.4516e-01, -5.7344e-01, -7.5789e-02,
1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]],
grad_fn=<TanhBackward>)]
official_hn:
tensor([[[-3.8583e-02, 5.8803e-02, 3.0251e-01, -6.3039e-01, 2.5051e-01,
-2.6322e-01, 1.0055e-02, -6.4175e-01, 2.5604e-01, -1.7878e-01,
3.9513e-01, -3.8902e-01, 5.8954e-01, 1.6296e-01, 1.4621e-01,
-6.8542e-01, -1.8024e-01, -3.1264e-01, -8.0587e-01, -1.9096e-01],
[-3.6807e-01, 2.0406e-01, 2.5604e-01, 6.0344e-01, 1.8878e-01,
4.7830e-02, 4.8223e-01, 6.5184e-02, -7.0430e-01, -2.1692e-01,
2.4618e-01, 1.3339e-01, -1.8806e-01, 4.5792e-01, -2.8516e-02,
1.4252e-01, 3.6637e-01, 4.9800e-01, 2.4424e-01, -5.3730e-01],
[-5.2422e-01, -7.4715e-02, -4.0400e-01, -8.3507e-02, 6.3144e-01,
1.5658e-01, 2.0493e-01, -1.7839e-01, -2.9904e-01, -3.9076e-01,
-2.9111e-01, -2.1097e-01, 6.3583e-01, 4.5969e-01, 7.0081e-02,
-3.3865e-01, 5.2179e-01, -5.2456e-01, -3.2368e-01, 5.5066e-02]],
[[-8.1422e-02, 1.8842e-02, 1.0144e-01, 3.8074e-02, -5.5781e-01,
3.0895e-01, -4.5643e-01, 1.6013e-01, -2.9781e-01, 1.9961e-01,
-1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01, 3.6203e-01,
5.9737e-02, 2.3827e-01, -1.1631e-01, -7.2618e-02, -1.0607e-01],
[-5.0455e-05, -1.7537e-01, -5.8831e-02, 1.4870e-01, -6.3641e-01,
-8.8198e-02, 3.6017e-01, -7.5252e-02, -1.5342e-01, -1.3452e-01,
-1.0299e-01, 5.4613e-02, 2.3648e-01, -4.3949e-01, 3.1918e-02,
-1.2566e-01, 5.4776e-01, 1.7124e-01, 4.7584e-01, -1.5331e-01],
[-2.6558e-01, 2.2919e-01, 5.0820e-01, 1.0909e-01, -4.1980e-01,
-1.5907e-01, 2.2342e-01, 5.8021e-02, -2.7018e-01, -3.4850e-01,
-1.8512e-01, 6.1239e-03, 2.4516e-01, -5.7344e-01, -7.5789e-02,
1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]]],
grad_fn=<StackBackward>)
--------------------------------------------------------------------------------
myoutput:
[tensor([[ 1.2620e-01, 6.0072e-01, -7.1112e-02, 2.0916e-01, -1.7033e-01,
-5.8128e-02, 3.4290e-01, 5.7120e-01, 7.6652e-04, 4.7431e-01,
-9.7752e-02, -6.9819e-01, 4.4204e-02, 1.8705e-01, 3.7682e-01,
-3.2877e-01, 3.3991e-01, -9.3203e-01, -4.5387e-01, -7.6271e-01],
[ 6.4547e-01, -3.3936e-01, -6.3192e-01, -3.2661e-02, -5.8965e-01,
-7.6409e-01, 1.3470e-01, -3.7835e-01, -2.0378e-01, -1.6322e-01,
-6.0952e-01, 2.9986e-02, 7.8969e-02, -6.4902e-01, 7.5271e-01,
1.7919e-01, 6.5517e-01, -6.5625e-01, 3.2050e-01, -3.4623e-01],
[-5.5407e-01, 2.0340e-01, 4.1821e-01, -9.7931e-02, 4.2492e-01,
8.5182e-01, 7.9682e-02, 7.5144e-01, -2.0973e-01, -1.3963e-01,
4.5111e-01, -5.1502e-01, -3.1101e-01, 8.7050e-02, 7.7077e-01,
4.9754e-01, -1.6914e-01, 5.5128e-01, -7.0215e-01, 2.6817e-01]],
grad_fn=<TanhBackward>), tensor([[-0.2918, 0.5013, 0.0336, -0.3569, -0.5727, -0.1577, -0.1704, 0.3353,
-0.4692, -0.2399, 0.3714, 0.3964, -0.2294, 0.0909, 0.1388, -0.1164,
0.2566, -0.4419, 0.6232, -0.5399],
[-0.7720, 0.3316, 0.4893, 0.4173, 0.1900, 0.5988, 0.2782, -0.3852,
0.1218, -0.1172, -0.4391, 0.1240, 0.3925, 0.3963, -0.5687, 0.2115,
0.4115, 0.5132, -0.1591, -0.1080],
[ 0.1837, 0.2649, 0.6524, 0.2677, 0.0456, 0.2033, -0.0522, 0.4843,
-0.4531, 0.4153, 0.0187, -0.6308, 0.1819, -0.5004, 0.6018, 0.4021,
0.4913, -0.5287, 0.1526, -0.1455]], grad_fn=<TanhBackward>), tensor([[ 0.1872, -0.1069, 0.4237, 0.4201, -0.6734, 0.0836, -0.0252, 0.2273,
-0.2810, -0.0137, -0.2922, -0.3051, -0.2602, -0.4907, 0.0777, 0.1137,
0.2030, -0.1614, -0.0779, -0.2083],
[-0.0990, 0.3498, 0.5492, -0.3256, 0.2025, 0.3302, -0.5011, -0.1571,
0.0209, 0.2982, 0.1901, -0.6905, 0.2419, -0.5201, 0.3651, 0.3990,
0.5685, -0.4665, 0.0143, -0.1595],
[-0.5264, -0.0514, 0.1115, 0.3346, -0.2498, -0.0302, 0.4115, 0.3076,
-0.5988, -0.0438, -0.3437, 0.1128, 0.2481, -0.0956, -0.2785, -0.1713,
0.2296, -0.1200, 0.0860, -0.2926]], grad_fn=<TanhBackward>), tensor([[-0.2957, 0.1804, 0.3002, 0.0617, -0.1344, 0.1993, -0.3224, 0.4173,
-0.0781, 0.3736, -0.2150, 0.2653, -0.0528, 0.0651, -0.0500, 0.2519,
-0.0915, -0.2620, -0.2110, -0.5948],
[-0.1506, 0.4123, 0.0162, 0.1171, 0.0414, -0.0956, -0.2576, 0.4046,
-0.6677, -0.0049, -0.2525, -0.2696, 0.2976, -0.4672, -0.0190, 0.1525,
0.2290, -0.4887, 0.0049, -0.7503],
[-0.2533, -0.2999, 0.0536, -0.4347, -0.4320, 0.2809, -0.2127, -0.5016,
-0.2124, 0.3309, -0.4574, -0.1008, -0.1006, -0.2328, 0.3993, 0.0364,
0.6901, 0.1125, 0.4137, 0.6626]], grad_fn=<TanhBackward>), tensor([[-8.1422e-02, 1.8842e-02, 1.0144e-01, 3.8074e-02, -5.5781e-01,
3.0895e-01, -4.5643e-01, 1.6013e-01, -2.9781e-01, 1.9961e-01,
-1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01, 3.6203e-01,
5.9737e-02, 2.3827e-01, -1.1631e-01, -7.2619e-02, -1.0607e-01],
[-5.0548e-05, -1.7537e-01, -5.8831e-02, 1.4870e-01, -6.3641e-01,
-8.8198e-02, 3.6017e-01, -7.5253e-02, -1.5342e-01, -1.3452e-01,
-1.0299e-01, 5.4613e-02, 2.3648e-01, -4.3949e-01, 3.1918e-02,
-1.2566e-01, 5.4776e-01, 1.7124e-01, 4.7584e-01, -1.5331e-01],
[-2.6558e-01, 2.2919e-01, 5.0820e-01, 1.0909e-01, -4.1980e-01,
-1.5907e-01, 2.2342e-01, 5.8021e-02, -2.7018e-01, -3.4850e-01,
-1.8512e-01, 6.1239e-03, 2.4516e-01, -5.7344e-01, -7.5789e-02,
1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]],
grad_fn=<TanhBackward>)]
official_output
tensor([[[ 1.2620e-01, 6.0072e-01, -7.1112e-02, 2.0916e-01, -1.7033e-01,
-5.8128e-02, 3.4290e-01, 5.7120e-01, 7.6649e-04, 4.7431e-01,
-9.7752e-02, -6.9819e-01, 4.4204e-02, 1.8705e-01, 3.7682e-01,
-3.2877e-01, 3.3991e-01, -9.3203e-01, -4.5387e-01, -7.6271e-01],
[ 6.4547e-01, -3.3936e-01, -6.3193e-01, -3.2661e-02, -5.8965e-01,
-7.6409e-01, 1.3470e-01, -3.7835e-01, -2.0378e-01, -1.6322e-01,
-6.0952e-01, 2.9986e-02, 7.8969e-02, -6.4902e-01, 7.5271e-01,
1.7919e-01, 6.5517e-01, -6.5625e-01, 3.2050e-01, -3.4623e-01],
[-5.5407e-01, 2.0340e-01, 4.1821e-01, -9.7931e-02, 4.2492e-01,
8.5182e-01, 7.9682e-02, 7.5144e-01, -2.0973e-01, -1.3963e-01,
4.5111e-01, -5.1502e-01, -3.1101e-01, 8.7050e-02, 7.7077e-01,
4.9754e-01, -1.6914e-01, 5.5128e-01, -7.0215e-01, 2.6817e-01]],
[[-2.9177e-01, 5.0127e-01, 3.3566e-02, -3.5687e-01, -5.7271e-01,
-1.5774e-01, -1.7043e-01, 3.3525e-01, -4.6915e-01, -2.3995e-01,
3.7142e-01, 3.9644e-01, -2.2941e-01, 9.0899e-02, 1.3878e-01,
-1.1636e-01, 2.5660e-01, -4.4189e-01, 6.2322e-01, -5.3986e-01],
[-7.7200e-01, 3.3155e-01, 4.8930e-01, 4.1734e-01, 1.8999e-01,
5.9885e-01, 2.7816e-01, -3.8521e-01, 1.2183e-01, -1.1717e-01,
-4.3911e-01, 1.2396e-01, 3.9253e-01, 3.9633e-01, -5.6871e-01,
2.1150e-01, 4.1146e-01, 5.1318e-01, -1.5914e-01, -1.0799e-01],
[ 1.8367e-01, 2.6493e-01, 6.5243e-01, 2.6774e-01, 4.5578e-02,
2.0329e-01, -5.2159e-02, 4.8428e-01, -4.5313e-01, 4.1533e-01,
1.8746e-02, -6.3081e-01, 1.8190e-01, -5.0044e-01, 6.0178e-01,
4.0211e-01, 4.9127e-01, -5.2867e-01, 1.5256e-01, -1.4553e-01]],
[[ 1.8723e-01, -1.0690e-01, 4.2369e-01, 4.2007e-01, -6.7342e-01,
8.3559e-02, -2.5240e-02, 2.2735e-01, -2.8096e-01, -1.3662e-02,
-2.9221e-01, -3.0512e-01, -2.6019e-01, -4.9072e-01, 7.7736e-02,
1.1373e-01, 2.0299e-01, -1.6141e-01, -7.7901e-02, -2.0833e-01],
[-9.8969e-02, 3.4982e-01, 5.4921e-01, -3.2558e-01, 2.0254e-01,
3.3020e-01, -5.0109e-01, -1.5706e-01, 2.0853e-02, 2.9821e-01,
1.9009e-01, -6.9054e-01, 2.4189e-01, -5.2012e-01, 3.6514e-01,
3.9902e-01, 5.6852e-01, -4.6647e-01, 1.4296e-02, -1.5953e-01],
[-5.2637e-01, -5.1397e-02, 1.1150e-01, 3.3456e-01, -2.4977e-01,
-3.0166e-02, 4.1154e-01, 3.0765e-01, -5.9878e-01, -4.3782e-02,
-3.4375e-01, 1.1282e-01, 2.4812e-01, -9.5623e-02, -2.7851e-01,
-1.7131e-01, 2.2957e-01, -1.1999e-01, 8.5984e-02, -2.9264e-01]],
[[-2.9568e-01, 1.8038e-01, 3.0018e-01, 6.1720e-02, -1.3442e-01,
1.9932e-01, -3.2239e-01, 4.1725e-01, -7.8142e-02, 3.7360e-01,
-2.1505e-01, 2.6528e-01, -5.2758e-02, 6.5120e-02, -4.9986e-02,
2.5186e-01, -9.1457e-02, -2.6198e-01, -2.1105e-01, -5.9480e-01],
[-1.5058e-01, 4.1227e-01, 1.6235e-02, 1.1707e-01, 4.1378e-02,
-9.5621e-02, -2.5761e-01, 4.0463e-01, -6.6765e-01, -4.8583e-03,
-2.5254e-01, -2.6960e-01, 2.9760e-01, -4.6718e-01, -1.9016e-02,
1.5246e-01, 2.2903e-01, -4.8867e-01, 4.9081e-03, -7.5035e-01],
[-2.5332e-01, -2.9992e-01, 5.3646e-02, -4.3469e-01, -4.3205e-01,
2.8094e-01, -2.1272e-01, -5.0162e-01, -2.1240e-01, 3.3086e-01,
-4.5738e-01, -1.0083e-01, -1.0064e-01, -2.3278e-01, 3.9928e-01,
3.6350e-02, 6.9007e-01, 1.1249e-01, 4.1367e-01, 6.6261e-01]],
[[-8.1422e-02, 1.8842e-02, 1.0144e-01, 3.8074e-02, -5.5781e-01,
3.0895e-01, -4.5643e-01, 1.6013e-01, -2.9781e-01, 1.9961e-01,
-1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01, 3.6203e-01,
5.9737e-02, 2.3827e-01, -1.1631e-01, -7.2618e-02, -1.0607e-01],
[-5.0455e-05, -1.7537e-01, -5.8831e-02, 1.4870e-01, -6.3641e-01,
-8.8198e-02, 3.6017e-01, -7.5252e-02, -1.5342e-01, -1.3452e-01,
-1.0299e-01, 5.4613e-02, 2.3648e-01, -4.3949e-01, 3.1918e-02,
-1.2566e-01, 5.4776e-01, 1.7124e-01, 4.7584e-01, -1.5331e-01],
[-2.6558e-01, 2.2919e-01, 5.0820e-01, 1.0909e-01, -4.1980e-01,
-1.5907e-01, 2.2342e-01, 5.8021e-02, -2.7018e-01, -3.4850e-01,
-1.8512e-01, 6.1239e-03, 2.4516e-01, -5.7344e-01, -7.5789e-02,
1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]]],
grad_fn=<StackBackward>)
As shown, the two models produce identical results; the hand-written model structure is correct.
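The same programmatic check works for the two-layer script, again assuming it is appended there (myht is the list [ht, ht1], so it is first stacked into the (num_layers, batch_size, hidden_size) shape of h_n):

print(torch.allclose(torch.stack(myht), official_hn, atol=1e-6))         # True
print(torch.allclose(torch.stack(myoutput), official_output, atol=1e-6)) # True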
Conclusion: both the theory and the verification code show that an RNN is a shared-parameter temporal neural network (shared parameters are also known as weight sharing). A one-layer RNN has four parameter tensors, and the sample at every timestep is combined with those same four tensors. A higher layer depends on the results of the layer below it, and the current timestep depends on the result of the previous timestep.
If anything in this article is inaccurate, criticism and corrections are welcome.