GAT pytorch 科普文章

1. 引言

图神经网络(Graph Neural Networks, GNNs)是最近兴起的一种深度学习方法,主要用于处理图结构数据。传统的神经网络无法有效地处理图数据,而GNN则通过对图结构的节点和边进行建模,能够充分利用图中节点和边的关系。其中,Graph Attention Network (GAT) 是一种非常流行的图神经网络模型,本文将对GAT进行科普介绍,并提供pytorch的代码示例。

2. GAT简介

GAT是由Veličković等人于2017年提出的一种图神经网络模型。GAT通过自适应地计算节点之间的注意力权重,并将注意力权重作为节点特征的加权和。这使得GAT能够对不同节点之间的关系赋予不同的重要性。GAT模型的关键点是使用注意力机制来计算节点之间的权重。

3. GAT模型

GAT模型可以分为以下几个步骤:

步骤1: 图数据准备

首先,我们需要将图数据转换为计算机可以处理的形式。通常,图数据可以使用邻接矩阵(Adjacency Matrix)和节点特征矩阵(Node Feature Matrix)来表示。

步骤2: 节点表示学习

GAT模型使用自注意力机制来计算节点之间的注意力权重。具体地,对于图中的每个节点i,GAT模型首先计算出节点i与其邻居节点之间的注意力权重。然后,使用注意力权重对节点i的邻居节点的特征进行加权平均。这样,我们可以得到节点i的表示向量。

步骤3: 分类预测

在节点表示学习之后,我们可以使用这些表示向量进行各种任务,如节点分类、图分类等。在节点分类任务中,我们可以使用全连接层将节点表示向量映射到不同的类别。

步骤4: 损失计算

最后,我们使用损失函数来计算模型的预测结果与真实标签之间的差异。常用的损失函数包括交叉熵损失函数等。

下面是GAT模型的代码示例:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphAttentionLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        h = torch.mm(input, self.W)
        N = h.size()[0]

        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(1))
        attention = torch.sparse_softmax(torch.sparse.FloatTensor(adj[0].t(), e, adj[1]))

        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super(GAT, self).__init__()
        self.dropout = dropout

        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)