师兄说学目标检测之前先学分类
坏了,内容好多!学学学
感谢up主,好人一生平安
混淆矩阵
- 什么是混淆矩阵:
- 横坐标:每一列属于该类的所有验证样本。每一列所有元素对应真实类别。
- 纵坐标:网络的预测类别。每一行对应预测结果属于该类的所有样本。
- 对角线:预测正确的样本个数。
- 预测值在对角线上分布的越密集,模型的性能就越好。
还能通过混淆矩阵看到这个网络对哪些类别更容易分类出错。 - 混淆矩阵的指标:
- 精确率precision不等于准确率accuracy!!
准确率:所有预测正确的样本个数 / 所有用于验证的样本个数
(对角线上所有数据之和 / 混淆矩阵所有数据之和) - 二分类简单示例:
- 每一列:预测值标签;
每一列:真实值标签。
TP、TN 都代表网络预测正确的部分。(越大越好)
FP、FN 都代表网络预测错误的部分。(越小越好)
- 准确率、精确率、灵敏度/召回率、特异度
准确率:对所有类别的统计
精确率、灵敏度/召回率、特异度:针对某个类别 - 实例:
以猫为例,可以把狗和猪的类别混在一起,统一整合为不为猫的情况。得到混淆矩阵: - 参考博文:
计算混淆矩阵与相关指标
这里使用numpy进行统一计算(可以同时在TensorFlow和pytorch中使用):
- 制定一个类:ConfusionMatrix
若图像显示不正常,就升级matplotlib
prettytable:将输出展示成列表的形式 - 初始化函数:init
- 1:传入了两个变量:分类网络的分类类别个数num_classes、分类标签列表:labels。
2:初始化一个行数和列数相等且均为num_classes的正方形的、值为零的矩阵。
3:将num_classes赋值给类变量num_classes。 - 第一个实现的类:update
- 1:预测值preds、真实标签labels
2:累加到混淆矩阵中。将预测、真实标签打包组合,进行遍历。
p:预测值;t:真实类别标签
3:矩阵[预测值(行),真实值(列)],[第t行,第p列] - 第一个实现的类:summary
统计计算各个指标
①准确率: - 遍历,0~num_classes-1,
统计对角线上的元素和,
计算acc值
②:计算每个类别的精确率、召回率、特异度 - (库)prettytable:将输出展示成列表的形式
初始化一张表table,
在第一行添加一些描述信息,
使用for i in range遍历每一个类别,
对于第i个类别,TP(true positive):对角线上元素m[i],
FP(false positive):这一行的所有元素之和(第i行)-TP,
FN(false negative):这一列的和(第i列)-TP - round:小数部分只取三位。
- 将当前类别信息添加到刚刚初始化的table里面
- 类别标签,precision,recall,specificity
- 绘制混淆矩阵:plot
- 1 将matrix赋值给matrix
2 打印混淆矩阵
3 使用imshow函数展示混淆矩阵。颜色变换:从白色到蓝色。 - 4 对于label:默认是0、1、2、3这种坐标。但是希望它展示的是标签的类别。
使用xticks,将原来x轴的信息(0~num_classes-1)替换成为labels,对x轴旋转45°
5 y轴同理
6 混淆矩阵右侧像色谱一样的colorbar。数值的密集程度,颜色越深,数值就越密集。 - 7 横坐标 True labels
8 纵坐标 predicted labels
9 图像标题 Confusion matrix
将每个区域的数值标注在图像上
1 设置阈值,指定数字文本的颜色。取matrix最大数值的一半
2 遍历x坐标(显示图像的时候,坐标原点一般在图像的左上角),x从左到右,y从上到下。
3 遍历y坐标
4 对每一个坐标,获取它的matrix信息**[y,x]!!!不是[x,y]!!!**。 取整,得到当前位置的统计个数
5 通过text方法,将info绘制在[x,y]坐标处。
6、7 绘制在水平方向、竖直方向的中心位置处。
8 color对应数字的颜色,大于阈值:白色
9 让图形显示更加紧凑,否则部分信息可能被遮挡
10 展示混淆矩阵
使用pytorch计算分类模型的混淆矩阵
1 判断设备,是否使用GPU
2 打印设备信息
3 之前训练目标net网络时,使用针对验证集的一个处理方式,直接使用了当时已经处理好的模型权重,所以此处要与它保持相同的预处理方式。
5 使用花分类数据集的验证集
6 dataloader载入验证集
1 实例化网络MobileNetV2(之前写过的,这里拿来继续用)
2 、3 载入之前已经训练好的MobileNetV2的模型权重
4 将模型连到设备上去
1、2 载入之前生成的json文件(对应着索引与类别信息)。读入后提取出所有的标签信息。载入后是字典形式,而我们只需要它的标签
4 label for_,label in class_indect.items():不要key,只要value
5 实例化上面定义的ConfusionMatrix类
6 启动验证模式
7 上下文管理器 torch.no_grad(),停止pytorch对变量梯度的跟踪
8 遍历dataloader数据集
9 分为图片,标签
10 把图片储存到设备中,输入网络,得到输出
11 softmax处理
12 通过argmax得到最大的元素
13 调用.update()方法输入预测值(outputs.numpy())、真实标签值(val_labels.numpy())
14 plot 绘制混淆矩阵
15 打印各个指标信息