check_estimator#

sklearn.utils.estimator_checks.check_estimator(estimator=None, generate_only=False, *, legacy: bool = True, expected_failed_checks: dict[str, str] | None = None, on_skip: Literal['warn'] | None = None, on_fail: Literal['raise', 'warn'] | None = 'raise', callback: Callable | None = None)[source]#

检查估计器是否符合 scikit-learn 约定。

此函数将运行一个广泛的测试套件,用于输入验证、形状等,确保估计器符合 scikit-learn 约定,如自定义估计器中所述。如果估计器类继承自 sklearn.base 中的相应混入类,则会运行分类器、回归器、聚类或转换器的额外测试。

scikit-learn 还提供了一个 pytest 特定的装饰器,parametrize_with_checks,使得测试多个估计器更加容易。

检查分为以下几组

参数:
estimator估计器对象

要检查的估计器实例。

generate_only布尔型, 默认为 False

False 时,检查在调用 check_estimator 时进行评估。当 True 时,check_estimator 返回一个生成器,生成 (estimator, check) 元组。通过调用 check(estimator) 来运行检查。

0.22 版新增。

自 1.6 版弃用:generate_only 将在 1.8 版中移除。请改用 estimator_checks_generator

legacy布尔型, 默认为 True

是否包含遗留检查。随着时间的推移,我们将从该类别中移除检查并将其移入特定类别。

1.6 版新增。

expected_failed_checks字典型, 默认为 None

形式如下的字典

{
    "check_name": "this check is expected to fail because ...",
}

其中 "check_name" 是检查的名称,"my reason" 是检查失败的原因。

1.6 版新增。

on_skip“warn”, None, 默认为 “warn”

此参数控制检查跳过时的行为。

  • “warn”:“SkipTestWarning” 被记录,并且测试继续运行。

  • None:不记录警告,并且测试继续运行。

1.6 版新增。

on_fail{“raise”, “warn”}, None, 默认为 “raise”

此参数控制检查失败时的行为。

  • “raise”:第一个失败的检查所引发的异常会被抛出,并且中止正在运行的测试。这不包括预期会失败的测试。

  • “warn”:EstimatorCheckFailedWarning 被记录,并且测试继续运行。

  • None:不抛出异常,也不记录警告。

请注意,如果 on_fail != "raise",即使检查失败也不会抛出异常。您需要检查 check_estimator 的返回值来检查是否有任何检查失败。

1.6 版新增。

callback可调用对象, 或 None, 默认为 None

此回调将被调用,并带有估计器、检查名称、异常(如果有)、检查状态(xfail, failed, skipped, passed),以及如果检查预期失败时的预期失败原因。可调用对象的签名需要是

def callback(
    estimator,
    check_name: str,
    exception: Exception,
    status: Literal["xfail", "failed", "skipped", "passed"],
    expected_to_fail: bool,
    expected_to_fail_reason: str,
)

callback 不能与 on_fail="raise" 同时提供。

1.6 版新增。

返回:
test_results列表

包含失败测试结果的字典列表,形式如下

{
    "estimator": estimator,
    "check_name": check_name,
    "exception": exception,
    "status": status (one of "xfail", "failed", "skipped", "passed"),
    "expected_to_fail": expected_to_fail,
    "expected_to_fail_reason": expected_to_fail_reason,
}
estimator_checks_generator生成器

生成器,生成 (estimator, check) 元组。当 generate_only=True 时返回。

自 1.6 版弃用:generate_only 将在 1.8 版中移除。请改用 estimator_checks_generator

引发:
异常

如果 on_fail="raise",第一个失败的检查所引发的异常会被抛出,并且中止正在运行的测试。

请注意,如果 on_fail != "raise",即使检查失败也不会抛出异常。您需要检查 check_estimator 的返回值来检查是否有任何检查失败。

另请参阅

parametrize_with_checks

用于参数化估计器检查的 Pytest 特定装饰器。

estimator_checks_generator

生成器,生成 (估计器, 检查) 元组。

示例

>>> from sklearn.utils.estimator_checks import check_estimator
>>> from sklearn.linear_model import LogisticRegression
>>> check_estimator(LogisticRegression())
[...]