开发 scikit-learn 估计器#

无论您是建议将估计器包含在 scikit-learn 中,开发与 scikit-learn 兼容的单独包,还是为自己的项目实现自定义组件,本章都详细介绍了如何开发与 scikit-learn 管道和模型选择工具安全交互的对象。

scikit-learn 对象的 API#

为了拥有统一的 API,我们尝试为所有对象提供一个通用的基本 API。此外,为了避免框架代码的泛滥,我们尝试采用简单的约定,并将对象必须实现的方法数量降至最低。

scikit-learn API 的元素在常见术语和 API 元素术语表中进行了更明确的描述。

不同的对象#

scikit-learn 中的主要对象是(一个类可以实现多个接口)

估计器:

基本对象,实现一个 fit 方法来从数据中学习,无论是

estimator = estimator.fit(data, targets)

还是

estimator = estimator.fit(data)
预测器:

对于监督学习或一些无监督问题,实现

prediction = predictor.predict(data)

分类算法通常还提供一种量化预测确定性的方法,可以使用 decision_functionpredict_proba

probability = predictor.predict_proba(data)
转换器:

用于以监督或无监督的方式修改数据(例如,通过添加、更改或删除列,但不添加或删除行)。实现

new_data = transformer.transform(data)

当拟合和转换可以比单独执行更有效地一起执行时,实现

new_data = transformer.fit_transform(data)
模型:

一个可以给出拟合优度度量或未见数据的似然的模型,实现(越高越好)

score = model.score(data)

估计器#

API 拥有一个主要对象:估计器。估计器是一个根据一些训练数据拟合模型的对象,并且能够推断新数据的某些属性。它可以是分类器或回归器。所有估计器都实现 fit 方法

estimator.fit(X, y)

所有内置估计器还具有一个 set_params 方法,该方法设置与数据无关的参数(覆盖之前传递给 __init__ 的参数值)。

scikit-learn 主代码库中的所有估计器都应该继承自 sklearn.base.BaseEstimator

实例化#

这涉及到对象的创建。对象的 __init__ 方法可能会接受常量作为参数,这些常量决定估计器的行为(例如 SVM 中的 C 常量)。但是,它不应该将实际训练数据作为参数,因为这留给了 fit() 方法

clf2 = SVC(C=2.3)
clf3 = SVC([[1, 2], [2, 3]], [-1, 1]) # WRONG!

__init__ 接受的参数都应该是具有默认值的关键字参数。换句话说,用户应该能够在不向估计器传递任何参数的情况下实例化估计器。所有参数都应该对应于描述模型或估计器试图解决的优化问题的超参数。这些初始参数(或参数)始终由估计器记住。还要注意,它们不应该在“属性”部分中记录,而应该在该估计器的“参数”部分中记录。

此外,每个由 __init__ 接受的关键字参数都应该对应于实例上的一个属性。scikit-learn 依赖于此来查找在进行模型选择时设置估计器的相关属性。

总之,一个 __init__ 应该看起来像

def __init__(self, param1=1, param2=2):
    self.param1 = param1
    self.param2 = param2

不应该有任何逻辑,甚至没有输入验证,并且参数不应该被更改。相应的逻辑应该放在使用参数的地方,通常在 fit 中。以下是不正确的

def __init__(self, param1=1, param2=2, param3=3):
    # WRONG: parameters should not be modified
    if param1 > 1:
        param2 += 1
    self.param1 = param1
    # WRONG: the object's attributes should have exactly the name of
    # the argument in the constructor
    self.param3 = param2

推迟验证的原因是,相同的验证必须在 set_params 中执行,而 set_params 用于 GridSearchCV 等算法中。

拟合#

您可能想要做的下一件事是估计模型中的一些参数。这在 fit() 方法中实现。

fit() 方法将训练数据作为参数,在无监督学习的情况下可以是一个数组,在监督学习的情况下可以是两个数组。

请注意,模型是使用 Xy 拟合的,但对象不保存对 Xy 的引用。但是,有一些例外,例如在预先计算的内核的情况下,必须存储这些数据以供 predict 方法使用。

参数

X

形状为 (n_samples, n_features) 的类数组

y

形状为 (n_samples,) 的类数组

kwargs

可选的与数据相关的参数

X.shape[0] 应该与 y.shape[0] 相同。如果未满足此要求,则应引发类型为 ValueError 的异常。

y 在无监督学习的情况下可能会被忽略。但是,为了能够将估计器用作可以混合监督和无监督转换器的管道的一部分,即使是无监督估计器也需要在第二个位置接受一个 y=None 关键字参数,该参数只是被估计器忽略。出于同样的原因,fit_predictfit_transformscorepartial_fit 方法需要在第二个位置接受一个 y 参数(如果它们已实现)。

该方法应该返回对象 (self)。这种模式对于能够在 IPython 会话中实现快速的一行代码很有用,例如

y_predicted = SVC(C=100).fit(X_train, y_train).predict(X_test)

根据算法的性质,fit 有时也可以接受额外的关键字参数。但是,任何可以在访问数据之前分配值的参数都应该是 __init__ 关键字参数。fit 参数应该仅限于直接与数据相关的变量。例如,从数据矩阵 X 预先计算的 Gram 矩阵或亲和矩阵是与数据相关的。容差停止标准 tol 不是直接与数据相关的(尽管根据某些评分函数的最佳值可能是)。

当调用 fit 时,应忽略对 fit 的任何先前调用。通常,调用 estimator.fit(X1) 然后调用 estimator.fit(X2) 应该与仅调用 estimator.fit(X2) 相同。但是,当 fit 依赖于某些随机过程时,这在实践中可能并不成立,请参阅random_state。此规则的另一个例外是当为支持它的估计器将超参数 warm_start 设置为 True 时。 warm_start=True 表示重用估计器的可训练参数的先前状态,而不是使用默认的初始化策略。

估计的属性#

从数据中估计的属性始终必须以尾部下划线结尾,例如,某些回归估计器的系数将在调用 fit 后存储在 coef_ 属性中。

当您第二次调用 fit 时,预计会覆盖估计的属性。

可选参数#

在迭代算法中,迭代次数应由一个名为 n_iter 的整数指定。

通用属性#

期望表格输入的估计器应在 fit 时设置一个 n_features_in_ 属性,以指示估计器在后续调用 predicttransform 时期望的特征数量。有关详细信息,请参阅 SLEP010

创建自己的估计器#

如果您想实现一个与 scikit-learn 兼容的新估计器,无论是为了自己使用还是为了贡献给 scikit-learn,除了上面概述的 scikit-learn API 之外,您还应该了解 scikit-learn 的几个内部机制。您可以通过对实例运行 check_estimator 来检查您的估计器是否符合 scikit-learn 接口和标准。 parametrize_with_checks pytest 装饰器也可以使用(有关详细信息和与 pytest 的可能交互,请参阅其文档字符串)。

>>> from sklearn.utils.estimator_checks import check_estimator
>>> from sklearn.svm import LinearSVC
>>> check_estimator(LinearSVC())  # passes

使一个类与 scikit-learn 估计器接口兼容的主要动机可能是您想将它与模型评估和选择工具一起使用,例如 model_selection.GridSearchCVpipeline.Pipeline

在详细说明下面所需的接口之前,我们将描述两种更轻松地实现正确接口的方法。

get_params 和 set_params#

所有 scikit-learn 估计器都有 get_paramsset_params 函数。 get_params 函数不接受任何参数,并返回估计器 __init__ 参数的字典,以及它们的值。

它必须接受一个关键字参数 deep,它接收一个布尔值,该值决定方法是否应该返回子估计器的参数(对于大多数估计器,这可以忽略)。 deep 的默认值应为 True。例如,考虑以下估计器

>>> from sklearn.base import BaseEstimator
>>> from sklearn.linear_model import LogisticRegression
>>> class MyEstimator(BaseEstimator):
...     def __init__(self, subestimator=None, my_extra_param="random"):
...         self.subestimator = subestimator
...         self.my_extra_param = my_extra_param

参数 deep 将控制是否应该报告 subestimator 的参数。因此,当 deep=True 时,输出将为

>>> my_estimator = MyEstimator(subestimator=LogisticRegression())
>>> for param, value in my_estimator.get_params(deep=True).items():
...     print(f"{param} -> {value}")
my_extra_param -> random
subestimator__C -> 1.0
subestimator__class_weight -> None
subestimator__dual -> False
subestimator__fit_intercept -> True
subestimator__intercept_scaling -> 1
subestimator__l1_ratio -> None
subestimator__max_iter -> 100
subestimator__multi_class -> deprecated
subestimator__n_jobs -> None
subestimator__penalty -> l2
subestimator__random_state -> None
subestimator__solver -> lbfgs
subestimator__tol -> 0.0001
subestimator__verbose -> 0
subestimator__warm_start -> False
subestimator -> LogisticRegression()

通常,subestimator 具有一个名称(例如,Pipeline 对象中的命名步骤),在这种情况下,键将变为 <name>__C<name>__class_weight 等。

而当 deep=False 时,输出将为

>>> for param, value in my_estimator.get_params(deep=False).items():
...     print(f"{param} -> {value}")
my_extra_param -> random
subestimator -> LogisticRegression()

另一方面,set_params__init__ 的参数作为关键字参数,将它们解包到 'parameter': value 形式的字典中,并使用此字典设置估计器的参数。返回值必须是估计器本身。

虽然 get_params 机制不是必需的(请参阅下面的 克隆),但 set_params 函数是必要的,因为它用于在网格搜索期间设置参数。

实现这些函数的最简单方法,以及获得合理的 __repr__ 方法,是从 sklearn.base.BaseEstimator 继承。如果您不想使您的代码依赖于 scikit-learn,实现接口的最简单方法是

def get_params(self, deep=True):
    # suppose this estimator has parameters "alpha" and "recursive"
    return {"alpha": self.alpha, "recursive": self.recursive}

def set_params(self, **parameters):
    for parameter, value in parameters.items():
        setattr(self, parameter, value)
    return self

参数和 init#

由于 model_selection.GridSearchCV 使用 set_params 将参数设置应用于估计器,因此调用 set_params 的效果必须与使用 __init__ 方法设置参数的效果相同。实现此目的最简单且推荐的方法是 **不要在** __init__ **中进行任何参数验证**。所有关于估计器参数的逻辑,例如将字符串参数转换为函数,都应该在 fit 中完成。

此外,预计以 _ 结尾的参数 **不应在** __init__ **方法中设置**。所有且仅由 fit 设置的公共属性以 _ 结尾。因此,以 _ 结尾的参数的存在用于检查估计器是否已拟合。

克隆#

为了与 model_selection 模块一起使用,估计器必须支持 base.clone 函数来复制估计器。这可以通过提供 get_params 方法来完成。如果存在 get_params,则 clone(estimator) 将是 type(estimator) 的实例,在该实例上已使用 estimator.get_params() 结果的克隆调用了 set_params

如果将 safe=False 传递给 clone,则不提供此方法的对象将被深度复制(使用 Python 标准函数 copy.deepcopy)。

估计器可以通过定义 __sklearn_clone__ 方法来自定义 base.clone 的行为。 __sklearn_clone__ 必须返回估计器的实例。 __sklearn_clone__ 在估计器需要在对估计器调用 base.clone 时保持某种状态时很有用。例如,可以将用于转换器的冻结元估计器定义如下

class FrozenTransformer(BaseEstimator):
    def __init__(self, fitted_transformer):
        self.fitted_transformer = fitted_transformer

    def __getattr__(self, name):
        # `fitted_transformer`'s attributes are now accessible
        return getattr(self.fitted_transformer, name)

    def __sklearn_clone__(self):
        return self

    def fit(self, X, y):
        # Fitting does not change the state of the estimator
        return self

    def fit_transform(self, X, y=None):
        # fit_transform only transforms the data
        return self.fitted_transformer.transform(X, y)

管道兼容性#

为了使估计器能够与 pipeline.Pipeline 一起使用(除了最后一步之外),它需要提供 fitfit_transform 函数。为了能够在除训练集之外的任何数据上评估管道,它还需要提供一个 transform 函数。管道中最后一步没有特殊要求,除了它有一个 fit 函数。所有 fitfit_transform 函数必须接受参数 X, y,即使不使用 y。类似地,为了使 score 可用,管道最后一步需要有一个 score 函数,该函数接受可选的 y

估计器类型#

一些常见的函数依赖于传递的估计器类型。例如,在 model_selection.GridSearchCVmodel_selection.cross_val_score 中,交叉验证默认情况下在分类器上使用时是分层的,但在其他情况下则不是。类似地,用于平均精度的评分器,需要对连续预测进行评分,需要对分类器调用 decision_function,但对回归器调用 predict。这种分类器和回归器之间的区别是通过 _estimator_type 属性实现的,该属性取字符串值。对于分类器,它应该是 "classifier",对于回归器,它应该是 "regressor",对于聚类方法,它应该是 "clusterer",才能按预期工作。继承自 ClassifierMixinRegressorMixinClusterMixin 将自动设置该属性。当元估计器需要区分估计器类型时,应该使用像 base.is_classifier 这样的辅助函数,而不是直接检查 _estimator_type

特定模型#

分类器应该接受 y(目标)参数传递给 fit,这些参数应该是字符串或整数的序列(列表、数组)。它们不应该假设类标签是连续的整数范围;相反,它们应该在一个 classes_ 属性或属性中存储一个类列表。此属性中类标签的顺序应与 predict_probapredict_log_probadecision_function 返回其值的顺序匹配。实现这一点最简单的方法是将

self.classes_, y = np.unique(y, return_inverse=True)

放入 fit 中。这将返回一个新的 y,它包含类索引,而不是标签,范围为 [0, n_classes)。

分类器的 predict 方法应该返回包含来自 classes_ 的类标签的数组。在实现 decision_function 的分类器中,这可以通过以下方式实现

def predict(self, X):
    D = self.decision_function(X)
    return self.classes_[np.argmax(D, axis=1)]

在线性模型中,系数存储在一个名为 coef_ 的数组中,独立项存储在 intercept_ 中。 sklearn.linear_model._base 包含一些实现常见线性模型模式的基类和混合类。

multiclass 模块包含用于处理多类和多标签问题的有用函数。

估计器标签#

警告

估计器标签是实验性的,API 可能会发生变化。

Scikit-learn 在版本 0.21 中引入了估计器标签。这些是估计器的注释,允许以编程方式检查其功能,例如稀疏矩阵支持、支持的输出类型和支持的方法。估计器标签是一个字典,由方法 _get_tags() 返回。这些标签用于由 check_estimator 函数和 parametrize_with_checks 装饰器运行的通用检查。标签决定要运行哪些检查以及哪些输入数据是合适的。标签可能取决于估计器参数,甚至取决于系统架构,通常只能在运行时确定。

当前的估计器标签集是

allow_nan (默认值为 False)

估计器是否支持将缺失值编码为 np.nan 的数据

array_api_support (默认值为 False)

估计器是否支持与 Array API 兼容的输入。

binary_only (默认值为 False)

估计器是否支持二元分类,但缺乏多类分类支持。

multilabel (默认值为 False)

估计器是否支持多标签输出

multioutput (默认值为 False)

回归器是否支持多目标输出,或者分类器是否支持多类多输出。

multioutput_only (默认值为 False)

估计器是否只支持多输出分类或回归。

no_validation (默认值为 False)

估计器是否跳过输入验证。这仅适用于无状态和虚拟转换器!

non_deterministic (默认值为 False)

给定固定的 random_state,估计器是否不确定性

pairwise (默认值为 False)

此布尔属性指示数据 (X) fit 和类似方法是否包含样本上的成对度量,而不是每个样本的特征表示。它通常为 True,其中估计器具有 metricaffinitykernel 参数,其值为 ‘precomputed’。它的主要目的是支持一个 元估计器 或一个交叉验证过程,该过程提取一个用于成对估计器的子样本数据,其中数据需要在两个轴上进行索引。具体来说,此标签由 sklearn.utils.metaestimators._safe_split 用于切片行和列。

preserves_dtype (默认值为 ``[np.float64]``)

仅适用于转换器。它对应于将被保留的数据类型,使得在调用 transformer.transform(X) 后,X_trans.dtypeX.dtype 相同。如果此列表为空,则转换器预计不会保留数据类型。列表中的第一个值被认为是默认数据类型,对应于当输入数据类型不会被保留时输出的数据类型。

poor_score (默认值为 False)

估计器是否无法提供“合理”的测试集分数,目前对于回归来说是 make_regression(n_samples=200, n_features=10, n_informative=1, bias=5.0, noise=20, random_state=42) 上的 R2 为 0.5,对于分类来说是 make_blobs(n_samples=300, random_state=0) 上的准确率为 0.83。这些数据集和值基于 sklearn 中当前的估计器,可能会被更系统的东西取代。

requires_fit (默认值为 True)

估计器是否需要在调用 transformpredictpredict_probadecision_function 之前进行拟合。

requires_positive_X (默认值为 False)

估计器是否需要正 X。

requires_y (默认值为 False)

估计器是否需要将 y 传递给 fitfit_predictfit_transform 方法。对于继承自 ~sklearn.base.RegressorMixin~sklearn.base.ClassifierMixin 的估计器,该标签为 True。

requires_positive_y (默认值为 False)

估计器是否需要正 y(仅适用于回归)。

_skip_test (默认值为 False)

是否完全跳过通用测试。除非你有充分的理由,否则不要使用它。

_xfail_checks (默认值为 False)

字典 {check_name: reason} 包含一些常见的检查,这些检查在使用 parametrize_with_checks 时,将被标记为 XFAIL 用于 pytest。这些检查将被简单地忽略,不会被 check_estimator 执行,但会引发 SkipTestWarning。除非你的估计器有非常充分的理由无法通过检查,否则不要使用此标记。还要注意,此标记的使用方式可能会发生很大变化,因为我们正在努力使其更加灵活:请做好应对未来破坏性更改的准备。

无状态 (默认值:False)

估计器是否需要访问数据进行拟合。即使估计器是无状态的,它可能仍然需要调用 fit 进行初始化。

X_types (默认值:[‘2darray’])

X 支持的输入类型,以字符串列表形式表示。目前,只有当列表中包含 ‘2darray’ 时才会运行测试,这表示估计器接受连续的二维 numpy 数组作为输入。默认值为 [‘2darray’]。其他可能的类型包括 'string''sparse''categorical'dict'1dlabels''2dlabels'。目标是在未来,支持的输入类型将决定测试期间使用的数据,特别是对于 'string''sparse''categorical' 数据。目前,稀疏数据的测试没有使用 'sparse' 标记。

每个标记的默认值不太可能适合你的特定估计器的需求。可以通过定义一个 _more_tags() 方法来创建额外的标记或覆盖默认标记,该方法返回一个包含所需覆盖标记或新标记的字典。例如

class MyMultiOutputEstimator(BaseEstimator):

    def _more_tags(self):
        return {'multioutput_only': True,
                'non_deterministic': True}

任何不在 _more_tags() 中的标记将回退到上面记录的默认值。

即使不建议这样做,也可以覆盖 _get_tags() 方法。但是请注意,**所有标记都必须出现在字典中**。如果上面记录的任何键不在 _get_tags() 的输出中,就会发生错误。

除了标记之外,估计器还需要在 _required_parameters 类属性中声明对 __init__ 的任何非可选参数,该属性是一个列表或元组。如果 _required_parameters 仅为 ["estimator"]["base_estimator"],则估计器将在测试中使用 LogisticRegression(或如果估计器是回归器,则使用 RidgeRegression)的实例进行实例化。选择这两个模型是有点特立独行的,但它们都应该提供稳健的闭式解。

用于 set_output 的开发者 API#

随着 SLEP018 的发布,scikit-learn 引入了 set_output API,用于配置转换器以输出 pandas DataFrame。如果转换器定义了 get_feature_names_out 并子类化了 base.TransformerMixin,则会自动定义 set_output API。 get_feature_names_out 用于获取 pandas 输出的列名。

base.OneToOneFeatureMixinbase.ClassNamePrefixFeaturesOutMixin 是用于定义 get_feature_names_out 的有用 mixin。 base.OneToOneFeatureMixin 在转换器对输入特征和输出特征之间存在一对一对应关系时很有用,例如 StandardScalerbase.ClassNamePrefixFeaturesOutMixin 在转换器需要生成自己的特征名称输出时很有用,例如 PCA

可以通过在定义自定义子类时设置 auto_wrap_output_keys=None 来选择退出 set_output API。

class MyTransformer(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None):

    def fit(self, X, y=None):
        return self
    def transform(self, X, y=None):
        return X
    def get_feature_names_out(self, input_features=None):
        ...

auto_wrap_output_keys 的默认值为 ("transform",),它会自动包装 fit_transformtransformTransformerMixin 使用 __init_subclass__ 机制来使用 auto_wrap_output_keys 并将所有其他关键字参数传递给它的超类。超类的 __init_subclass__ **不应该** 依赖于 auto_wrap_output_keys

对于在 transform 中返回多个数组的转换器,自动包装将只包装第一个数组,而不会更改其他数组。

有关如何使用 API 的示例,请参阅 介绍 set_output API

用于 check_is_fitted 的开发者 API#

默认情况下, check_is_fitted 检查实例中是否存在任何以下划线结尾的属性,例如 coef_。估计器可以通过实现一个不接受任何输入并返回布尔值的 __sklearn_is_fitted__ 方法来更改行为。如果此方法存在, check_is_fitted 将简单地返回其输出。

有关如何使用 API 的示例,请参阅 __sklearn_is_fitted__ 作为开发者 API

用于 HTML 表示的开发者 API#

警告

HTML 表示 API 处于实验阶段,API 可能会发生变化。

继承自 BaseEstimator 的估计器在交互式编程环境(如 Jupyter 笔记本)中显示其自身的 HTML 表示。例如,我们可以显示此 HTML 图表

from sklearn.base import BaseEstimator

BaseEstimator()

通过在估计器实例上调用函数 estimator_html_repr 来获取原始 HTML 表示。

要自定义指向估计器文档的 URL 链接(即单击“?”图标时),请覆盖 _doc_link_module_doc_link_template 属性。此外,还可以提供一个 _doc_link_url_param_generator 方法。将 _doc_link_module 设置为包含估计器的(顶级)模块的名称。如果该值与顶级模块名称不匹配,则 HTML 表示将不包含指向文档的链接。对于 scikit-learn 估计器,它被设置为 "sklearn"

_doc_link_template 用于构建最终 URL。默认情况下,它可以包含两个变量:estimator_module(包含估计器的模块的完整名称)和 estimator_name(估计器的类名)。如果你需要更多变量,则应该实现 _doc_link_url_param_generator 方法,该方法应该返回一个包含变量及其值的字典。此字典将用于渲染 _doc_link_template

编码指南#

以下是一些关于如何编写新代码以包含在 scikit-learn 中的指南,这些指南也可能适用于外部项目。当然,存在特殊情况,这些规则也会有例外。但是,在提交新代码时遵循这些规则可以使审查更容易,从而可以更快地集成新代码。

统一格式化的代码使代码共享所有权变得更容易。scikit-learn 项目试图严格遵循 PEP8 中详细说明的官方 Python 指南,这些指南详细说明了代码的格式和缩进方式。请阅读并遵循它。

此外,我们添加以下指南

  • 使用下划线分隔非类名称中的单词:n_samples 而不是 nsamples

  • 避免在一行中使用多个语句。在控制流语句 (if/for) 之后首选换行。

  • 对 scikit-learn 内部引用使用相对导入。

  • 单元测试是之前规则的例外;它们应该使用绝对导入,就像客户端代码一样。推论是,如果 sklearn.foo 导出在 sklearn.foo.bar.baz 中实现的类或函数,则测试应该从 sklearn.foo 导入它。

  • 请不要使用 import * 在任何情况下。它被 官方 Python 建议 认为是有害的。它使代码更难阅读,因为符号的来源不再被明确引用,但最重要的是,它阻止使用像 pyflakes 这样的静态分析工具来自动查找 scikit-learn 中的错误。

  • 在所有文档字符串中使用 numpy 文档字符串标准

我们喜欢的代码的一个很好的例子可以在这里找到 这里

输入验证#

模块 sklearn.utils 包含用于执行输入验证和转换的各种函数。有时,np.asarray 足以进行验证;不要使用 np.asanyarraynp.atleast_2d,因为这些允许 NumPy 的 np.matrix 通过,它具有不同的 API(例如,*np.matrix 上表示点积,但在 np.ndarray 上表示 Hadamard 积)。

在其他情况下,请确保在传递给 scikit-learn API 函数的任何类数组参数上调用 check_array。要使用的确切参数主要取决于是否以及哪些 scipy.sparse 矩阵必须被接受。

有关更多信息,请参阅 开发者实用程序 页面。

随机数#

如果您的代码依赖于随机数生成器,请不要使用 numpy.random.random() 或类似例程。为了确保错误检查中的可重复性,例程应该接受一个关键字 random_state 并使用它来构造一个 numpy.random.RandomState 对象。请参阅 sklearn.utils.check_random_state开发者实用程序 中。

这是一个使用上述一些指南的简单代码示例

from sklearn.utils import check_array, check_random_state

def choose_random_sample(X, random_state=0):
    """Choose a random point from X.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        An array representing the data.
    random_state : int or RandomState instance, default=0
        The seed of the pseudo random number generator that selects a
        random sample. Pass an int for reproducible output across multiple
        function calls.
        See :term:`Glossary <random_state>`.

    Returns
    -------
    x : ndarray of shape (n_features,)
        A random point selected from X.
    """
    X = check_array(X)
    random_state = check_random_state(random_state)
    i = random_state.randint(X.shape[0])
    return X[i]

如果您在估计器中使用随机性而不是独立函数,则会应用一些额外的指南。

首先,估计器应该在其 __init__ 中接受一个 random_state 参数,其默认值为 None。它应该将该参数的值(未修改)存储在属性 random_state 中。 fit 可以调用 check_random_state 在该属性上获取实际的随机数生成器。如果由于某种原因,在 fit 之后需要随机性,则 RNG 应该存储在属性 random_state_ 中。以下示例应该使这一点清楚

class GaussianNoise(BaseEstimator, TransformerMixin):
    """This estimator ignores its input and returns random Gaussian noise.

    It also does not adhere to all scikit-learn conventions,
    but showcases how to handle randomness.
    """

    def __init__(self, n_components=100, random_state=None):
        self.random_state = random_state
        self.n_components = n_components

    # the arguments are ignored anyway, so we make them optional
    def fit(self, X=None, y=None):
        self.random_state_ = check_random_state(self.random_state)

    def transform(self, X):
        n_samples = X.shape[0]
        return self.random_state_.randn(n_samples, self.n_components)

这种设置的原因是可重复性:当估计器对相同的数据进行两次 fit 时,它应该在两次都产生相同的模型,因此在 fit 中进行验证,而不是在 __init__ 中进行验证。

测试中的数值断言#

当断言连续值数组的准相等性时,请使用 sklearn.utils._testing.assert_allclose

相对容差会根据提供的数组数据类型自动推断(特别是对于 float32 和 float64 数据类型),但您可以通过 rtol 覆盖。

当比较零元素数组时,请通过 atol 提供一个非零的绝对容差值。

有关更多信息,请参阅 sklearn.utils._testing.assert_allclose 的文档字符串。