1. 功能描述:
MindSpore实现WeightNorm参数归一化。
2. 实现分析:
在MindSpore实现高性能方案,建议采用图模式,同时也能保证动静统一。MindSpore图模式需要把归一化操作表达到整图里,可以采用自定义的方式在网络结构中实现。
3. 参数归一化功能简介(背景介绍):
在深度学习中通常对卷积层的权重进行参数归一化,参数归一化功能根据以下公式对传入的 layer 中的权重参数进行归一化:
公式中w是网络权重,g 代表长度变量 ,v代表方向变量。权重归一化可以将神经网络中权重w的向量长度g与其方向v解耦,将w用g和v两个变量表示。 (例如:详细可以参考论文: https://arxiv.org/pdf/1602.07868.pdf。)
4. 解决方案:
实现MindSpore的WeightNorm需要注意:
- 4.1 MindSpore实现时,需要封装一个Wrapper,将WeightNorm和需要进行参数归一化的网络结构(如卷积)封装为一个整体,这样每次在卷积执行之前,就会先执行WeightNorm。具体伪代码如下:
class WeightNorm(nn.Cell):
def __init__(self):
...
register_w_v_g()
self.layer = layer
def construct(self, inputs):
compute_weight_norm()
result = self.layer(inputs)
return result
- 4.2 使用参数归一化需要能够添加和删除weight norm,但MindSpore静态图编译后无法删除Weight Norm
remove_weight_norm的场景:
- 4.2.1 inference,即推理阶段需要移除Weight Norm。
- 4.2.2 进行一次Weight Norm计算,然后固定w(WeightNorm.remove()的执行逻辑) remove_weight_norm的使用场景,即模型进行推理时,在加载Checkpoint后进行操作,此时未涉及到静态图的编译阶段,因此可以对实例化的模型进行任意修改。 PS: 静态图不支持在训练过程中移除weight norm。
MindSpore WeightNorm示例:
class WeightNorm (nn.Cell):
def __init__(self, module, dim:int=0):
super().__init__()
if dim is None:
dim = -1
self.dim = dim
self.module = module
self.assign = P.Assign()
# add g and v as new parameters and express w as g/||v|| * v
self.param_g = Parameter(Tensor(norm_except_dim(self.module.weight, 2, dim)))
self.param_v = Parameter(Tensor(self.module.weight.data))
self.module.weight.set_data(_weight_norm(self.param_v, self.param_g, self.dim))
self.use_weight_norm = True
def construct(self, *inputs, **kwargs):
if not self.use_weight_norm:
return self.module(*inputs, **kwargs)
self.assign(self.module.weight, _weight_norm(self.param_v, self.param_g, self.dim))
return self.module(*inputs, **kwargs)
def remove_weight_norm(self):
self.assign(self.module.weight, _weight_norm(self.param_v, self.param_g, self.dim))
self.use_weight_norm = False
- 4.3 use_weight_norm可以达到移除WeightNorm的目的。即调用remove_weight_norm方法后,将self.use_weight_norm设置为False,当再次construct函数时,就会直接调用self.module,忽略Weight Norm计算。
- 4.4 self.param_g = Parameter(Tensor(norm_except_dim(self.module.weight, 2, dim))) 实现 w和 ||v|| 的计算,静态图不支持getattr方法,考虑到MindSpore的nn层设计,就固定module的权重为module.weight。
def norm_except_dim(v, pow, dim):
if dim == -1:
return mnp.norm(v, pow)
elif dim == 0:
output_size = (v.shape[0],) + (1,) * (v.ndim - 1)
return mnp.norm(v.view((v.shape[0], -1)), pow, 1).view(output_size)
elif dim == (v.ndim - 1):
output_size = (1,) * (v.ndim - 1) + (v.shape[v.ndim - 1])
return mnp.norm(v.view((-1, v.shape[v.ndim - 1])), pow, 0).view(output_size)
else:
return norm_except_dim(v.swapaxes(0, dim), pow, dim).swapaxes(0,dim)
def _weight_norm(v, g, dim):
return v * (g / norm_except_dim(v, 2, dim))
- 4.5 上述代码WeightNorm中,self.module.weight是要进行归一化的网络权重,self.param_g是长度变量,self.param_v是方向变量, 其中norm_except_dim函数用于计算指定维度的长度。
5. MindSpore的WeightNorm简单使用方式
# assume we need apply weight norm on nn.Dense layer
m = WeightNorm(nn.Dense(20, 40))
# m.param_g.shape is (40, 1)
# m.param_v.shape is (40, 20)
# use m as normal nn.Dense
inputs = Tensor(np.random.randn(10, 20), mstype.float32)
outputs = m(inputs)
# if you want to remove weight norm, just call it
m.remove_weight_norm()
# m.use_weight_norm == False