文章目录(Table of Contents)
简介
这一篇是对于决策树的介绍, 使用决策树来解决分类问题, 同时我们会将决策树的结果进行可视化, 来查看他的分类的过程.
参考资料
主要参考内容来自sklearn的官方教程: 1.10. Decision Trees
这也是一个进行可视化的博客, 最后保存和显示的方式会有不同. Creating and Visualizing Decision Trees with Python
模型的训练与可视化
导入数据集
首先我们导入我们需要的库.
- import sklearn.datasets as datasets
- import pandas as pd
接下来我们就可以导入数据集
- iris = datasets.load_iris()
- df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
- y = iris.target
我们简单看一下使用的数据集.
训练决策树模型
接着, 我们训练决策树模型, 为了最后的显示效果, 我们控制决策树的深度.
- from sklearn.tree import DecisionTreeClassifier
- # 训练决策树模型(控制决策树的深度, 这里控制最大深度是2)
- dtree = DecisionTreeClassifier(max_depth=2)
- dtree.fit(df, y)
- """
- DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=2,
- max_features=None, max_leaf_nodes=None,
- min_impurity_decrease=0.0, min_impurity_split=None,
- min_samples_leaf=1, min_samples_split=2,
- min_weight_fraction_leaf=0.0, presort=False, random_state=None,
- splitter='best')
- """
关于sklearn中决策树的参考文档: sklearn.tree.DecisionTreeClassifier
模型的评价
关于模型的评价, 可以参考下面的链接: 模型评价指标说明与实践–混淆矩阵的说明
- from tools import *
我们对上面的决策树进行评价(我直接在训练集上进行看准确率了, 这里就是做一个演示, 实际使用的时候需要划分测试集).
- # 模型的评估
- y_pre = dtree.predict(df)
- display_model_performance_metrics(true_labels=y,
- predicted_labels=y_pre,
- classes=[0,1,2])
最终的模型的准确率如下所示, 可以看到即使深度不是很深, 准确率也是可以的.
结果的可视化
接下来我们就将决策树的结果进行可视化. 我们会使用到graphviz.
- from sklearn.tree import export_graphviz
- import graphviz
需要将graphviz添加到路径中, 具体的内容看下面的一些问题中的内容.
- import os
- os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin'
对于生成的图像, 我们可以设置每一类的lable的名字, 等一些其他的设置.
- dot_data = export_graphviz(dtree, out_file=None,
- feature_names=iris.feature_names,
- class_names=iris.target_names,
- filled=True, rounded=True,
- special_characters=True)
- # 可以设置图像的显示
- graph = graphviz.Source(dot_data)
最后只需要将图像进行保存即可.
- graph.render(filename ="iris", directory ='./', format='pdf')
关于图像保存的参数的说明, 可以参考文档: graphviz.render
这个是graphviz的说明文档: graphviz稳定版文档
最终的可视化的效果, 如下图所示:
简单说明一下上面的图像, 每一个叶子节点中有class, 表示按照上面的规则, 会被分到哪一个类别中. 同时, 每一个节点中有values, 表示到这一个节点中每一个类别的样本有多少个, 如上面的例子中一共有3类样本, 所以values中有三个数字, 分别是三个类别的样本的个数.
一些问题
Make sure the Graphviz executables are on your systems’ PATH
详细的报错信息:
- RuntimeError: failed to execute ['dot', '-Tpdf', '-O', 'test'], make sure the Graphviz executables are on your systems' path
在进行可视化的时候, 会出现如上的报错, 我们可以使用下面的方式进行解决.
- 首先下载: Graphviz - Graph Visualization Software
- 接着在运行的时候添加到路径即可:
- import os
- os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin'
这样再次运行即可.
- 微信公众号
- 关注微信公众号
- QQ群
- 我们的QQ群号
评论