1.前置

GCN的提出可以说缓解了大小卷积核之间的矛盾,大小卷积核对模型产生的作用各有千秋,总结来说:

大卷积核:

  • 优点:感受域较大,能够获得丰富的上下文信息
  • 缺点:参数量大,计算量大
  • 举例:AlexNet,LeNet等使用了5*5,11*11的大卷积核

小卷积核:

  • 优点: 参数量小,计算量小,用多个小卷积核代替一个大卷积核,可以进行多次非线性激活,使模型判别能力增加。
  • 缺点:感受域不足,产生的特征图可能 比较稀疏,深度堆叠卷积可能会产生不可控问题(例如模型退化等)
  • 举例:VGG之后

当输入通道和输出通道相同时,使用小卷积核参数会比使用大卷积核参数要少,当输入通道和输出通道不同时,使用大卷积核参数反而会比较少。

自VGG之后各位大佬们都以解决小卷积核带来的弊端这个方向做研究,提出了空洞卷积,ASPP等方法,GCN可以说是从另一个方向入手问题,解决大卷积核带来的问题。在引言部分GCN从结果出发,提出语义分割的两个挑战——分类和定位。分类和定位是相互矛盾的,对分类任务来说,模型需要具有平移不变性,以应对目标的多种变化例如平移和旋转,而对于定位来说,模型需要对变换所敏感,能够精确定位每个像素。好的模型应该处理好上述两种挑战。为了解决定位任务,模型应该使用全卷积网络,去除全卷积或全局池化;为了解决分类任务,则可以使用较大的卷积核。

组卷积:

组卷积是对输入特征图进行分组,每组分别进行卷积,总参数量可以减少为原来的1/组数。

深度可分离卷积:

深度可分离卷积是组卷积的一种极端情况,即对每一个输入特征图的通道都用一个单独的卷积。

随机分组:

分组卷积会存在一个缺陷,那就是不同组之间没有信息交互,这样可能会导致即使某一组产生的特征提取特别垃圾,也依旧会往后传递,导致整个模型的效果也不能令人满意。随机分组则是把不同组之间打乱顺序,这样促进了不同组之间的信息交互,使得效果更好。

2.Pytorch实现GCN

gcc 架构 gcn架构优缺点_卷积核

from PIL import Image
import torch.nn as nn
from torchvision import models
import torch as t
resnet152_pretrained=models.resnet152(pretrained=True)

class GCM(nn.Module):
    def __init__(self,in_channels,num_class,k=15):
        super(GCM,self).__init__()

        pad=(k-1)//2

        self.conv1=nn.Sequential(nn.Conv2d(in_channels,num_class,kernel_size=(1,k),padding=(0,pad),bias=False),
                                 nn.Conv2d(num_class,num_class,kernel_size=(k,1),padding=(pad,0),bias=False)
                                 )
        self.conv2=nn.Sequential(nn.Conv2d(in_channels,num_class,kernel_size=(k,1),padding=(pad,0),bias=False),
                                 nn.Conv2d(num_class,num_class,kernel_size=(1,k),padding=(0,pad),bias=False)
                                 )
    def forward(self,x):
        x1=self.conv1(x)
        x2=self.conv2(x)

        assert x1.shape==x2.shape

        return x1+x2

class BR(nn.Module):
    def __init__(self,num_class):
        super(BR,self).__init__()

        self.shortcut=nn.Sequential(nn.Conv2d(num_class,num_class,3,padding=1,bias=False),
                                    nn.ReLU(),
                                    nn.Conv2d(num_class,num_class,3,padding=1,bias=False))
    def forward(self,x):
        return x+self.shortcut(x)

class GCN_BR_BR_Deconv(nn.Module):
    def __init__(self,in_channels,num_class,k=15):
        super(GCN_BR_BR_Deconv, self).__init__()

        self.gcn=GCM(in_channels,num_class,k)
        self.br=BR(num_class)

        self.deconv=nn.ConvTranspose2d(num_class,num_class,4,2,1,bias=False)

    def forward(self,x1,x2=None):
        x1=self.gcn(x1)
        x1=self.br(x1)

        if x2 is None:
            x=self.deconv(x1)
        else:
            x=x1+x2
            x=self.br(x)
            x=self.deconv(x)

        return x

class GCN(nn.Module):
    def __init__(self,num_classes,k=15):
        super(GCN,self).__init__()
        self.num_class=num_classes
        self.k=k


        self.layer0=nn.Sequential(resnet152_pretrained.conv1,resnet152_pretrained.bn1,resnet152_pretrained.relu)
        self.layer1=nn.Sequential(resnet152_pretrained.maxpool,resnet152_pretrained.layer1)
        self.layer2=resnet152_pretrained.layer2
        self.layer3=resnet152_pretrained.layer3
        self.layer4=resnet152_pretrained.layer4

        self.br=BR(self.num_class)
        self.deconv=nn.ConvTranspose2d(self.num_class,self.num_class,4,2,1,bias=False)

    def forward(self,input):
        x0=self.layer0(input)
        x1=self.layer1(x0)
        x2=self.layer2(x1)
        x3=self.layer3(x2)
        x4=self.layer4(x3)

        branch4=GCN_BR_BR_Deconv(x4.shape[1],self.num_class,self.k)#在前向传播中定义方便使用参数,而不用自己再计算
        branch3 = GCN_BR_BR_Deconv(x3.shape[1], self.num_class, self.k)
        branch2 = GCN_BR_BR_Deconv(x2.shape[1], self.num_class, self.k)
        branch1 = GCN_BR_BR_Deconv(x1.shape[1], self.num_class, self.k)

        branch4=branch4(x4)
        branch3=branch3(x3,branch4)
        branch2 = branch2(x2, branch3)
        branch1 = branch1(x1, branch2)

        x=self.br(branch1)
        x=self.deconv(x)
        x=self.br(x)

        return x

if __name__ == "__main__":
    rgb=t.randn(1,3,512,512)
    net=GCN(21)
    out=net(rgb)
    print(out.shape)