GraphSage 算法原理介绍与源码浅析
文章目录
- 邻居采样
- 采样方法
- 邻居聚合
- 聚合方法
- 总结
前言
最近在做 Graph 相关的工作, 两年前做过一段时间, 想不到兜兜转转又回到最初的起点~???????????? 工作继续稳步推进, 同时打算复习下基础算法. 论文也忒多了, 一段时间没看, 已经跟不上了 ????????????
这里插句题外话, 之前我写的一些博客, 代码分析的太过细节了, 我自己平时翻看的时候, 都会直接将琐碎的东西给略过. 从这一行为可以看出, 之前博客中记录了太多冗余的内容, 不仅在记录时浪费了时间, 更给后续查阅带来了一些阻碍. 鉴于此, 以后做代码分析打算尽力只分析源码的核心部分, 再加上部分感兴趣的内容.
广而告之
可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号;另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中;
文章信息
- 论文标题: Inductive Representation Learning on Large Graphs
- 论文地址:https://arxiv.org/pdf/1706.02216.pdf
- 代码地址:https://github.com/williamleif/GraphSAGE
- 发表时间: NIPS, 2017
- 论文作者: William L. Hamilton, Rex Ying, Jure Leskovec
- 作者单位: Stanford
补充: 在 全面理解 PinSage 文章中详细介绍了 PinSage 算法, GraphSage 算法是 PinSage 的理论基础, 而 PinSage 包含了很多工程上的实践经验, 两者可以结合起来看看.
核心观点
GraphSage (Graph SAmple and aggreGatE) 属于 Inductive learning 算法, 它学习一种聚合函数, 通过聚合节点邻居的特征信息来学习目标节点本身的 embedding 表达. 从它的名字中可以看出算法的核心步骤分别是邻居采样以及特征聚合; GraphSage 就是我们通常意义的机器学习任务, 对于未知的节点具有泛化能力, 它和 Transductive Learning 算法 (如 GCN, DeepWalk, 在固定的图结构上学习节点的 embedding) 不同的是, Transductive Learning 算法在图中加入新节点后, 需要将模型重新训练.
核心观点解读
在介绍具体的算法之前, 先简要对比一下 Inductive learning 与 Transductive learning. 关于它们的详细介绍推荐阅读文章 Inductive vs. Transductive Learning.
其中:
Inductive learning is the same as what we commonly know as traditional supervised learning. We build and train a machine learning model based on a labelled training dataset we already have. Then we use this trained model to predict the labels of a testing dataset which we have never encountered before.
In contrast to inductive learning, transductive learning techniques have observed all the data beforehand, both the training and testing datasets. We learn from the already observed training dataset and then predict the labels of the testing dataset. Even though we do not know the labels of the testing datasets, we can make use of the patterns and additional information present in this data during the learning process.
The main difference is that during transductive learning, you have already encountered both the training and testing datasets when training the model. However, inductive learning encounters only the training data when training the model and applies the learned model on a dataset which it has never seen before.
Transduction does not build a predictive model. If a new data point is added to the testing dataset, then we will have to re-run the algorithm from the beginning, train the model and then use it to predict the labels. On the other hand, inductive learning builds a predictive model. When you encounter new data points, there is no need to re-run the algorithm from the beginning.
(Inductive Learning 被翻译为归纳式学习, Transductive Leanring 为直推式学习. 说实话, 这两个翻译把我整迷糊了, 从来没有记住过, 但是上面的英文释义却非常好记, 不容易忘????????????)
GraphSage 属于 Inductive learning 算法, 它学习一种聚合函数, 通过聚合节点邻居的特征信息来学习目标节点本身的 embedding 表达. 它的主要步骤就记录在它的名字中: Sample 与 Aggregate. 其中 Sample 阶段通过随机采样获取多跳邻居; Aggregate 阶段聚合邻居节点特征生成目标节点自身的 embedding. 以聚合 2 跳邻居为例, 它将首先聚合 2 跳邻居的特征生成 1 跳邻居的 embedding, 之后再聚合 1 跳邻居的 embedding 来生成节点本身的 embedding. 由于生成 1 跳邻居 embedding 时, 已经包含了 2 跳邻居的特征信息, 此时目标节点也将获得 2 跳邻居的特征信息. 论文中的图示形象地展示了这一过程:
生成完目标节点的 embedding 后, 可以提供给下游的机器学习系统做诸如节点分类的预估任务.
GraphSage 的前向传播算法如下图:
第一个 for 循环针对层数进行遍历, 第二个 for 循环用于遍历 Graph 中的所有节点, 针对每个节点 , 对邻居进行采样得到 , 并通过 对邻居节点的 embedding 进行聚合, 得到 , 再将它与目标节点当前的 embedding 进行拼接, 经过非线性变换后赋给 , 从而完成目标节点 的一次更新. 当外层的 for 循环 () 遍历结束时, 节点 将完成
在具体代码实现时, 实际上采用的是 minibatch 的形式, 论文 Appendix A 进行了介绍, 待会在源码分析中也将进行描述.
源码分析
本次分析的代码位于 https://github.com/williamleif/GraphSAGE, 是官方开源的 TensorFlow 版本.
GraphSage 的核心在于 Sample 和 Aggregate. 由于训练模型时, 我们一般采用 minibatch 的方式进行训练, 因此在论文的 Appendix A 中, 还给出了一份 minibatch 版本的伪代码, 如下:
其中代码 1 ~ 7 行表示对邻居进行采样, 而 8 ~ 15 行表示邻居聚合.
在 GraphSage 的代码中, 邻居采样以及聚合代码均位于 https://github.com/williamleif/GraphSAGE/blob/master/graphsage/models.py 文件中, 在进行介绍之前, 需要解释一个会令人困惑的点. 作者对于 Graph 中每层节点的采样个数设置如下:
实际上表达的含义如下图:
注意图中的 layer 1
层采样的节点数为 10, 而 layer 2
层采样的节点数为 25
, 刚好和代码中的定义相反. 关于这一点作者在 Appendix A 中介绍伪代码的下方进行了说明, 而且注意到上面伪代码的第一行, 令 , 一开始就将目标节点赋值给 , 采样的时候是按 的顺序进行遍历 (伪代码第 2 行), 而聚合时则是按 的顺序进行遍历 (伪代码第 9 行).
看源码时如果不注意这一点, 容易有些困惑. 为了方便介绍, 后续我就拿具体的数字, 比如 10, 25 之类的来说明代码含义, 这样可以快速判断当前在 Graph 中的第几层.
邻居采样
GraphSage 邻居采样代码定义如下:
阅读时不需要太在意实现细节 (比如 与
-
inputs
: 大小为[B,]
的 Tensor, 表示目标节点的 ID; -
layer_infos
: 假设 Graph 深度为, 那么layer_infos
的大小为, 保存 Graph 中每一层的相关信息, 比如采样的邻居数num_samples
, 采样方法neigh_sampler
等.
由于从目标节点开始采样, 采样结束后:
-
samples
保存 3 个 Tensor, 大小为:[Tensor(B*1,), Tensor(B*10,), Tensor(B*250,)]
, 表示 Graph 中每一层的节点 id -
support_sizes
为[1, 10, 250]
, 表示对每一个目标节点, 它在 Graph 中每一层的邻居个数.
采样方法
采样方法 sampler
定义在: https://github.com/williamleif/GraphSAGE/blob/master/graphsage/neigh_samplers.py, 代码如下:
对其进行调用时传入 inputs
, 包含目标节点的 ids
以及采样个数 num_samples
, 最后返回大小为 [B, num_samples]
的 Tensor. 另外注意邻接矩阵的生成也包括放回采样以及不放回采样, 具体见作者源码, 这里不过多介绍. (详见 https://github.com/williamleif/GraphSAGE/blob/master/graphsage/minibatch.py#L76 中 construct_adj
函数)
邻居聚合
邻居聚合代码位于: https://github.com/williamleif/GraphSAGE/blob/master/graphsage/models.py, 定义如下 (只保留了核心的代码):
由于之前提到的, 作者在采样的时候是按 的顺序进行遍历 (伪代码第 2 行), 而聚合时则是按
聚合方法
关于聚合方法, 主要定义在: https://github.com/williamleif/GraphSAGE/blob/master/graphsage/aggregators.py
MeanAggregator
定义如下:
对 neigh_vecs
邻居节点的 embedding 进行 mean pooling 后, 再和目标节点本身的 embedding 进行相加或者拼接.
GCNAggregator
先使用 tf.expand_dims(self_vecs, axis=1)
展开成 [B, 1, E]
的形式, 再和 neigh_vecs
进行 concat, 最后整体求 mean;
MaxPoolingAggregator
代码中使用 neigh_h = tf.reduce_max(neigh_h, axis=1)
对邻居 embedding 进行聚合.
其他还有 TwoMaxLayerPoolingAggregator, MeanPoolingAggregator 以及 SeqAggregator (实现 LSTM Aggregator) 就不多分析, 后续有需要的时候再看.
总结
国庆快乐~