多分类混淆矩阵(Confusion Matrix)及其在Python中的实现

引言

在机器学习和统计学中,我们经常需要评估一个分类模型的性能。混淆矩阵是一种常用的评估分类模型性能的工具。它可以帮助我们了解模型在每个类别上的表现,并进一步计算出各种性能指标,如准确率、召回率和 F1 分数。

本文将介绍什么是多分类混淆矩阵,如何使用 Python 中的混淆矩阵库来计算和可视化混淆矩阵,并给出实际案例和代码示例。

多分类混淆矩阵是什么?

多分类混淆矩阵是一种将分类模型的预测结果与真实标签进行比较的表格。它将模型的预测结果按照实际类别进行分类,并统计每个分类的数量。通过观察混淆矩阵,我们可以了解模型在每个类别上的表现,包括正确分类和错误分类的数量。

多分类混淆矩阵通常是一个 N × N 的矩阵,其中 N 是类别的数量。矩阵的每一行表示真实标签,每一列表示模型的预测标签。对角线上的元素表示模型正确分类的数量,而非对角线上的元素表示错误分类的数量。

Python 中的混淆矩阵库

在 Python 中,我们可以使用 scikit-learn 库中的 confusion_matrix 函数来计算多分类混淆矩阵。该函数接受两个参数:真实标签和模型预测标签。它将返回一个 N × N 的矩阵,其中 N 是类别的数量。

以下是使用 confusion_matrix 函数计算多分类混淆矩阵的示例代码:

from sklearn.metrics import confusion_matrix

# 定义真实标签和模型预测标签
y_true = [0, 1, 2, 0, 1, 2, 0, 1, 2]
y_pred = [0, 0, 2, 0, 1, 1, 0, 0, 2]

# 计算多分类混淆矩阵
cm = confusion_matrix(y_true, y_pred)
print(cm)

输出结果为:

[[3 0 0]
 [2 1 0]
 [1 0 2]]

可视化多分类混淆矩阵

除了计算多分类混淆矩阵,我们还可以使用 Python 中的 seaborn 库来可视化混淆矩阵。seaborn 是一个基于 matplotlib 的数据可视化库,提供了更美观和直观的图表风格。

以下是使用 seaborn 库可视化多分类混淆矩阵的示例代码:

import seaborn as sns

# 使用之前的混淆矩阵
cm = [[3, 0, 0],
      [2, 1, 0],
      [1, 0, 2]]

# 创建热力图
sns.heatmap(cm, annot=True, cmap="Blues")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

运行以上代码将显示一个热力图,其中 x 轴表示模型预测的标签,y 轴表示真实的标签。每个单元格的数字表示该类别的样本数量。通过颜色的深浅,我们可以直观地了解模型在每个类别上的性能。

Confusion Matrix Heatmap

应用案例

为了更好地理解多分类混淆矩阵的应用,让我们假设我们正在构建一个垃圾邮件分类器。我们要根据邮件的内容将其分为垃圾邮件和非垃圾邮件。

在训练和测试后