当训练任务结束,常常需要评价函数(Metrics)来评估模型的好坏。不同的训练任务往往需要不同的Metrics函数。例如,对于二分类问题,常用的评价指标有precision(准确率)、recall(召回率)等,而对于多分类任务,可使用宏平均(Macro)和微平均(Micro)来评估。

MindSpore提供了大部分常见任务的评价函数,如Accuracy、Precision、MAE和MSE等,由于MindSpore提供的评价函数无法满足所有任务的需求,很多情况下用户需要针对具体的任务自定义Metrics来评估训练的模型。

本章主要介绍如何自定义Metrics以及如何在mindspore.train.Model中使用Metrics。

自定义Metrics

自定义Metrics函数需要继承mindspore.train.Metric父类,并重新实现父类中的clear方法、update方法和eval方法。

  • clear:初始化相关的内部参数。
  • update:接收网络预测输出和标签,计算误差,每次step后并更新内部评估结果。
  • eval:计算最终评估结果,在每次epoch结束后计算最终的评估结果。

平均绝对误差(MAE)算法如式(1)所示:

昇思MindSpore学习入门-评价指标_自定义

下面以简单的MAE算法为例,介绍clear、update和eval三个函数及其使用方法。

昇思MindSpore学习入门-评价指标_评价指标_02

模型训练中使用Metrics

mindspore.train.Model是用于训练和评估的高层API,可以将自定义或MindSpore已有的Metrics作为参数传入,Model能够自动调用传入的Metrics进行评估。

在网络模型训练后,需要使用评价指标,来评估网络模型的训练效果,因此在演示具体代码之前首先简单拟定数据集,对数据集进行加载和定义一个简单的线性回归网络模型:

昇思MindSpore学习入门-评价指标_自定义_03

昇思MindSpore学习入门-评价指标_评价指标_04

使用内置评价指标

使用MindSpore内置的Metrics作为参数传入Model时,Metrics可以定义为一个字典类型,字典的key值为字符串类型,字典的value值为MindSpore内置的评价指标,如下示例使用train.Accuracy计算分类的准确率。

昇思MindSpore学习入门-评价指标_评价指标_05

使用自定义评价指标

如下示例在Model中传入上述自定义的评估指标MAE(),将验证数据集传入model.fit()接口边训练边验证。

验证结果为一个字典类型,验证结果的key值与metrics的key值相同,验证结果的value值为预测值与实际值的平均绝对误差。

昇思MindSpore学习入门-评价指标_数据集_06