parametrize_with_checks#

sklearn.utils.estimator_checks.parametrize_with_checks(estimators, *, legacy: bool = True, expected_failed_checks: Callable | None = None)[源代码]#

用于参数化评估器检查的 Pytest 特殊装饰器。

检查分为以下几组:

  • API 检查:一组检查,用于确保与 scikit-learn 的 API 兼容性。请参考 https://scikit-learn.cn/dev/developers/develop.html,这是 scikit-learn 评估器的一个要求。

  • legacy(旧版):一组将逐渐归类到其他类别的检查。

每个检查的 id 被设置为评估器的 pprint 版本以及检查名称及其关键字参数。这允许使用 pytest -k 来指定要运行的测试。

pytest test_check_estimators.py -k check_estimators_fit_returns_self
参数:
estimators评估器实例列表

用于生成检查的评估器。

0.24 版本中的变化: 在 0.23 版本中,传递类已被弃用,并在 0.24 版本中移除了对类的支持。请改为传递实例。

0.24 版本新增。

legacy布尔值, 默认值=True

是否包含旧版检查。随着时间推移,我们将从该类别中移除检查,并将其移至其特定类别。

1.6 版本新增。

expected_failed_checks可调用对象, 默认值=None

一个可调用对象,它接受一个评估器作为输入,并返回一个形式为以下所示的字典:

{
    "check_name": "my reason",
}

其中 "check_name" 是检查的名称,"my reason" 是检查失败的原因。如果检查失败,这些测试将被标记为 xfail。

1.6 版本新增。

返回:
decoratorpytest.mark.parametrize

另请参阅

check_estimator

检查评估器是否符合 scikit-learn 约定。

示例

>>> from sklearn.utils.estimator_checks import parametrize_with_checks
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.tree import DecisionTreeRegressor
>>> @parametrize_with_checks([LogisticRegression(),
...                           DecisionTreeRegressor()])
... def test_sklearn_compatible_estimator(estimator, check):
...     check(estimator)