https://arxiv.org/pdf/2105.01601
----------------------------------------------------------
2021-09-02
感知机:判别模型 线性二分类
token-mixing:作用于列(沿 patch 维度),混合提炼不同 patch 的特征,相当于 depth-wise conv
channel-mixing:作用于行(沿 channel 维度),混合提炼不同 channel 的特征,相当于 1×1 卷积
# NOTE(review): this code expects `nn`, `partial`, `Rearrange` and `Reduce` to
# be in scope -- presumably torch.nn, functools.partial and the einops torch
# layers; confirm against the file's imports.


class PreNormResidual(nn.Module):
    """Wrap ``fn`` with a pre-LayerNorm and a residual (skip) connection.

    Args:
        dim: Size of the last input dimension; passed to ``nn.LayerNorm``.
        fn: Callable applied to the normalized input. Its output is added to
            the un-normalized input, so the two must be addable.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return ``fn(norm(x)) + x``."""
        return self.fn(self.norm(x)) + x


def FeedForward(dim, expansion_factor=4, dropout=0, dense=nn.Linear):
    """Build a two-layer MLP: dense -> GELU -> dropout -> dense -> dropout.

    Args:
        dim: Input and output feature size handed to ``dense``.
        expansion_factor: Hidden size is ``int(dim * expansion_factor)``.
            Fractional factors are accepted and truncated to an integer.
        dropout: Probability passed to both ``nn.Dropout`` layers.
        dense: Layer factory called as ``dense(in_features, out_features)``.

    Returns:
        An ``nn.Sequential`` containing the five layers above.
    """
    # int() keeps integer factors unchanged while allowing e.g. 0.5.
    inner_dim = int(dim * expansion_factor)
    return nn.Sequential(
        dense(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(inner_dim, dim),
        nn.Dropout(dropout),
    )


def MLPMixer(
    *,
    image_size,
    channels,
    patch_size,
    dim,
    depth,
    num_classes,
    expansion_factor=4,
    dropout=0,
):
    """Build an MLP-Mixer classifier as a single ``nn.Sequential``.

    Args:
        image_size: Side length used to compute the number of patches.
        channels: Number of input image channels.
        patch_size: Side length of each patch.
        dim: Embedding size each flattened patch is projected to.
        depth: Number of mixer layers (token-mixing + channel-mixing pairs).
        num_classes: Output size of the final linear layer.
        expansion_factor: Hidden-size multiplier for every feed-forward block.
        dropout: Dropout probability for every feed-forward block.

    Returns:
        An ``nn.Sequential`` mapping an image batch to ``num_classes`` outputs.

    Raises:
        ValueError: If ``image_size`` is not a multiple of ``patch_size``.
    """
    # Fail at construction time: num_patches below is floor-divided, so a
    # non-divisible size would otherwise only surface later, when the patch
    # rearrangement is applied to real input.
    if image_size % patch_size != 0:
        raise ValueError(
            f"image_size ({image_size}) must be divisible by "
            f"patch_size ({patch_size})"
        )

    num_patches = (image_size // patch_size) ** 2
    # A kernel-size-1 Conv1d treats dim 1 (the patch axis of a "b n c"
    # tensor) as its channels, so it mixes across patches; nn.Linear acts on
    # the last axis, so it mixes across features.
    chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear

    mixer_layers = [
        nn.Sequential(
            # Token mixing: feed-forward over the num_patches axis.
            PreNormResidual(
                dim,
                FeedForward(num_patches, expansion_factor, dropout, chan_first),
            ),
            # Channel mixing: feed-forward over the dim axis.
            PreNormResidual(
                dim,
                FeedForward(dim, expansion_factor, dropout, chan_last),
            ),
        )
        for _ in range(depth)
    ]

    return nn.Sequential(
        # Cut the image into patches and flatten each one into a vector.
        Rearrange(
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=patch_size,
            p2=patch_size,
        ),
        nn.Linear((patch_size ** 2) * channels, dim),
        *mixer_layers,
        nn.LayerNorm(dim),
        # Average over the patch axis before classification.
        Reduce("b n c -> b c", "mean"),
        nn.Linear(dim, num_classes),
    )