图卷积神经网络层的pytorch复现
- 基本概念:
- 图卷积层的数学描述:
- 图的总体架构:
- 图卷积层pytorch代码实现[^2]和注释:
- 参考:
基本概念:
图结构非常常见,属于非欧式空间,例如社交网络图、知识图谱、用户点击购买产品产生的关系图、分子结构图、人体关节点连接图。图卷积神经网络算法是一种根据图卷积和神经网络的理论,应用于广泛存在的图结构的实体的算法。图卷积来源于二维卷积,神经网络算法相当于在传统机器学习算法上加上可以学习的权重,使用梯度下降算法更新权重。总的来说,图卷积神经网络是一种结合信号处理和神经网络应用于图结构的一种新算法。具体应用可以对节点、边和整个图进行分类、分割、检测等应用。本文主要记录学习图卷积神经网络的一些理论和想法。
图卷积层的数学描述:
图卷积层经过很多的优化和迭代,目前比较主流的一种方法是每一层的复杂度更低,而通过堆叠多层进行更深层次的学习的方法进行学习。具体的推导过程在文献1中,这里省略大篇幅的推导过程。
多层的图卷积网络按照下面的逐层递推规则:
式子中的指的是图的邻接矩阵形式,指的是可学习的权重,是图节点的最初的特征矩阵经过每一层变换后的矩阵,指的是激活函数。邻接矩阵和拉普拉斯矩阵可以参考2。
图的总体架构:
图的总体架构如下所示,本篇文章需要实现的就是里面的hidden layers,GraphConvolutionLayer不改变图的结构,所以图结构进过图卷积神经网络层后仍然保持原来的结构。但是后面层的节点能够聚合前面层的节点信息。类似于卷积神经网络的“视野”的概念。深层得到更多的语义信息,浅层则保留更多的原始特征信息。
图卷积层pytorch代码实现3和注释:
# -*- coding: utf-8 -*-
# # @Use : Paper reproduction
# # @Time : 2022/8/11 21:30
# # @FileName: GraphConvolutionLayer.py
# # @Software: PyCharm
# # @Paper : Spectral Networks and Locally Connected Networks on Graphs
import torch
import torch.nn as nn
class GraphConvolutionLayer(nn.Module):
"""
图卷积神经网络
"""
def __init__(self, input_dim, output_dim, adjacency_matrix=None, use_bias=True):
super(GraphConvolutionLayer, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.use_bias = use_bias
data = torch.tensor(input_dim, output_dim)
self.weight = nn.Parameter(data=data)
if self.use_bias:
self.bias = nn.Parameter(torch.tensor(input_dim, output_dim))
else:
self.register_parameter('bias', None)
self.reset_parameters()
self.L_matrix = self.calculate_L_matrix(adjacency_matrix)
def reset_parameters(self):
"""
重置权重
"""
nn.init.kaiming_normal_(self.weight)
if self.use_bias:
nn.init.zeros_(self.bias)
def forward(self, input_feature):
"""
邻接矩阵是稀疏矩阵,使用稀疏矩阵的乘法
@param input_feature:输入特征
"""
# 计算图卷积的输出
# (\widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})
suport = torch.mm(input_feature, self.weight)
output = torch.sparse.mm(self.L_matrix, suport) # 注意因为邻接矩阵是稀疏矩阵,所以使用稀疏矩阵乘法提高效率
if self.use_bias:
output += self.bias
return output
@staticmethod
def calculate_L_matrix(adjcency: torch.Tensor) -> torch.Tensor:
"""
根据图的邻接矩阵计算矩阵L_matrix
L_matrix = \widetilde{D}^{-\frac{1}{2}}\widetilde{A}\widetilde{D}^{-\frac{1}{2}}
"""
dim = adjcency.shape[0]
A_ware = adjcency + torch.eye(dim) # 生成单位矩阵
D_ii = torch.flatten(torch.sum(A_ware, dim=0)) # 按照列进行求和,并且展平成一维向量
D_ware = torch.diag_embed(D_ii) # 转换成对角矩阵
D_ware_temp = torch.pow(D_ware, -0.5) # 求对角阵的-1/2指数
L_matrix = torch.mm(torch.mm(D_ware_temp, A_ware), D_ware_temp) # 使用广播机制进行矩阵乘法
return L_matrix