Task6概览

之前的任务是学习如何设计图神经网络进行节点表征学习,并基于习得的节点表征进行下游的任务,例如节点分类或者链路预测。

本次任务将进行图级别表示的学习,称之为图表征学习。图表征学习要求根据节点属性、边和边的属性(如果有的话)生成一个向量作为图的表征,基于图表征可以做图的预测。

基于图同构网络(Graph Isomorphism Network, GIN)的图表征网络是当前最经典的图表征学习网络。本文将以GIN为例,介绍图同构的相关概念,同时介绍图同构测试的经典算法——Weisfeiler-Lehamn算法,并对GNN与WL-test之间的关系进行简单的分析;最后,给出了GIN的架构设计以及代码的实现及解释。

一、图同构与WL-test

1.1图同构背景

两个图是同构的,意思是两个图拥有一样的拓扑结构。

简单来说,两个图中的节点和边数量相同且边的连接关系相同,则两个图同构,两个图在拓扑上等价。



图神经网络异常检测数据集 图神经网络测试_pytorch


以上两个图虽然在形状上大不一样,但是根据定义来说,它们是属于同构的。

图同构测试的意义

比如在蛋白质结构、基因网络中,具有相似结构(同构测试或相似度计算)的蛋白质或基因结构可能具有相似的功能特性。又比如两位作者相似的期刊引文网络结构可能表示两位作者的研究内容相似等等。

1.2 WL-test原理及计算步骤

Weisfeiler-Lehman 图的同构性测试算法,简称WL-Test,是一种用于测试两个图是否同构的算法。

但是WL-test是图同构的一个必要但不充分的条件。也就是说,两个图的WL-test结果显示有差异,可认为这两个图是非同构的;但如WL-test结果显示没有差异,只能表述为这两个图可能同构。

WL-test包括四个步骤

  • 聚合邻居节点标签;
  • 多重集排序;
  • 标签压缩;
  • 更新标签。

上述四个步骤对应下面四张图。



图神经网络异常检测数据集 图神经网络测试_图神经网络异常检测数据集_02


然而WL算法只能判断两个图在k次iteration下是否同构,但无法度量图之间的相似性。进一步地,可以利用WL Subtree Kernel方法估计两个图的相似性。该方法实际上是在WL-test算法基础上增加了第五步,如下图所示:



图神经网络异常检测数据集 图神经网络测试_pytorch_03


迭代 1 轮后,利用计数函数分别得到两张图的计数特征,如:分别统计图中各类标签出现的次数,存于一个向量,这个向量可以作为图的表征。两个图的这样的向量的内积,即可作为这两个图的相似性的估计。

1.3 GNN与WL-test的关系

GNN实际上指的是基于消息传递图神经网络(Message Passing Neural Networks,MPNN) )架构设计的GNN模型,主要包括以下几个部分:

  1. 利用消息函数Mt聚合邻居特征(消息);
  2. 利用节点更新函数Ut更新节点自身的特征。

和前面WL-test计算步骤对比可以发现,WL-test和基于MPNN设计的图神经网络模型很相似。从另一个角度来讲,GCN模型可以看作图上非常有名的Weisfeiler-Lehman算法的一种变形

1.4作业

  • 请画出下方图片中的6号、3号和5号节点的从1层到3层到WL子树。


图神经网络异常检测数据集 图神经网络测试_神经网络_04


- 解答:

  1. 6号


图神经网络异常检测数据集 图神经网络测试_神经网络_05


  1. 3号


图神经网络异常检测数据集 图神经网络测试_深度学习_06


  1. 5号


图神经网络异常检测数据集 图神经网络测试_pytorch_07


参考资料:

  1. Datawhale组队学习【图神经网络】

二、图同构网络架构(GIN)及代码

基于图同构网络的图表征学习主要包含以下两个过程:

  1. 首先计算得到节点表征;
  2. 其次对图上各个节点的表征做图池化(Graph Pooling),或称为图读出(Graph Readout), 得到图的表征(Graph Representation)。

接下来将采用自顶向下的方式,来学习基于图同构模型(GIN)的图表征学习方法。

2.1 GIN-学习图中节点的表征(聚合与更新操作)

2.1.1 卷积层设计

在节点嵌入模块中的关键组件为GINConv,需要复写MPNN框架中的message、aggregate和update函数以实现GIN中的卷积过程。

import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder

### GIN convolution along the graph structure

class GINConv(MessagePassing):
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr = "add")
        self.mlp = nn.Sequential(nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, emb_dim))
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr) # 先将类别型边属性转换为边表征
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out

2.1.2 节点表示学习模块

输入到此节点嵌入模块的节点属性为类别型向量,因此首先用 AtomEncoder 对其做嵌入得到第0层节点表征 。然后逐层计算节点表征 ,从第1层开始到第num_layers层,每一层节点表征的计算都以上一层的节点表征 h_list[layer] 、边 edge_index 和边的属性 edge_attr 为输入 。需要注意的是,GINConv的层数越多,此节点嵌入模块的感受野(receptive field)越大,结点 i 的表征最远能捕获到结点 i 的距离为 num_layers 的邻接节点的信息

import torch
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F

# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        """GIN Node Embedding Module"""

        super(GINNodeEmbedding, self).__init__()
        # num_layers (int, optional): number of GINConv layers. Defaults to 5.
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        # add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
    # emb_dim (int, optional): dimension of node embedding. Defaults to 300.
        self.atom_encoder = AtomEncoder(emb_dim)# 先将类别型节点属性转化为节点表征

        # List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        # computing input node embedding
        h_list = [self.atom_encoder(x)]  # 先将类别型节点属性转化为节点表征
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        # Different implementations of Jk-concat
        # JK (str, optional): 可选的值为"last"和"sum"。选"last",只取最后一层的结点的嵌入,选"sum"对各层的结点的嵌入求和。Defaults to "last".
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layers + 1):
                node_representation += h_list[layer]

        return node_representation

AtomEncoder 与 BondEncoder
当节点和边的属性都为离散值时,它们属于不同的空间,无法直接将它们融合在一起。通过嵌入(Embedding),可以将节点属性和边属性分别映射到一个新的空间,在这个新的空间中,就可以对节点和边进行信息融合。从而在GINConv中,message()函数中的x_j + edge_attr 操作可以执行节点信息和边信息的融合。

import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims 

full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()

class AtomEncoder(torch.nn.Module):

    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        
        self.atom_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_atom_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            x_embedding += self.atom_embedding_list[i](x[:,i])

        return x_embedding

class BondEncoder(torch.nn.Module):
    
    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        
        self.bond_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:,i])

        return bond_embedding   

if __name__ == '__main__':
    from ogb.graphproppred.dataset_pyg import PygGraphPropPredDataset
    dataset = PygGraphPropPredDataset(name = 'ogbg-molhiv')
    atom_enc = AtomEncoder(100)
    bond_enc = BondEncoder(100)

    print(atom_enc(dataset[0].x))
    print(bond_enc(dataset[0].edge_attr))

2.2 GIN-图池化Graph Pooling/图读出Graph Readout

GIN中的READOUT 函数为 SUM函数,通过对每次迭代得到的所有节点的特征求和得到该轮迭代的图特征,再拼接起每一轮迭代的图特征来得到最终的图特征:



图神经网络异常检测数据集 图神经网络测试_数据挖掘_08


采用拼接每一轮迭代的图特征而不是相加的原因在于不同层节点的表征属于不同的特征空间。这样得到的图的表示与WL Subtree Kernel得到的图的表征是等价的。

首先采用GINNodeEmbedding模块对图上每一个节点做节点嵌入(Node Embedding),得到节点表征;然后对节点表征做图池化得到图的表征;最后用一层线性变换对图表征转换为对图的预测。

import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding

class GINGraphRepr(nn.Module):

    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
        """GIN Graph Pooling Module
        Args:
            num_tasks (int, optional): number of labels to be predicted. Defaults to 1 (控制了图表征的维度,dimension of graph representation).
            num_layers (int, optional): number of GINConv layers. Defaults to 5.
            emb_dim (int, optional): dimension of node embedding. Defaults to 300.
            residual (bool, optional): adding residual connection or not. Defaults to False.
            drop_ratio (float, optional): dropout rate. Defaults to 0.
            JK (str, optional): 可选的值为"last"和"sum"。选"last",只取最后一层的结点的嵌入,选"sum"对各层的结点的嵌入求和。Defaults to "last".
            graph_pooling (str, optional): pooling method of node embedding. 可选的值为"sum","mean","max","attention"和"set2set"。 Defaults to "sum".

        Out:
            graph representation
        """
        super(GINGraphPooling, self).__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)

        # Pooling function to generate whole-graph embeddings
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(gate_nn=nn.Sequential(
                nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1)))
        elif graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = nn.Linear(2*self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        h_node = self.gnn_node(batched_data)

        h_graph = self.pool(h_node, batched_data.batch)
        output = self.graph_pred_linear(h_graph)

        if self.training:
            return output
        else:
            # At inference time, relu is applied to output to ensure positivity
            # 因为预测目标的取值范围就在 (0, 50] 内
            return torch.clamp(output, min=0, max=50)

Task6学习心得

这次的任务教程最初看的几遍比较陌生,具体原理感觉没有弄透,后面主动上网查找了相关资料,才慢慢理清思路,已经熟悉了大体的流程。但弄懂代码的具体细节还需要进一步学习~

参考资料:

  1. Datawhale组队学习【图神经网络】
  2. GIN:逼近WL-test的GNN架构