check_estimator#

sklearn.utils.estimator_checks.check_estimator(estimator=None, *, legacy: bool = True, expected_failed_checks: dict[str, str] | None = None, on_skip: Literal['warn'] | None = 'warn', on_fail: Literal['raise', 'warn'] | None = 'raise', callback: Callable | None = None)[source]#

检查估算器是否符合 scikit-learn 约定。

此函数将运行一套全面的测试套件,用于输入验证、形状等,确保估计器符合 scikit-learn 约定,如 编写自己的估计器 中所述。如果估计器类继承自 sklearn.base 中的相应 mixin,则将运行针对分类器、回归器、聚类器或转换器的额外测试。

scikit-learn 还提供了一个 pytest 特定装饰器 parametrize_with_checks,使测试多个估计器更加容易。

检查分为以下几组

参数:
estimator估计器对象

要检查的估计器实例。

legacybool, default=True

是否包含遗留检查。随着时间的推移,我们会从这个类别中删除检查,并将它们移动到特定的类别中。

版本 1.6 中新增。

expected_failed_checksdict, default=None

一个字典,形式为

{
    "check_name": "this check is expected to fail because ...",
}

其中 "check_name" 是检查名称,"my reason" 是检查失败的原因。

版本 1.6 中新增。

on_skip“warn”, None, default=”warn”

此参数控制检查被跳过时发生的情况。

  • “warn”:记录 SkipTestWarning 并继续运行测试。

  • None:不记录警告并继续运行测试。

版本 1.6 中新增。

on_fail{“raise”, “warn”}, None, default=”raise”

此参数控制检查失败时发生的情况。

  • “raise”:引发第一个失败检查引发的异常,并中止正在运行的测试。这不包括预期会失败的测试。

  • “warn”:记录 EstimatorCheckFailedWarning 并继续运行测试。

  • None:不引发异常,不记录警告。

请注意,如果 on_fail != "raise",即使检查失败,也不会引发异常。您需要检查 check_estimator 的返回结果以查看是否有任何检查失败。

版本 1.6 中新增。

callbackcallable, or None, default=None

此回调函数将使用估计器和检查名称、异常(如果有)、检查状态(xfail、failed、skipped、passed)以及预期失败原因(如果检查预期会失败)进行调用。回调函数的签名必须为

def callback(
    estimator,
    check_name: str,
    exception: Exception,
    status: Literal["xfail", "failed", "skipped", "passed"],
    expected_to_fail: bool,
    expected_to_fail_reason: str,
)

callback 不能与 on_fail="raise" 一起提供。

版本 1.6 中新增。

返回:
test_resultslist

包含失败测试结果的字典列表,形式为

{
    "estimator": estimator,
    "check_name": check_name,
    "exception": exception,
    "status": status (one of "xfail", "failed", "skipped", "passed"),
    "expected_to_fail": expected_to_fail,
    "expected_to_fail_reason": expected_to_fail_reason,
}
Raises:
异常

如果 on_fail="raise",则引发第一个失败检查引发的异常,并中止正在运行的测试。

请注意,如果 on_fail != "raise",即使检查失败,也不会引发异常。您需要检查 check_estimator 的返回结果以查看是否有任何检查失败。

另请参阅

parametrize_with_checks

用于参数化估算器检查的 pytest 特定装饰器。

estimator_checks_generator

生成器,生成 (estimator, check) 元组。

示例

>>> from sklearn.utils.estimator_checks import check_estimator
>>> from sklearn.linear_model import LogisticRegression
>>> check_estimator(LogisticRegression())
[...]