元数据路由#

本文档展示了如何使用 scikit-learn 中的 元数据路由机制 将元数据路由到使用它们的估计器、评分器和 CV 分割器。

为了更好地理解以下文档,我们需要介绍两个概念:路由器和消费者。路由器是一个将给定数据和元数据转发到其他对象的对象。在大多数情况下,路由器是一个元估计器,即一个将另一个估计器作为参数的估计器。像sklearn.model_selection.cross_validate这样的函数,它接受一个估计器作为参数并转发数据和元数据,也是一个路由器。

另一方面,消费者是一个接受并使用某些给定元数据的对象。例如,一个在fit方法中考虑sample_weight的估计器是sample_weight的消费者。

一个对象可以同时是路由器和消费者。例如,元估计器可能会在某些计算中考虑sample_weight,但它也可能将其路由到底层估计器。

首先是一些导入和一些用于脚本其余部分的随机数据。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from pprint import pprint

import numpy as np

from sklearn import set_config
from sklearn.base import (
    BaseEstimator,
    ClassifierMixin,
    MetaEstimatorMixin,
    RegressorMixin,
    TransformerMixin,
    clone,
)
from sklearn.linear_model import LinearRegression
from sklearn.utils import metadata_routing
from sklearn.utils.metadata_routing import (
    ,
    ,
    ,
    ,
)
from sklearn.utils.validation import check_is_fitted

n_samples, n_features = 100, 4
rng = np.random.RandomState(42)
X = rng.rand(n_samples, n_features)
y = rng.randint(0, 2, size=n_samples)
my_groups = rng.randint(0, 10, size=n_samples)
my_weights = rng.rand(n_samples)
my_other_weights = rng.rand(n_samples)

只有显式启用元数据路由才可用。

set_config(enable_metadata_routing=True)

此实用程序函数是一个虚拟函数,用于检查是否传递了元数据。

def check_metadata(obj, **kwargs):
    for key, value in kwargs.items():
        if value is not None:
            print(
                f"Received {key} of length = {len(value)} in {obj.__class__.__name__}."
            )
        else:
            print(f"{key} is None in {obj.__class__.__name__}.")

一个用于很好地打印对象路由信息的实用程序函数。

def print_routing(obj):
    pprint(obj.get_metadata_routing()._serialize())

使用估计器#

在这里,我们演示了估计器如何公开所需的API以支持元数据路由作为消费者。想象一个简单的分类器,在其fit方法中接受sample_weight作为元数据,在其predict方法中接受groups作为元数据。

class ExampleClassifier(ClassifierMixin, BaseEstimator):
    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        # all classifiers need to expose a classes_ attribute once they're fit.
        self.classes_ = np.array([0, 1])
        return self

    def predict(self, X, groups=None):
        check_metadata(self, groups=groups)
        # return a constant value of 1, not a very smart classifier!
        return np.ones(len(X))

上述估计器现在具备了使用元数据所需的一切。这是通过BaseEstimator中完成的一些操作实现的。上述类现在公开了三个方法:set_fit_requestset_predict_requestget_metadata_routing。还有一个用于sample_weightset_score_request,因为它存在于ClassifierMixin中,后者实现了一个接受sample_weightscore方法。继承自RegressorMixin的回归器也适用。

默认情况下,不请求任何元数据,我们可以看到:

print_routing(ExampleClassifier())
{'fit': {'sample_weight': None},
 'predict': {'groups': None},
 'score': {'sample_weight': None}}

上述输出意味着ExampleClassifier没有请求sample_weightgroups,如果给路由器提供这些元数据,它应该抛出错误,因为用户没有显式设置它们是否必需。 score方法中的sample_weight也是如此,它继承自ClassifierMixin。为了显式设置这些元数据的请求值,我们可以使用这些方法:

est = (
    ExampleClassifier()
    .set_fit_request(sample_weight=False)
    .set_predict_request(groups=True)
    .set_score_request(sample_weight=False)
)
print_routing(est)
{'fit': {'sample_weight': False},
 'predict': {'groups': True},
 'score': {'sample_weight': False}}

注意

请注意,只要上述估计器不用于元估计器,用户就不需要为元数据设置任何请求,并且设置的值会被忽略,因为消费者不会验证或路由给定的元数据。上述估计器的简单用法将按预期工作。

est = ExampleClassifier()
est.fit(X, y, sample_weight=my_weights)
est.predict(X[:3, :], groups=my_groups)
Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleClassifier.

array([1., 1., 1.])

路由元估计器#

现在,我们展示如何设计一个元估计器作为路由器。作为一个简化的例子,这是一个除了路由元数据之外什么都不做的元估计器。

class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        # This method defines the routing for this meta-estimator.
        # In order to do so, a `MetadataRouter` instance is created, and the
        # routing is added to it. More explanations follow below.
        router = (owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=()
            .add(caller="fit", callee="fit")
            .add(caller="predict", callee="predict")
            .add(caller="score", callee="score"),
        )
        return router

    def fit(self, X, y, **fit_params):
        # `get_routing_for_object` returns a copy of the `MetadataRouter`
        # constructed by the above `get_metadata_routing` method, that is
        # internally called.
        request_router = (self)
        # Meta-estimators are responsible for validating the given metadata.
        # `method` refers to the parent's method, i.e. `fit` in this example.
        request_router.validate_metadata(params=fit_params, method="fit")
        # `MetadataRouter.route_params` maps the given metadata to the metadata
        # required by the underlying estimator based on the routing information
        # defined by the MetadataRouter. The output of type `Bunch` has a key
        # for each consuming object and those hold keys for their consuming
        # methods, which then contain key for the metadata which should be
        # routed to them.
        routed_params = request_router.route_params(params=fit_params, caller="fit")

        # A sub-estimator is fitted and its classes are attributed to the
        # meta-estimator.
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        check_is_fitted(self)
        # As in `fit`, we get a copy of the object's MetadataRouter,
        request_router = (self)
        # then we validate the given metadata,
        request_router.validate_metadata(params=predict_params, method="predict")
        # and then prepare the input to the underlying `predict` method.
        routed_params = request_router.route_params(
            params=predict_params, caller="predict"
        )
        return self.estimator_.predict(X, **routed_params.estimator.predict)

让我们分解上述代码的不同部分。

首先,get_routing_for_object 获取我们的元估计器(self)并返回一个MetadataRouter,或者如果对象是消费者,则返回一个MetadataRequest,这基于估计器的get_metadata_routing方法的输出。

然后在每个方法中,我们使用route_params方法构造一个{"object_name": {"method_name": {"metadata": value}}}形式的字典,传递到底层估计器的方法。 object_name(在上面的routed_params.estimator.fit示例中为estimator)与在get_metadata_routing中添加的对象名相同。validate_metadata确保所有给定的元数据都被请求,以避免出现静默错误。

接下来,我们将说明不同的行为,特别是引发的错误类型。

meta_est = MetaClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
meta_est.fit(X, y, sample_weight=my_weights)
Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier(estimator=ExampleClassifier())
在Jupyter环境中,请重新运行此单元格以显示HTML表示或信任笔记本。
在GitHub上,HTML表示无法呈现,请尝试使用nbviewer.org加载此页面。


请注意,以上示例通过 ExampleClassifier 调用我们的实用函数 check_metadata()。它检查 sample_weight 是否正确传递。如果未正确传递,例如在以下示例中,它将打印 sample_weightNone

meta_est.fit(X, y)
sample_weight is None in ExampleClassifier.
MetaClassifier(estimator=ExampleClassifier())
在Jupyter环境中,请重新运行此单元格以显示HTML表示或信任笔记本。
在GitHub上,HTML表示无法呈现,请尝试使用nbviewer.org加载此页面。


如果我们传递未知元数据,则会引发错误。

try:
    meta_est.fit(X, y, test=my_weights)
except TypeError as e:
    print(e)
MetaClassifier.fit got unexpected argument(s) {'test'}, which are not routed to any object.

如果我们传递未明确请求的元数据。

try:
    meta_est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups)
except ValueError as e:
    print(e)
Received sample_weight of length = 100 in ExampleClassifier.
[groups] are passed but are not explicitly set as requested or not requested for ExampleClassifier.predict, which is used within MetaClassifier.predict. Call `ExampleClassifier.set_predict_request({metadata}=True/False)` for each metadata you want to request/ignore.

此外,如果我们明确将其设置为未请求,但它被提供。

meta_est = MetaClassifier(
    estimator=ExampleClassifier()
    .set_fit_request(sample_weight=True)
    .set_predict_request(groups=False)
)
try:
    meta_est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups)
except TypeError as e:
    print(e)
Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier.predict got unexpected argument(s) {'groups'}, which are not routed to any object.

另一个需要介绍的概念是**别名元数据 (aliased metadata)**。当估计器请求的元数据变量名与默认变量名不同时,就会出现这种情况。例如,在一个管道中存在两个估计器的设置中,一个可以请求 sample_weight1,另一个可以请求 sample_weight2。请注意,这不会改变估计器期望的内容,它只告诉元估计器如何将提供的元数据映射到所需的内容。这是一个示例,我们将 aliased_sample_weight 传递给元估计器,但元估计器理解 aliased_sample_weightsample_weight 的别名,并将它作为 sample_weight 传递给底层估计器。

meta_est = MetaClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
meta_est.fit(X, y, aliased_sample_weight=my_weights)
Received sample_weight of length = 100 in ExampleClassifier.
MetaClassifier(estimator=ExampleClassifier())
在Jupyter环境中,请重新运行此单元格以显示HTML表示或信任笔记本。
在GitHub上,HTML表示无法呈现,请尝试使用nbviewer.org加载此页面。


在此处传递 sample_weight 将失败,因为它使用别名请求,并且未请求同名 sample_weight

try:
    meta_est.fit(X, y, sample_weight=my_weights)
except TypeError as e:
    print(e)
MetaClassifier.fit got unexpected argument(s) {'sample_weight'}, which are not routed to any object.

这将我们引向 get_metadata_routing。Scikit-learn 中路由的工作方式是,使用者请求他们需要的内容,路由器将这些内容传递过去。此外,路由器还会公开它自身所需的内容,以便它可以用于另一个路由器中,例如网格搜索对象中的管道。 get_metadata_routing 的输出是一个 MetadataRouter 的字典表示,它包含所有嵌套对象请求的元数据的完整树及其对应的使用方法路由,即元估计器的哪个方法使用了子估计器的哪个方法。

print_routing(meta_est)
{'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': 'aliased_sample_weight'},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

如您所见,方法 fit 唯一请求的元数据是 "sample_weight",其别名为 "aliased_sample_weight"~utils.metadata_routing.MetadataRouter 类使我们能够轻松创建路由对象,该对象将为我们的 get_metadata_routing 创建所需的输出。

为了理解元估计器中别名是如何工作的,让我们想象一下我们的元估计器在另一个元估计器内部。

meta_meta_est = MetaClassifier(estimator=meta_est).fit(
    X, y, aliased_sample_weight=my_weights
)
Received sample_weight of length = 100 in ExampleClassifier.

在上面的示例中,这就是 meta_meta_estfit 方法如何调用其子估计器的 fit 方法。

# user feeds `my_weights` as `aliased_sample_weight` into `meta_meta_est`:
meta_meta_est.fit(X, y, aliased_sample_weight=my_weights):
    ...

    # the first sub-estimator (`meta_est`) expects `aliased_sample_weight`
    self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight):
        ...

        # the second sub-estimator (`est`) expects `sample_weight`
        self.estimator_.fit(X, y, sample_weight=aliased_sample_weight):
            ...

使用和路由元估计器 (Consuming and routing Meta-Estimator)#

对于稍微复杂一些的示例,考虑一个元估计器,它像以前一样将元数据路由到底层估计器,但它还在自己的方法中使用一些元数据。这个元估计器同时是使用者和路由器。实现它与我们之前的实现非常相似,但有一些调整。

class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def get_metadata_routing(self):
        router = (
            (owner=self.__class__.__name__)
            # defining metadata routing request values for usage in the meta-estimator
            .add_self_request(self)
            # defining metadata routing request values for usage in the sub-estimator
            .add(
                estimator=self.estimator,
                method_mapping=()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict")
                .add(caller="score", callee="score"),
            )
        )
        return router

    # Since `sample_weight` is used and consumed here, it should be defined as
    # an explicit argument in the method's signature. All other metadata which
    # are only routed, will be passed as `**fit_params`:
    def fit(self, X, y, sample_weight, **fit_params):
        if self.estimator is None:
            raise ValueError("estimator cannot be None!")

        check_metadata(self, sample_weight=sample_weight)

        # We add `sample_weight` to the `fit_params` dictionary.
        if sample_weight is not None:
            fit_params["sample_weight"] = sample_weight

        request_router = (self)
        request_router.validate_metadata(params=fit_params, method="fit")
        routed_params = request_router.route_params(params=fit_params, caller="fit")
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X, **predict_params):
        check_is_fitted(self)
        # As in `fit`, we get a copy of the object's MetadataRouter,
        request_router = (self)
        # we validate the given metadata,
        request_router.validate_metadata(params=predict_params, method="predict")
        # and then prepare the input to the underlying ``predict`` method.
        routed_params = request_router.route_params(
            params=predict_params, caller="predict"
        )
        return self.estimator_.predict(X, **routed_params.estimator.predict)

上述元估计器与我们之前的元估计器不同的关键部分在于,它在 fit 中显式地接受 sample_weight 并将其包含在 fit_params 中。由于 sample_weight 是一个显式参数,我们可以确定此方法存在 set_fit_request(sample_weight=...)。元估计器既是使用者,也是 sample_weight 的路由器。

get_metadata_routing 中,我们使用 add_self_requestself 添加到路由中,以指示此估计器也在使用 sample_weight 并且是一个路由器;这也向路由信息添加了一个 $self_request 键,如下所示。现在让我们看一些例子。

  • 未请求元数据。

meta_est = RouterConsumerClassifier(estimator=ExampleClassifier())
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': None},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}
  • 子估计器请求 sample_weight

meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight=True)
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': True},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}
  • 元估计器请求 sample_weight

meta_est = RouterConsumerClassifier(estimator=ExampleClassifier()).set_fit_request(
    sample_weight=True
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': True},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': None},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

请注意上面请求的元数据表示之间的区别。

  • 我们还可以为元估计器和子估计器的 fit 方法传递不同的值来为元数据设置别名。

meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="clf_sample_weight"),
).set_fit_request(sample_weight="meta_clf_sample_weight")
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': 'meta_clf_sample_weight'},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': 'clf_sample_weight'},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

但是,元估计器的 fit 只需要子估计器的别名,并将其自身的样本权重指定为 sample_weight,因为它不验证和路由其自身所需的元数据。

meta_est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights)
Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.
RouterConsumerClassifier(estimator=ExampleClassifier())
在Jupyter环境中,请重新运行此单元格以显示HTML表示或信任笔记本。
在GitHub上,HTML表示无法呈现,请尝试使用nbviewer.org加载此页面。


  • 仅在子估计器上使用别名。

当我们不希望元估计器使用元数据,但子估计器应该使用时,这很有用。

meta_est = RouterConsumerClassifier(
    estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight")
)
print_routing(meta_est)
{'$self_request': {'fit': {'sample_weight': None},
                   'score': {'sample_weight': None}},
 'estimator': {'mapping': [{'callee': 'fit', 'caller': 'fit'},
                           {'callee': 'predict', 'caller': 'predict'},
                           {'callee': 'score', 'caller': 'score'}],
               'router': {'fit': {'sample_weight': 'aliased_sample_weight'},
                          'predict': {'groups': None},
                          'score': {'sample_weight': None}}}}

元估计器无法使用 aliased_sample_weight,因为它期望它作为 sample_weight 传递。即使在其上设置了 set_fit_request(sample_weight=True),这也适用。

简单管道 (Simple Pipeline)#

一个稍微复杂一点的用例是一个类似于Pipeline的元估计器。这是一个元估计器,它接受一个转换器和一个分类器。当调用它的fit方法时,它会在对转换后的数据运行分类器之前应用转换器的fittransform。在predict时,它会在使用分类器的predict方法对转换后的新数据进行预测之前应用转换器的transform

class SimplePipeline(ClassifierMixin, BaseEstimator):
    def __init__(self, transformer, classifier):
        self.transformer = transformer
        self.classifier = classifier

    def get_metadata_routing(self):
        router = (
            (owner=self.__class__.__name__)
            # We add the routing for the transformer.
            .add(
                transformer=self.transformer,
                method_mapping=()
                # The metadata is routed such that it retraces how
                # `SimplePipeline` internally calls the transformer's `fit` and
                # `transform` methods in its own methods (`fit` and `predict`).
                .add(caller="fit", callee="fit")
                .add(caller="fit", callee="transform")
                .add(caller="predict", callee="transform"),
            )
            # We add the routing for the classifier.
            .add(
                classifier=self.classifier,
                method_mapping=()
                .add(caller="fit", callee="fit")
                .add(caller="predict", callee="predict"),
            )
        )
        return router

    def fit(self, X, y, **fit_params):
        routed_params = (self, "fit", **fit_params)

        self.transformer_ = clone(self.transformer).fit(
            X, y, **routed_params.transformer.fit
        )
        X_transformed = self.transformer_.transform(
            X, **routed_params.transformer.transform
        )

        self.classifier_ = clone(self.classifier).fit(
            X_transformed, y, **routed_params.classifier.fit
        )
        return self

    def predict(self, X, **predict_params):
        routed_params = (self, "predict", **predict_params)

        X_transformed = self.transformer_.transform(
            X, **routed_params.transformer.transform
        )
        return self.classifier_.predict(
            X_transformed, **routed_params.classifier.predict
        )

注意MethodMapping的用法,它声明子估计器(被调用者)的哪些方法在元估计器(调用者)的哪些方法中使用。正如你所看到的,SimplePipelinefit中使用转换器的transformfit方法,并在predict中使用它的transform方法,这就是你在管道类的路由结构中看到的实现。

上面这个例子与之前的例子另一个不同之处在于使用了process_routing,它处理输入参数,进行必要的验证,并返回我们在之前的例子中创建的routed_params。这减少了开发人员在每个元估计器方法中需要编写的样板代码。强烈建议开发人员使用此函数,除非有充分理由反对。

为了测试上面的管道,让我们添加一个示例转换器。

class ExampleTransformer(TransformerMixin, BaseEstimator):
    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        return self

    def transform(self, X, groups=None):
        check_metadata(self, groups=groups)
        return X

    def fit_transform(self, X, y, sample_weight=None, groups=None):
        return self.fit(X, y, sample_weight).transform(X, groups)

请注意,在上面的例子中,我们实现了fit_transform,它使用适当的元数据调用fittransform。只有当transform接受元数据时,才需要这样做,因为TransformerMixin中的默认fit_transform实现不会将元数据传递给transform

现在我们可以测试我们的管道,看看元数据是否正确传递。这个例子使用我们的SimplePipeline,我们的ExampleTransformer,以及使用我们的ExampleClassifierRouterConsumerClassifier

pipe = SimplePipeline(
    transformer=ExampleTransformer()
    # we set transformer's fit to receive sample_weight
    .set_fit_request(sample_weight=True)
    # we set transformer's transform to receive groups
    .set_transform_request(groups=True),
    classifier=RouterConsumerClassifier(
        estimator=ExampleClassifier()
        # we want this sub-estimator to receive sample_weight in fit
        .set_fit_request(sample_weight=True)
        # but not groups in predict
        .set_predict_request(groups=False),
    )
    # and we want the meta-estimator to receive sample_weight as well
    .set_fit_request(sample_weight=True),
)
pipe.fit(X, y, sample_weight=my_weights, groups=my_groups).predict(
    X[:3], groups=my_groups
)
Received sample_weight of length = 100 in ExampleTransformer.
Received groups of length = 100 in ExampleTransformer.
Received sample_weight of length = 100 in RouterConsumerClassifier.
Received sample_weight of length = 100 in ExampleClassifier.
Received groups of length = 100 in ExampleTransformer.
groups is None in ExampleClassifier.

array([1., 1., 1.])

弃用/默认值更改#

在本节中,我们将展示如何处理路由器也成为使用者的情况,尤其是在它使用与其子估计器相同的元数据,或者使用者开始使用在旧版本中没有使用的元数据的情况下。在这种情况下,应该发出警告一段时间,让用户知道行为与以前的版本有所改变。

class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        routed_params = (self, "fit", **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)

    def get_metadata_routing(self):
        router = (owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=().add(caller="fit", callee="fit"),
        )
        return router

如上所述,如果my_weights不应作为sample_weight传递给MetaRegressor,则这是一个有效的用法。

reg = MetaRegressor(estimator=LinearRegression().set_fit_request(sample_weight=True))
reg.fit(X, y, sample_weight=my_weights)

现在想象一下,我们进一步开发MetaRegressor,它现在也使用sample_weight

class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    # show warning to remind user to explicitly set the value with
    # `.set_{method}_request(sample_weight={boolean})`
    __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, sample_weight=None, **fit_params):
        routed_params = (
            self, "fit", sample_weight=sample_weight, **fit_params
        )
        check_metadata(self, sample_weight=sample_weight)
        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)

    def get_metadata_routing(self):
        router = (
            (owner=self.__class__.__name__)
            .add_self_request(self)
            .add(
                estimator=self.estimator,
                method_mapping=().add(caller="fit", callee="fit"),
            )
        )
        return router

上面的实现几乎与MetaRegressor相同,并且由于在__metadata_request__fit中定义的默认请求值,所以在拟合时会发出警告。

with warnings.catch_warnings(record=True) as record:
    WeightedMetaRegressor(
        estimator=LinearRegression().set_fit_request(sample_weight=False)
    ).fit(X, y, sample_weight=my_weights)
for w in record:
    print(w.message)
Received sample_weight of length = 100 in WeightedMetaRegressor.
Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata.

当估计器使用它以前没有使用的元数据时,可以使用以下模式来警告用户。

class ExampleRegressor(RegressorMixin, BaseEstimator):
    __metadata_request__fit = {"sample_weight": metadata_routing.WARN}

    def fit(self, X, y, sample_weight=None):
        check_metadata(self, sample_weight=sample_weight)
        return self

    def predict(self, X):
        return np.zeros(shape=(len(X)))


with warnings.catch_warnings(record=True) as record:
    MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights)
for w in record:
    print(w.message)
sample_weight is None in ExampleRegressor.
Support for sample_weight has recently been added to this class. To maintain backward compatibility, it is ignored now. Using `set_fit_request(sample_weight={True, False})` on this method of the class, you can set the request value to False to silence this warning, or to True to consume and use the metadata.

最后,我们禁用元数据路由的配置标志。

set_config(enable_metadata_routing=False)

第三方开发和scikit-learn依赖#

如上所示,信息使用MetadataRequestMetadataRouter在类之间进行通信。强烈建议不要这样做,但如果严格希望拥有一个与scikit-learn兼容的估计器,而无需依赖scikit-learn包,则可以对与元数据路由相关的工具进行供应商管理。如果满足以下所有条件,则完全无需修改代码

  • 你的估计器继承自BaseEstimator

  • 你的估计器方法(例如fit)使用的参数在方法的签名中明确定义,而不是*args*kwargs

  • 你的估计器不向底层对象路由任何元数据,即它不是路由器

脚本总运行时间:(0 分钟 0.043 秒)

相关示例

使用预计算的 Gram 矩阵和加权样本拟合弹性网络

使用预计算的 Gram 矩阵和加权样本拟合弹性网络

scikit-learn 1.4 发行亮点

scikit-learn 1.4 发行亮点

SGD:加权样本

SGD:加权样本

scikit-learn 1.6 发行亮点

scikit-learn 1.6 发行亮点

由 Sphinx-Gallery 生成的图库