GMM 初始化方法#

高斯混合模型中不同初始化方法的示例

有关估计器的更多信息,请参阅 高斯混合模型

这里我们生成了一些样本数据,其中包含四个易于识别的聚类。本示例的目的是展示初始化参数 *init_param* 的四种不同方法。

四种初始化方法是 *kmeans*(默认)、*random*、*random_from_data* 和 *k-means++*。

橙色菱形表示由 *init_param* 生成的 GMM 的初始化中心。其余数据表示为十字,颜色表示 GMM 完成后最终关联的分类。

每个子图右上角的数字表示 GaussianMixture 收敛所需的迭代次数以及算法初始化部分运行的相对时间。初始化时间越短,收敛所需的迭代次数往往越多。

初始化时间是该方法所用时间与默认 *kmeans* 方法所用时间的比率。如您所见,与 *kmeans* 相比,所有三种替代方法的初始化时间都更短。

在本例中,当使用 *random_from_data* 或 *random* 初始化时,模型需要更多次迭代才能收敛。这里 *k-means++* 在初始化时间短和 GaussianMixture 迭代次数少方面都做得很好。

GMM iterations and relative time taken to initialize, kmeans, Iter 8 | Init Time 1.00x, random_from_data, Iter 137 | Init Time 0.74x, k-means++, Iter 11 | Init Time 1.12x, random, Iter 47 | Init Time 0.55x
# Author: Gordon Walsh <[email protected]>
# Data generation code from Jake Vanderplas <[email protected]>

from timeit import default_timer as timer

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets._samples_generator import make_blobs
from sklearn.mixture import GaussianMixture
from sklearn.utils.extmath import row_norms

print(__doc__)

# Generate some data

X, y_true = make_blobs(n_samples=4000, centers=4, cluster_std=0.60, random_state=0)
X = X[:, ::-1]

n_samples = 4000
n_components = 4
x_squared_norms = row_norms(X, squared=True)


def get_initial_means(X, init_params, r):
    # Run a GaussianMixture with max_iter=0 to output the initialization means
    gmm = GaussianMixture(
        n_components=4, init_params=init_params, tol=1e-9, max_iter=0, random_state=r
    ).fit(X)
    return gmm.means_


methods = ["kmeans", "random_from_data", "k-means++", "random"]
colors = ["navy", "turquoise", "cornflowerblue", "darkorange"]
times_init = {}
relative_times = {}

plt.figure(figsize=(4 * len(methods) // 2, 6))
plt.subplots_adjust(
    bottom=0.1, top=0.9, hspace=0.15, wspace=0.05, left=0.05, right=0.95
)

for n, method in enumerate(methods):
    r = np.random.RandomState(seed=1234)
    plt.subplot(2, len(methods) // 2, n + 1)

    start = timer()
    ini = get_initial_means(X, method, r)
    end = timer()
    init_time = end - start

    gmm = GaussianMixture(
        n_components=4, means_init=ini, tol=1e-9, max_iter=2000, random_state=r
    ).fit(X)

    times_init[method] = init_time
    for i, color in enumerate(colors):
        data = X[gmm.predict(X) == i]
        plt.scatter(data[:, 0], data[:, 1], color=color, marker="x")

    plt.scatter(
        ini[:, 0], ini[:, 1], s=75, marker="D", c="orange", lw=1.5, edgecolors="black"
    )
    relative_times[method] = times_init[method] / times_init[methods[0]]

    plt.xticks(())
    plt.yticks(())
    plt.title(method, loc="left", fontsize=12)
    plt.title(
        "Iter %i | Init Time %.2fx" % (gmm.n_iter_, relative_times[method]),
        loc="right",
        fontsize=10,
    )
plt.suptitle("GMM iterations and relative time taken to initialize")
plt.show()

**脚本总运行时间:**(0 分 0.808 秒)

相关示例

GMM 协方差

GMM 协方差

K-Means++ 初始化示例

K-Means++ 初始化示例

手写数字数据上的 K 均值聚类演示

手写数字数据上的 K 均值聚类演示

高斯混合模型椭球

高斯混合模型椭球

由 Sphinx-Gallery 生成的图库