目录
- 前言
- 1. 数据处理
- 2. GCN链接预测
- 2.1 负采样
- 2.2 模型搭建
- 2.3 模型训练/测试
- 完整代码
前言
1. 数据处理
这里以CiteSeer网络为例:Citeseer网络是一个引文网络,节点为论文,一共3327篇论文。论文一共分为六类:Agents、AI(人工智能)、DB(数据库)、IR(信息检索)、ML(机器语言)和HCI。如果两篇论文间存在引用关系,那么它们之间就存在链接关系。
加载数据:
dataset = Planetoid('data', name='CiteSeer')
print(dataset[0])
输出:
Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])
x=[3327, 3703]表示一共有3327个节点,然后节点的特征维度为3703,这里实际上是去除停用词和在文档中出现频率小于10次的词,整理得到3703个唯一词。edge_index=[2, 9104],表示一共9104条edge,数据一共两行,每一行都表示节点编号。
利用PyG封装的RandomLinkSplit我们很容易实现数据集的划分:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
T.NormalizeFeatures(),
T.ToDevice(device),
T.RandomLinkSplit(num_val=0.1, num_test=0.1, is_undirected=True,
add_negative_train_samples=False),
])
dataset = Planetoid('data', name='CiteSeer', transform=transform)
train_data, val_data, test_data = dataset[0]
最终我们得到train_data, val_data, test_data
。
输出一下原始数据集和三个被划分出来的数据集:
Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])
Data(x=[3327, 3703], edge_index=[2, 7284], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327], edge_label=[3642], edge_label_index=[2, 3642])
Data(x=[3327, 3703], edge_index=[2, 7284], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327], edge_label=[910], edge_label_index=[2, 910])
Data(x=[3327, 3703], edge_index=[2, 8194], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327], edge_label=[910], edge_label_index=[2, 910])
从上到下依次为原始数据集、训练集、验证集以及测试集。其中,训练集中一共有3642个正样本,验证集和测试集中均为455个正样本+455个负样本。
2. GCN链接预测
本次实验使用GCN来进行链接预测:首先利用GCN对训练集中的节点进行编码,得到节点的向量表示,然后使用这些向量表示对训练集中的正负样本(在每一轮训练时重新采样负样本)进行有监督学习,具体来讲就是利用节点向量求得样本中节点对的内积,然后与标签求损失,最后反向传播更新参数。
2.1 负采样
链接预测训练过程中的每一轮我们都需要对训练集进行采样以得到与正样本数量相同的负样本,验证集和测试集在数据集划分阶段已经进行了负采样,因此不必再进行采样。
负采样函数:
def negative_sample():
# 从训练集中采样与正边相同数量的负边
neg_edge_index = negative_sampling(
edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
num_neg_samples=train_data.edge_label_index.size(1), method='sparse')
# print(neg_edge_index.size(1)) # 3642条负边,即每次采样与训练集中正边数量一致的负边
edge_label_index = torch.cat(
[train_data.edge_label_index, neg_edge_index],
dim=-1,
)
edge_label = torch.cat([
train_data.edge_label,
train_data.edge_label.new_zeros(neg_edge_index.size(1))
], dim=0)
return edge_label, edge_label_index
这里用到了negative_sampling
方法,其参数有:
具体来讲,negative_sampling
方法利用传入的edge_index
参数进行负采样,即采样num_neg_samples
条edge_index
中不存在的边。num_nodes
指定节点个数,method
指定采样方法,有sparse
和dense
两种方法。
采样后将neg_edge_index
与训练集中原有的正样本train.edge_label_index
进行拼接以得到完整的样本集,同时也需要在原本的train_data.edge_label
后面添加指定数量的0用于表示负样本。
2.2 模型搭建
GCN链接预测模型搭建如下:
class GCN_LP(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def encode(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return x
def decode(self, z, edge_label_index):
# z所有节点的表示向量
src = z[edge_label_index[0]]
dst = z[edge_label_index[1]]
# print(dst.size()) # (7284, 64)
r = (src * dst).sum(dim=-1)
# print(r.size()) (7284)
return r
def forward(self, x, edge_index, edge_label_index):
z = self.encode(x, edge_index)
return self.decode(z, edge_label_index)
编码器由一个两层GCN组成,用于得到训练集中节点的向量表示,解码器用于得到训练集中节点对向量间的内积。
由前面可知训练集中的正样本数量为3642,经过负采样函数negative_sample
得到3642个负样本,一共7284个样本,最终解码器返回7284个节点对间的内积。
损失函数采用BCEWithLogitsLoss
,要想弄懂BCEWithLogitsLoss
,就要先了解BCELoss
。
BCELoss
是一种二元交叉熵损失:
而BCEWithLogitsLoss
则是在BCELoss
的基础上增加了Sigmoid
选项,即先把输入经过一个Sigmoid
,然后再计算BCELoss
。
评价指标采用AUC:
roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
2.3 模型训练/测试
代码如下:
def test(model, data):
model.eval()
with torch.no_grad():
z = model.encode(data.x, data.edge_index)
out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
model.train()
return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
def train():
model = GCN_LP(dataset.num_features, 128, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss().to(device)
min_epochs = 10
best_model = None
best_val_auc = 0
final_test_auc = 0
model.train()
for epoch in range(100):
optimizer.zero_grad()
edge_label, edge_label_index = negative_sample()
out = model(train_data.x, train_data.edge_index, edge_label_index).view(-1)
loss = criterion(out, edge_label)
loss.backward()
optimizer.step()
# validation
val_auc = test(model, val_data)
test_auc = test(model, test_data)
if epoch + 1 > min_epochs and val_auc > best_val_auc:
best_val_auc = val_auc
final_test_auc = test_auc
print('epoch {:03d} train_loss {:.8f} val_auc {:.4f} test_auc {:.4f}'
.format(epoch, loss.item(), val_auc, test_auc))
return final_test_auc
最终测试集上的AUC为:
final best auc: 0.9076681560198044
完整代码
代码地址:GNNs-for-Link-Prediction。原创不易,下载时请给个follow和star!感谢!!