混淆矩阵的绘制(Plot a confusion matrix)

  • A+
所属分类:机器学习
摘要这一篇简单介绍一下混淆矩阵的计算和绘制,混淆矩阵可以用来判断模型预测的结果。

介绍

这一篇主要介绍一下绘制混淆矩阵(confusion matrix)的方式。通常在看model的效果的时候,我们会使用混淆矩阵来进行检测。

主要参考资料 : 

具体绘制方式

混淆矩阵的计算

混淆矩阵就是我们会计算最后分类错误的个数, 如计算将class1分为class2的个数,以此类推。

我们可以使用下面的方式来进行混淆矩阵的计算。

  1. # 绘制混淆矩阵
  2. def confusion_matrix(preds, labels, conf_matrix):
  3.     preds = torch.argmax(preds, 1)
  4.     for p, t in zip(preds, labels):
  5.         conf_matrix[p, t] += 1
  6.     return conf_matrix
  7. conf_matrix = torch.zeros(10, 10)
  8. for data, target in test_loader:
  9.     output = fullModel(data.to(device))
  10.     conf_matrix = confusion_matrix(output, target, conf_matrix)

最后得到的conf_matrix就是混淆矩阵的值。

混淆矩阵的绘制(Plot a confusion matrix)

混淆矩阵的可视化

有了上面的混淆矩阵中具体的值,下面就是进行可视化的步骤。可视化我们使用seaborn来进行完成。因为我这里conf_matrix的值是tensor, 所以需要先转换为Numpy.

  1. import seaborn as sn
  2. df_cm = pd.DataFrame(conf_matrix.numpy(),
  3.                      index = [i for i in list(Attack2Index.keys())],
  4.                      columns = [i for i in list(Attack2Index.keys())])
  5. plt.figure(figsize = (10,7))
  6. sn.heatmap(df_cm, annot=True, cmap="BuPu")

最终的混淆矩阵的图如下所示:

混淆矩阵的绘制(Plot a confusion matrix)

混淆矩阵的可视化(进行美化)

当然, 我们还可以对混淆矩阵做更多的处理, 使得显示的时候能更加好看一些. 下面的绘制混淆矩阵的函数我是在下面的链接里看到的, 最终的效果很是不错。

参考链接 : AE_RL_NSL_KDD

这里简单贴一下代码,可以方便直接进行使用。

  1. import itertools
  2. # 绘制混淆矩阵
  3. def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
  4.     """
  5.     This function prints and plots the confusion matrix.
  6.     Normalization can be applied by setting `normalize=True`.
  7.     Input
  8.     - cm : 计算出的混淆矩阵的值
  9.     - classes : 混淆矩阵中每一行每一列对应的列
  10.     - normalize : True:显示百分比, False:显示个数
  11.     """
  12.     if normalize:
  13.         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
  14.         print("Normalized confusion matrix")
  15.     else:
  16.         print('Confusion matrix, without normalization')
  17.     print(cm)
  18.     plt.imshow(cm, interpolation='nearest', cmap=cmap)
  19.     plt.title(title)
  20.     plt.colorbar()
  21.     tick_marks = np.arange(len(classes))
  22.     plt.xticks(tick_marks, classes, rotation=45)
  23.     plt.yticks(tick_marks, classes)
  24.     fmt = '.2f' if normalize else 'd'
  25.     thresh = cm.max() / 2.
  26.     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
  27.         plt.text(j, i, format(cm[i, j], fmt),
  28.                  horizontalalignment="center",
  29.                  color="white" if cm[i, j] > thresh else "black")
  30.     plt.tight_layout()
  31.     plt.ylabel('True label')
  32.     plt.xlabel('Predicted label')

测试数据如下所示:

  1. cnf_matrix = np.array([[8707, 64, 731, 164, 45],
  2.                       [1821, 5530, 79, 0, 28],
  3.                       [266, 167, 1982, 4, 2],
  4.                       [691, 0, 107, 1930, 26],
  5.                       [30, 0, 111, 17, 42]])
  6. attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']

我们分别测试normalize=True/False的效果。

  1. plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Normalized confusion matrix')
混淆矩阵的绘制(Plot a confusion matrix)
  1. plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=False, title='Normalized confusion matrix')
混淆矩阵的绘制(Plot a confusion matrix)
  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: