最近使用带有SE block的网络在pytorch框架下做训练。training loss 随着epoch增多不断下降,但是突然到某一个epoch出现loss为nan的情况,但是两三个epoch之后,loss竟然又恢复正常,而且下降了。
这几篇博客是我debug的借鉴,真的非常有用。
这篇介绍了出现nan的基本解决思路。

这篇介绍了为什么在多层dense layer之后某一层dense layer输出可能会出现nan,以及weight初始化的重要性及初始化方法。
https://yey.world/2020/12/17/Pytorch-14/ 我自己的基本排查思路是:

  1. 首先检查数据中是否有inf 或者nan的情况。

普通numpy数组可用

np.all(np.isfinite(data))

我自己用的是医疗图像nifiti格式的压缩图像,所以使用代码如下:

def check_image(img_fname: str):
	#检查图像/numpy数组中是否有nan存在
    npy = sitk.GetArrayFromImage(sitk.ReadImage(img_fname))
    return np.any(np.isnan(npy))

def check_image_inf(img_fname: str):
	#检查图像/numpy数组中数据是否都是finit的(提示:若使用此函数,单独检查nan的函数可不用)
    npy = sitk.GetArrayFromImage(sitk.ReadImage(img_fname))
    return np.all(np.isfinite(npy))

def check_for_nan(input_folder: str):
    nii_files = subfiles(input_folder, suffix='.nii.gz')
    for n in nii_files:
        if check_image(n):
            print("nans found in ", n)
        elif not check_image_inf(n):
            print("infs found in ", n)
            
img_fold = '' #保存数据的文件夹
check_for_nan(img_fold)

如果数据没有问题,检查数据是否有normalization,如果没有归一化也可能出现nan或者网络层计算中出现infinit。

  1. 检查使用loss是否带有除法,算log的时候有负数或者很小的数。
    我所用的检查loss是否为nan的方法:
assert torch.isnan(loss).sum() == 0 and torch.isinf(loss).sum() == 0, ('loss is nan or ifinit', loss)

如果loss中有infit或者nan,则会输出

'loss is nan or ifinit', loss(这里会输出loss的值)

如果确认loss也并没有问题,那么问题可能出现在forward path中。

  1. 检查forward path每一层的输出结果,进行问题定位。在每一层后加入:
assert torch.isnan(out).sum() == 0 and torch.isinf(out).sum() == 0, ('output of XX layer is nan or infinit', out.std()) #out 是你本层的输出 out.std()输出标准差

如果是某一层计算出问题,考虑是不是初始化函数没有使用或者用得不对。
接下来,就到我的检查血泪史了。我发现我SE block的nn.AdaptiveAvgPool3d(1)输出中有inf!!!奇怪的是,这一次层的输入却没这个问题。后来我直接用了torch.nn.AvgPool3d代替前面的函数,它终于正常跑起来了。

b, c, D, H, W = x.size()  #b: batch size, c: channels, (D, H, W) data shape
y = torch.nn.AvgPool3d((D,H,W), padding=0)(x)
y = y.view(b, c) # 将y变成shape为(b,c)的tensor, 后面接全连接层

一周的辛酸血泪,终于好了。
具体nn.AdaptiveAvgPool3d()函数的内部实现我还没来得及研究,之后搞明白了再来分享到底为啥出问题。