注意
转到末尾以下载完整示例代码。或者通过 JupyterLite 或 Binder 在浏览器中运行此示例
绘制学习曲线和检查模型的扩展性#
在此示例中,我们将展示如何使用 LearningCurveDisplay
类轻松绘制学习曲线。此外,我们还将对朴素贝叶斯分类器和 SVM 分类器所获得的学习曲线进行解释。
然后,我们将通过考察这些预测模型的计算成本,而不仅仅是其统计准确性,来探讨并得出关于其可扩展性的一些结论。
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
学习曲线#
学习曲线显示了在训练过程中增加更多样本的影响。这种影响通过检查模型在训练分数和测试分数方面的统计性能来体现。
在这里,我们使用数字数据集计算朴素贝叶斯分类器和带有 RBF 核的 SVM 分类器的学习曲线。
from sklearn.datasets import load_digits
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
X, y = load_digits(return_X_y=True)
naive_bayes = GaussianNB()
svc = SVC(kernel="rbf", gamma=0.001)
from_estimator
方法根据数据集和要分析的预测模型显示学习曲线。为了获得分数不确定性的估计,该方法使用交叉验证过程。
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import LearningCurveDisplay, ShuffleSplit
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 6), sharey=True)
common_params = {
"X": X,
"y": y,
"train_sizes": np.linspace(0.1, 1.0, 5),
"cv": ShuffleSplit(n_splits=50, test_size=0.2, random_state=0),
"score_type": "both",
"n_jobs": 4,
"line_kw": {"marker": "o"},
"std_display_style": "fill_between",
"score_name": "Accuracy",
}
for ax_idx, estimator in enumerate([naive_bayes, svc]):
LearningCurveDisplay.from_estimator(estimator, **common_params, ax=ax[ax_idx])
handles, label = ax[ax_idx].get_legend_handles_labels()
ax[ax_idx].legend(handles[:2], ["Training Score", "Test Score"])
ax[ax_idx].set_title(f"Learning Curve for {estimator.__class__.__name__}")

我们首先分析朴素贝叶斯分类器的学习曲线。它的形状在更复杂的数据集中很常见:当使用少量样本进行训练时,训练分数非常高,随着样本数量的增加而降低,而测试分数一开始非常低,然后随着样本的增加而升高。当所有样本都用于训练时,训练和测试分数会变得更真实。
我们看到了带有 RBF 核的 SVM 分类器的另一个典型学习曲线。无论训练集大小如何,训练分数都保持高位。另一方面,测试分数随着训练数据集的大小而增加。实际上,它会增加到一个点,达到一个平台期。观察到这样的平台期表明,获取新数据来训练模型可能不再有用,因为模型的泛化性能将不再提高。
复杂性分析#
除了这些学习曲线之外,还可以从训练时间和评分时间方面考察预测模型的可扩展性。
LearningCurveDisplay
类不提供此类信息。我们需要转而使用 learning_curve
函数并手动绘制图表。
from sklearn.model_selection import learning_curve
common_params = {
"X": X,
"y": y,
"train_sizes": np.linspace(0.1, 1.0, 5),
"cv": ShuffleSplit(n_splits=50, test_size=0.2, random_state=0),
"n_jobs": 4,
"return_times": True,
}
train_sizes, _, test_scores_nb, fit_times_nb, score_times_nb = learning_curve(
naive_bayes, **common_params
)
train_sizes, _, test_scores_svm, fit_times_svm, score_times_svm = learning_curve(
svc, **common_params
)
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(16, 12), sharex=True)
for ax_idx, (fit_times, score_times, estimator) in enumerate(
zip(
[fit_times_nb, fit_times_svm],
[score_times_nb, score_times_svm],
[naive_bayes, svc],
)
):
# scalability regarding the fit time
ax[0, ax_idx].plot(train_sizes, fit_times.mean(axis=1), "o-")
ax[0, ax_idx].fill_between(
train_sizes,
fit_times.mean(axis=1) - fit_times.std(axis=1),
fit_times.mean(axis=1) + fit_times.std(axis=1),
alpha=0.3,
)
ax[0, ax_idx].set_ylabel("Fit time (s)")
ax[0, ax_idx].set_title(
f"Scalability of the {estimator.__class__.__name__} classifier"
)
# scalability regarding the score time
ax[1, ax_idx].plot(train_sizes, score_times.mean(axis=1), "o-")
ax[1, ax_idx].fill_between(
train_sizes,
score_times.mean(axis=1) - score_times.std(axis=1),
score_times.mean(axis=1) + score_times.std(axis=1),
alpha=0.3,
)
ax[1, ax_idx].set_ylabel("Score time (s)")
ax[1, ax_idx].set_xlabel("Number of training samples")

我们看到 SVM 和朴素贝叶斯分类器的可扩展性截然不同。SVM 分类器在拟合和评分时的复杂度随样本数量迅速增加。实际上,众所周知,该分类器的拟合时间复杂度与样本数量呈二次以上关系,这使得其难以扩展到样本量超过数万的数据集。相比之下,朴素贝叶斯分类器在拟合和评分时的复杂度较低,扩展性要好得多。
随后,我们可以检查训练时间增加与交叉验证分数之间的权衡。
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))
for ax_idx, (fit_times, test_scores, estimator) in enumerate(
zip(
[fit_times_nb, fit_times_svm],
[test_scores_nb, test_scores_svm],
[naive_bayes, svc],
)
):
ax[ax_idx].plot(fit_times.mean(axis=1), test_scores.mean(axis=1), "o-")
ax[ax_idx].fill_between(
fit_times.mean(axis=1),
test_scores.mean(axis=1) - test_scores.std(axis=1),
test_scores.mean(axis=1) + test_scores.std(axis=1),
alpha=0.3,
)
ax[ax_idx].set_ylabel("Accuracy")
ax[ax_idx].set_xlabel("Fit time (s)")
ax[ax_idx].set_title(
f"Performance of the {estimator.__class__.__name__} classifier"
)
plt.show()

在这些图中,我们可以寻找交叉验证分数不再增加,而只有训练时间增加的拐点。
脚本总运行时间:(0 分 23.630 秒)
相关示例