使用绘图 API 进行开发#
Scikit-learn 定义了一个简单的 API,用于创建机器学习的可视化效果。此 API 的关键特性是只需运行一次计算,并且能够灵活地调整后续的可视化效果。本节面向希望开发或维护绘图工具的开发者。关于用法,用户应参考用户指南。
绘图 API 概述#
此逻辑封装在一个显示对象中,其中存储了计算后的数据,绘图在一个plot
方法中完成。显示对象的__init__
方法只包含创建可视化所需的数据。plot
方法接收仅与可视化相关的参数,例如 matplotlib 坐标轴。plot
方法将 matplotlib 绘图元素存储为属性,允许通过显示对象进行样式调整。Display
类应定义一个或两个类方法:from_estimator
和from_predictions
。这些方法允许从估计器和一些数据或从真实值和预测值创建Display
对象。这些类方法使用计算后的值创建显示对象后,调用显示对象的 plot 方法。请注意,plot
方法定义了与 matplotlib 相关的属性,例如线条绘图元素。这允许在调用plot
方法后进行自定义。
例如,RocCurveDisplay
定义了以下方法和属性
class RocCurveDisplay:
def __init__(self, fpr, tpr, roc_auc, estimator_name):
...
self.fpr = fpr
self.tpr = tpr
self.roc_auc = roc_auc
self.estimator_name = estimator_name
@classmethod
def from_estimator(cls, estimator, X, y):
# get the predictions
y_pred = estimator.predict_proba(X)[:, 1]
return cls.from_predictions(y, y_pred, estimator.__class__.__name__)
@classmethod
def from_predictions(cls, y, y_pred, estimator_name):
# do ROC computation from y and y_pred
fpr, tpr, roc_auc = ...
viz = RocCurveDisplay(fpr, tpr, roc_auc, estimator_name)
return viz.plot()
def plot(self, ax=None, name=None, **kwargs):
...
self.line_ = ...
self.ax_ = ax
self.figure_ = ax.figure_
更多内容请阅读使用可视化 API 绘制 ROC 曲线和用户指南。
使用多个坐标轴进行绘图#
一些绘图工具,例如from_estimator
和PartialDependenceDisplay
支持在多个坐标轴上绘图。支持两种不同的场景
1. 如果传入坐标轴列表,plot
将检查坐标轴数量是否与其预期数量一致,然后在这些坐标轴上绘图。2. 如果传入单个坐标轴,该坐标轴定义了放置多个坐标轴的空间。在这种情况下,我们建议使用 matplotlib 的~matplotlib.gridspec.GridSpecFromSubplotSpec
来分割空间
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec
fig, ax = plt.subplots()
gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec())
ax_top_left = fig.add_subplot(gs[0, 0])
ax_top_right = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, :])
默认情况下,plot
中的ax
关键字为None
。在这种情况下,将创建单个坐标轴,并使用 gridspec api 创建绘图区域。
例如,参见from_estimator
,它使用此 API 绘制多条线和等高线。定义边界框的坐标轴存储在bounding_ax_
属性中。创建的各个坐标轴存储在axes_
ndarray 中,对应于网格上的坐标轴位置。未使用的位置设置为None
。此外,matplotlib 绘图元素存储在lines_
和contours_
中,其中键是网格上的位置。当传入坐标轴列表时,axes_
、lines_
和contours_
是对应于传入坐标轴列表的一维 ndarray。