ROC 曲线与可视化 API#

Scikit-learn 定义了一个简单的 API,用于创建机器学习的可视化。此 API 的关键功能是允许快速绘图和视觉调整,而无需重新计算。在此示例中,我们将演示如何通过比较 ROC 曲线来使用可视化 API。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

加载数据并训练 SVC#

首先,我们加载葡萄酒数据集并将其转换为二元分类问题。然后,我们在训练数据集上训练一个支持向量分类器。

import matplotlib.pyplot as plt

from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_wine(return_X_y=True)
y = y == 2

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
svc = SVC(random_state=42)
svc.fit(X_train, y_train)
SVC(random_state=42)
在 Jupyter 环境中,请重新运行此单元格以显示 HTML 表示形式或信任 notebook。
在 GitHub 上,HTML 表示形式无法渲染,请尝试使用 nbviewer.org 加载此页面。


绘制 ROC 曲线#

接下来,我们通过一次调用 sklearn.metrics.RocCurveDisplay.from_estimator 来绘制 ROC 曲线。返回的 svc_disp 对象允许我们在未来的绘图中继续使用已经计算出的 SVC ROC 曲线。

svc_disp = RocCurveDisplay.from_estimator(svc, X_test, y_test)
plt.show()
plot roc curve visualization api

训练随机森林并绘制 ROC 曲线#

我们训练一个随机森林分类器,并创建一个将其与 SVC ROC 曲线进行比较的图表。请注意 svc_disp 如何使用 plot 来绘制 SVC ROC 曲线,而无需重新计算 ROC 曲线本身的值。此外,我们将 alpha=0.8 传递给绘图函数以调整曲线的 alpha 值。

rfc = RandomForestClassifier(n_estimators=10, random_state=42)
rfc.fit(X_train, y_train)
ax = plt.gca()
rfc_disp = RocCurveDisplay.from_estimator(
    rfc, X_test, y_test, ax=ax, curve_kwargs=dict(alpha=0.8)
)
svc_disp.plot(ax=ax, curve_kwargs=dict(alpha=0.8))
plt.show()
plot roc curve visualization api

脚本总运行时间: (0 minutes 0.146 seconds)

相关示例

带交叉验证的接收者操作特征(ROC)

带交叉验证的接收者操作特征(ROC)

scikit-learn 0.22 发布亮点

scikit-learn 0.22 发布亮点

多类接收者操作特征(ROC)

多类接收者操作特征(ROC)

检测错误权衡(DET)曲线

检测错误权衡(DET)曲线

由 Sphinx-Gallery 生成的图库