Task5 概览

图神经网络已经成功地应用于许多节点或边的预测任务,然而,在超大图上进行图神经网络的训练仍然具有挑战。普通的基于SGD的图神经网络的训练方法存在算力、内存、精度等各方面的问题。

本次任务将首先分析传统方法处理超大图存在的问题;接着介绍一种新的图神经网络的训练方法——Cluster-GCN,用于解决超大图的训练问题;最后使用代码来实现Cluster-GCN算法。

一、处理超大图存在的问题及现有方法

使用传统方法处理超大图主要存在以下问题:

  • 算力问题: 随着图神经网络层数增加,计算成本 呈指数增长
  • 内存问题: 保存整个图的信息和每一层每个节点的表征到内存(显存)而消耗 巨大内存(显存)空间
  • 精度损耗问题: 无需保存整个图信息和每一层每个节点表征的方法,可能会损失预测精度 或者对内存利用率提高不明显

基于现有的问题,PyG文档给出了常见的应对超大图节点表征学习的的方法。

一种主流的思路是将数据集进行划分,分批训练,再使用合适的方法耦合训练结果。相关论文包括:

  1. Inductive Representation Learning on Large Graphs, 这篇文章使用了大图中节点的低维嵌入方法,同时提出了GraphSAGE;


图神经网络聚合 aggregation 图神经网络 组合优化_深度学习


  1. Deep Graph Neural Networks with Shallow Subgraph Samplers,这篇文章使用浅子图采样器进行大图训练,主要针对深图神经网络的计算爆炸问题;
  2. Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Network,提出一种新的图神经网络的训练方法,它利用图聚类结构进行数据集采样。即本次学习的模型。


图神经网络聚合 aggregation 图神经网络 组合优化_神经网络_02


二、Cluster-GCN方法

正是由于传统的方法存在以上的各种局限,于是有研究者提出了Cluster-GCN方法,解决了普通训练方法无法训练超大图的问题,主要分成一下几点:

  1. 利用图节点聚类算法将一个图的节点划分为c个簇,每一次选择几个簇的节点和这些节点对应的边构成一个子图,然后对子图做训练。
  2. 由于是利用图节点聚类算法将节点划分为多个簇,所以簇内边的数量要比簇间边的数量多得多,所以可以提高表征利用率,并提高图神经网络的训练效率。
  3. 每一次随机选择多个簇来组成一个batch,这样不会丢失簇间的边,同时也不会有batch内类别分布偏差过大的问题。
  4. 基于小图进行训练,不会消耗很多内存空间,于是我们可以训练更深的神经网络,进而可以达到更高的精度。

三、Cluster-GCN实践

3.1源代码分析

3.1.1 加载数据集

from torch_geometric.datasets import Reddit
from torch_geometric.data import ClusterData, ClusterLoader, NeighborSampler

dataset = Reddit('../dataset/Reddit')
data = dataset[0]
print(dataset.num_classes)
print(data.num_nodes)
print(data.num_edges)
print(data.num_features)

3.1.2 图节点聚类与数据加载器生成

cluster_data = ClusterData(data, num_parts=1500, recursive=False, save_dir=dataset.processed_dir)
# 此数据加载器返回的一个batch由多个簇组成
train_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True, num_workers=12)
# 使用此数据加载器对图节点聚类
subgraph_loader = NeighborSampler(data.edge_index, sizes=[-1], batch_size=1024, shuffle=False, num_workers=12)

3.1.3 图神经网络的构建

class Net(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Net, self).__init__()
        self.convs = ModuleList(
            [SAGEConv(in_channels, 128),
             SAGEConv(128, out_channels)])

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i != len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return F.log_softmax(x, dim=-1)

  # inference方法用于推理阶段,获取更高的预测精度
    def inference(self, x_all):
        pbar = tqdm(total=x_all.size(0) * len(self.convs))
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch.
        for i, conv in enumerate(self.convs):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                edge_index, _, size = adj.to(device)
                x = x_all[n_id].to(device)
                x_target = x[:size[1]]
                x = conv((x, x_target), edge_index)
                if i != len(self.convs) - 1:
                    x = F.relu(x)
                xs.append(x.cpu())

                pbar.update(batch_size)

            x_all = torch.cat(xs, dim=0)

        pbar.close()

        return x_all

3.1.4 训练、验证与测试

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_features, dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

def train():
    model.train()

    total_loss = total_nodes = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
        loss.backward()
        optimizer.step()

        nodes = batch.train_mask.sum().item()
        total_loss += loss.item() * nodes
        total_nodes += nodes

    return total_loss / total_nodes


@torch.no_grad()
def test():  # Inference should be performed on the full graph.
    model.eval()

    out = model.inference(data.x)
    y_pred = out.argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        correct = y_pred[mask].eq(data.y[mask]).sum().item()
        accs.append(correct / mask.sum().item())
    return accs


for epoch in range(1, 31):
    loss = train()
    if epoch % 5 == 0:
        train_acc, val_acc, test_acc = test()
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
              f'Val: {val_acc:.4f}, test: {test_acc:.4f}')
    else:
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')

【注意】:如果内存不够大的话这里同样也会报错,可以通过增加虚拟内存解决该问题。

运行结果如下:

Epoch: 01, Loss: 1.1692
Epoch: 02, Loss: 0.4743
Epoch: 03, Loss: 0.3937
Epoch: 04, Loss: 0.3554
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [18:36<00:00, 417.46it/s]
Epoch: 05, Loss: 0.3465, Train: 0.9570, Val: 0.9557, test: 0.9527
Epoch: 06, Loss: 0.3177
Epoch: 07, Loss: 0.3175
Epoch: 08, Loss: 0.3054
Epoch: 09, Loss: 0.2904
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [18:13<00:00, 426.15it/s]
Epoch: 10, Loss: 0.3034, Train: 0.9530, Val: 0.9456, test: 0.9439
Epoch: 11, Loss: 0.2816
Epoch: 12, Loss: 0.2738
Epoch: 13, Loss: 0.2745
Epoch: 14, Loss: 0.2858
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [17:41<00:00, 439.03it/s]
Epoch: 15, Loss: 0.2681, Train: 0.9657, Val: 0.9549, test: 0.9521
Epoch: 16, Loss: 0.2662
Epoch: 17, Loss: 0.2626
Epoch: 18, Loss: 0.2564
Epoch: 19, Loss: 0.2780
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [17:39<00:00, 439.92it/s]
Epoch: 20, Loss: 0.2623, Train: 0.9639, Val: 0.9477, test: 0.9466
Epoch: 21, Loss: 0.2503
Epoch: 22, Loss: 0.2437
Epoch: 23, Loss: 0.2382
Epoch: 24, Loss: 0.2426
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [17:38<00:00, 440.08it/s]
Epoch: 25, Loss: 0.2419, Train: 0.9680, Val: 0.9523, test: 0.9512
Epoch: 26, Loss: 0.2437
Epoch: 27, Loss: 0.2693
Epoch: 28, Loss: 0.2393
Epoch: 29, Loss: 0.2305
Evaluating: 100%|█████████████████████████████████████████████████████████████| 465930/465930 [17:38<00:00, 440.00it/s]
Epoch: 30, Loss: 0.2307, Train: 0.9721, Val: 0.9541, test: 0.9522

3.2 不同数量簇的实验

尝试将数据集切分成不同数量的簇进行实验,然后观察结果并进行比较。分别设置num_parts为1500、1000和2000。部分实验结果如下:

Epoch: 01, Loss: 1.3562
Epoch: 02, Loss: 0.5103
......
Epoch: 29, Loss: 0.2154
Evaluating: 100%|██████████| 465930/465930 [17:43<00:00, 441.52it/s]
Epoch: 30, Loss: 0.2126, Train: 0.9741, Val: 0.9542, test: 0.9530

根据实验结果观察,可以发现簇的数量越多,占用服务器的内存连越多,但是效果不一定好。在对Reddit数据及的实验中,num_parts为1000的效果更好。