check_estimator#
- sklearn.utils.estimator_checks.check_estimator(estimator=None, generate_only=False, *, legacy: bool = True, expected_failed_checks: dict[str, str] | None = None, on_skip: Literal['warn'] | None = None, on_fail: Literal['raise', 'warn'] | None = 'raise', callback: Callable | None = None)[source]#
检查估计器是否符合 scikit-learn 约定。
此函数将运行一个广泛的测试套件,用于输入验证、形状等,确保估计器符合
scikit-learn
约定,如自定义估计器中所述。如果估计器类继承自 sklearn.base 中的相应混入类,则会运行分类器、回归器、聚类或转换器的额外测试。scikit-learn 还提供了一个 pytest 特定的装饰器,
parametrize_with_checks
,使得测试多个估计器更加容易。检查分为以下几组
API 检查:一系列检查,确保与 scikit-learn 的 API 兼容性。请参阅 https://scikit-learn.cn/dev/developers/develop.html,这是 scikit-learn 估计器的要求之一。
遗留:一系列检查,将逐步归入其他类别。
- 参数:
- estimator估计器对象
要检查的估计器实例。
- generate_only布尔型, 默认为 False
当
False
时,检查在调用check_estimator
时进行评估。当True
时,check_estimator
返回一个生成器,生成 (estimator, check) 元组。通过调用check(estimator)
来运行检查。0.22 版新增。
自 1.6 版弃用:
generate_only
将在 1.8 版中移除。请改用estimator_checks_generator
。- legacy布尔型, 默认为 True
是否包含遗留检查。随着时间的推移,我们将从该类别中移除检查并将其移入特定类别。
1.6 版新增。
- expected_failed_checks字典型, 默认为 None
形式如下的字典
{ "check_name": "this check is expected to fail because ...", }
其中
"check_name"
是检查的名称,"my reason"
是检查失败的原因。1.6 版新增。
- on_skip“warn”, None, 默认为 “warn”
此参数控制检查跳过时的行为。
“warn”:“SkipTestWarning” 被记录,并且测试继续运行。
None:不记录警告,并且测试继续运行。
1.6 版新增。
- on_fail{“raise”, “warn”}, None, 默认为 “raise”
此参数控制检查失败时的行为。
“raise”:第一个失败的检查所引发的异常会被抛出,并且中止正在运行的测试。这不包括预期会失败的测试。
“warn”:
EstimatorCheckFailedWarning
被记录,并且测试继续运行。None:不抛出异常,也不记录警告。
请注意,如果
on_fail != "raise"
,即使检查失败也不会抛出异常。您需要检查check_estimator
的返回值来检查是否有任何检查失败。1.6 版新增。
- callback可调用对象, 或 None, 默认为 None
此回调将被调用,并带有估计器、检查名称、异常(如果有)、检查状态(xfail, failed, skipped, passed),以及如果检查预期失败时的预期失败原因。可调用对象的签名需要是
def callback( estimator, check_name: str, exception: Exception, status: Literal["xfail", "failed", "skipped", "passed"], expected_to_fail: bool, expected_to_fail_reason: str, )
callback
不能与on_fail="raise"
同时提供。1.6 版新增。
- 返回:
- test_results列表
包含失败测试结果的字典列表,形式如下
{ "estimator": estimator, "check_name": check_name, "exception": exception, "status": status (one of "xfail", "failed", "skipped", "passed"), "expected_to_fail": expected_to_fail, "expected_to_fail_reason": expected_to_fail_reason, }
- estimator_checks_generator生成器
生成器,生成 (estimator, check) 元组。当
generate_only=True
时返回。自 1.6 版弃用:
generate_only
将在 1.8 版中移除。请改用estimator_checks_generator
。
- 引发:
- 异常
如果
on_fail="raise"
,第一个失败的检查所引发的异常会被抛出,并且中止正在运行的测试。请注意,如果
on_fail != "raise"
,即使检查失败也不会抛出异常。您需要检查check_estimator
的返回值来检查是否有任何检查失败。
另请参阅
parametrize_with_checks
用于参数化估计器检查的 Pytest 特定装饰器。
estimator_checks_generator
生成器,生成 (估计器, 检查) 元组。
示例
>>> from sklearn.utils.estimator_checks import check_estimator >>> from sklearn.linear_model import LogisticRegression >>> check_estimator(LogisticRegression()) [...]