【机器学习基础】让人惊艳的决策树可视化
共 1914字,需浏览 4分钟
·
2020-12-22 17:27
本文目标
本文的目标是引入dtreeviz来可视化分类决策树,比scikit-learn包自带的可视化效果更好。我们将在Scikit学习使用iris数据集学习决策树教程。
请注意,如果我们使用决策树进行回归,可视化效果会有所不同。
scikit-learn和dtreeviz可视化对比
前期准备
首先使用pip或conda命令安装模块,如下所示。
pip install dtreeviz
pip install graphviz #这个包需要下载安装,并配置对应的变量环境
加载数据集
import numpy as np
from sklearn.datasets import load_iris, load_boston
from sklearn import tree
iris = load_iris()
df_iris = pd.DataFrame(iris['data'], columns=iris['feature_names'])
df_iris['target'] = iris['target']
df_iris.head()
sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
5.1 | 3.5 | 1.4 | 0.2 | 0 |
4.9 | 3 | 1.4 | 0.2 | 0 |
4.7 | 3.2 | 1.3 | 0.2 | 0 |
4.6 | 3.1 | 1.5 | 0.2 | 0 |
5 | 3.6 | 1.4 | 0.2 | 0 |
训练决策树
# Train the Decision tree model
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
可视化决策树
Scikit-learn
import graphviz
dot_data = tree.export_graphviz(clf, out_file=None,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True, rounded=True,
special_characters=True)
graph = graphviz.Source(dot_data)
graph
dtreeviz
from dtreeviz.trees import dtreeviz
viz = dtreeviz(clf,
iris['data'],
iris['target'],
target_name='',
feature_names=np.array(iris['feature_names']),
class_names={0:'setosa',1:'versicolor',2:'virginica'})
viz
通过上面的对比,发现Treeviz的可视化会更好
你可以在每个节点上看到每个类的分布
你可以看到每个分割的决策边界在哪里
你可以看到每片叶子上的样品大小与圆的大小相同
原文名称:How to visualize a decision tree beyond scikit-learn
原文链接:https://h1ros.github.io/posts/how-to-visualize-a-decision-tree-beyond-scikit-learn/
往期精彩回顾
获取本站知识星球优惠券,复制链接直接打开:
https://t.zsxq.com/qFiUFMV
本站qq群704220115。
加入微信群请扫码: