检查估计器#
- sklearn.utils.estimator_checks.check_estimator(estimator=None, generate_only=False, *, legacy: bool = True, expected_failed_checks: dict[str, str] | None = None, on_skip: Literal['warn'] | None = 'warn', on_fail: Literal['raise', 'warn'] | None = 'raise', callback: Callable | None = None)[source]#
检查估计器是否符合scikit-learn约定。
此函数将运行广泛的测试套件,用于输入验证、形状等,确保估计器符合
scikit-learn
约定,详见 创建您自己的估计器。如果估计器类继承自sklearn.base中的相应mixin,则将运行分类器、回归器、聚类或转换器的附加测试。scikit-learn还提供了一个特定于pytest的装饰器,
parametrize_with_checks
,使测试多个估计器更容易。检查分为以下几组:
API 检查:一组检查以确保与 scikit-learn 的 API 兼容性。请参考 https://scikit-learn.cn/dev/developers/develop.html,这是 scikit-learn 估计器的要求。
legacy:一组检查,这些检查将逐渐被归入其他类别。
- 参数:
- estimator估计器对象
要检查的估计器实例。
- generate_only布尔值,默认为 False
当
False
时,在调用check_estimator
时评估检查。当True
时,check_estimator
返回一个生成器,该生成器产生 (estimator, check) 元组。通过调用check(estimator)
运行检查。0.22版本新增。
1.6版本已弃用:
generate_only
将在 1.8 版本中移除。请改用estimator_checks_generator
。- legacy布尔值,默认为 True
是否包含旧版检查。随着时间的推移,我们将从此类别中删除检查并将其移动到其特定类别。
1.6版本新增。
- expected_failed_checks字典,默认为 None
形式为:
{ "check_name": "this check is expected to fail because ...", }
其中
"check_name"
是检查的名称,"my reason"
是检查失败的原因。1.6版本新增。
- on_skip“warn”,None,默认为“warn”
此参数控制跳过检查时发生的情况。
“warn”:记录
SkipTestWarning
警告,并继续运行测试。None:不记录警告,并继续运行测试。
1.6版本新增。
- on_fail {“raise”, “warn”},None,默认为“raise”
此参数控制检查失败时发生的情况。
“raise”:引发第一个失败检查引发的异常,并中止正在运行的测试。这并不包括预期会失败的测试。
“warn”:记录一个
EstimatorCheckFailedWarning
警告,并继续运行测试。None:不引发异常,也不记录警告。
请注意,如果
on_fail != "raise"
,则即使检查失败,也不会引发异常。您需要检查check_estimator
的返回结果,以检查是否有任何检查失败。1.6版本新增。
- callback可调用对象或 None,默认为 None
此回调将使用估计器和检查名称、异常(如有)、检查的状态(xfail、failed、skipped、passed)以及如果检查预期失败则预期失败的原因来调用。可调用的签名需要为:
def callback( estimator, check_name: str, exception: Exception, status: Literal["xfail", "failed", "skipped", "passed"], expected_to_fail: bool, expected_to_fail_reason: str, )
callback
不能与on_fail="raise"
一起提供。1.6版本新增。
- 返回值:
- test_results列表
包含失败测试结果的字典列表,格式如下:
{ "estimator": estimator, "check_name": check_name, "exception": exception, "status": status (one of "xfail", "failed", "skipped", "passed"), "expected_to_fail": expected_to_fail, "expected_to_fail_reason": expected_to_fail_reason, }
- estimator_checks_generator生成器
生成 (估算器, 检查) 元组的生成器。当
generate_only=True
时返回。1.6版本已弃用:
generate_only
将在 1.8 版本中移除。请改用estimator_checks_generator
。
- 引发:
- 异常
如果
on_fail="raise"
,则引发第一个失败检查引发的异常,并中止正在运行的测试。请注意,如果
on_fail != "raise"
,则即使检查失败,也不会引发异常。您需要检查check_estimator
的返回结果,以检查是否有任何检查失败。
示例
>>> from sklearn.utils.estimator_checks import check_estimator >>> from sklearn.linear_model import LogisticRegression >>> check_estimator(LogisticRegression()) [...]