Dual_RNN_Block应该是整个网络中最重要的部分了。
这里,每一个Block相当于网络内部的一层 ,源码中默认设置4层Dual_RNN_Block。
每一个Dual_RNN_Block又分为intra_rnn(块内rnn)和inter_rnn(块间rnn)
intra_rnn和inter_rnn是dual的灵魂,但是刚开始接触很难理解这个概念。
结合代码和原论文的配图,可以理解为对Dual_RNN_Block的3D上对K和S维度训练
输入张量
输入的张量shape为[B, N, K, S], 具体的来源可以参考这里。
其中B为batch-size,每一个batch里的N,K,S,如下图。(K=2P)
intra_rnn
RNN是最最后一维做训练,但是与其他维度也有关联。尤其是-2维度。
intra_rnn是针对K的训练,K是形容block的变量,即在这个维度上理解为intra
下图为intra_rnn block的流程图
inter_rnn
intra_rnn是针对S的训练,S是形容block个数的变量,是block与block之间的关系,即在这个维度上理解为inter
下图为inter_rnn block的流程图
双剑合璧 Dual_RNN_Block
上述两个intra_rnn + inter_rnn就是dual_rnn了。
但是有点细节:
- intra_rnn的结果是加上了输入张量x 再送到 inter_rnn计算
- inter_rnn的结果是加上了intra_rnn的结果再输出
最后,把paper中的图贴在这里方便大家理解。