- 相关论文:A deep graph neural network architecture for modelling spatio-temporal dynamics in resting-state functional MRI data
0 abstract
静息态功能核磁共振(Resting-state functional magnetic resonance imaging rs-fMRI)成功的应用在人脑组织的了解。通常大脑被分成多个ROI区域,并且建模为一个“图”。每一个ROI表示一个node,而association meansures between ROI-specific blood-oxygen-level-dependent time series作为edge。
最近图神经网络GNN have seen a surge in popularity due to 他们在modeling unstructured relational data的成功。然而最近的GNN的发展并没有用在rs-fMRI的分析当中,特别它的spatio-temporal dynamics。
本文中,我们提出了一种novel deep neural network architecture,结合了GNN和TCN,来端到端的学习spatial和temporal的内容。特别是这从涉及到intra-feature learning和inter-feature learning。
- intra-feature learning特征内学习,例如用TCN来学习tempora dynamics
- inter-feature learning特征间学习,例如利用ROI-wise dynamics with GNN
我们用消融实验来评估了我们的研究,使用了25159个来自UK biobank的rs-fMRI数据库。同时用了smaller Human Connectome Project数据库,在unimodal和multimodel fashion下进行的。
此外,我们还展示了我们的架构包含explainability-related features,可以轻松的映射到现实的神经生物学结论。
We suggest that this model could lay the groundwork for future deep learning architectures focused on leveraging the inherently and inextricably spatio-temporal nature of rs-fMRI data.(表示这个研究可以为未来的研究打下基础。)
1 introduction
rs-fMRI是当前了解人脑的一种无创成像技术。通常rs-fMRI数据的使用包含一些graph-theoretical measures,例如centrality measures and community structures,用这样的方式来总结高维的,全脑的数据。
为了实现这样的效果,通常需要降低数据的维度,用下面的三种方法之一:
- 折叠时间维度,例如 connectivity matrices between brain regions;
- 降低空间维度,例如 将voxelwise signal集合到预先定义的脑部区域;
- 1和2的方法都采用。
这些特征工程的步骤可以起到作用,主要是因为典型的rs-fMRI数据量大,而信噪比又比较低。(换言之,噪音大)。
上述的降低维度的方法,尽管大幅度降低了计算量,但是inevitably的忽略了大量的信息,这些信息在分析任务当中很可能是有价值的。比方说,这些rs-fMRI的时间维度到static volume静态体积,那么大脑不同区域之间的interactions也会固定。然而又很多的研究表明,大脑的之间的功能连接是动态的,并且随着时间的推移不断变化。另外一个例子是,最常用的关联性测量师基于线性模型的,但是众所周知,neuromonitoring data和大脑信号是非线性的。
为了克服上面的问题,一种可能的方案是设计一个模型,这个模型可以结合特征工程和大脑功能活动的低维表示学习。
Such a model would need to be able to accommodate both the spatial and temporal complexities of rs-fMRI data. To date, deep learning architectures have had great success at leveraging specific inductive biases from complex high-dimensional data. Convolutional neural networks (CNNs), for instance, are extremely effective at extracting shared spatial features such as corners and edges from grid-like data (e.g., 2D and 3D images).
至今为止,深度学习架构在复杂高维数据的归纳偏差方面取得了巨大的成功。例如卷积网络在网格状数据中提取共享空间特征方面非常有效。CNN、RNN、GNN都简单介绍了一下。
在我们的工作中,我们提出了一种模型,利用GNN负责空间的脑部区域间的关系,使用TCN捕捉血氧等级依赖性时间序列的特征。通过incorporating GNN和CNN,我们结合了特征间和特征内学习。特别的,GNN可以捕获ROI之间的高阶相互关系,消除脑区域特定时间序列之间的相互作用中的线性假设的弊端。
我们进一步的设计了我们的结构,保留了edge weights,避免了阈值化和二值化邻接矩阵的行为,保留了一定的可解释性。这样做的目的是可以对模型内部的行为做出神经科学上的解释。然后是将自己用的数据和可重复性。
本文的贡献在于:
- 使用了更大的数据集;
- 扩展了choices in the graph threshold hyperparameter
- 分析了变换权重的重要性
- 先前的研究仅仅使用了1D卷积核二值化后的GNN,而不是包含边缘权重的GNN
- 先前的研究没有进行可解释性的分析。
- 代码已经公开。
Related Works
Methods
问题定义
为了将rs-fMRI表示为无向权重图,大脑被空间上的划分成了N个ROI区域,表示N个nodes,并且记作:.表示节点i的BOLD时间序列特征。BOLD时间序列长度为T。ROI之间的连接被表示为VxV的图的边。每一条边(i,j)连接两个节点,边的connection strength被定义为。节点集合为V,边的集合为E,图记作G=(V,E)。我们定义图结构为G,为节点特征,边特征为,adjacency matrix邻接矩阵为.
TCN网络
有证据表明TCN的性能甚至会优于RNN的性能,对于序列数据。卷积算子的优势在于:
- 对于长序列输入有着等地的需求,LSTM和GRU需要需要大量内存;
- 更好的并行化,因为CNN和TCN是作为整体处理的,而不是RNN那样按照顺序处理的。
- 更容易训练。因为LSTM会经常存在梯度消失的问题。
TCN网络的笔记之前做过, (1277条消息) LSTM的备胎,用卷积处理时间序列——TCN与因果卷积(理论+Python实践)
GNN图网络
主要用的是这个论文提出的模型: Relational inductive biases, deep learning, and graph networks.(也有少许出入)
主要的block叫做GN block,包含两个update functions 和一个aggregation functions。
【edge model】 第一个update function是叫做edge model,记作,为每一个edge更新属性。如何更新呢?基于原来的边attributes 和相连接节点i和j的特征。因此可以写作:
需要注意的是,每一个边原来只是一个标量,但是经过这个的操作后,结果的维度可能变得不同。可能变成一个M维度的向量。
【node model】 接下来更新每一个节点的特征。对于每一个节点,需要聚合edge特征。
其中,,是一个从node i为起点的边,与node j通过边edge k相连接。其中需要对边缘的顺序具有不变性,从而解释数据的unordered structure无需结构。平均求和是这类对顺序具有不变性的一个例子。
最终我们更新节点的特征使用下面的更新公式:
The aggregation function 需要对边的顺序具有不变性,但是更新函数,可以更加的灵活一些。比方说,如果特征向量是1D的,更新函数可以用MLP,如果特征表示图像,那么CNN更合适,如果特征表示为序列,那么RNN更合适。
graph pooling
本文的预测结果当中,是需要一个graph-level的预测结果的。因此需要把node的特征做一个整合(pool)。
通常来说,用一个全局池化机制可以实现。每一个节点的特征平均或者concat成整个graph的特征,因此构建一个最终、低维度的图的嵌入特征。
但是不同的节点有着不同的重要性,在不同的子任务当中。(不同脑部疾病可能更关注不同的脑部区域)。我们假定一个hierarchical pooling mechanism可以创造更棒的embedding。最终,我们采用了differentiable pooling operator,可以称作DiffPool。这个东西可以如何将节点折叠成更小的簇,直到最终只有一个节点的存在。
当使用GN模块的时候,节点和边的稀疏表达是没问题的。到那时DiffPool只能作用在稠密表示的图上。在其他的
这一块到时候看代码的时候边看边思考。