师兄说学目标检测之前先学分类

坏了,内容好多!学学学

感谢up主,好人一生平安

python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀

混淆矩阵

  1. 什么是混淆矩阵
  2. python 代码在线混淆 膨胀 pytorch 混淆矩阵_分类_02

  3. 横坐标:每一列属于该类的所有验证样本。每一列所有元素对应真实类别。
  4. python 代码在线混淆 膨胀 pytorch 混淆矩阵_混淆矩阵_03

  5. 纵坐标:网络的预测类别。每一行对应预测结果属于该类的所有样本。
  6. python 代码在线混淆 膨胀 pytorch 混淆矩阵_学习_04

  7. 对角线:预测正确的样本个数。
  8. python 代码在线混淆 膨胀 pytorch 混淆矩阵_人工智能_05

  9. 预测值在对角线上分布的越密集,模型的性能就越好。
    还能通过混淆矩阵看到这个网络对哪些类别更容易分类出错。
  10. 混淆矩阵的指标
  11. python 代码在线混淆 膨胀 pytorch 混淆矩阵_混淆矩阵_06

  12. 精确率precision不等于准确率accuracy!!
    准确率:所有预测正确的样本个数 / 所有用于验证的样本个数
    (对角线上所有数据之和 / 混淆矩阵所有数据之和)
  13. 二分类简单示例
  14. python 代码在线混淆 膨胀 pytorch 混淆矩阵_人工智能_07

  15. 每一列:预测值标签;
    每一列:真实值标签。

TP、TN 都代表网络预测正确的部分。(越大越好)
FP、FN 都代表网络预测错误的部分。(越小越好)

  1. 准确率、精确率、灵敏度/召回率、特异度

    准确率:对所有类别的统计
    精确率、灵敏度/召回率、特异度:针对某个类别
  2. 实例

    以猫为例,可以把狗和猪的类别混在一起,统一整合为不为猫的情况。得到混淆矩阵:

  3. 参考博文

计算混淆矩阵与相关指标

这里使用numpy进行统一计算(可以同时在TensorFlow和pytorch中使用):

python 代码在线混淆 膨胀 pytorch 混淆矩阵_混淆矩阵_08

  1. 制定一个类:ConfusionMatrix
    若图像显示不正常,就升级matplotlib
    prettytable:将输出展示成列表的形式
  2. python 代码在线混淆 膨胀 pytorch 混淆矩阵_人工智能_09


  3. python 代码在线混淆 膨胀 pytorch 混淆矩阵_分类_10


  4. python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀_11


  5. python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀_12

  6. 初始化函数:init
  7. python 代码在线混淆 膨胀 pytorch 混淆矩阵_学习_13

  8. 1:传入了两个变量:分类网络的分类类别个数num_classes、分类标签列表:labels。
    2:初始化一个行数和列数相等且均为num_classes的正方形的、值为零的矩阵。
    3:将num_classes赋值给类变量num_classes。
  9. 第一个实现的类:update
  10. python 代码在线混淆 膨胀 pytorch 混淆矩阵_分类_14

  11. 1:预测值preds、真实标签labels
    2:累加到混淆矩阵中。将预测、真实标签打包组合,进行遍历。
    p:预测值;t:真实类别标签
    3:矩阵[预测值(行),真实值(列)],[第t行,第p列]
  12. 第一个实现的类:summary
    统计计算各个指标
    ①准确率:
  13. python 代码在线混淆 膨胀 pytorch 混淆矩阵_分类_15

  14. 遍历,0~num_classes-1,
    统计对角线上的元素和,
    计算acc值
    ②:计算每个类别的精确率、召回率、特异度
  15. python 代码在线混淆 膨胀 pytorch 混淆矩阵_混淆矩阵_16

  16. (库)prettytable:将输出展示成列表的形式
    初始化一张表table,
    在第一行添加一些描述信息,
    使用for i in range遍历每一个类别,
    对于第i个类别,TP(true positive):对角线上元素m[i],
    FP(false positive):这一行的所有元素之和(第i行)-TP,
    FN(false negative):这一列的和(第i列)-TP
  17. python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀_17

  18. round:小数部分只取三位。
  19. python 代码在线混淆 膨胀 pytorch 混淆矩阵_人工智能_18

  20. 将当前类别信息添加到刚刚初始化的table里面
  21. python 代码在线混淆 膨胀 pytorch 混淆矩阵_学习_19

  22. 类别标签,precision,recall,specificity
  23. 绘制混淆矩阵:plot
  24. python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀_20

  25. 1 将matrix赋值给matrix
    2 打印混淆矩阵
    3 使用imshow函数展示混淆矩阵。颜色变换:从白色到蓝色。
  26. python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀_21

  27. 4 对于label:默认是0、1、2、3这种坐标。但是希望它展示的是标签的类别。
    使用xticks,将原来x轴的信息(0~num_classes-1)替换成为labels,对x轴旋转45°
    5 y轴同理
    6 混淆矩阵右侧像色谱一样的colorbar。数值的密集程度,颜色越深,数值就越密集。
  28. python 代码在线混淆 膨胀 pytorch 混淆矩阵_学习_22

  29. 7 横坐标 True labels
    8 纵坐标 predicted labels
    9 图像标题 Confusion matrix

将每个区域的数值标注在图像上

python 代码在线混淆 膨胀 pytorch 混淆矩阵_学习_23


1 设置阈值,指定数字文本的颜色。取matrix最大数值的一半

2 遍历x坐标(显示图像的时候,坐标原点一般在图像的左上角),x从左到右,y从上到下。

3 遍历y坐标

4 对每一个坐标,获取它的matrix信息**[y,x]!!!不是[x,y]!!!**。 取整,得到当前位置的统计个数

python 代码在线混淆 膨胀 pytorch 混淆矩阵_混淆矩阵_24


5 通过text方法,将info绘制在[x,y]坐标处。

6、7 绘制在水平方向、竖直方向的中心位置处。

8 color对应数字的颜色,大于阈值:白色

9 让图形显示更加紧凑,否则部分信息可能被遮挡

10 展示混淆矩阵

使用pytorch计算分类模型的混淆矩阵

python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀_25


1 判断设备,是否使用GPU

2 打印设备信息

3 之前训练目标net网络时,使用针对验证集的一个处理方式,直接使用了当时已经处理好的模型权重,所以此处要与它保持相同的预处理方式。

5 使用花分类数据集的验证集

6 dataloader载入验证集

python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀_26


1 实例化网络MobileNetV2(之前写过的,这里拿来继续用)

python 代码在线混淆 膨胀 pytorch 混淆矩阵_分类_27


python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀_28


2 、3 载入之前已经训练好的MobileNetV2的模型权重

4 将模型连到设备上去

python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀_29


python 代码在线混淆 膨胀 pytorch 混淆矩阵_分类_30


1、2 载入之前生成的json文件(对应着索引与类别信息)。读入后提取出所有的标签信息。载入后是字典形式,而我们只需要它的标签

python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀_31


4 label for_,label in class_indect.items():不要key,只要value

5 实例化上面定义的ConfusionMatrix类

6 启动验证模式

7 上下文管理器 torch.no_grad(),停止pytorch对变量梯度的跟踪

8 遍历dataloader数据集

9 分为图片,标签

10 把图片储存到设备中,输入网络,得到输出

11 softmax处理

12 通过argmax得到最大的元素

13 调用.update()方法输入预测值(outputs.numpy())、真实标签值(val_labels.numpy())

14 plot 绘制混淆矩阵

15 打印各个指标信息

python 代码在线混淆 膨胀 pytorch 混淆矩阵_python 代码在线混淆 膨胀_32