目录
1.整个AutoCorrelation的结构
2.对q和k进行处理
3.对Value进行处理
4.aggregation(最难的部分)
1.整个AutoCorrelation的结构
整个AutoCorrelation层初始化了三个函数,但是只用到了两个,这里存在疑问,不知道为什么是这样的。
2.对q和k进行处理
q_fft和k_fft是得到的傅里叶变换的一个向量,他的维度是(32,8,64,49)。
res是将这两个得到的向量进行共轭相乘,结果是res,是(32,8,64,49)维度的数据。
corr是将这个结果进行逆向傅里叶,得到的是(32,8,64,96)的维度的数据。
傅里叶的意义在哪?在本文的最后会记个笔记。
3.对Value进行处理
最后返回的值一个是corr一个是v,也就是在这整个步骤里面,主要进行的就是对V和利用q和k得到这个corr。
在train的时候对values的处理,主要是这个函数:time_delay_agg_training,这个函数需要传入两个参数,一个是corr一个是v,所以应该是利用这两个参数进行处理v,因为返回的值只有一个v。
从结果来看,因为其中的步骤不是很懂,index最终的数是[0,94,1,95]也就是说对于这个96这个维度的数进行取top?得到其中corr比较大的索引?
weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1):这一句得到了一个(32,4)的向量,他是对原本的mean_value进行堆叠,这个mean_value是怎么来的呢,是对这个corr(32,8,64,96)的8和64这个维度进行取mean,得到的是一个(32,96)的mean,这个32是batch_size,也就是说对这个batch中的每个corr都取平均。所以weights的意义在于,把这个batch中的每个数据的这个比较大的数的维度的值都取出来进行堆叠,也就是说这个weights是这32个数据内部corr比较大的位置的数据的堆叠。这个weights也就相当于将比较活跃的attention进行提取出来堆叠。
tmp_corr = torch.softmax(weights, dim=-1):这一步是对这个比较活跃的部分进行softmax。
下面一部分,是aggregation应该是最难的一部分,到了午饭时间还没看懂,等吃完饭回来再看。
# aggregation
tmp_values = values
delays_agg = torch.zeros_like(values).float()
for i in range(top_k):
pattern = torch.roll(tmp_values, -int(index[i]), -1)
delays_agg = delays_agg + pattern * \
(tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
4.aggregation(最难的部分)
上面的部分大概是像informer中对活跃的进行提取,下面绿色的部分是aggregation的部分。
(1)tmp_values = values:
第一步先把原始的values给这个tmp_values,现在这个tmp_values和原来的values都是一样的都是维度为(32,8,64,96)维度的数据。
(2)delays_agg = torch.zeros_like(values).float():第二步,这一步生成一个和values一样维度的向量,他的维度也是(32,8,64,96)
(3)pattern = torch.roll(tmp_values, -int(index[i]), -1):这一步的意思是,将-1这个维度像上移动index的这个步数。
(4)for i in range(top_k):
pattern = torch.roll(tmp_values, -int(index[i]), -1)
delays_agg = delays_agg + pattern * \
(tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)):
这一步的意思是将这个roll的tmp_values沿着上面得到的index这个数往上卷,然后用这个得到的pattern和经过softmax的比较大的那个类似qk的weights进行相乘。
(5)return delays_agg:最后return的是这个delays_agg,不知道这里有什么用??看的不是很懂。
AutoCorrelation这个函数最后返回的是V。这个V在最后是所得到的output,也就是相当于attention的编码的感觉。
这个V就是这里的out。一直重复。