梯度提升中的提前停止#

梯度提升是一种集成技术,它结合多个弱学习器(通常是决策树)来创建一个强大且鲁棒的预测模型。它以迭代的方式进行,其中每个新阶段(树)都纠正前面阶段的错误。

提前停止是梯度提升中的一种技术,它允许我们找到构建一个能够很好地泛化到未见数据的模型所需的最佳迭代次数,并避免过拟合。这个概念很简单:我们将数据集的一部分作为验证集(使用validation_fraction指定)来评估模型在训练过程中的性能。随着模型通过增加阶段(树)进行迭代构建,其在验证集上的性能作为步骤数量的函数进行监控。

当模型在验证集上的性能在一个特定数量的连续阶段(由n_iter_no_change指定)内趋于平稳或变差(在由tol指定的偏差范围内)时,提前停止变得有效。这表明模型已经达到了一个点,进一步迭代可能会导致过拟合,是时候停止训练了。

应用提前停止后,最终模型中的估计器(树)数量可以使用n_estimators_属性访问。总的来说,提前停止是在梯度提升中平衡模型性能和效率的宝贵工具。

许可证:BSD 3条款

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

数据准备#

首先,我们加载并准备加利福尼亚州房价数据集用于训练和评估。它对数据集进行子集化,并将其分成训练集和验证集。

import time

import matplotlib.pyplot as plt

from sklearn.datasets import fetch_california_housing
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

data = fetch_california_housing()
X, y = data.data[:600], data.target[:600]

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

模型训练和比较#

训练了两个GradientBoostingRegressor模型:一个使用提前停止,另一个不使用。目的是比较它们的性能。它还计算训练时间和两个模型使用的n_estimators_

params = dict(n_estimators=1000, max_depth=5, learning_rate=0.1, random_state=42)

gbm_full = GradientBoostingRegressor(**params)
gbm_early_stopping = GradientBoostingRegressor(
    **params,
    validation_fraction=0.1,
    n_iter_no_change=10,
)

start_time = time.time()
gbm_full.fit(X_train, y_train)
training_time_full = time.time() - start_time
n_estimators_full = gbm_full.n_estimators_

start_time = time.time()
gbm_early_stopping.fit(X_train, y_train)
training_time_early_stopping = time.time() - start_time
estimators_early_stopping = gbm_early_stopping.n_estimators_

误差计算#

代码计算前面部分训练模型的训练数据集和验证数据集的mean_squared_error。它计算每个提升迭代的误差。目的是评估模型的性能和收敛性。

train_errors_without = []
val_errors_without = []

train_errors_with = []
val_errors_with = []

for i, (train_pred, val_pred) in enumerate(
    zip(
        gbm_full.staged_predict(X_train),
        gbm_full.staged_predict(X_val),
    )
):
    train_errors_without.append(mean_squared_error(y_train, train_pred))
    val_errors_without.append(mean_squared_error(y_val, val_pred))

for i, (train_pred, val_pred) in enumerate(
    zip(
        gbm_early_stopping.staged_predict(X_train),
        gbm_early_stopping.staged_predict(X_val),
    )
):
    train_errors_with.append(mean_squared_error(y_train, train_pred))
    val_errors_with.append(mean_squared_error(y_val, val_pred))

可视化比较#

它包括三个子图

  1. 绘制两个模型在提升迭代中的训练误差。

  2. 绘制两个模型在提升迭代中的验证误差。

  3. 创建一个条形图来比较使用和不使用提前停止的模型的训练时间和使用的估计器。

fig, axes = plt.subplots(ncols=3, figsize=(12, 4))

axes[0].plot(train_errors_without, label="gbm_full")
axes[0].plot(train_errors_with, label="gbm_early_stopping")
axes[0].set_xlabel("Boosting Iterations")
axes[0].set_ylabel("MSE (Training)")
axes[0].set_yscale("log")
axes[0].legend()
axes[0].set_title("Training Error")

axes[1].plot(val_errors_without, label="gbm_full")
axes[1].plot(val_errors_with, label="gbm_early_stopping")
axes[1].set_xlabel("Boosting Iterations")
axes[1].set_ylabel("MSE (Validation)")
axes[1].set_yscale("log")
axes[1].legend()
axes[1].set_title("Validation Error")

training_times = [training_time_full, training_time_early_stopping]
labels = ["gbm_full", "gbm_early_stopping"]
bars = axes[2].bar(labels, training_times)
axes[2].set_ylabel("Training Time (s)")

for bar, n_estimators in zip(bars, [n_estimators_full, estimators_early_stopping]):
    height = bar.get_height()
    axes[2].text(
        bar.get_x() + bar.get_width() / 2,
        height + 0.001,
        f"Estimators: {n_estimators}",
        ha="center",
        va="bottom",
    )

plt.tight_layout()
plt.show()
Training Error, Validation Error

gbm_fullgbm_early_stopping之间训练误差的差异源于gbm_early_stoppingvalidation_fraction的训练数据作为内部验证集。提前停止是根据这个内部验证分数决定的。

总结#

在加利福尼亚州房价数据集上使用GradientBoostingRegressor模型的示例中,我们演示了提前停止的实际好处。

  • 防止过拟合:我们展示了验证误差如何在某个点之后稳定下来或开始增加,这表明模型更好地泛化到未见数据。这是通过在过拟合发生之前停止训练过程来实现的。

  • 提高训练效率:我们比较了使用和不使用提前停止的模型的训练时间。使用提前停止的模型在获得相当的准确性的同时,所需的估计器数量明显减少,从而加快了训练速度。

脚本的总运行时间:(0分钟3.826秒)

相关示例

随机梯度下降的提前停止

随机梯度下降的提前停止

梯度提升回归

梯度提升回归

缩放SVC的正则化参数

缩放SVC的正则化参数

比较随机森林和直方图梯度提升模型

比较随机森林和直方图梯度提升模型

由Sphinx-Gallery生成的图库