连续减半迭代#

此示例说明了连续减半搜索 (HalvingGridSearchCVHalvingRandomSearchCV) 如何迭代地从多个候选者中选择最佳参数组合。

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import randint

from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.experimental import enable_halving_search_cv  # noqa
from sklearn.model_selection import HalvingRandomSearchCV

我们首先定义参数空间并训练一个 HalvingRandomSearchCV 实例。

rng = np.random.RandomState(0)

X, y = datasets.make_classification(n_samples=400, n_features=12, random_state=rng)

clf = RandomForestClassifier(n_estimators=20, random_state=rng)

param_dist = {
    "max_depth": [3, None],
    "max_features": randint(1, 6),
    "min_samples_split": randint(2, 11),
    "bootstrap": [True, False],
    "criterion": ["gini", "entropy"],
}

rsh = HalvingRandomSearchCV(
    estimator=clf, param_distributions=param_dist, factor=2, random_state=rng
)
rsh.fit(X, y)
HalvingRandomSearchCV(estimator=RandomForestClassifier(n_estimators=20,
                                                       random_state=RandomState(MT19937) at 0x7F4DFC443240),
                      factor=2,
                      param_distributions={'bootstrap': [True, False],
                                           'criterion': ['gini', 'entropy'],
                                           'max_depth': [3, None],
                                           'max_features': <scipy.stats._distn_infrastructure.rv_discrete_frozen object at 0x7f4e31b83940>,
                                           'min_samples_split': <scipy.stats._distn_infrastructure.rv_discrete_frozen object at 0x7f4df6cfc280>},
                      random_state=RandomState(MT19937) at 0x7F4DFC443240)
在 Jupyter 环境中,请重新运行此单元格以显示 HTML 表示或信任笔记本。
在 GitHub 上,HTML 表示无法呈现,请尝试使用 nbviewer.org 加载此页面。


我们现在可以使用搜索估计器的 cv_results_ 属性来检查和绘制搜索的演变。

results = pd.DataFrame(rsh.cv_results_)
results["params_str"] = results.params.apply(str)
results.drop_duplicates(subset=("params_str", "iter"), inplace=True)
mean_scores = results.pivot(
    index="iter", columns="params_str", values="mean_test_score"
)
ax = mean_scores.plot(legend=False, alpha=0.6)

labels = [
    f"iter={i}\nn_samples={rsh.n_resources_[i]}\nn_candidates={rsh.n_candidates_[i]}"
    for i in range(rsh.n_iterations_)
]

ax.set_xticks(range(rsh.n_iterations_))
ax.set_xticklabels(labels, rotation=45, multialignment="left")
ax.set_title("Scores of candidates over iterations")
ax.set_ylabel("mean test score", fontsize=15)
ax.set_xlabel("iterations", fontsize=15)
plt.tight_layout()
plt.show()
Scores of candidates over iterations

每次迭代的候选者数量和资源量#

在第一次迭代中,使用少量资源。这里的资源是估计器训练的样本数量。评估所有候选者。

在第二次迭代中,只评估了最佳候选者的前一半。分配的资源数量增加了一倍:候选者在两倍的样本上进行评估。

重复此过程,直到最后一次迭代,此时只剩下 2 个候选者。最佳候选者是在最后一次迭代中得分最高的候选者。

脚本的总运行时间:(0 分钟 4.979 秒)

相关示例

网格搜索和连续减半的比较

网格搜索和连续减半的比较

比较随机搜索和网格搜索以进行超参数估计

比较随机搜索和网格搜索以进行超参数估计

scikit-learn 0.24 的发布亮点

scikit-learn 0.24 的发布亮点

使用交叉验证的网格搜索的自定义重新拟合策略

使用交叉验证的网格搜索的自定义重新拟合策略

由 Sphinx-Gallery 生成的画廊