Background:

    RNN (Recurrent Neural Networks) is usually translated into Chinese as 循环神经网络 ("cyclic neural network") or 递归神经网络 ("recursive neural network"). In my humble opinion, neither rendering is apt; it would be better called a (deep) shared-parameter sequential neural network (elaborated below).

    The RNN formula (from the PyTorch RNN docs):

\begin{align} h_t &= \tanh(W_{ih}x_t + b_{ih} + W_{hh}h_{(t-1)} + b_{hh}) \tag{1} \end{align}

     This formula captures each RNN layer's input (input gate) computation and hidden state computation; it is the computation rule for every RNN layer, where:

          $W_{ih}$ and $W_{hh}$ are the layer's weight parameters for computing the input gate and the hidden state, respectively;

          $b_{ih}$ and $b_{hh}$ are the layer's corresponding biases;

          $h_{(t-1)}$ is the hidden state from the previous timestep (for the first element of a sequence, $h_{(t-1)}$, i.e. $h_0$, can be randomly initialized).

          Note:

             1. $x_t$: for the first layer, $x_t$ is naturally the training sample. But then what is $x_t$ for the second layer? Still the training sample? Keep this question in mind as you read on.

             2. Every RNN layer computes formula (1), and every layer has exactly the parameters $W_{ih}$, $W_{hh}$, $b_{ih}$, $b_{hh}$: four parameter variables per layer (no matter the sequence length, and no matter the dimension of each sample). $x_t$ and $h_{(t-1)}$ are intermediate results, not parameters. That is, a one-layer RNN has 4 parameter variables, a two-layer RNN has 8, and an N-layer RNN has 4N. Each parameter variable is, of course, a parameter matrix. Once you know how many parameters a model has, its structure is largely determined, so this point is very important. (A quick check follows below.)
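
As a quick check of this count (a minimal sketch using PyTorch's parameter-inspection API; the layer sizes here are arbitrary):

import torch
from torch import nn

#a 2-layer RNN should expose 4 parameter variables per layer, 8 in total
rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2)
for name, p in rnn.named_parameters():
    print(name, tuple(p.shape))
#weight_ih_l0 (20, 10)   weight_hh_l0 (20, 20)   bias_ih_l0 (20,)   bias_hh_l0 (20,)
#weight_ih_l1 (20, 20)   weight_hh_l1 (20, 20)   bias_ih_l1 (20,)   bias_hh_l1 (20,)
print(len(list(rnn.parameters()))) #8

Note that weight_ih_l1 has shape (hidden_size, hidden_size) rather than (hidden_size, input_size); that is already a hint at question 1 above.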

Model structure:

      There are many diagrams online explaining the structure of the RNN model.

      Some look like this:

[image: wrong upload]

         Sorry, wrong upload.

Some like this, Figure 1 below:

[Figure 1: an RNN structure diagram]

Some like this, Figure 2 below:

[Figure 2: an RNN structure diagram]

Some like this, Figure 3 below:

[Figure 3: an RNN structure diagram]

And some like this, Figure 4 below:

[Figure 4: an RNN structure diagram]

         In my humble opinion, the one that can be read most clearly is Figure 5 below, from Zhihu user Scofield:

[Figure 5: RNN structure diagram, by Zhihu user Scofield]

The animation below shows the computation of a multi-layer RNN over samples of dimension 3. In my humble opinion, this video is the most illuminating of all. Thanks to Zhihu user 刘大力; the animation comes from a Zhihu article (if the animation cannot be viewed, please go to that link).

A screenshot of the animation, Figure 6, is shown below:

[Figure 6: screenshot of the animation]

                 The pictures, video, and animation above illustrate the computational flow as the sample at each timestep t of a sequence enters the RNN, but I do not endorse how they depict the model output Y.

            A training sample spans t timesteps, and the sample at each timestep has several dimensions. In the animation, t = 5 and dimension = 3.

            In my view, a one-layer RNN model has four parameter variables in total: $W_{ih}$, $W_{hh}$, $b_{ih}$, $b_{hh}$. The training samples at different timesteps share these parameters (the same parameter variables; their values can change during training). The elements of the sequence are combined with these four parameter variables in temporal order, and the hidden state $h_{(t-1)}$ comes from the previous timestep's result, i.e. a temporal dependency. So a one-layer RNN ought to be called a shared-parameter sequential neural network, and a multi-layer RNN a deep shared-parameter sequential neural network. Such names would not spread well, but they are easy to understand.

            In a two-layer RNN, the sample at the first timestep is first combined with the first layer's 4 parameters via formula (1); the resulting hidden state $h_{(t)_{l0}}$ is stored and fed as $x_t$ into the second layer's formula (1), whose hidden state $h_{(t)_{l1}}$ is stored as well. The two stored values $h_{(t)_{l0}}$ and $h_{(t)_{l1}}$ then take part in the second timestep's computation (becoming $h_{(t-1)_{l0}}$ and $h_{(t-1)_{l1}}$), i.e. they serve as the $h_{(t-1)}$ of the first layer's and the second layer's formula (1), respectively. Deeper RNNs proceed likewise. So an RNN with some number of layers carries that many hidden states $h_{(t)_{lx}}$, which is exactly what the shape of the output h_n in the PyTorch RNN docs reflects (see the check below). The key point: each $h_{(t)_{l}}$ is used in two places, as the next layer's $x_t$ and as the next timestep's $h_{(t-1)_{l}}$. Many RNN structure diagrams online fail to make this clear.
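
The point about h_n can be checked directly (a minimal sketch; the shapes follow the PyTorch nn.RNN documentation):

import torch
from torch import nn

rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=3)
x = torch.randn(5, 3, 10) #(seq_len, batch_size, input_size)
output, h_n = rnn(x) #h0 defaults to zeros when omitted
print(h_n.shape) #torch.Size([3, 3, 20]): one hidden state per layer, at the last timestep
print(output.shape) #torch.Size([5, 3, 20]): the last layer's hidden state at every timestep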

             Having read this far, now look at structure diagram 7 below (same parameters at every timestep; each hidden state has two destinations):

[Figure 7: RNN structure diagram, showing shared parameters across timesteps and the two destinations of each hidden state]

Does it make a bit more sense now?

Code verification: a hand-written RNN compared against the official PyTorch RNN (one-layer RNN):

import torch
from torch import nn

#network parameters
input_size = 10
hidden_size = 20
num_layers = 1 #fixed at 1; this demo implements a single layer only

#data parameters
seq_len = 5
batch_size = 3
data_dim = input_size

#input data
data = torch.randn(seq_len, batch_size, data_dim)

#official rnn in pytorch
ornn = nn.RNN(input_size, hidden_size, num_layers)

#init hidden state
h0 = torch.randn(num_layers,batch_size,hidden_size)

#rnn implemented by myself
class MyRNN():
    def __init__(self):
        #keep the weight and bias parameters identical to the official rnn
        #so that the final results can be compared directly
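        #note: pytorch stores weight_ih_l0 with shape (hidden_size, input_size),
        #hence the .T below so that x can be right-multiplied in forward()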
        self.W_ih = torch.nn.Parameter(ornn.weight_ih_l0.T)
        self.b_ih = torch.nn.Parameter(ornn.bias_ih_l0)
        self.W_hh = torch.nn.Parameter(ornn.weight_hh_l0.T)
        self.b_hh = torch.nn.Parameter(ornn.bias_hh_l0)
        self.ht = torch.nn.Parameter(h0)
        self.myoutput = []

    def forward(self,x): #x shape: (seq_len,batch_size,data_dim)
        for i in range(seq_len): #this loop is the KEY to understanding RNN. Important!
            igates = torch.matmul(x[i],self.W_ih) + self.b_ih
            hgates = torch.matmul(self.ht,self.W_hh) + self.b_hh
            self.ht = torch.tanh(igates + hgates)#this line is the formula of RNN. Important!
            self.myoutput.append(self.ht)
        return self.ht,self.myoutput

myrnn = MyRNN()
myht,myoutput = myrnn.forward(data)
official_output,official_hn = ornn(data,h0)

print ('myht:')
print (myht)
print ('official_hn:')
print (official_hn)

print ("--" * 40)
print ('myoutput:')
print (myoutput)
print ('official_output:')
print (official_output)

Output:

myht:
tensor([[[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713,
           0.3459, -0.0215,  0.1415, -0.3584,  0.0867,  0.5099,  0.5690,
          -0.2088,  0.2408,  0.5786,  0.2619,  0.5168,  0.3505],
         [-0.1239, -0.6814, -0.0618, -0.2011,  0.2991,  0.2501, -0.4828,
          -0.2206,  0.4686,  0.8130, -0.2484,  0.4848, -0.4787, -0.1967,
          -0.2221,  0.5839,  0.1737, -0.4704, -0.0062,  0.3153],
         [ 0.2130,  0.4595,  0.5993,  0.5096, -0.0535, -0.2761,  0.1230,
           0.4049, -0.3259,  0.3720,  0.2531, -0.5754,  0.0438, -0.1637,
          -0.0579,  0.1214,  0.7484, -0.8096, -0.2217,  0.1534]]],
       grad_fn=<TanhBackward>)
official_hn:
tensor([[[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713,
           0.3459, -0.0215,  0.1415, -0.3584,  0.0867,  0.5099,  0.5690,
          -0.2088,  0.2408,  0.5786,  0.2619,  0.5168,  0.3505],
         [-0.1239, -0.6814, -0.0618, -0.2011,  0.2991,  0.2501, -0.4828,
          -0.2206,  0.4686,  0.8130, -0.2484,  0.4848, -0.4787, -0.1967,
          -0.2221,  0.5839,  0.1737, -0.4704, -0.0062,  0.3153],
         [ 0.2130,  0.4595,  0.5993,  0.5096, -0.0535, -0.2761,  0.1230,
           0.4049, -0.3259,  0.3720,  0.2531, -0.5754,  0.0438, -0.1637,
          -0.0579,  0.1214,  0.7484, -0.8096, -0.2217,  0.1534]]],
       grad_fn=<StackBackward>)
--------------------------------------------------------------------------------
myoutput:
[tensor([[[ 0.1838, -0.5729,  0.7425, -0.1386,  0.4525, -0.0928,  0.4676,
           0.1947, -0.2111, -0.2790, -0.3584,  0.1215, -0.5577,  0.3709,
           0.9216,  0.0695,  0.0420, -0.5991, -0.8501,  0.4155],
         [-0.0024, -0.5132, -0.6784,  0.7312, -0.1101, -0.4194,  0.1185,
           0.4437, -0.5395,  0.8785, -0.6332, -0.5439, -0.4265,  0.1511,
          -0.0327,  0.4625, -0.4097, -0.9240, -0.6085,  0.3099],
         [ 0.1994,  0.6158,  0.9422,  0.8493, -0.6427,  0.0086,  0.0350,
           0.1801, -0.8858,  0.4427, -0.2625,  0.7059, -0.4321,  0.5412,
           0.5879,  0.5385, -0.2290, -0.8183, -0.4205, -0.7687]]],
       grad_fn=<TanhBackward>), tensor([[[ 0.3837, -0.0271,  0.1710,  0.5887, -0.1873, -0.0959,  0.3320,
           0.0613,  0.3565, -0.7313, -0.2641, -0.8821,  0.7630,  0.2369,
           0.5095, -0.7738,  0.0350,  0.1001,  0.4966,  0.4144],
         [-0.1493, -0.3873,  0.6141,  0.1870,  0.0825, -0.0518,  0.0583,
           0.3066,  0.6362,  0.1345, -0.2821,  0.0061, -0.3376, -0.2284,
           0.1351,  0.3951,  0.0039, -0.6607, -0.1473,  0.6156],
         [ 0.8971, -0.1361,  0.0733,  0.5407, -0.5882, -0.4531,  0.2926,
           0.5090,  0.4893, -0.2589,  0.1735, -0.1201, -0.0110, -0.4264,
           0.3931,  0.0637,  0.5885,  0.4706,  0.1418,  0.3165]]],
       grad_fn=<TanhBackward>), tensor([[[ 0.3517, -0.7295, -0.0883, -0.6818,  0.3883,  0.3556, -0.1627,
          -0.1085,  0.6256,  0.8205, -0.6915,  0.5160, -0.0390,  0.3519,
          -0.0271,  0.0300,  0.0965, -0.3939, -0.0956,  0.2624],
         [ 0.5152,  0.0578,  0.4200, -0.6778, -0.6455, -0.1427, -0.2189,
           0.1818, -0.1449,  0.1035, -0.6252,  0.7734, -0.5083,  0.6138,
           0.4270,  0.5684,  0.6656,  0.5341, -0.0336,  0.6554],
         [-0.2308,  0.4569,  0.2901, -0.1212, -0.4826, -0.2699,  0.2559,
          -0.3331, -0.0299,  0.0830,  0.2832, -0.5203, -0.0953, -0.3784,
          -0.1478, -0.1610, -0.3416, -0.7735,  0.4389,  0.4663]]],
       grad_fn=<TanhBackward>), tensor([[[-0.4632, -0.7146, -0.0497, -0.4927,  0.0778,  0.6394,  0.0383,
          -0.6022,  0.4774, -0.0682, -0.1731, -0.5328, -0.2757,  0.1885,
           0.6235, -0.0990, -0.3720,  0.0275,  0.4964,  0.7343],
         [ 0.7086, -0.7316, -0.7619,  0.4543, -0.0888,  0.5574,  0.1033,
           0.4042,  0.4909,  0.2489, -0.6275, -0.9121,  0.4050,  0.5086,
           0.1161, -0.3312, -0.0297,  0.0204,  0.0442,  0.5536],
         [-0.0543,  0.0078, -0.8657,  0.6617, -0.2335, -0.0423,  0.2600,
           0.1319, -0.2510,  0.3286,  0.1862, -0.6161, -0.1817,  0.4460,
          -0.6628, -0.1969,  0.5526, -0.8781, -0.4859,  0.6430]]],
       grad_fn=<TanhBackward>), tensor([[[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713,
           0.3459, -0.0215,  0.1415, -0.3584,  0.0867,  0.5099,  0.5690,
          -0.2088,  0.2408,  0.5786,  0.2619,  0.5168,  0.3505],
         [-0.1239, -0.6814, -0.0618, -0.2011,  0.2991,  0.2501, -0.4828,
          -0.2206,  0.4686,  0.8130, -0.2484,  0.4848, -0.4787, -0.1967,
          -0.2221,  0.5839,  0.1737, -0.4704, -0.0062,  0.3153],
         [ 0.2130,  0.4595,  0.5993,  0.5096, -0.0535, -0.2761,  0.1230,
           0.4049, -0.3259,  0.3720,  0.2531, -0.5754,  0.0438, -0.1637,
          -0.0579,  0.1214,  0.7484, -0.8096, -0.2217,  0.1534]]],
       grad_fn=<TanhBackward>)]
official_output:
tensor([[[ 0.1838, -0.5729,  0.7425, -0.1386,  0.4525, -0.0928,  0.4676,
           0.1947, -0.2111, -0.2790, -0.3584,  0.1215, -0.5577,  0.3709,
           0.9216,  0.0695,  0.0420, -0.5991, -0.8501,  0.4155],
         [-0.0024, -0.5132, -0.6784,  0.7312, -0.1101, -0.4194,  0.1185,
           0.4437, -0.5395,  0.8785, -0.6332, -0.5439, -0.4265,  0.1511,
          -0.0327,  0.4625, -0.4097, -0.9240, -0.6085,  0.3099],
         [ 0.1994,  0.6158,  0.9422,  0.8493, -0.6427,  0.0086,  0.0350,
           0.1801, -0.8858,  0.4427, -0.2625,  0.7059, -0.4321,  0.5412,
           0.5879,  0.5385, -0.2290, -0.8183, -0.4205, -0.7687]],

        [[ 0.3837, -0.0271,  0.1710,  0.5887, -0.1873, -0.0959,  0.3320,
           0.0613,  0.3565, -0.7313, -0.2641, -0.8821,  0.7630,  0.2369,
           0.5095, -0.7738,  0.0350,  0.1001,  0.4966,  0.4144],
         [-0.1493, -0.3873,  0.6141,  0.1870,  0.0825, -0.0518,  0.0583,
           0.3066,  0.6362,  0.1345, -0.2821,  0.0061, -0.3376, -0.2284,
           0.1351,  0.3951,  0.0039, -0.6607, -0.1473,  0.6156],
         [ 0.8971, -0.1361,  0.0733,  0.5407, -0.5882, -0.4531,  0.2926,
           0.5090,  0.4893, -0.2589,  0.1735, -0.1201, -0.0110, -0.4264,
           0.3931,  0.0637,  0.5885,  0.4706,  0.1418,  0.3165]],

        [[ 0.3517, -0.7295, -0.0883, -0.6818,  0.3883,  0.3556, -0.1627,
          -0.1085,  0.6256,  0.8205, -0.6915,  0.5160, -0.0390,  0.3519,
          -0.0271,  0.0300,  0.0965, -0.3939, -0.0956,  0.2624],
         [ 0.5152,  0.0578,  0.4200, -0.6778, -0.6455, -0.1427, -0.2189,
           0.1818, -0.1449,  0.1035, -0.6252,  0.7734, -0.5083,  0.6138,
           0.4270,  0.5684,  0.6656,  0.5341, -0.0336,  0.6554],
         [-0.2308,  0.4569,  0.2901, -0.1212, -0.4826, -0.2699,  0.2559,
          -0.3331, -0.0299,  0.0830,  0.2832, -0.5203, -0.0953, -0.3784,
          -0.1478, -0.1610, -0.3416, -0.7735,  0.4389,  0.4663]],

        [[-0.4632, -0.7146, -0.0497, -0.4927,  0.0778,  0.6394,  0.0383,
          -0.6022,  0.4774, -0.0682, -0.1731, -0.5328, -0.2757,  0.1885,
           0.6235, -0.0990, -0.3720,  0.0275,  0.4964,  0.7343],
         [ 0.7086, -0.7316, -0.7619,  0.4543, -0.0888,  0.5574,  0.1033,
           0.4042,  0.4909,  0.2489, -0.6275, -0.9121,  0.4050,  0.5086,
           0.1161, -0.3312, -0.0297,  0.0204,  0.0442,  0.5536],
         [-0.0543,  0.0078, -0.8657,  0.6617, -0.2335, -0.0423,  0.2600,
           0.1319, -0.2510,  0.3286,  0.1862, -0.6161, -0.1817,  0.4460,
          -0.6628, -0.1969,  0.5526, -0.8781, -0.4859,  0.6430]],

        [[ 0.7169, -0.0977, -0.2668, -0.3212, -0.2726, -0.0469, -0.4713,
           0.3459, -0.0215,  0.1415, -0.3584,  0.0867,  0.5099,  0.5690,
          -0.2088,  0.2408,  0.5786,  0.2619,  0.5168,  0.3505],
         [-0.1239, -0.6814, -0.0618, -0.2011,  0.2991,  0.2501, -0.4828,
          -0.2206,  0.4686,  0.8130, -0.2484,  0.4848, -0.4787, -0.1967,
          -0.2221,  0.5839,  0.1737, -0.4704, -0.0062,  0.3153],
         [ 0.2130,  0.4595,  0.5993,  0.5096, -0.0535, -0.2761,  0.1230,
           0.4049, -0.3259,  0.3720,  0.2531, -0.5754,  0.0438, -0.1637,
          -0.0579,  0.1214,  0.7484, -0.8096, -0.2217,  0.1534]]],
       grad_fn=<StackBackward>)

The two match, which confirms that the model structure is correct. (Rather than eyeballing the tensors, the match can also be checked programmatically, as sketched below.)
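
A quick check (a minimal sketch; it assumes the variables from the script above are still in scope, and uses a small tolerance to absorb floating-point differences between the two implementations):

print(torch.allclose(myht, official_hn, atol=1e-6)) #expected: True
#myoutput is a list of (1, batch, hidden) tensors; cat joins them into (seq_len, batch, hidden)
print(torch.allclose(torch.cat(myoutput), official_output, atol=1e-6)) #expected: True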

Code verification: a hand-written RNN compared against the official PyTorch RNN (two-layer RNN):

import torch
from torch import nn

#network parameters
input_size = 10
hidden_size = 20
num_layers = 2

#data parameters
seq_len = 5
batch_size = 3
data_dim = input_size

data = torch.randn(seq_len, batch_size, data_dim)

#original official rnn in pytorch
ornn = nn.RNN(input_size, hidden_size, num_layers)
h0 = torch.randn(num_layers,batch_size,hidden_size)

class MyRNN():
    def __init__(self):
        #layer-0 parameters, copied from the official rnn (transposed for right-multiplication)
        self.W_ih = torch.nn.Parameter(ornn.weight_ih_l0.T)
        self.b_ih = torch.nn.Parameter(ornn.bias_ih_l0)
        self.W_hh = torch.nn.Parameter(ornn.weight_hh_l0.T)
        self.b_hh = torch.nn.Parameter(ornn.bias_hh_l0)
        self.ht = torch.nn.Parameter(h0)
        self.myoutput = []
        if num_layers == 2:
            self.ht = torch.nn.Parameter(h0[0])
            self.ht1 = torch.nn.Parameter(h0[1])
            self.W_ih_l1 = torch.nn.Parameter(ornn.weight_ih_l1.T)
            self.b_ih_l1 = torch.nn.Parameter(ornn.bias_ih_l1)
            self.W_hh_l1 = torch.nn.Parameter(ornn.weight_hh_l1.T)
            self.b_hh_l1 = torch.nn.Parameter(ornn.bias_hh_l1)

    def forward(self,x): #x: (seq_len,batch_size,data_dim)
        for i in range(seq_len):
            #the first layer. apply the formula
            igates = torch.matmul(x[i],self.W_ih) + self.b_ih
            hgates = torch.matmul(self.ht,self.W_hh) + self.b_hh #ht carried over from the previous timestep
            self.ht = torch.tanh(igates + hgates) #ht update
            if num_layers == 2:
                #the second layer. apply the formula
                igates = torch.matmul(self.ht,self.W_ih_l1) + self.b_ih_l1 #ht read from the first layer at this timestep. important!
                hgates = torch.matmul(self.ht1,self.W_hh_l1) + self.b_hh_l1 #ht1 carried over from the previous timestep
                self.ht1 = torch.tanh(igates + hgates) #ht1 update
            ht_final_layer = [self.ht,self.ht1] #hidden states of both layers at the current timestep
            self.myoutput.append(self.ht1) #important: only ht1, the output of the LAST layer, goes into the output sequence
        return ht_final_layer,self.myoutput

myrnn = MyRNN()
myht,myoutput = myrnn.forward(data)
official_output,official_hn = ornn(data,h0)

print ('myht:')
print (myht)
print ('official_hn:')
print (official_hn)

print ("--" * 40)
print ('myoutput:')
print (myoutput)
print ('official_output')
print (official_output)

Output:

myht:
[tensor([[-0.0386,  0.0588,  0.3025, -0.6304,  0.2505, -0.2632,  0.0101, -0.6417,
          0.2560, -0.1788,  0.3951, -0.3890,  0.5895,  0.1630,  0.1462, -0.6854,
         -0.1802, -0.3126, -0.8059, -0.1910],
        [-0.3681,  0.2041,  0.2560,  0.6034,  0.1888,  0.0478,  0.4822,  0.0652,
         -0.7043, -0.2169,  0.2462,  0.1334, -0.1881,  0.4579, -0.0285,  0.1425,
          0.3664,  0.4980,  0.2442, -0.5373],
        [-0.5242, -0.0747, -0.4040, -0.0835,  0.6314,  0.1566,  0.2049, -0.1784,
         -0.2990, -0.3908, -0.2911, -0.2110,  0.6358,  0.4597,  0.0701, -0.3386,
          0.5218, -0.5246, -0.3237,  0.0551]], grad_fn=<TanhBackward>), tensor([[-8.1422e-02,  1.8842e-02,  1.0144e-01,  3.8074e-02, -5.5781e-01,
          3.0895e-01, -4.5643e-01,  1.6013e-01, -2.9781e-01,  1.9961e-01,
         -1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01,  3.6203e-01,
          5.9737e-02,  2.3827e-01, -1.1631e-01, -7.2619e-02, -1.0607e-01],
        [-5.0548e-05, -1.7537e-01, -5.8831e-02,  1.4870e-01, -6.3641e-01,
         -8.8198e-02,  3.6017e-01, -7.5253e-02, -1.5342e-01, -1.3452e-01,
         -1.0299e-01,  5.4613e-02,  2.3648e-01, -4.3949e-01,  3.1918e-02,
         -1.2566e-01,  5.4776e-01,  1.7124e-01,  4.7584e-01, -1.5331e-01],
        [-2.6558e-01,  2.2919e-01,  5.0820e-01,  1.0909e-01, -4.1980e-01,
         -1.5907e-01,  2.2342e-01,  5.8021e-02, -2.7018e-01, -3.4850e-01,
         -1.8512e-01,  6.1239e-03,  2.4516e-01, -5.7344e-01, -7.5789e-02,
          1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]],
       grad_fn=<TanhBackward>)]
official_hn:
tensor([[[-3.8583e-02,  5.8803e-02,  3.0251e-01, -6.3039e-01,  2.5051e-01,
          -2.6322e-01,  1.0055e-02, -6.4175e-01,  2.5604e-01, -1.7878e-01,
           3.9513e-01, -3.8902e-01,  5.8954e-01,  1.6296e-01,  1.4621e-01,
          -6.8542e-01, -1.8024e-01, -3.1264e-01, -8.0587e-01, -1.9096e-01],
         [-3.6807e-01,  2.0406e-01,  2.5604e-01,  6.0344e-01,  1.8878e-01,
           4.7830e-02,  4.8223e-01,  6.5184e-02, -7.0430e-01, -2.1692e-01,
           2.4618e-01,  1.3339e-01, -1.8806e-01,  4.5792e-01, -2.8516e-02,
           1.4252e-01,  3.6637e-01,  4.9800e-01,  2.4424e-01, -5.3730e-01],
         [-5.2422e-01, -7.4715e-02, -4.0400e-01, -8.3507e-02,  6.3144e-01,
           1.5658e-01,  2.0493e-01, -1.7839e-01, -2.9904e-01, -3.9076e-01,
          -2.9111e-01, -2.1097e-01,  6.3583e-01,  4.5969e-01,  7.0081e-02,
          -3.3865e-01,  5.2179e-01, -5.2456e-01, -3.2368e-01,  5.5066e-02]],

        [[-8.1422e-02,  1.8842e-02,  1.0144e-01,  3.8074e-02, -5.5781e-01,
           3.0895e-01, -4.5643e-01,  1.6013e-01, -2.9781e-01,  1.9961e-01,
          -1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01,  3.6203e-01,
           5.9737e-02,  2.3827e-01, -1.1631e-01, -7.2618e-02, -1.0607e-01],
         [-5.0455e-05, -1.7537e-01, -5.8831e-02,  1.4870e-01, -6.3641e-01,
          -8.8198e-02,  3.6017e-01, -7.5252e-02, -1.5342e-01, -1.3452e-01,
          -1.0299e-01,  5.4613e-02,  2.3648e-01, -4.3949e-01,  3.1918e-02,
          -1.2566e-01,  5.4776e-01,  1.7124e-01,  4.7584e-01, -1.5331e-01],
         [-2.6558e-01,  2.2919e-01,  5.0820e-01,  1.0909e-01, -4.1980e-01,
          -1.5907e-01,  2.2342e-01,  5.8021e-02, -2.7018e-01, -3.4850e-01,
          -1.8512e-01,  6.1239e-03,  2.4516e-01, -5.7344e-01, -7.5789e-02,
           1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]]],
       grad_fn=<StackBackward>)
--------------------------------------------------------------------------------
myoutput:
[tensor([[ 1.2620e-01,  6.0072e-01, -7.1112e-02,  2.0916e-01, -1.7033e-01,
         -5.8128e-02,  3.4290e-01,  5.7120e-01,  7.6652e-04,  4.7431e-01,
         -9.7752e-02, -6.9819e-01,  4.4204e-02,  1.8705e-01,  3.7682e-01,
         -3.2877e-01,  3.3991e-01, -9.3203e-01, -4.5387e-01, -7.6271e-01],
        [ 6.4547e-01, -3.3936e-01, -6.3192e-01, -3.2661e-02, -5.8965e-01,
         -7.6409e-01,  1.3470e-01, -3.7835e-01, -2.0378e-01, -1.6322e-01,
         -6.0952e-01,  2.9986e-02,  7.8969e-02, -6.4902e-01,  7.5271e-01,
          1.7919e-01,  6.5517e-01, -6.5625e-01,  3.2050e-01, -3.4623e-01],
        [-5.5407e-01,  2.0340e-01,  4.1821e-01, -9.7931e-02,  4.2492e-01,
          8.5182e-01,  7.9682e-02,  7.5144e-01, -2.0973e-01, -1.3963e-01,
          4.5111e-01, -5.1502e-01, -3.1101e-01,  8.7050e-02,  7.7077e-01,
          4.9754e-01, -1.6914e-01,  5.5128e-01, -7.0215e-01,  2.6817e-01]],
       grad_fn=<TanhBackward>), tensor([[-0.2918,  0.5013,  0.0336, -0.3569, -0.5727, -0.1577, -0.1704,  0.3353,
         -0.4692, -0.2399,  0.3714,  0.3964, -0.2294,  0.0909,  0.1388, -0.1164,
          0.2566, -0.4419,  0.6232, -0.5399],
        [-0.7720,  0.3316,  0.4893,  0.4173,  0.1900,  0.5988,  0.2782, -0.3852,
          0.1218, -0.1172, -0.4391,  0.1240,  0.3925,  0.3963, -0.5687,  0.2115,
          0.4115,  0.5132, -0.1591, -0.1080],
        [ 0.1837,  0.2649,  0.6524,  0.2677,  0.0456,  0.2033, -0.0522,  0.4843,
         -0.4531,  0.4153,  0.0187, -0.6308,  0.1819, -0.5004,  0.6018,  0.4021,
          0.4913, -0.5287,  0.1526, -0.1455]], grad_fn=<TanhBackward>), tensor([[ 0.1872, -0.1069,  0.4237,  0.4201, -0.6734,  0.0836, -0.0252,  0.2273,
         -0.2810, -0.0137, -0.2922, -0.3051, -0.2602, -0.4907,  0.0777,  0.1137,
          0.2030, -0.1614, -0.0779, -0.2083],
        [-0.0990,  0.3498,  0.5492, -0.3256,  0.2025,  0.3302, -0.5011, -0.1571,
          0.0209,  0.2982,  0.1901, -0.6905,  0.2419, -0.5201,  0.3651,  0.3990,
          0.5685, -0.4665,  0.0143, -0.1595],
        [-0.5264, -0.0514,  0.1115,  0.3346, -0.2498, -0.0302,  0.4115,  0.3076,
         -0.5988, -0.0438, -0.3437,  0.1128,  0.2481, -0.0956, -0.2785, -0.1713,
          0.2296, -0.1200,  0.0860, -0.2926]], grad_fn=<TanhBackward>), tensor([[-0.2957,  0.1804,  0.3002,  0.0617, -0.1344,  0.1993, -0.3224,  0.4173,
         -0.0781,  0.3736, -0.2150,  0.2653, -0.0528,  0.0651, -0.0500,  0.2519,
         -0.0915, -0.2620, -0.2110, -0.5948],
        [-0.1506,  0.4123,  0.0162,  0.1171,  0.0414, -0.0956, -0.2576,  0.4046,
         -0.6677, -0.0049, -0.2525, -0.2696,  0.2976, -0.4672, -0.0190,  0.1525,
          0.2290, -0.4887,  0.0049, -0.7503],
        [-0.2533, -0.2999,  0.0536, -0.4347, -0.4320,  0.2809, -0.2127, -0.5016,
         -0.2124,  0.3309, -0.4574, -0.1008, -0.1006, -0.2328,  0.3993,  0.0364,
          0.6901,  0.1125,  0.4137,  0.6626]], grad_fn=<TanhBackward>), tensor([[-8.1422e-02,  1.8842e-02,  1.0144e-01,  3.8074e-02, -5.5781e-01,
          3.0895e-01, -4.5643e-01,  1.6013e-01, -2.9781e-01,  1.9961e-01,
         -1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01,  3.6203e-01,
          5.9737e-02,  2.3827e-01, -1.1631e-01, -7.2619e-02, -1.0607e-01],
        [-5.0548e-05, -1.7537e-01, -5.8831e-02,  1.4870e-01, -6.3641e-01,
         -8.8198e-02,  3.6017e-01, -7.5253e-02, -1.5342e-01, -1.3452e-01,
         -1.0299e-01,  5.4613e-02,  2.3648e-01, -4.3949e-01,  3.1918e-02,
         -1.2566e-01,  5.4776e-01,  1.7124e-01,  4.7584e-01, -1.5331e-01],
        [-2.6558e-01,  2.2919e-01,  5.0820e-01,  1.0909e-01, -4.1980e-01,
         -1.5907e-01,  2.2342e-01,  5.8021e-02, -2.7018e-01, -3.4850e-01,
         -1.8512e-01,  6.1239e-03,  2.4516e-01, -5.7344e-01, -7.5789e-02,
          1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]],
       grad_fn=<TanhBackward>)]
official_output
tensor([[[ 1.2620e-01,  6.0072e-01, -7.1112e-02,  2.0916e-01, -1.7033e-01,
          -5.8128e-02,  3.4290e-01,  5.7120e-01,  7.6649e-04,  4.7431e-01,
          -9.7752e-02, -6.9819e-01,  4.4204e-02,  1.8705e-01,  3.7682e-01,
          -3.2877e-01,  3.3991e-01, -9.3203e-01, -4.5387e-01, -7.6271e-01],
         [ 6.4547e-01, -3.3936e-01, -6.3193e-01, -3.2661e-02, -5.8965e-01,
          -7.6409e-01,  1.3470e-01, -3.7835e-01, -2.0378e-01, -1.6322e-01,
          -6.0952e-01,  2.9986e-02,  7.8969e-02, -6.4902e-01,  7.5271e-01,
           1.7919e-01,  6.5517e-01, -6.5625e-01,  3.2050e-01, -3.4623e-01],
         [-5.5407e-01,  2.0340e-01,  4.1821e-01, -9.7931e-02,  4.2492e-01,
           8.5182e-01,  7.9682e-02,  7.5144e-01, -2.0973e-01, -1.3963e-01,
           4.5111e-01, -5.1502e-01, -3.1101e-01,  8.7050e-02,  7.7077e-01,
           4.9754e-01, -1.6914e-01,  5.5128e-01, -7.0215e-01,  2.6817e-01]],

        [[-2.9177e-01,  5.0127e-01,  3.3566e-02, -3.5687e-01, -5.7271e-01,
          -1.5774e-01, -1.7043e-01,  3.3525e-01, -4.6915e-01, -2.3995e-01,
           3.7142e-01,  3.9644e-01, -2.2941e-01,  9.0899e-02,  1.3878e-01,
          -1.1636e-01,  2.5660e-01, -4.4189e-01,  6.2322e-01, -5.3986e-01],
         [-7.7200e-01,  3.3155e-01,  4.8930e-01,  4.1734e-01,  1.8999e-01,
           5.9885e-01,  2.7816e-01, -3.8521e-01,  1.2183e-01, -1.1717e-01,
          -4.3911e-01,  1.2396e-01,  3.9253e-01,  3.9633e-01, -5.6871e-01,
           2.1150e-01,  4.1146e-01,  5.1318e-01, -1.5914e-01, -1.0799e-01],
         [ 1.8367e-01,  2.6493e-01,  6.5243e-01,  2.6774e-01,  4.5578e-02,
           2.0329e-01, -5.2159e-02,  4.8428e-01, -4.5313e-01,  4.1533e-01,
           1.8746e-02, -6.3081e-01,  1.8190e-01, -5.0044e-01,  6.0178e-01,
           4.0211e-01,  4.9127e-01, -5.2867e-01,  1.5256e-01, -1.4553e-01]],

        [[ 1.8723e-01, -1.0690e-01,  4.2369e-01,  4.2007e-01, -6.7342e-01,
           8.3559e-02, -2.5240e-02,  2.2735e-01, -2.8096e-01, -1.3662e-02,
          -2.9221e-01, -3.0512e-01, -2.6019e-01, -4.9072e-01,  7.7736e-02,
           1.1373e-01,  2.0299e-01, -1.6141e-01, -7.7901e-02, -2.0833e-01],
         [-9.8969e-02,  3.4982e-01,  5.4921e-01, -3.2558e-01,  2.0254e-01,
           3.3020e-01, -5.0109e-01, -1.5706e-01,  2.0853e-02,  2.9821e-01,
           1.9009e-01, -6.9054e-01,  2.4189e-01, -5.2012e-01,  3.6514e-01,
           3.9902e-01,  5.6852e-01, -4.6647e-01,  1.4296e-02, -1.5953e-01],
         [-5.2637e-01, -5.1397e-02,  1.1150e-01,  3.3456e-01, -2.4977e-01,
          -3.0166e-02,  4.1154e-01,  3.0765e-01, -5.9878e-01, -4.3782e-02,
          -3.4375e-01,  1.1282e-01,  2.4812e-01, -9.5623e-02, -2.7851e-01,
          -1.7131e-01,  2.2957e-01, -1.1999e-01,  8.5984e-02, -2.9264e-01]],

        [[-2.9568e-01,  1.8038e-01,  3.0018e-01,  6.1720e-02, -1.3442e-01,
           1.9932e-01, -3.2239e-01,  4.1725e-01, -7.8142e-02,  3.7360e-01,
          -2.1505e-01,  2.6528e-01, -5.2758e-02,  6.5120e-02, -4.9986e-02,
           2.5186e-01, -9.1457e-02, -2.6198e-01, -2.1105e-01, -5.9480e-01],
         [-1.5058e-01,  4.1227e-01,  1.6235e-02,  1.1707e-01,  4.1378e-02,
          -9.5621e-02, -2.5761e-01,  4.0463e-01, -6.6765e-01, -4.8583e-03,
          -2.5254e-01, -2.6960e-01,  2.9760e-01, -4.6718e-01, -1.9016e-02,
           1.5246e-01,  2.2903e-01, -4.8867e-01,  4.9081e-03, -7.5035e-01],
         [-2.5332e-01, -2.9992e-01,  5.3646e-02, -4.3469e-01, -4.3205e-01,
           2.8094e-01, -2.1272e-01, -5.0162e-01, -2.1240e-01,  3.3086e-01,
          -4.5738e-01, -1.0083e-01, -1.0064e-01, -2.3278e-01,  3.9928e-01,
           3.6350e-02,  6.9007e-01,  1.1249e-01,  4.1367e-01,  6.6261e-01]],

        [[-8.1422e-02,  1.8842e-02,  1.0144e-01,  3.8074e-02, -5.5781e-01,
           3.0895e-01, -4.5643e-01,  1.6013e-01, -2.9781e-01,  1.9961e-01,
          -1.5315e-01, -1.9327e-01, -1.6216e-01, -3.1397e-01,  3.6203e-01,
           5.9737e-02,  2.3827e-01, -1.1631e-01, -7.2618e-02, -1.0607e-01],
         [-5.0455e-05, -1.7537e-01, -5.8831e-02,  1.4870e-01, -6.3641e-01,
          -8.8198e-02,  3.6017e-01, -7.5252e-02, -1.5342e-01, -1.3452e-01,
          -1.0299e-01,  5.4613e-02,  2.3648e-01, -4.3949e-01,  3.1918e-02,
          -1.2566e-01,  5.4776e-01,  1.7124e-01,  4.7584e-01, -1.5331e-01],
         [-2.6558e-01,  2.2919e-01,  5.0820e-01,  1.0909e-01, -4.1980e-01,
          -1.5907e-01,  2.2342e-01,  5.8021e-02, -2.7018e-01, -3.4850e-01,
          -1.8512e-01,  6.1239e-03,  2.4516e-01, -5.7344e-01, -7.5789e-02,
           1.2656e-01, -6.1614e-02, -3.4017e-01, -3.2117e-01, -4.1551e-02]]],
       grad_fn=<StackBackward>)

As shown, the two models produce identical results; the hand-written model's structure is correct. (Again, the match can be checked programmatically, as sketched below.)
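
The same kind of check works here (a minimal sketch under the same assumptions and tolerance as before; in the two-layer script, myht and myoutput are Python lists, so stack them first):

#myht is [ht, ht1]; stacking gives (num_layers, batch, hidden), matching official_hn
print(torch.allclose(torch.stack(myht), official_hn, atol=1e-6)) #expected: True
#myoutput holds the last layer's ht1 per timestep; stacking gives (seq_len, batch, hidden)
print(torch.allclose(torch.stack(myoutput), official_output, atol=1e-6)) #expected: True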


Conclusion: theory and working code both show that an RNN is a shared-parameter sequential neural network (shared parameters, also known as weight sharing). A one-layer RNN has four parameter variables, and the sample at every timestep is computed with those same four parameter variables. A higher layer depends on the results of the layer below it, and the current timestep depends on the results of the previous timestep.


If anything in this article is inaccurate, criticism and corrections are most welcome.