归总一下从头复现论文或是自己写网络模型进行训练时遇到过的一些问题。

主要包括train loss不收敛train loss收敛但是val acc低train loss收敛但是仍然不够低(训练技巧)

本文主要从这三个问题入手,列出常见的问题可能性以及排查方式,以及容易被人忽略的一些原因,例如对BN处理对一些细节。

 

1. train loss不收敛,震荡

基础问题

(a) coding有问题。仔细检查代码,如果是tensorflow注意检查不同变量是否加入到了正确的GraphKey中。

(b)输入输出不正确或者有bad case。写一个datacheck,仔细检查每一对<输入,输出>是否都是正确对。

(c)模型结构设计有问题。简化模型,去掉一切技巧,只使用极少对data让模型去拟合,检查模型对合理性。

(d)loss设计有问题。同理,使用最简单对loss,例如cross entropy等让模型去学习。

(f)训练技巧问题。使用小学习率,使用warmup学习率。

(g)梯度消失,梯度爆炸。打印梯度,查看梯度是否正常,爆炸就加clip,消失则考虑其他技巧。

(h)正则化L1,L2权重过大。打印ckpt中每一层的均值与方差,权重过大导致模型过于简单,泛化能力不够。

 

进阶问题

(a)python对shallow copy改变了label的值。bad case很有可能是由于python对shallow copy对某些label做变换处理时直接改变了原始值,尤其是label里面有用到list,tuple等。

(b)输出位置加入了BN。模型最终输出是不可以加入BN。BN会将输出强制拉到0 1正态分布,导致模型什么都没有学习到。

(c)激活函数搭配的初始化方式不对。用不同的激活函数,relu,leakly relu,relu-6,p-relu,tanh。用不同的初始化方式,Normal,Xavir,He-norm。但是经验告诉我,模型加入了BN,对于初始化并没有太多要求,用Xavir即可应对。

(d)loss里面有log等函数输入<=0,有除数=0。当输入太小可能会超过float32精度被认为是0,将输入和一个非常小的数eps相加,避免这种情况。

(e)数据分布的var非常小,模型学习能力不够。对应于排除了基础问题(c),多训练一些epoch,如果仍然无法拟合则考虑重新设计一下网络结构或者对data做预处理

 

PS:

通常来说,如果模型是你自己写的,那么基础问题a,b也就是编程和输入输出存在问题是最有可能的。如果能排除,则检查c,简化模型通常也能找到问题,f,g,h其实都是对应于c提出来的。

当以上都不管用则要要考虑进阶问题了,其中a问题主要出现在关键点检测,目标检测,label使用list或者tuple保存的,list,tuple里面可能还有嵌套,由于python函数参数传递对于可变参数是址传递,而且list.copy这种属于浅拷贝,经过坐标点变化就很容易出现问题。BN也是一个很重要的层,但是绝不能用在输出层,会让模型学不到东西。激活函数和初始化方式其实影响不大,但是也不能排除有问题,导致权重学习偏了。最后一个可能性e则是说模型泛化性不够。

 

2. train loss收敛, val loss不收敛,val acc很低

基础问题

(a)过拟合。首先检查是否过拟合了,用train set的数据去做inference,如果acc很高则是过拟合。使用dropout,正则化,数据增强等方式缓解。

 

进阶问题

(a)BN中参数没有更新。打印某一层BN的参数,做infer的时候看看值是否有改变。tensorflow是图计算模型,Batch norm中的moving-mean和moving-var参数不是trainable,如果用train_op直接更新计算图,这两个参数会得不到更新。需要手动设置依赖,强制更新这两个参数。

(b)BN中参数没有保存。同上一条,tf.train.Saver默认只保存trainable参数,需要手动加入保存。

(c)BN没有warm up。训练的时候打印BN参数,看看moving-mean有没有震荡。由于BN在inference的时候是用整体估计局部,训练的时候用了一个很小的值去decay,需要用小学习率多训练一些epoch直BN层warmup

 

ps:

碰到这个问题首先考虑是否过拟合了。其次用了BN的话很容易忽略BN参数没有更新或者保存的问题。而BN需要进行warmup,也就是很多时候你的train loss很低了,不变了,但是继续训练后val的acc会继续提升的原因。

 

3. 进一步降低train loss,提升val acc

此部分内容偏向训练技巧,并不属于问题了

基础技巧

(a)做数据增强。如果获取不了更多数据就考虑做数据增强,提升数据容量,提升泛化性。

(b)检查数据增强合理性。某些数据增强并不适用于所有的任务,例如做关键点检测需要区分左右腿就不能做镜像翻转。

(c)学习率设置。使用warmup技巧,Plateau Detection技巧。

(d)超参数设置。使用不同超参数,属于调参技巧了

进阶技巧

(a)optimizer选择。可以使用Adam作为前期的优化器,训练到中间段之后用Momentum进行训练。

(b)对图像进行预处理。尽量不要直接resize,尤其是对于长宽比例不协调,模型输入又是1:1这种,考虑做预处理,padding等

(c)Assemble。多训练几个模型做assemble。