均值漂移#
- class sklearn.cluster.MeanShift(*, bandwidth=None, seeds=None, bin_seeding=False, min_bin_freq=1, cluster_all=True, n_jobs=None, max_iter=300)[source]#
使用平面核的均值漂移聚类。
均值漂移聚类旨在发现样本平滑密度中的“斑块”。它是一种基于质心的算法,其工作原理是将质心的候选值更新为给定区域内点的平均值。然后在后处理阶段过滤这些候选值以消除近似重复项,从而形成最终的质心集。
为了提高可扩展性,使用分箱技术进行种子初始化。
有关如何使用 MeanShift 聚类的示例,请参阅:均值漂移聚类算法演示。
在 用户指南 中了解更多信息。
- 参数:
- bandwidth浮点数,默认值=None
平面核中使用的带宽。
如果未给出,则使用 sklearn.cluster.estimate_bandwidth 估计带宽;有关可扩展性的提示,请参阅该函数的文档(另请参阅下面的注释)。
- seeds形状为 (n_samples, n_features) 的类数组,默认值=None
用于初始化内核的种子。如果未设置,则种子由 clustering.get_bin_seeds 计算,带宽作为网格大小,其他参数使用默认值。
- bin_seeding布尔值,默认值=False
如果为真,则初始内核位置不是所有点的位 置,而是点的离散化版本的 位置,其中点被分箱到一个网格上,其粗糙度对应于带宽。将此选项设置为 True 将加快算法速度,因为初始化的种子更少。默认值为 False。如果 seeds 参数不是 None,则忽略。
- min_bin_freq整数,默认值=1
为了加快算法速度,只接受至少有 min_bin_freq 个点的箱作为种子。
- cluster_all布尔值,默认值=True
如果为真,则对所有点进行聚类,即使是那些不在任何内核中的孤立点。孤立点被分配给最近的内核。如果为假,则孤立点的聚类标签为 -1。
- n_jobs整数,默认值=None
用于计算的作业数。以下任务受益于并行化:
带宽估计和标签分配的最近邻搜索。详情请参见
NearestNeighbors
类的文档字符串。所有种子的爬山优化。
更多详情请参见 词汇表。
None
表示 1,除非在joblib.parallel_backend
上下文中。-1
表示使用所有处理器。更多详情请参见 词汇表。- max_iter整数,默认值=300
每个种子点在聚类操作终止(对于该种子点)之前的最大迭代次数,如果尚未收敛。
在 0.22 版本中添加。
- 属性:
另请参见
KMeans
K 均值聚类。
注释
可扩展性
因为此实现使用平面核和球树来查找每个核的成员,所以复杂度在较低维度下将趋向于 O(T*n*log(n)),其中 n 是样本数,T 是点数。在较高维度下,复杂度将趋向于 O(T*n^2)。
可以使用较少的种子来提高可扩展性,例如在 get_bin_seeds 函数中使用较高的 min_bin_freq 值。
请注意,estimate_bandwidth 函数的可扩展性远低于均值漂移算法,如果使用它,它将成为瓶颈。
参考文献
Dorin Comaniciu 和 Peter Meer,“均值漂移:一种用于特征空间分析的鲁棒方法”。IEEE 模式分析和机器智能汇刊。2002 年。第 603-619 页。
示例
>>> from sklearn.cluster import MeanShift >>> import numpy as np >>> X = np.array([[1, 1], [2, 1], [1, 0], ... [4, 7], [3, 5], [3, 6]]) >>> clustering = MeanShift(bandwidth=2).fit(X) >>> clustering.labels_ array([1, 1, 1, 0, 0, 0]) >>> clustering.predict([[0, 0], [5, 5]]) array([1, 0]) >>> clustering MeanShift(bandwidth=2)
- fit(X, y=None)[source]#
执行聚类。
- 参数:
- X形状为 (n_samples, n_features) 的类数组
要聚类的样本。
- y忽略
未使用,根据约定保留以保持 API 一致性。
- 返回:
- self对象
拟合后的实例。
- fit_predict(X, y=None, **kwargs)[source]#
对
X
执行聚类并返回聚类标签。- 参数:
- X形状为 (n_samples, n_features) 的类数组
输入数据。
- y忽略
未使用,根据约定保留以保持 API 一致性。
- **kwargs字典
要传递给
fit
的参数。在 1.4 版本中添加。
- 返回:
- labels形状为 (n_samples,),dtype=np.int64 的 ndarray
聚类标签。
- get_metadata_routing()[source]#
获取此对象的元数据路由。
请查看用户指南了解路由机制的工作原理。
- 返回:
- routingMetadataRequest
一个
MetadataRequest
封装了路由信息。
- get_params(deep=True)[source]#
获取此估计器的参数。
- 参数:
- deepbool, default=True
如果为 True,则将返回此估计器和包含的作为估计器的子对象的参数。
- 返回:
- paramsdict
参数名称与其值的映射。