这次涉及到了图像分类的核心内容,在本地进行模型训练,最近事情太多,没有时间去建立新的数据集,选择了开源的fruit30数据集。

首先,我们需要载入数据集,使用常用的ImageFolder()函数,载入各类别的图像,并将类别对应到索引号上,方便后期使用。

然后,定义数据加载器DataLoader,将一个一个的batch喂到模型中进行训练。

最重要的一步,也就是在Imagenet训练好的模型基础上进行迁移学习,很多情况下,我们并不会从头到尾进行模型训练,这样会牵扯到更多的训练时间与资源,可以在一些比较好的模型上利用我们的数据集进行改进,本次使用的fruit30数据集很大部分已经包含在了Imagenet里面,所以我们只需在最后的全连接层进行微调,将全连接层的输出与当前数据集类别数对应。本次使用交叉熵损失函数,规定训练轮次为20,每个轮次一次获得一个又一个batch的标注和数据,进行损失函数计算后就可以开始1.清除梯度;2.反向传播;3.优化更新三部曲。训练完成后可以使用测试集进行模型评估。

当然,我们还可以对训练进行可视化操作,可以使用wandb对模型进程进行记录,并在网站上实时显示,也可以在训练期间记录下模型准确度、损失函数大小等数据保存下来,便于分析是否出现过拟合等情况。就像下图所示,将参数绘制成图,非常便于对模型进行评估。

RUSBOOST图像分类 图像分类流程_数据集

除此之外还有一些需要注意的事项:
1.千万不要把训练集里面的数据拿到测试集中或者把测试集的数据拿到训练集,这样会很容易导致出现虚假的准确度或者过拟合现象。
2.模型最后评估出的准确度仅仅是一种参考,影响准确度的因素有很多,数据集的完善情况、损失函数的选择、训练方式的选择、预训练模型的选择等都会有影响,所以想要训练出更好的模型,这些因素都要考虑到并进行相应的调整。