理解决策树结构#

可以分析决策树结构以进一步了解特征与目标预测之间的关系。在此示例中,我们将展示如何检索

  • 二叉树结构;

  • 每个节点的深度以及它是否是叶子节点;

  • 使用decision_path方法到达的节点;

  • 使用apply方法到达的叶子节点;

  • 用于预测样本的规则;

  • 一组样本共享的决策路径。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
from matplotlib import pyplot as plt

from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

训练树分类器#

首先,我们使用DecisionTreeClassifierload_iris数据集拟合模型。

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)
DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
在Jupyter环境中,请重新运行此单元格以显示HTML表示或信任笔记本。
在GitHub上,HTML表示无法呈现,请尝试使用nbviewer.org加载此页面。


树结构#

决策分类器有一个名为tree_的属性,它允许访问低级属性,例如node_count(节点总数)和max_depth(树的最大深度)。tree_.compute_node_depths()方法计算树中每个节点的深度。tree_还存储整个二叉树结构,表示为多个并行数组。每个数组的第i个元素包含有关节点i的信息。节点0是树的根。一些数组仅适用于叶子节点或分割节点。在这种情况下,另一种类型的节点的值是任意的。例如,数组featurethreshold仅适用于分割节点。因此,这些数组中叶子节点的值是任意的。

在这些数组中,我们有

  • children_left[i]:节点i的左子节点的ID,如果为叶子节点则为-1

  • children_right[i]:节点i的右子节点的ID,如果为叶子节点则为-1

  • feature[i]:用于分割节点i的特征

  • threshold[i]:节点i处的阈值

  • n_node_samples[i]:到达节点i的训练样本数

  • impurity[i]:节点i处的杂质

  • weighted_n_node_samples[i]:到达节点i的加权训练样本数

  • value[i, j, k]:对于输出j和类别k,到达节点i的训练样本的汇总(对于回归树,类别设置为1)。有关value的更多信息,请参见下文。

使用这些数组,我们可以遍历树结构来计算各种属性。下面,我们将计算每个节点的深度以及它是否是叶子节点。

n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold
values = clf.tree_.value

node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, 0)]  # start with the root node id (0) and its depth (0)
while len(stack) > 0:
    # `pop` ensures each node is only visited once
    node_id, depth = stack.pop()
    node_depth[node_id] = depth

    # If the left and right child of a node is not the same we have a split
    # node
    is_split_node = children_left[node_id] != children_right[node_id]
    # If a split node, append left and right children and depth to `stack`
    # so we can loop through them
    if is_split_node:
        stack.append((children_left[node_id], depth + 1))
        stack.append((children_right[node_id], depth + 1))
    else:
        is_leaves[node_id] = True

print(
    "The binary tree structure has {n} nodes and has "
    "the following tree structure:\n".format(n=n_nodes)
)
for i in range(n_nodes):
    if is_leaves[i]:
        print(
            "{space}node={node} is a leaf node with value={value}.".format(
                space=node_depth[i] * "\t", node=i, value=np.around(values[i], 3)
            )
        )
    else:
        print(
            "{space}node={node} is a split node with value={value}: "
            "go to node {left} if X[:, {feature}] <= {threshold} "
            "else to node {right}.".format(
                space=node_depth[i] * "\t",
                node=i,
                left=children_left[i],
                feature=feature[i],
                threshold=threshold[i],
                right=children_right[i],
                value=np.around(values[i], 3),
            )
        )
The binary tree structure has 5 nodes and has the following tree structure:

node=0 is a split node with value=[[0.33  0.304 0.366]]: go to node 1 if X[:, 3] <= 0.800000011920929 else to node 2.
        node=1 is a leaf node with value=[[1. 0. 0.]].
        node=2 is a split node with value=[[0.    0.453 0.547]]: go to node 3 if X[:, 2] <= 4.950000047683716 else to node 4.
                node=3 is a leaf node with value=[[0.    0.917 0.083]].
                node=4 is a leaf node with value=[[0.    0.026 0.974]].

此处使用的values数组是什么?#

tree_.value数组是一个形状为[n_nodesn_classesn_outputs]的3D数组,它提供了针对每个类别和每个输出到达节点的样本比例。每个节点都有一个value数组,它是相对于父节点到达此节点的加权样本比例,针对每个输出和类别。

可以通过将此数字乘以给定节点的tree_.weighted_n_node_samples[node_idx]来将其转换为到达节点的绝对加权样本数。请注意,在此示例中未使用样本权重,因此加权样本数是到达节点的样本数,因为默认情况下每个样本的权重为1。

例如,在上例中,在鸢尾花数据集上构建的树的根节点具有value = [0.33, 0.304, 0.366],这表明根节点处有33%的类别0样本、30.4%的类别1样本和36.6%的类别2样本。可以通过乘以到达根节点的样本数tree_.weighted_n_node_samples[0]将其转换为样本的绝对数量。然后,根节点具有value = [37, 34, 41],这表明根节点处有37个类别0样本、34个类别1样本和41个类别2样本。

遍历树时,样本会被拆分,结果是到达每个节点的value数组会发生变化。根节点的左子节点具有value = [1., 0, 0](或转换为样本绝对数量时为value = [37, 0, 0]),因为左子节点中的所有 37 个样本都属于类别 0。

注意:在这个例子中,n_outputs=1,但是树分类器也可以处理多输出问题。value数组在每个节点上将只是一个二维数组。

我们可以将上述输出与决策树的图进行比较。在这里,我们显示到达每个节点的每个类别的样本比例,这对应于tree_.value数组的实际元素。

tree.plot_tree(clf, proportion=True)
plt.show()
plot unveil tree structure

决策路径#

我们还可以检索感兴趣样本的决策路径。decision_path方法输出一个指示矩阵,允许我们检索感兴趣的样本遍历的节点。指示矩阵中位置(i, j)的非零元素表示样本i通过节点j。或者,对于一个样本i,指示矩阵中第i行的非零元素的位置指定该样本经过的节点的ID。

可以使用apply方法获得感兴趣样本到达的叶节点ID。这将返回一个数组,其中包含每个感兴趣样本到达的叶节点的节点ID。使用叶节点ID和decision_path,我们可以获得用于预测样本或样本组的分割条件。首先,让我们对一个样本进行操作。注意node_index是一个稀疏矩阵。

node_indicator = clf.decision_path(X_test)
leaf_id = clf.apply(X_test)

sample_id = 0
# obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id`
node_index = node_indicator.indices[
    node_indicator.indptr[sample_id] : node_indicator.indptr[sample_id + 1]
]

print("Rules used to predict sample {id}:\n".format(id=sample_id))
for node_id in node_index:
    # continue to the next node if it is a leaf node
    if leaf_id[sample_id] == node_id:
        continue

    # check if value of the split feature for sample 0 is below threshold
    if X_test[sample_id, feature[node_id]] <= threshold[node_id]:
        threshold_sign = "<="
    else:
        threshold_sign = ">"

    print(
        "decision node {node} : (X_test[{sample}, {feature}] = {value}) "
        "{inequality} {threshold})".format(
            node=node_id,
            sample=sample_id,
            feature=feature[node_id],
            value=X_test[sample_id, feature[node_id]],
            inequality=threshold_sign,
            threshold=threshold[node_id],
        )
    )
Rules used to predict sample 0:

decision node 0 : (X_test[0, 3] = 2.4) > 0.800000011920929)
decision node 2 : (X_test[0, 2] = 5.1) > 4.950000047683716)

对于一组样本,我们可以确定样本共同经过的节点。

sample_ids = [0, 1]
# boolean array indicating the nodes both samples go through
common_nodes = node_indicator.toarray()[sample_ids].sum(axis=0) == len(sample_ids)
# obtain node ids using position in array
common_node_id = np.arange(n_nodes)[common_nodes]

print(
    "\nThe following samples {samples} share the node(s) {nodes} in the tree.".format(
        samples=sample_ids, nodes=common_node_id
    )
)
print("This is {prop}% of all nodes.".format(prop=100 * len(common_node_id) / n_nodes))
The following samples [0, 1] share the node(s) [0 2] in the tree.
This is 40.0% of all nodes.

脚本总运行时间:(0 分钟 0.081 秒)

相关示例

绘制层次聚类树状图

绘制层次聚类树状图

使用成本复杂度剪枝后剪枝决策树

使用成本复杂度剪枝后剪枝决策树

决策树回归

决策树回归

绘制在鸢尾花数据集上训练的决策树的决策面

绘制在鸢尾花数据集上训练的决策树的决策面

由 Sphinx-Gallery 生成的图库