GAT 算法原理介绍与源码分析
文章目录
- 广而告之
- 五. 总结
零. 前言 (与正文无关, 请忽略)
对自己之前分析过的文章做一个简单的总结:
- 机器学习基础:LR /LibFM
- 特征交叉:DCN /PNN /DeepMCP /xDeepFM /FiBiNet /AFM
- 用户行为建模:DSIN /DMR /DMIN
- 多任务建模:MMOE
- Graph 建模:GraphSage
广而告之
可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号;另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.上的阅读体验会更好一些, 地址是
一. 文章信息
- 论文标题: Graph Attention Networks
- 论文地址:https://arxiv.org/pdf/1710.10903.pdf
- 代码地址:https://github.com/PetarV-/GAT
- 发表时间: ICLR 2018
- 论文作者: 详见文章
- 作者单位: University of Cambridge
二. 核心观点
GAT (Graph Attention Networks) 采用 Attention 机制来学习邻居节点的权重, 通过对邻居节点的加权求和来获得节点本身的表达.
三. 核心观点解读
GAT 的实现机制如下图所示:
注意右图中, GAT 采用 Multi-Head Attention, 图中有 3 种颜色的曲线, 表示 3 个不同的 Head. 在不同的 Head 下, 节点 可以学习到不同的 embedding, 然后将这些 embedding 进行 concat/avg 便生成 .
下面直接看分析代码吧.
四. 源码分析
GAT 的源码位于: https://github.com/PetarV-/GAT
GAT 网络本身是通过堆叠多个 Graph Attention Layer 层构成的, 首先介绍 Graph Attention Layer 的实现.
4.1 Graph Attention Layer
Graph Attention Layer 的定义:
设 个输入节点的特征为: , 采用 Attention 机制生成新的节点特征
Attention 系数按如下方式生成:
其中 , 而
其代码实现位于: https://github.com/PetarV-/GAT/blob/master/utils/layers.py. 注意在代码实现中, 作者的写法很简洁精妙, 不是照着上面的公式直接写的, 而是做了一点程度的变换.
由于 , 因此令 , 其中 , 那么 其实等效于 , 下面代码实现中, 采用的就是等效的写法.
这里再补充两个小要点: conv1d
的实现以及 bias_mat
的生成. 首先看 conv1d
的实现:
再来看 bias_mat
的生成, 代码位于: https://github.com/PetarV-/GAT/blob/master/utils/process.py, 实现如下:
4.2 GAT 网络
代码定义于: https://github.com/PetarV-/GAT/blob/master/models/gat.py, 实现如下:
主要内容为堆叠 Graph Attention Layer, 就不详细介绍了.
五. 总结
没有总结, 内心只有纠结.