Dual_RNN_Block应该是整个网络中最重要的部分了。

这里,每一个Block相当于网络内部的一层 ,源码中默认设置4层Dual_RNN_Block。

每一个Dual_RNN_Block又分为intra_rnn(块内rnn)和inter_rnn(块间rnn)


intra_rnninter_rnn是dual的灵魂,但是刚开始接触很难理解这个概念。
结合代码和原论文的配图,可以理解为对Dual_RNN_Block的3D上对K和S维度训练


输入张量

输入的张量shape为[B, N, K, S], 具体的来源可以参考​​这里​​。

其中B为batch-size,每一个batch里的N,K,S,如下图。(K=2P)

【Dual-Path-RNN-Pytorch源码分析】Dual_RNN_Block_rnn

intra_rnn

RNN是最最后一维做训练,但是与其他维度也有关联。尤其是-2维度。

intra_rnn是针对K的训练,K是形容block的变量,即在这个维度上理解为intra

下图为intra_rnn block的流程图

【Dual-Path-RNN-Pytorch源码分析】Dual_RNN_Block_深度学习_02

inter_rnn

intra_rnn是针对S的训练,S是形容block个数的变量,是block与block之间的关系,即在这个维度上理解为inter

下图为inter_rnn block的流程图

【Dual-Path-RNN-Pytorch源码分析】Dual_RNN_Block_深度学习_03

双剑合璧 Dual_RNN_Block

上述两个intra_rnn + inter_rnn就是dual_rnn了。

但是有点细节:


  1. intra_rnn的结果是加上了输入张量x 再送到 inter_rnn计算
  2. inter_rnn的结果是加上了intra_rnn的结果再输出
    【Dual-Path-RNN-Pytorch源码分析】Dual_RNN_Block_pytorch_04
    最后,把paper中的图贴在这里方便大家理解。
    【Dual-Path-RNN-Pytorch源码分析】Dual_RNN_Block_rnn_05