通常情况下,定义训练和评估网络并直接运行,已经可以满足基本需求。
一方面,Model可以在一定程度上简化代码。例如:无需手动遍历数据集;在不需要自定义nn.TrainOneStepCell的场景下,可以借助Model自动构建训练网络;可以使用Model的eval接口进行模型评估,直接输出评估结果,无需手动调用评价指标的clear、update、eval函数等。
另一方面,Model提供了很多高阶功能,如数据下沉、混合精度等,在不借助Model的情况下,使用这些功能需要花费较多的时间仿照Model进行自定义。
本文档首先对MindSpore的Model进行基本介绍,然后重点讲解如何使用Model进行模型训练、评估和推理。
Model基本介绍
Model是MindSpore提供的高阶API,可以进行模型训练、评估和推理。其接口的常用参数如下:
- network:用于训练或推理的神经网络。
- loss_fn:所使用的损失函数。
- optimizer:所使用的优化器。
- metrics:用于模型评估的评价函数。
- eval_network:模型评估所使用的网络,未定义情况下,Model会使用network和loss_fn进行封装。
Model提供了以下接口用于模型训练、评估和推理:
- fit:边训练边评估模型。
- train:用于在训练集上进行模型训练。
- eval:用于在验证集上进行模型评估。
- predict:用于对输入的一组数据进行推理,输出预测结果。
使用Model接口
对于简单场景的神经网络,可以在定义Model时指定前向网络network、损失函数loss_fn、优化器optimizer和评价函数metrics。
下载并处理数据集
使用download库下载数据集,通过 vison.Rescale 接口对图片进行缩放, vision.Normalize 接口对输入图片进行归一化处理, vision.HWC2CHW 接口对数据格式进行转换。
创建模型
关于模型创建的讲解可以参考 网络构建 。
定义损失函数和优化器
要训练神经网络模型,需要定义损失函数和优化器函数。
- 损失函数这里使用交叉熵损失函数CrossEntropyLoss。
- 优化器这里使用SGD。
训练及保存模型
在开始训练之前,MindSpore需要提前声明网络模型在训练过程中是否需要保存中间过程和结果,因此使用ModelCheckpoint接口用于保存网络模型和参数,以便进行后续的Fine-tuning(微调)操作。
通过MindSpore提供的model.fit接口可以方便地进行网络的训练与评估,LossMonitor可以监控训练过程中loss值的变化。
训练过程中会打印loss值,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。每个人运行的loss值有一定随机性,不一定完全相同。
通过模型运行测试数据集得到的结果,验证模型的泛化能力:
- 使用model.eval接口读入测试数据集。
- 使用保存后的模型参数进行推理。