一、什么是k折交叉验证?

在训练阶段,我们一般不会使用全部的数据进行训练,而是采用交叉验证的方式来训练。交叉验证(Cross Validation,CV)是机器学习模型的重要环节之一。它可以增强随机性,从有限的数据中获得更全面的信息,减少噪声干扰,从而缓解过拟合,增强模型的泛化能力。

比赛一般会只给我们训练集,但是测试集我们是看不到的,所以我们一般会将训练集按照一定的方式划分为训练集和验证集。训练集用于模型的训练,验证集用于本地验证,选取最好的pt权重文件,再提交到比赛官网进行测试集的验证。所以如何划分训练集和验证集,让我们最大限度的利用训练集,学习有效的特征,是至关重要的。交叉验证就是做这个事的。

交叉验证步骤:

  1. 将整个数据集划分为大小相等的K个部分;
  2. 每次选取其中一份作为验证集,其余K-1份作为训练集进行训练;
  3. 重复K次,直至每一份数据都被当作验证集验证了一遍;
  4. 模型的最终精度是通过K个子模型的平均精度来计算的;

下面这个图可以比较好的诠释上面这个过程:

深度学习 五折交叉验证 五折交叉验证的目的_交叉验证

我们一般不会自己实现这个功能,一般都是调用SKLearn包直接使用,SKlearn帮我们实现了KFold、Stratified KFold、Group KFold和Stratified Group KFold四种方式,下面我一一介绍它们的区别和用法。

二、常见的几种交叉验证方式

2.1、KFold

KFold是最简单的一种K折交叉验证,它的具体步骤如下图所示是一个4折交叉验证,橘色代表验证集(1份),蓝色代表训练集(3份),整个数据集有三个类别(对应图中三种颜色的分布情况);这些数据属于很多个不同的组;

深度学习 五折交叉验证 五折交叉验证的目的_kaggle比赛_02


可以看的很清楚,这种K折交叉验证,有两个缺点:

  1. 不适应于数据集样本不均衡的情况,因为很可能会把整个少数的类别划分为验证集或训练集;
  2. 不适应于时间序列问题;

2.2、Stratified(分层) KFold

上面讲到,KFold不适应于数据不平衡的问题,所以Stratified KFold(分层)交叉验证就是专门来解决这个问题的。如下图,在分层交叉验证中,数据集依然被划分为K组,但是验证组的目标类别是从各个类中分层抽取出来的,是均匀的,所以就不会存在少数类别被全部划分为验证集或训练集。

深度学习 五折交叉验证 五折交叉验证的目的_Group_03


特点:可以解决数据不平衡问题,但是不适应于时间序列问题。

2.3、Group (分组)KFold

GroupKFold是KFold一个变体,目的在于将group严格分开,就是说同一个group的数据只能出现在训练集或者验证集,不能同时出现在训练集和验证集,如下图:

深度学习 五折交叉验证 五折交叉验证的目的_数据_04

特点:可以将数据的group完全分开,避免高度相似的样本既出现在训练集又测试在验证集。

2.4、Stratified Group KFold

Group KFold和Stratified KFold的合体,如下图:

深度学习 五折交叉验证 五折交叉验证的目的_交叉验证_05


特点:可以将数据的group和标签的class完全分层划开,避免出现样本高度相似和标签分布不均的问题。

2.5、Time Series Split

可以解决时间序列相关的问题。对于时间序列数据集,根据时间将数据分为训练和验证,也称为前向链接方法或滚动交叉验证。

深度学习 五折交叉验证 五折交叉验证的目的_Group_06

使用方式举例:

skf = StratifiedGroupKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['empty'], groups = df["case"])):

三、什么是TTA?

TTA,即Test time augmention,测试时增强。数据增强一般是出现在训练阶段,使用数据增强一般都能提升性能。而测试时数据增强是指在测试的时候,将原图进行数据增强(比如水平翻转、垂直翻转、对角线翻转、旋转等,这里假设使用了3种数据增强),可以得到4张测试图片,对这四张测试图片分布进行推理,得到推理结果。再对三张增强后的推理结果再变换回来(比如我对原图进行水平翻转,得到的mask,再对mask进行水平翻转)。最后就得到了4张预测结果,对这四张预测结果mask对应位置相加取平均,就得到了最终的mask预测果。

使用方式举例:

model = build_model(CFG, test_flag=True)
        model.load_state_dict(torch.load(sub_ckpt_path))
           model.eval()
           y_preds = model(images) # [b, c, w, h]
           y_preds   = torch.nn.Sigmoid()(y_preds)
           masks += y_preds

           #x,y,xy flips as TTA
           if CFG.tta:
               flips = [[-1]]  # 水平翻转
               for f in flips:
                   images_f = torch.flip(images, f)
                   y_preds = model(images_f) # [b, c, w, h]
                   y_preds = torch.flip(y_preds, f)
                   y_preds   = torch.nn.Sigmoid()(y_preds)
                   masks += y_preds

        if CFG.tta:
            total_ckpt_paths = len(ckpt_paths_dict) * CFG.n_fold * 2
        else:
            total_ckpt_paths = len(ckpt_paths_dict) * CFG.n_fold