通常情况下,定义训练和评估网络并直接运行,已经可以满足基本需求。

一方面,Model可以在一定程度上简化代码。例如:无需手动遍历数据集;在不需要自定义nn.TrainOneStepCell的场景下,可以借助Model自动构建训练网络;可以使用Model的eval接口进行模型评估,直接输出评估结果,无需手动调用评价指标的clear、update、eval函数等。

另一方面,Model提供了很多高阶功能,如数据下沉、混合精度等,在不借助Model的情况下,使用这些功能需要花费较多的时间仿照Model进行自定义。

本文档首先对MindSpore的Model进行基本介绍,然后重点讲解如何使用Model进行模型训练、评估和推理。

昇思MindSpore学习入门-高阶封装_神经网络

Model基本介绍

Model是MindSpore提供的高阶API,可以进行模型训练、评估和推理。其接口的常用参数如下:

  • network:用于训练或推理的神经网络。
  • loss_fn:所使用的损失函数。
  • optimizer:所使用的优化器。
  • metrics:用于模型评估的评价函数。
  • eval_network:模型评估所使用的网络,未定义情况下,Model会使用network和loss_fn进行封装。

Model提供了以下接口用于模型训练、评估和推理:

  • fit:边训练边评估模型。
  • train:用于在训练集上进行模型训练。
  • eval:用于在验证集上进行模型评估。
  • predict:用于对输入的一组数据进行推理,输出预测结果。

使用Model接口

对于简单场景的神经网络,可以在定义Model时指定前向网络network、损失函数loss_fn、优化器optimizer和评价函数metrics

下载并处理数据集

使用download库下载数据集,通过 vison.Rescale 接口对图片进行缩放, vision.Normalize 接口对输入图片进行归一化处理, vision.HWC2CHW 接口对数据格式进行转换。

昇思MindSpore学习入门-高阶封装_数据集_02

 

创建模型

关于模型创建的讲解可以参考 网络构建 。

定义损失函数和优化器

要训练神经网络模型,需要定义损失函数和优化器函数。

  • 损失函数这里使用交叉熵损失函数CrossEntropyLoss。
  • 优化器这里使用SGD。

训练及保存模型

在开始训练之前,MindSpore需要提前声明网络模型在训练过程中是否需要保存中间过程和结果,因此使用ModelCheckpoint接口用于保存网络模型和参数,以便进行后续的Fine-tuning(微调)操作。

通过MindSpore提供的model.fit接口可以方便地进行网络的训练与评估,LossMonitor可以监控训练过程中loss值的变化。

训练过程中会打印loss值,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。每个人运行的loss值有一定随机性,不一定完全相同。

通过模型运行测试数据集得到的结果,验证模型的泛化能力:

  1. 使用model.eval接口读入测试数据集。
  2. 使用保存后的模型参数进行推理。

 

昇思MindSpore学习入门-高阶封装_损失函数_03

昇思MindSpore学习入门-高阶封装_损失函数_04