MeanShift#
- class sklearn.cluster.MeanShift(*, bandwidth=None, seeds=None, bin_seeding=False, min_bin_freq=1, cluster_all=True, n_jobs=None, max_iter=300)[source]#
使用平坦核的均值漂移聚类。
均值漂移聚类旨在在样本的平滑密度中发现“团块”。它是一种基于质心的算法,通过将质心的候选点更新为给定区域内点的平均值来工作。然后,这些候选点在后处理阶段进行过滤,以消除近似重复项,从而形成最终的质心集合。
播种是使用分箱技术进行的,以实现可伸缩性。
有关如何使用 MeanShift 聚类的示例,请参阅:均值漂移聚类算法演示。
在用户指南中阅读更多内容。
- 参数:
- bandwidth浮点数,默认值=None
平坦核中使用的带宽。
如果未给出,则使用 sklearn.cluster.estimate_bandwidth 估计带宽;有关该函数的文档,请参阅有关可伸缩性的提示(另请参阅下面的注释)。
- seeds形状为 (n_samples, n_features) 的类数组,默认值=None
用于初始化核的种子。如果未设置,则种子将由 clustering.get_bin_seeds 计算,其中带宽作为网格大小,其他参数使用默认值。
- bin_seeding布尔值,默认值=False
如果为 True,初始核位置不是所有点的位置,而是点的离散化版本的位置,其中点被分箱到粗细与带宽相对应的网格上。将此选项设置为 True 将加快算法速度,因为将初始化更少的种子。默认值为 False。如果 `seeds` 参数不是 None,则忽略此参数。
- min_bin_freq整数,默认值=1
为了加快算法速度,只接受至少包含 `min_bin_freq` 个点的箱作为种子。
- cluster_all布尔值,默认值=True
如果为 True,则所有点都被聚类,即使是那些不在任何核内的孤立点。孤立点被分配到最近的核。如果为 False,则孤立点被赋予聚类标签 -1。
- n_jobs整数,默认值=None
用于计算的作业数。以下任务受益于并行化
用于带宽估计和标签分配的最近邻搜索。详情请参见
NearestNeighbors
类的文档字符串。所有种子的爬山优化。
有关更多详细信息,请参阅术语表。
None
表示 1,除非在joblib.parallel_backend
上下文中。-1
表示使用所有处理器。有关更多详细信息,请参阅术语表。- max_iter整数,默认值=300
每个种子点在聚类操作终止(对于该种子点)前的最大迭代次数,如果尚未收敛。
在 0.22 版本中新增。
- 属性:
另请参阅
KMeans
K-均值聚类。
注释
可伸缩性
由于此实现使用平坦核和球树来查找每个核的成员,因此在较低维度中,复杂度将趋向于 O(T*n*log(n)),其中 n 是样本数量,T 是点数量。在较高维度中,复杂度将趋向于 O(T*n^2)。
通过使用更少的种子可以提高可伸缩性,例如在 get_bin_seeds 函数中使用更高的 `min_bin_freq` 值。
请注意,estimate_bandwidth 函数的可伸缩性远低于均值漂移算法,如果使用它,它将成为瓶颈。
参考文献
Dorin Comaniciu and Peter Meer, “Mean Shift: A robust approach toward feature space analysis”. IEEE Transactions on Pattern Analysis and Machine Intelligence. 2002. pp. 603-619.
示例
>>> from sklearn.cluster import MeanShift >>> import numpy as np >>> X = np.array([[1, 1], [2, 1], [1, 0], ... [4, 7], [3, 5], [3, 6]]) >>> clustering = MeanShift(bandwidth=2).fit(X) >>> clustering.labels_ array([1, 1, 1, 0, 0, 0]) >>> clustering.predict([[0, 0], [5, 5]]) array([1, 0]) >>> clustering MeanShift(bandwidth=2)
有关均值漂移聚类与其他聚类算法的比较,请参阅在玩具数据集上比较不同的聚类算法
- fit(X, y=None)[source]#
执行聚类。
- 参数:
- X形状为 (n_samples, n_features) 的类数组
要聚类的样本。
- y忽略
未使用,按约定为 API 一致性而存在。
- 返回:
- self对象
拟合实例。
- fit_predict(X, y=None, **kwargs)[source]#
在
X
上执行聚类并返回聚类标签。- 参数:
- X形状为 (n_samples, n_features) 的类数组
输入数据。
- y忽略
未使用,按约定为 API 一致性而存在。
- **kwargs字典
要传递给
fit
的参数。在 1.4 版本中新增。
- 返回:
- labels形状为 (n_samples,) 且 dtype=np.int64 的 ndarray
聚类标签。
- get_metadata_routing()[source]#
获取此对象的元数据路由。
请查看 用户指南,了解路由机制的工作原理。
- 返回:
- routingMetadataRequest
一个封装路由信息的
MetadataRequest
。
- get_params(deep=True)[source]#
获取此估计器的参数。
- 参数:
- deep布尔值,默认值=True
如果为 True,则返回此估计器及其包含的作为估计器的子对象的参数。
- 返回:
- params字典
参数名称与其值之间的映射。