scikit-learn 1.6 的发布亮点#

我们很高兴地宣布 scikit-learn 1.6 发布!此次发布新增了许多错误修复和改进,以及一些关键的新功能。下面我们详细介绍此次发布的亮点。有关所有更改的详尽列表,请参阅发布说明

安装最新版本(使用 pip)

pip install --upgrade scikit-learn

或使用 conda

conda install -c conda-forge scikit-learn

FrozenEstimator:冻结估计器#

此元估计器允许您接受一个估计器并冻结其 fit 方法,这意味着调用 fit 不会执行任何操作;此外,fit_predictfit_transform 分别调用 predicttransform 而不调用 fit。原始估计器的其他方法和属性保持不变。一个有趣的用例是,将预拟合模型用作管道中的转换器步骤,或将预拟合模型传递给某些元估计器。这里有一个简短的例子

import time

from sklearn.datasets import make_classification
from sklearn.frozen import FrozenEstimator
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import FixedThresholdClassifier

X, y = make_classification(n_samples=1000, random_state=0)

start = time.time()
classifier = SGDClassifier().fit(X, y)
print(f"Fitting the classifier took {(time.time() - start) * 1_000:.2f} milliseconds")

start = time.time()
threshold_classifier = FixedThresholdClassifier(
    estimator=FrozenEstimator(classifier), threshold=0.9
).fit(X, y)
print(
    f"Fitting the threshold classifier took {(time.time() - start) * 1_000:.2f} "
    "milliseconds"
)
Fitting the classifier took 2.96 milliseconds
Fitting the threshold classifier took 0.61 milliseconds

拟合阈值分类器跳过了内部 SGDClassifier 的拟合。更多详细信息请参阅示例使用 FrozenEstimator 的示例

转换管道中除 X 之外的数据#

Pipeline 现在支持在必要时转换除了 X 之外传入的数据。这可以通过设置新的 transform_input 参数来完成。这在通过管道传递验证集时特别有用。

例如,假设 EstimatorWithValidationSet 是一个接受验证集的估计器。我们现在可以有一个管道,它将转换验证集并将其传递给估计器

with sklearn.config_context(enable_metadata_routing=True):
    est_gs = GridSearchCV(
        Pipeline(
            (
                StandardScaler(),
                EstimatorWithValidationSet(...).set_fit_request(X_val=True, y_val=True),
            ),
            # telling pipeline to transform these inputs up to the step which is
            # requesting them.
            transform_input=["X_val"],
        ),
        param_grid={"estimatorwithvalidationset__param_to_optimize": list(range(5))},
        cv=5,
    ).fit(X, y, X_val=X_val, y_val=y_val)

在上述代码中,关键部分是调用 set_fit_request 以指定 EstimatorWithValidationSet.fit 方法需要 X_valy_val,以及 transform_input 参数以告诉管道在将 X_val 传递给 EstimatorWithValidationSet.fit 之前对其进行转换。

请注意,目前 scikit-learn 估计器尚未扩展为接受用户指定的验证集。此功能提前发布是为了收集可能从中受益的第三方库的反馈。

LogisticRegression(solver="newton-cholesky") 的多类支持#

"newton-cholesky" 求解器(最初在 scikit-learn 1.2 版本中引入)之前仅限于二元 LogisticRegression 和其他一些广义线性回归估计器(即 PoissonRegressorGammaRegressorTweedieRegressor)。

此新版本包含了对多类(多项式)LogisticRegression 的支持。

此求解器在特征数量为中小型时特别有用。经验表明,在某些带有独热编码分类特征的中型数据集上,它比其他求解器收敛更可靠、更快,如拉取请求的基准测试结果所示。

Extra Trees 对缺失值的支持#

ensemble.ExtraTreesClassifierensemble.ExtraTreesRegressor 现在支持缺失值。更多详细信息请参阅用户指南

import numpy as np

from sklearn.ensemble import ExtraTreesClassifier

X = np.array([0, 1, 6, np.nan]).reshape(-1, 1)
y = [0, 0, 1, 1]

forest = ExtraTreesClassifier(random_state=0).fit(X, y)
forest.predict(X)
array([0, 0, 1, 1])

从网络下载任何数据集#

函数 datasets.fetch_file 允许从任何给定 URL 下载文件。此便捷函数提供内置的本地磁盘缓存、sha256 摘要完整性检查以及网络错误时的自动重试机制。

目标是提供与数据集获取器相同的便利性和可靠性,同时提供从任意在线源和文件格式处理数据的灵活性。

下载的文件可以通过通用或特定领域函数(如 pandas.read_csvpandas.read_parquet 等)加载。

Array API 支持#

自 1.5 版本以来,更多估计器和函数已更新以支持 Array API 兼容输入,特别是来自 sklearn.model_selection 模块的超参数调优元估计器和来自 sklearn.metrics 模块的度量指标。

请参阅Array API 支持页面,了解如何将 scikit-learn 与 PyTorch 或 CuPy 等 Array API 兼容库一起使用。

几乎完全的元数据路由支持#

除 AdaBoost 外,所有剩余的估计器和函数都已添加元数据路由支持。有关更多详细信息,请参阅元数据路由用户指南

Free-threaded CPython 3.13 支持#

scikit-learn 对 free-threaded CPython 提供了初步支持,特别是我们所有支持的平台都提供了 free-threaded wheels。

Free-threaded(也称为 nogil)CPython 3.13 是 CPython 3.13 的一个实验版本,旨在通过移除全局解释器锁 (GIL) 来实现高效的多线程用例。

有关 free-threaded CPython 的更多详细信息,请参阅 py-free-threading 文档,特别是 如何安装 free-threaded CPython生态系统兼容性跟踪

请随时在您的用例中尝试 free-threaded CPython 并报告任何问题!

第三方库开发者 API 的改进#

我们一直致力于改进第三方库的开发者 API。这项工作仍在进行中,但在此版本中已完成大量工作。此版本包括

  • 引入了 sklearn.utils.validation.validate_data 并取代了之前私有的 BaseEstimator._validate_data 方法。此函数扩展了 check_array 并添加了记住输入特征计数和名称的支持。

  • 估计器标签现已通过 sklearn.utils.Tags 重新设计并成为公共 API 的一部分。估计器现在应重写 BaseEstimator.__sklearn_tags__ 方法,而不是实现 _more_tags 方法。如果您希望支持多个 scikit-learn 版本,则可以在类中同时实现这两种方法。

  • 由于开发了公共标签 API,我们已删除了 _xfail_checks 标签,并且预期会失败的测试直接传递给 check_estimatorparametrize_with_checks。有关更多详细信息,请参阅其相应的 API 文档。

  • 通用测试套件中的许多测试已更新并提供更有用的错误消息。我们还添加了一些新测试,这应该有助于您更轻松地修复估计器中的潜在问题。

我们的开发 scikit-learn 估计器的更新版本也已推出,我们建议您查看。

脚本总运行时间: (0 分钟 0.104 秒)

相关示例

scikit-learn 1.7 的发布亮点

scikit-learn 1.7 的发布亮点

scikit-learn 1.2 的发布亮点

scikit-learn 1.2 的发布亮点

scikit-learn 0.22 的发布亮点

scikit-learn 0.22 的发布亮点

scikit-learn 1.3 的发布亮点

scikit-learn 1.3 的发布亮点

由 Sphinx-Gallery 生成的图库