1. 功能描述:

MindSpore实现WeightNorm参数归一化。

2. 实现分析:

在MindSpore实现高性能方案,建议采用图模式,同时也能保证动静统一。MindSpore图模式需要把归一化操作表达到整图里,可以采用自定义的方式在网络结构中实现。

3. 参数归一化功能简介(背景介绍):

在深度学习中通常对卷积层的权重进行参数归一化,参数归一化功能根据以下公式对传入的 layer 中的权重参数进行归一化: 

归一化权重原理_归一化

 公式中w是网络权重,g 代表长度变量 ,v代表方向变量。权重归一化可以将神经网络中权重w的向量长度g与其方向v解耦,将w用g和v两个变量表示。 (例如:详细可以参考论文: https://arxiv.org/pdf/1602.07868.pdf。)

4. 解决方案:

实现MindSpore的WeightNorm需要注意:

  • 4.1 MindSpore实现时,需要封装一个Wrapper,将WeightNorm和需要进行参数归一化的网络结构(如卷积)封装为一个整体,这样每次在卷积执行之前,就会先执行WeightNorm。具体伪代码如下:
class WeightNorm(nn.Cell):
    def __init__(self):
        ...
        register_w_v_g()
        self.layer = layer

    def construct(self, inputs):
        compute_weight_norm()
        result = self.layer(inputs)
        return result
  • 4.2 使用参数归一化需要能够添加和删除weight norm,但MindSpore静态图编译后无法删除Weight Norm

remove_weight_norm的场景:

  - 4.2.1 inference,即推理阶段需要移除Weight Norm。

  - 4.2.2 进行一次Weight Norm计算,然后固定w(WeightNorm.remove()的执行逻辑) remove_weight_norm的使用场景,即模型进行推理时,在加载Checkpoint后进行操作,此时未涉及到静态图的编译阶段,因此可以对实例化的模型进行任意修改。 PS: 静态图不支持在训练过程中移除weight norm。

MindSpore  WeightNorm示例:

class WeightNorm (nn.Cell):

      def __init__(self, module, dim:int=0):
           super().__init__()

           if dim is None:
               dim = -1

          self.dim = dim
          self.module = module
          self.assign = P.Assign()
          # add g and v as new parameters and express w as g/||v|| * v
          self.param_g = Parameter(Tensor(norm_except_dim(self.module.weight, 2, dim)))
          self.param_v = Parameter(Tensor(self.module.weight.data))
          self.module.weight.set_data(_weight_norm(self.param_v, self.param_g, self.dim))
          self.use_weight_norm = True

     def construct(self, *inputs, **kwargs):
           if not self.use_weight_norm:
               return self.module(*inputs, **kwargs)
          self.assign(self.module.weight, _weight_norm(self.param_v, self.param_g, self.dim))
              return self.module(*inputs, **kwargs)

     def remove_weight_norm(self):
          self.assign(self.module.weight, _weight_norm(self.param_v, self.param_g, self.dim))
          self.use_weight_norm = False
  • 4.3 use_weight_norm可以达到移除WeightNorm的目的。即调用remove_weight_norm方法后,将self.use_weight_norm设置为False,当再次construct函数时,就会直接调用self.module,忽略Weight Norm计算。
  • 4.4 self.param_g = Parameter(Tensor(norm_except_dim(self.module.weight, 2, dim))) 实现 w和 ||v|| 的计算,静态图不支持getattr方法,考虑到MindSpore的nn层设计,就固定module的权重为module.weight。
def norm_except_dim(v, pow, dim):
    if dim == -1:
       return mnp.norm(v, pow)
    elif dim == 0:
       output_size = (v.shape[0],) + (1,) * (v.ndim - 1)
       return mnp.norm(v.view((v.shape[0], -1)), pow, 1).view(output_size)
    elif dim == (v.ndim - 1):
       output_size = (1,) * (v.ndim - 1) + (v.shape[v.ndim - 1])
       return mnp.norm(v.view((-1, v.shape[v.ndim - 1])), pow, 0).view(output_size)
    else:
       return norm_except_dim(v.swapaxes(0, dim), pow, dim).swapaxes(0,dim)

def _weight_norm(v, g, dim):
    return v * (g / norm_except_dim(v, 2, dim))
  • 4.5 上述代码WeightNorm中,self.module.weight是要进行归一化的网络权重,self.param_g是长度变量,self.param_v是方向变量, 其中norm_except_dim函数用于计算指定维度的长度。

5. MindSpore的WeightNorm简单使用方式

# assume we need apply weight norm on nn.Dense layer
   m = WeightNorm(nn.Dense(20, 40))
   # m.param_g.shape is (40, 1)
   # m.param_v.shape is (40, 20)
   # use m as normal nn.Dense

   inputs = Tensor(np.random.randn(10, 20), mstype.float32)
   outputs = m(inputs)

   # if you want to remove weight norm, just call it
   m.remove_weight_norm()
   # m.use_weight_norm == False