检查估计器#

sklearn.utils.estimator_checks.check_estimator(estimator=None, generate_only=False, *, legacy: bool = True, expected_failed_checks: dict[str, str] | None = None, on_skip: Literal['warn'] | None = 'warn', on_fail: Literal['raise', 'warn'] | None = 'raise', callback: Callable | None = None)[source]#

检查估计器是否符合scikit-learn约定。

此函数将运行广泛的测试套件,用于输入验证、形状等,确保估计器符合scikit-learn约定,详见 创建您自己的估计器。如果估计器类继承自sklearn.base中的相应mixin,则将运行分类器、回归器、聚类或转换器的附加测试。

scikit-learn还提供了一个特定于pytest的装饰器,parametrize_with_checks,使测试多个估计器更容易。

检查分为以下几组:

参数:
estimator估计器对象

要检查的估计器实例。

generate_only布尔值,默认为 False

False时,在调用check_estimator时评估检查。当True时,check_estimator返回一个生成器,该生成器产生 (estimator, check) 元组。通过调用check(estimator)运行检查。

0.22版本新增。

1.6版本已弃用: generate_only 将在 1.8 版本中移除。请改用 estimator_checks_generator

legacy布尔值,默认为 True

是否包含旧版检查。随着时间的推移,我们将从此类别中删除检查并将其移动到其特定类别。

1.6版本新增。

expected_failed_checks字典,默认为 None

形式为:

{
    "check_name": "this check is expected to fail because ...",
}

其中"check_name"是检查的名称,"my reason"是检查失败的原因。

1.6版本新增。

on_skip“warn”,None,默认为“warn”

此参数控制跳过检查时发生的情况。

  • “warn”:记录SkipTestWarning警告,并继续运行测试。

  • None:不记录警告,并继续运行测试。

1.6版本新增。

on_fail {“raise”, “warn”},None,默认为“raise”

此参数控制检查失败时发生的情况。

  • “raise”:引发第一个失败检查引发的异常,并中止正在运行的测试。这并不包括预期会失败的测试。

  • “warn”:记录一个EstimatorCheckFailedWarning警告,并继续运行测试。

  • None:不引发异常,也不记录警告。

请注意,如果on_fail != "raise",则即使检查失败,也不会引发异常。您需要检查check_estimator的返回结果,以检查是否有任何检查失败。

1.6版本新增。

callback可调用对象或 None,默认为 None

此回调将使用估计器和检查名称、异常(如有)、检查的状态(xfail、failed、skipped、passed)以及如果检查预期失败则预期失败的原因来调用。可调用的签名需要为:

def callback(
    estimator,
    check_name: str,
    exception: Exception,
    status: Literal["xfail", "failed", "skipped", "passed"],
    expected_to_fail: bool,
    expected_to_fail_reason: str,
)

callback不能与on_fail="raise"一起提供。

1.6版本新增。

返回值:
test_results列表

包含失败测试结果的字典列表,格式如下:

{
    "estimator": estimator,
    "check_name": check_name,
    "exception": exception,
    "status": status (one of "xfail", "failed", "skipped", "passed"),
    "expected_to_fail": expected_to_fail,
    "expected_to_fail_reason": expected_to_fail_reason,
}
estimator_checks_generator生成器

生成 (估算器, 检查) 元组的生成器。当 generate_only=True 时返回。

1.6版本已弃用: generate_only 将在 1.8 版本中移除。请改用 estimator_checks_generator

引发:
异常

如果 on_fail="raise",则引发第一个失败检查引发的异常,并中止正在运行的测试。

请注意,如果on_fail != "raise",则即使检查失败,也不会引发异常。您需要检查check_estimator的返回结果,以检查是否有任何检查失败。

另请参阅

带检查的参数化

用于参数化估算器检查的 Pytest 特定装饰器。

估计器检查生成器

生成 (估算器, 检查) 元组的生成器。

示例

>>> from sklearn.utils.estimator_checks import check_estimator
>>> from sklearn.linear_model import LogisticRegression
>>> check_estimator(LogisticRegression())
[...]