Non-local U-Nets for Biomedical Image Segmentation

 

简介

自2015年以来,在生物医学图像分割领域,U-Net得到了广泛的应用,至今,U-Net已经有了很多变体。U-Net如下图所示,是一个encoder-decoder结构,左边一半的encoder包括若干卷积,池化,把图像进行下采样,右边的decoder进行上采样,恢复到原图的形状,给出每个像素的预测。


tensorflow图片分类_tensorflow图像分割unet

这篇文章的作者提出了一个新的U-Net模型: Non-local U-Nets,推理速度更快,精度更高,参数量更少。提出了一个新的上/下采样方法:Global Aggregation Block,把self-attention和上/下采样相结合,在上/下采样的同时考虑全图(Non-local)信息。

更多的U-Net可以在以下代码库了解:

https://github.com/ShawnBIT/UNet-family

存在的不足

作者首先分析了U-Net存在的不足。

  1. U-Net的encoder部分由若干卷积层和池化层组成,由于他们都是local的运算,只能看到局部的信息,因此需要通过堆叠多层来提取长距离信息,这种方式较为低效,参数量大,计算量也大。过多的下采样导致更多空间信息的损失(U-Net下采样16倍),图像分割要求对每个像素进行准确地预测,空间信息的损失会导致分割图不准确。
  2. decoder的形式与encoder部分正好相反,包括若干个上采样运算,使用反卷积或插值方法,他们也都是local的方法。

创新点

  1. 为了解决以上问题,作者基于self-attention提出了一个Non-local的结构,global aggregation block用于在上/下采样时可以看到全图的信息,这样会使得到更精准的分割图。
  2. 简化U-Net,减少参数量,提高推理速度,上采样和下采样使用global aggregation block,使分割更准确。

global aggregation block



tensorflow图片分类_卷积_02

global aggregation block如上图所示,图看起来很复杂,实际上运算过程很简单。

1.该结构与Attention Is All You Need这篇文章的形式很相似。输入图为X(B*H*W*C),经过QueryTransform和1*1卷积,转换为Q(B*Hq*Wq*Ck),K(B*H*W*Ck),V(B*H*W*Cv)。QueryTransform可以为卷积,反卷积,插值等你喜欢的方法,最后的输出结果的H与W将与这个值一致。

2.代码里在Unfold之前有Multi-Head(Attention Is All You Need)的操作,不过在论文中没有说明,实际上是把通道分为N等份。Unfold是把Batch,height,width,N通道合并,Q(B*Hq*Wq*N*ck),K(B*H*W*N*ck),V(B*H*W*N*cv)。



tensorflow图片分类_图像分割_03

3.接下来是经典的点积attention操作,得到一个权值矩阵A((B*Hq*Wq*N)*(B*H*W*N)),用于self-attention的信息加权,分母Ck是通道数,作用是调节矩阵的数值不要过大,使训练更稳定(这个也是Attention Is All You Need提出的)。最后权值矩阵A和V点乘,得到最终的结果((B*Hq*Wq*N)*cv),可见输出的height和width由Q决定,通道数由V决定。



tensorflow图片分类_卷积_04

4.我最近看了两篇上采样的论文,DUpsample和CARAFE,现在很多上采样相关的论文关注于在上采样时用CNN扩大感受野,增加图像局部信息。这篇文章提出的global aggregation block是一个将注意力机制和上/下采样相结合的方法,关注全图信息,感受野更大,可以在其他任务上试用一下,效果如何还是要看实践的结果。

Non-local U-Nets



tensorflow图片分类_卷积_05

文章提出的Non-local U-Nets如上图所示。

相比U-Net,卷积层数减少,图像下采样倍率从16倍变成4倍,保留了更多空间信息。encoder和decoder之间的skip connections用相加的方式,不是拼接的方式,让模型推理速度更快。

上下采样使用global aggregation block,使分割图更准确。

实验结果

实际上这篇文章提出的东西没有很高深,但是它数据很多,很漂亮。

医学图像分割的评价指标不太了解哈,不过看数据效果挺好。



tensorflow图片分类_卷积_06

tensorflow图片分类_图像分割_07

参数量减少了很多,推理速度更快。



tensorflow图片分类_图像分割tensorflow_08

结果图如下:(训练数据直译是3D多模态等强度婴儿脑MR图像)



tensorflow图片分类_tensorflow图片分类_09