BaseEstimator#

class sklearn.base.BaseEstimator[source]#

scikit-learn 中所有估计器的基类。

继承此类的对象将获得以下默认实现:

  • 用于 GridSearchCV 及相关对象的参数设置与获取;

  • 在终端和 IDE 中显示的文本和 HTML 表示;

  • 估计器序列化;

  • 参数验证;

  • 数据验证;

  • 特征名称验证。

请参阅用户指南以了解更多信息。

备注

所有估计器都应在其 __init__ 方法中,将所有可在类级别设置的参数指定为显式关键字参数(不使用 *args**kwargs)。

示例

>>> import numpy as np
>>> from sklearn.base import BaseEstimator
>>> class MyEstimator(BaseEstimator):
...     def __init__(self, *, param=1):
...         self.param = param
...     def fit(self, X, y=None):
...         self.is_fitted_ = True
...         return self
...     def predict(self, X):
...         return np.full(shape=X.shape[0], fill_value=self.param)
>>> estimator = MyEstimator(param=2)
>>> estimator.get_params()
{'param': 2}
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
>>> y = np.array([1, 0, 1])
>>> estimator.fit(X, y).predict(X)
array([2, 2, 2])
>>> estimator.set_params(param=3).fit(X, y).predict(X)
array([3, 3, 3])
get_metadata_routing()[source]#

获取此对象的元数据路由。

请查阅用户指南,了解路由机制的工作原理。

返回:
routingMetadataRequest

一个封装了路由信息的 MetadataRequest 对象。

get_params(deep=True)[source]#

获取此估计器的参数。

参数:
deep布尔型,默认为 True

如果为 True,将返回此估计器及其包含的子估计器的参数。

返回:
params字典

参数名称及其对应值的映射。

set_params(**params)[source]#

设置此估计器的参数。

此方法适用于简单估计器以及嵌套对象(例如 Pipeline)。后者具有 <component>__<parameter> 形式的参数,因此可以更新嵌套对象的每个组件。

参数:
**params字典

估计器参数。

返回:
self估计器实例

估计器实例。