使用绘图 API 进行开发#

Scikit-learn 定义了一个简单的 API,用于创建机器学习的可视化。此 API 的主要特点是只运行一次计算,并能在事后灵活调整可视化。本节旨在为希望开发或维护绘图工具的开发人员提供指导。有关用法,用户应参阅用户指南

绘图 API 概述#

此逻辑封装在一个显示对象中,其中计算数据被存储,并在 plot 方法中完成绘图。显示对象的 __init__ 方法仅包含创建可视化所需的数据。plot 方法接受仅与可视化相关的参数,例如 matplotlib 轴。 plot 方法将 matplotlib artists 作为属性存储,允许通过显示对象进行样式调整。Display 类应定义一个或两个类方法:from_estimatorfrom_predictions。这些方法允许从估计器和一些数据或从真实值和预测值创建 Display 对象。在这些类方法使用计算值创建显示对象后,然后调用显示对象的 plot 方法。请注意,plot 方法定义了与 matplotlib 相关的属性,例如线条艺术家。这允许在调用 plot 方法后进行自定义。

例如,RocCurveDisplay 定义了以下方法和属性

class RocCurveDisplay:
    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        ...
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    @classmethod
    def from_estimator(cls, estimator, X, y):
        # get the predictions
        y_pred = estimator.predict_proba(X)[:, 1]
        return cls.from_predictions(y, y_pred, estimator.__class__.__name__)

    @classmethod
    def from_predictions(cls, y, y_pred, estimator_name):
        # do ROC computation from y and y_pred
        fpr, tpr, roc_auc = ...
        viz = RocCurveDisplay(fpr, tpr, roc_auc, estimator_name)
        return viz.plot()

    def plot(self, ax=None, name=None, **kwargs):
        ...
        self.line_ = ...
        self.ax_ = ax
        self.figure_ = ax.figure_

详情请参阅带可视化 API 的 ROC 曲线用户指南

使用多个轴进行绘图#

某些绘图工具,例如 from_estimatorPartialDependenceDisplay 支持在多个轴上绘图。支持两种不同的情况:

1. 如果传入轴列表,plot 将检查轴的数量是否与预期轴的数量一致,然后在这些轴上绘图。2. 如果传入单个轴,该轴将定义一个空间,用于放置多个轴。在这种情况下,我们建议使用 matplotlib 的 ~matplotlib.gridspec.GridSpecFromSubplotSpec 来分割空间。

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec

fig, ax = plt.subplots()
gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec())

ax_top_left = fig.add_subplot(gs[0, 0])
ax_top_right = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, :])

默认情况下,plot 中的 ax 关键字为 None。在这种情况下,会创建一个单个轴,并使用 gridspec API 创建用于绘图的区域。

例如,请参阅 from_estimator,它使用此 API 绘制多条线和等高线。定义边界框的轴保存在 bounding_ax_ 属性中。创建的各个轴存储在 axes_ ndarray 中,对应于网格上的轴位置。未使用的位置设置为 None。此外,matplotlib Artists 存储在 lines_contours_ 中,其中键是网格上的位置。当传入轴列表时,axes_lines_contours_ 是一个与传入轴列表对应的 1D ndarray。