文章目录(Table of Contents)
介绍
这一篇主要介绍一下绘制混淆矩阵(confusion matrix)的方式。通常在看model的效果的时候,我们会使用混淆矩阵来进行检测。
主要参考资料 :
- 混淆矩阵的计算方式 : How to check and read Confusion matrix?
- 混淆矩阵可视化方式 : How can I plot a confusion matrix? [duplicate]
具体绘制方式
混淆矩阵的计算
混淆矩阵就是我们会计算最后分类错误的个数, 如计算将class1分为class2的个数,以此类推。
我们可以使用下面的方式来进行混淆矩阵的计算。
- # 绘制混淆矩阵
- def confusion_matrix(preds, labels, conf_matrix):
- preds = torch.argmax(preds, 1)
- for p, t in zip(preds, labels):
- conf_matrix[p, t] += 1
- return conf_matrix
- conf_matrix = torch.zeros(10, 10)
- for data, target in test_loader:
- output = fullModel(data.to(device))
- conf_matrix = confusion_matrix(output, target, conf_matrix)
最后得到的conf_matrix就是混淆矩阵的值。
混淆矩阵的可视化
有了上面的混淆矩阵中具体的值,下面就是进行可视化的步骤。可视化我们使用seaborn来进行完成。因为我这里conf_matrix的值是tensor, 所以需要先转换为Numpy.
- import seaborn as sn
- df_cm = pd.DataFrame(conf_matrix.numpy(),
- index = [i for i in list(Attack2Index.keys())],
- columns = [i for i in list(Attack2Index.keys())])
- plt.figure(figsize = (10,7))
- sn.heatmap(df_cm, annot=True, cmap="BuPu")
最终的混淆矩阵的图如下所示:
混淆矩阵的可视化(进行美化)
当然, 我们还可以对混淆矩阵做更多的处理, 使得显示的时候能更加好看一些. 下面的绘制混淆矩阵的函数我是在下面的链接里看到的, 最终的效果很是不错。
参考链接 : AE_RL_NSL_KDD
这里简单贴一下代码,可以方便直接进行使用。
- import itertools
- # 绘制混淆矩阵
- def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
- """
- This function prints and plots the confusion matrix.
- Normalization can be applied by setting `normalize=True`.
- Input
- - cm : 计算出的混淆矩阵的值
- - classes : 混淆矩阵中每一行每一列对应的列
- - normalize : True:显示百分比, False:显示个数
- """
- if normalize:
- cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
- print("Normalized confusion matrix")
- else:
- print('Confusion matrix, without normalization')
- print(cm)
- plt.imshow(cm, interpolation='nearest', cmap=cmap)
- plt.title(title)
- plt.colorbar()
- tick_marks = np.arange(len(classes))
- plt.xticks(tick_marks, classes, rotation=45)
- plt.yticks(tick_marks, classes)
- fmt = '.2f' if normalize else 'd'
- thresh = cm.max() / 2.
- for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
- plt.text(j, i, format(cm[i, j], fmt),
- horizontalalignment="center",
- color="white" if cm[i, j] > thresh else "black")
- plt.tight_layout()
- plt.ylabel('True label')
- plt.xlabel('Predicted label')
测试数据如下所示:
- cnf_matrix = np.array([[8707, 64, 731, 164, 45],
- [1821, 5530, 79, 0, 28],
- [266, 167, 1982, 4, 2],
- [691, 0, 107, 1930, 26],
- [30, 0, 111, 17, 42]])
- attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']
我们分别测试normalize=True/False的效果。
- plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Normalized confusion matrix')
- plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=False, title='Normalized confusion matrix')
- 微信公众号
- 关注微信公众号
- QQ群
- 我们的QQ群号
评论