GAT pytorch 科普文章
1. 引言
图神经网络(Graph Neural Networks, GNNs)是最近兴起的一种深度学习方法,主要用于处理图结构数据。传统的神经网络无法有效地处理图数据,而GNN则通过对图结构的节点和边进行建模,能够充分利用图中节点和边的关系。其中,Graph Attention Network (GAT) 是一种非常流行的图神经网络模型,本文将对GAT进行科普介绍,并提供pytorch的代码示例。
2. GAT简介
GAT是由Veličković等人于2017年提出的一种图神经网络模型。GAT通过自适应地计算节点之间的注意力权重,并将注意力权重作为节点特征的加权和。这使得GAT能够对不同节点之间的关系赋予不同的重要性。GAT模型的关键点是使用注意力机制来计算节点之间的权重。
3. GAT模型
GAT模型可以分为以下几个步骤:
步骤1: 图数据准备
首先,我们需要将图数据转换为计算机可以处理的形式。通常,图数据可以使用邻接矩阵(Adjacency Matrix)和节点特征矩阵(Node Feature Matrix)来表示。
步骤2: 节点表示学习
GAT模型使用自注意力机制来计算节点之间的注意力权重。具体地,对于图中的每个节点i,GAT模型首先计算出节点i与其邻居节点之间的注意力权重。然后,使用注意力权重对节点i的邻居节点的特征进行加权平均。这样,我们可以得到节点i的表示向量。
步骤3: 分类预测
在节点表示学习之后,我们可以使用这些表示向量进行各种任务,如节点分类、图分类等。在节点分类任务中,我们可以使用全连接层将节点表示向量映射到不同的类别。
步骤4: 损失计算
最后,我们使用损失函数来计算模型的预测结果与真实标签之间的差异。常用的损失函数包括交叉熵损失函数等。
下面是GAT模型的代码示例:
import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphAttentionLayer(nn.Module):
def __init__(self, in_features, out_features, dropout, alpha, concat=True):
super(GraphAttentionLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.dropout = dropout
self.alpha = alpha
self.concat = concat
self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
nn.init.xavier_uniform_(self.W.data, gain=1.414)
self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
nn.init.xavier_uniform_(self.a.data, gain=1.414)
self.leakyrelu = nn.LeakyReLU(self.alpha)
def forward(self, input, adj):
h = torch.mm(input, self.W)
N = h.size()[0]
a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1)
e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(1))
attention = torch.sparse_softmax(torch.sparse.FloatTensor(adj[0].t(), e, adj[1]))
attention = F.dropout(attention, self.dropout, training=self.training)
h_prime = torch.matmul(attention, h)
if self.concat:
return F.elu(h_prime)
else:
return h_prime
class GAT(nn.Module):
def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
super(GAT, self).__init__()
self.dropout = dropout
self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
for i, attention in enumerate(self.attentions):
self.add_module('attention_{}'.format(i), attention)
self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)