这个方法更好的解决了模型过拟合问题。

EarlyStopping的原理是提前结束训练轮次来达到“早停“的目的,故训练轮次需要设置的大一点以求更好的早停(比如可以设置100epoch)。

首先,我们需要一个一个标识,可以采用'val_acc’、’val_loss’等等,这些量在每一个轮次中都会不断更新自己的值,也和模型的参数息息相关,所以我们想通过他们间接操作模型参数。以val_loss来说,当模型训练时可能会出现当val_loss到一定值的时候会出现回弹的情况,所以我们希望在他回弹之前结束模型的训练。 

早停法其实一共有3类停止标准,这里我们选用最简单的一种入门。话不多说,上代码!!!

import numpy as np
import torch

导入两个最基本的包就行,因为早停法是一种可以自己就写出来的算法!!!

参数有5个:

第一个patience:这个是当有连续的patience个轮次数值没有继续下降,反而上升的时候结束训练的条件(以val_loss为例)

第二个verbose:这个其实就是是否print一些值,可也不传参,因为他有默认值

第三个delta:这个就是控制对比是的”标准线“

第四个path:这个是权重保存路径,早停法会在每一轮次次产生最优解(就是val_loss继续减少)的时候保存当前的模型参数。注:只要保存路径不变,每一次保存在文件里面的参数都会覆盖上一次保存在文件里面的参数。

第五个trace_func:这个就是显示每一个轮次变化的数值的方式,默认print,也可以改成进度条显示(tqdm的对象)

class EarlyStopping:
   
    def __init__(self, patience=7, verbose=False, delta=0, path='weight7-stop.pth', trace_func=print):
       
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

重点就在中间那个__call__方法里面,比较的是这一轮的val_loss和之前最好的val_loss(可以加上一个数实现‘标准线’的‘上移’或者‘下移’)

实际应用与项目当中

这是我再积水检测项目中的代码的一部分。

我设置了patience为7.

epoch为200。(这个推荐小一点,因为太大没有意义,一定会过拟合的)

pytorch代码提前停止训练 pytorch如何暂停训练_深度学习

pytorch代码提前停止训练 pytorch如何暂停训练_人工智能_02

 注:本文使用的早停法源代码不是原创,取自github。