目录

1.整个AutoCorrelation的结构

2.对q和k进行处理

3.对Value进行处理

4.aggregation(最难的部分)


1.整个AutoCorrelation的结构

整个AutoCorrelation层初始化了三个函数,但是只用到了两个,这里存在疑问,不知道为什么是这样的。

pytorch 变分自动编码器 pytorch self-attention代码_人工智能

2.对q和k进行处理

q_fft和k_fft是得到的傅里叶变换的一个向量,他的维度是(32,8,64,49)。

res是将这两个得到的向量进行共轭相乘,结果是res,是(32,8,64,49)维度的数据。

corr是将这个结果进行逆向傅里叶,得到的是(32,8,64,96)的维度的数据。

傅里叶的意义在哪?在本文的最后会记个笔记。

pytorch 变分自动编码器 pytorch self-attention代码_机器学习_02

3.对Value进行处理

最后返回的值一个是corr一个是v,也就是在这整个步骤里面,主要进行的就是对V和利用q和k得到这个corr。

pytorch 变分自动编码器 pytorch self-attention代码_pytorch 变分自动编码器_03

在train的时候对values的处理,主要是这个函数:time_delay_agg_training,这个函数需要传入两个参数,一个是corr一个是v,所以应该是利用这两个参数进行处理v,因为返回的值只有一个v。

pytorch 变分自动编码器 pytorch self-attention代码_目标检测_04

 从结果来看,因为其中的步骤不是很懂,index最终的数是[0,94,1,95]也就是说对于这个96这个维度的数进行取top?得到其中corr比较大的索引?


weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1):这一句得到了一个(32,4)的向量,他是对原本的mean_value进行堆叠,这个mean_value是怎么来的呢,是对这个corr(32,8,64,96)的8和64这个维度进行取mean,得到的是一个(32,96)的mean,这个32是batch_size,也就是说对这个batch中的每个corr都取平均。所以weights的意义在于,把这个batch中的每个数据的这个比较大的数的维度的值都取出来进行堆叠,也就是说这个weights是这32个数据内部corr比较大的位置的数据的堆叠。这个weights也就相当于将比较活跃的attention进行提取出来堆叠。


tmp_corr = torch.softmax(weights, dim=-1):这一步是对这个比较活跃的部分进行softmax。
 
下面一部分,是aggregation应该是最难的一部分,到了午饭时间还没看懂,等吃完饭回来再看。
 
# aggregation
tmp_values = values
delays_agg = torch.zeros_like(values).float()
for i in range(top_k):
    pattern = torch.roll(tmp_values, -int(index[i]), -1)
    delays_agg = delays_agg + pattern * \
                 (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))


pytorch 变分自动编码器 pytorch self-attention代码_pytorch 变分自动编码器_05

4.aggregation(最难的部分)

上面的部分大概是像informer中对活跃的进行提取,下面绿色的部分是aggregation的部分。

pytorch 变分自动编码器 pytorch self-attention代码_开发语言_06


(1)tmp_values = values:
第一步先把原始的values给这个tmp_values,现在这个tmp_values和原来的values都是一样的都是维度为(32,8,64,96)维度的数据。
(2)delays_agg = torch.zeros_like(values).float():第二步,这一步生成一个和values一样维度的向量,他的维度也是(32,8,64,96)
(3)pattern = torch.roll(tmp_values, -int(index[i]), -1):这一步的意思是,将-1这个维度像上移动index的这个步数。
(4)for i in range(top_k):
    pattern = torch.roll(tmp_values, -int(index[i]), -1)
    delays_agg = delays_agg + pattern * \
                 (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)):
这一步的意思是将这个roll的tmp_values沿着上面得到的index这个数往上卷,然后用这个得到的pattern和经过softmax的比较大的那个类似qk的weights进行相乘。
(5)return delays_agg:最后return的是这个delays_agg,不知道这里有什么用??看的不是很懂。



 

pytorch 变分自动编码器 pytorch self-attention代码_pytorch 变分自动编码器_07

pytorch 变分自动编码器 pytorch self-attention代码_人工智能_08

 


AutoCorrelation这个函数最后返回的是V。这个V在最后是所得到的output,也就是相当于attention的编码的感觉。


 

pytorch 变分自动编码器 pytorch self-attention代码_目标检测_09

 这个V就是这里的out。一直重复。

pytorch 变分自动编码器 pytorch self-attention代码_机器学习_10