check_cv#
- sklearn.model_selection.check_cv(cv=5, y=None, *, classifier=False)[源代码]#
用于构建交叉验证器的输入检查实用程序。
- 参数:
- cvint, cross-validation generator, iterable or None, default=5
确定交叉验证拆分策略。cv 的可能输入是: - None,使用默认的 5 折交叉验证, - 整数,指定折数。 - CV 分割器, - 一个生成 (训练, 测试) 拆分的索引数组的可迭代对象。
对于整数/None 输入,如果 classifier 为 True 且
y是二元或多元分类,则使用StratifiedKFold。在所有其他情况下,使用KFold。有关此处可使用的各种交叉验证策略,请参阅 用户指南。
版本 0.22 中已更改:
cv的默认值从 3 折更改为 5 折。- yarray-like, default=None
用于监督学习问题的目标变量。
- classifierbool, default=False
任务是否为分类任务,如果是,则将使用分层 K 折。
- 返回:
- checked_cva cross-validator instance.
返回值是一个交叉验证器,它通过
split方法生成训练/测试拆分。
示例
>>> from sklearn.model_selection import check_cv >>> check_cv(cv=5, y=None, classifier=False) KFold(...) >>> check_cv(cv=5, y=[1, 1, 0, 0, 0, 0], classifier=True) StratifiedKFold(...)