如何画混淆矩阵 Python
混淆矩阵(Confusion Matrix)是机器学习中用于评估分类模型性能的常见工具。它将模型的预测结果与真实标签进行比较,并将其分为四个不同的类别:真正例(True Positive,TP)、假正例(False Positive,FP)、真反例(True Negative,TN)和假反例(False Negative,FN)。在 Python 中,我们可以使用一些库来绘制混淆矩阵,如 matplotlib 和 seaborn。
在本文中,我们将介绍如何使用 Python 画混淆矩阵。首先,我们需要安装 matplotlib 和 seaborn 库。可以使用以下命令来安装它们:
pip install matplotlib seaborn
接下来,我们将使用一个简单的示例来说明如何画混淆矩阵。假设我们有一个二分类模型,用于预测某个人是否患有某种疾病。我们有一组真实标签和模型的预测结果,如下所示:
真实标签 | 预测结果 |
---|---|
1 | 1 |
1 | |
1 | 1 |
1 | |
首先,我们需要将真实标签和预测结果转换为 NumPy 数组。然后,我们可以使用 sklearn 库的 confusion_matrix 函数来计算混淆矩阵。接下来,我们可以使用 seaborn 库的 heatmap 函数绘制混淆矩阵。
以下是具体的代码示例:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
# 真实标签和预测结果
y_true = np.array([1, 0, 1, 0, 1, 0])
y_pred = np.array([1, 1, 1, 0, 0, 0])
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 绘制混淆矩阵
sns.heatmap(cm, annot=True, cmap='Blues', fmt='d')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
上述代码中,首先导入了需要的库。然后,我们定义了真实标签(y_true)和预测结果(y_pred)。接下来,使用 confusion_matrix 函数计算混淆矩阵,并将结果存储在 cm 变量中。最后,使用 seaborn 库的 heatmap 函数绘制混淆矩阵。heatmap 函数的参数包括:混淆矩阵(cm)、是否在矩阵中显示数值(annot)、颜色映射(cmap)和数值的格式(fmt)。
执行上述代码,将会得到如下的混淆矩阵图:
![](