图神经网络中的注意力机制

本文讨论了 GNN 中常用的注意力机制,相关论文有:

Graph Attention Networks (GAT)

GAT 的基本原理

GAT 是 GNN 中的经典模型,原始论文为 Graph Attention Networks 。在最初的 GCN 中,中心节点从邻域节点得到的消息会通过 sum, max, mean 等方式进行聚合,每个节点消息的重要性都是相等的。所谓注意力,就是希望中心节点对不同节点传递的消息做不同的对待,即对所有消息都分配一个权重。GAT 的思路非常简单,节点嵌入的计算方式为
Graphviz绘制神经网络图 图神经网络gat_深度学习
其中 Graphviz绘制神经网络图 图神经网络gat_Graph_02 表示节点 Graphviz绘制神经网络图 图神经网络gat_深度学习_03 对节点 Graphviz绘制神经网络图 图神经网络gat_Graphviz绘制神经网络图_04 的注意力,计算公式为
Graphviz绘制神经网络图 图神经网络gat_神经网络_05
如果边 Graphviz绘制神经网络图 图神经网络gat_机器学习_06 也有特征,那么注意力 Graphviz绘制神经网络图 图神经网络gat_Graph_02 可以是
Graphviz绘制神经网络图 图神经网络gat_Graph_08
可以把上面的式子分解为两步,一是计算消息的权重 (weight)
Graphviz绘制神经网络图 图神经网络gat_Graphviz绘制神经网络图_09
二是通过 softmax 计算注意力
Graphviz绘制神经网络图 图神经网络gat_Graphviz绘制神经网络图_10
以上是单头注意力公式,如果考虑多头注意力 (multi-head attention) ,可以将多个注意力计算的结果联结 (concat) 组成一个嵌入向量,也可以计算多个注意力结果的平均值。联结计算方式为
Graphviz绘制神经网络图 图神经网络gat_Graphviz绘制神经网络图_11
平均计算方式为
Graphviz绘制神经网络图 图神经网络gat_机器学习_12
GAT 消息传递的过程可以用论文中的 Figure 1 来说明

Graphviz绘制神经网络图 图神经网络gat_深度学习_13

GAT 的优缺点

优点如下:

  • 注意力计算只涉及到中心节点及其邻域节点,很容易实现并行计算
  • 因为只关心图结构中的局部注意力,所以能够将训练的模型应用到陌生的图数据中,并不局限于训练数据中才有的图结构

缺点有

  • 注意力只关注节点的局部特征,在获取全局特征上效果可能不佳(个人观点)

代码实现

此处的代码为 Pytorch-Geometric 中的 GATConv ,为了简单起见,我们不考虑二分图 (bipartite graphs)的情况。

class GATConv(MessagePassing):
   
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.0,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GATConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        # 节点特征变换需要的算子
        self.lin_src = Linear(in_channels, heads * out_channels,
                              bias=False, weight_initializer='glorot')
        self.lin_dst = self.lin_src

        # 计算注意力需要的权重参数 W
        self.att_src = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_dst = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_src.reset_parameters()
        self.lin_dst.reset_parameters()
        glorot(self.att_src)
        glorot(self.att_dst)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None, return_attention_weights=None):

        H, C = self.heads, self.out_channels

        # 首先使用 torch 中的 Linear 对输入节点特征做变换,这里源节点和目标节点
        # 变换计算的权重是共享的,如果输入时二分图,二者的权重就不同
        assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
        x_src = x_dst = self.lin_src(x).view(-1, H, C)

        x = (x_src, x_dst)

        # 接下来计算节点级别的注意力系数,源节点和目标节点都需要计算
        # 计算公式为 a^T @ x_i
        alpha_src = (x_src * self.att_src).sum(dim=-1)
        alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
        alpha = (alpha_src, alpha_dst)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                # We only want to add self-loops for nodes that appear both as
                # source and target nodes:
                num_nodes = x_src.size(0)
                if x_dst is not None:
                    num_nodes = min(num_nodes, x_dst.size(0))
                num_nodes = min(size) if size is not None else num_nodes
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
        out = self.propagate(edge_index, x=x, alpha=alpha, size=size)

        alpha = self._alpha
        assert alpha is not None
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        if isinstance(return_attention_weights, bool):
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:

        # 在 message 之前,MassagePassing 通过 __collect__ 函数计算 message 需要的
        # 参数。这里 propagate 的输入参数为节点特征 x(类型为 Tuple), 节点注意力系数 
        # alpha (类型为 Tuple) 以及节点数量,经过 __collect__ 过后,x_j=x[0],
        # alpha_j = alpha[0], alpha_i = alpha[1]
                
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha  # Save for later use.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)

注意到代码中 Graphviz绘制神经网络图 图神经网络gat_神经网络_14 的计算并不是按照公式直译为代码,实际计算过程为
Graphviz绘制神经网络图 图神经网络gat_Graph_15
这样做我觉得是为了适应 MassagePassing 结构,需要的内存也小一点。

GATv2 Conv

GATv2 Conv 是对 GAT 的改进,原始论文为 How Attentive are Graph Attention Networks. 相对于 GAT, GATv2 只是修改的注意力中线性变换 Linear 的计算顺序,并引入了静态注意力 (Static attention ) 和动态注意力 (Dynamic attention). 具体计算公式如下
Graphviz绘制神经网络图 图神经网络gat_机器学习_16
注意力 Graphviz绘制神经网络图 图神经网络gat_Graph_02
Graphviz绘制神经网络图 图神经网络gat_Graph_18
对比 GAT,只是改变了 Graphviz绘制神经网络图 图神经网络gat_Graphviz绘制神经网络图_19 , Graphviz绘制神经网络图 图神经网络gat_神经网络_20 , Graphviz绘制神经网络图 图神经网络gat_神经网络_21

Transformer Conv

Transformer conv 基本原理

Transformer conv 是来自百度的论文 Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification ,该论文使用了类似于 Transformer 的方式计算节点注意力。

对于每一条从节点 Graphviz绘制神经网络图 图神经网络gat_深度学习_03 指向节点 Graphviz绘制神经网络图 图神经网络gat_Graphviz绘制神经网络图_04 的边,我们需要计算 query, key, value,具体计算公式如下
Graphviz绘制神经网络图 图神经网络gat_机器学习_24
其中

  • q, k 分别是 query 和 key 向量
  • Graphviz绘制神经网络图 图神经网络gat_深度学习_25
  • Graphviz绘制神经网络图 图神经网络gat_Graph_26, 与 Transformer 中一样是计算 query 和 key 之间的点积注意力

写成矩阵形式就是
Graphviz绘制神经网络图 图神经网络gat_深度学习_27
注意力为
Graphviz绘制神经网络图 图神经网络gat_Graphviz绘制神经网络图_28
如果考虑边的特征
Graphviz绘制神经网络图 图神经网络gat_Graphviz绘制神经网络图_29
如果是多头注意力,仍然可以采用 GAT 中的联结 (concat) 和平均 (mean) 两种方式将多头注意的结果变换为一个节点特征向量。

除此之外,论文中还增加了一个门控单元来计算残差的权重,以避免过平滑问题 (over smoothing).
Graphviz绘制神经网络图 图神经网络gat_Graphviz绘制神经网络图_29
论文中 Transformer Conv 的网络结构如下

Graphviz绘制神经网络图 图神经网络gat_Graph_31

Transformer Conv 代码实现

代码为 PyG 中 TransformerConv 的实现,最好结合 PyG 文档阅读

class TransformerConv(MessagePassing):
     def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        beta: bool = False,
        dropout: float = 0.,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        root_weight: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super(TransformerConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads 
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        #  query, key, value 的变换算子,使用 Linear 完成
        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)
        # edge feature 的变换算子,在实际计算中 edge feature 可以是节点相对位置,
        # 或者其他能够表示节点相对信息的特征
        if edge_dim is not None:
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)

        # 使用 cancat 方式组合多头注意的结果,需要计算的变量有
        # 残差连接 (skip), skip 的门控权重 beta
        if concat:
            self.lin_skip = Linear(in_channels[1], heads * out_channels,
                                   bias=bias)
            if self.beta:
                self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)
            else:
                self.lin_beta = self.register_parameter('lin_beta', None)
        else:
            self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)
            if self.beta:
                self.lin_beta = Linear(3 * out_channels, 1, bias=False)
            else:
                self.lin_beta = self.register_parameter('lin_beta', None)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        if self.edge_dim:
            self.lin_edge.reset_parameters()
        self.lin_skip.reset_parameters()
        if self.beta:
            self.lin_beta.reset_parameters()

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, return_attention_weights=None):
		# edge_attr 为每一条边的特征
                
        
        if isinstance(x, Tensor):
            x: PairTensor = (x, x)

        # propagate_type: (x: PairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)

        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        # 计算中心节点特征,一般推荐使用 root_weight=True
        # 是否使用门控计算 beta,看需求
        # 对应的公式为 x_i' = W_1 * x_i + \sum message_j 或者
        #  x_i' = \beta_i W_1 * x_i + (1 - \beta_i) (\sum message_j)
        if self.root_weight:
            x_r = self.lin_skip(x[1])
            if self.lin_beta is not None:
                beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
                beta = beta.sigmoid()
                out = beta * x_r + (1 - beta) * out
            else:
                out += x_r

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
		
        # 计算 query, key
        # query = W_3 * x_i
        # key = W_4 * x_j
        query = self.lin_query(x_i).view(-1, self.heads, self.out_channels)
        key = self.lin_key(x_j).view(-1, self.heads, self.out_channels)

        # 计算边的特征
        # edge_feat = W_6 * edge_attr
        # key = key + edge_feat
        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,
                                                      self.out_channels)
            key += edge_attr

        # 计算 query 和 key 的点积注意力
        alpha = (query * key).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # 计算边的消息,out = (message + edge_attr) * alpha_{ij}
        out = self.lin_value(x_j).view(-1, self.heads, self.out_channels)
        if edge_attr is not None:
            out += edge_attr

        out *= alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)