注意
Go to the end 下载完整的示例代码,或者通过 JupyterLite 或 Binder 在浏览器中运行此示例。
高斯混合模型正弦曲线#
此示例演示了高斯混合模型(GMM)在不从高斯随机变量混合中采样的数据上拟合时的行为。数据集由100个点组成,这些点松散地沿着一条有噪声的正弦曲线分布。因此,高斯分量的数量没有地面真值(ground truth)。
第一个模型是一个经典的高斯混合模型,有10个分量,使用期望最大化(Expectation-Maximization)算法进行拟合。
第二个模型是一个贝叶斯高斯混合模型,带有狄利克雷过程先验(Dirichlet process prior),使用变分推断(variational inference)进行拟合。浓度先验(concentration prior)的低值使模型倾向于较低数量的活动分量。该模型“决定”将其建模能力集中在数据集结构的宏观层面:用非对角协方差矩阵建模的、方向交替的点组。这些交替的方向大致捕捉了原始正弦信号的交替性质。
第三个模型也是一个带有狄利克雷过程先验的贝叶斯高斯混合模型,但这次浓度先验的值更高,这使模型有更大的自由度来建模数据的细粒度结构。结果是一个具有较多活动分量的混合模型,与第一个模型相似,我们在第一个模型中任意决定将分量数量固定为10。
哪个模型最好是一个主观判断的问题:我们是希望偏爱只捕捉宏观层面来总结和解释大部分数据结构而忽略细节的模型,还是偏爱紧密跟随信号高密度区域的模型?
最后两个面板展示了我们如何从最后两个模型中采样。得到的样本分布与原始数据分布并不完全相同。这种差异主要源于我们所做的近似误差,即使用了一个假设数据是由有限数量的高斯分量生成的模型,而不是一个连续的带噪声的正弦曲线。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import itertools
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from scipy import linalg
from sklearn import mixture
color_iter = itertools.cycle(["navy", "c", "cornflowerblue", "gold", "darkorange"])
def plot_results(X, Y, means, covariances, index, title):
splot = plt.subplot(5, 1, 1 + index)
for i, (mean, covar, color) in enumerate(zip(means, covariances, color_iter)):
v, w = linalg.eigh(covar)
v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
u = w[0] / linalg.norm(w[0])
# as the DP will not use every component it has access to
# unless it needs it, we shouldn't plot the redundant
# components.
if not np.any(Y == i):
continue
plt.scatter(X[Y == i, 0], X[Y == i, 1], 0.8, color=color)
# Plot an ellipse to show the Gaussian component
angle = np.arctan(u[1] / u[0])
angle = 180.0 * angle / np.pi # convert to degrees
ell = mpl.patches.Ellipse(mean, v[0], v[1], angle=180.0 + angle, color=color)
ell.set_clip_box(splot.bbox)
ell.set_alpha(0.5)
splot.add_artist(ell)
plt.xlim(-6.0, 4.0 * np.pi - 6.0)
plt.ylim(-5.0, 5.0)
plt.title(title)
plt.xticks(())
plt.yticks(())
def plot_samples(X, Y, n_components, index, title):
plt.subplot(5, 1, 4 + index)
for i, color in zip(range(n_components), color_iter):
# as the DP will not use every component it has access to
# unless it needs it, we shouldn't plot the redundant
# components.
if not np.any(Y == i):
continue
plt.scatter(X[Y == i, 0], X[Y == i, 1], 0.8, color=color)
plt.xlim(-6.0, 4.0 * np.pi - 6.0)
plt.ylim(-5.0, 5.0)
plt.title(title)
plt.xticks(())
plt.yticks(())
# Parameters
n_samples = 100
# Generate random sample following a sine curve
np.random.seed(0)
X = np.zeros((n_samples, 2))
step = 4.0 * np.pi / n_samples
for i in range(X.shape[0]):
x = i * step - 6.0
X[i, 0] = x + np.random.normal(0, 0.1)
X[i, 1] = 3.0 * (np.sin(x) + np.random.normal(0, 0.2))
plt.figure(figsize=(10, 10))
plt.subplots_adjust(
bottom=0.04, top=0.95, hspace=0.2, wspace=0.05, left=0.03, right=0.97
)
# Fit a Gaussian mixture with EM using ten components
gmm = mixture.GaussianMixture(
n_components=10, covariance_type="full", max_iter=100
).fit(X)
plot_results(
X, gmm.predict(X), gmm.means_, gmm.covariances_, 0, "Expectation-maximization"
)
dpgmm = mixture.BayesianGaussianMixture(
n_components=10,
covariance_type="full",
weight_concentration_prior=1e-2,
weight_concentration_prior_type="dirichlet_process",
mean_precision_prior=1e-2,
covariance_prior=1e0 * np.eye(2),
init_params="random",
max_iter=100,
random_state=2,
).fit(X)
plot_results(
X,
dpgmm.predict(X),
dpgmm.means_,
dpgmm.covariances_,
1,
"Bayesian Gaussian mixture models with a Dirichlet process prior "
r"for $\gamma_0=0.01$.",
)
X_s, y_s = dpgmm.sample(n_samples=2000)
plot_samples(
X_s,
y_s,
dpgmm.n_components,
0,
"Gaussian mixture with a Dirichlet process prior "
r"for $\gamma_0=0.01$ sampled with $2000$ samples.",
)
dpgmm = mixture.BayesianGaussianMixture(
n_components=10,
covariance_type="full",
weight_concentration_prior=1e2,
weight_concentration_prior_type="dirichlet_process",
mean_precision_prior=1e-2,
covariance_prior=1e0 * np.eye(2),
init_params="kmeans",
max_iter=100,
random_state=2,
).fit(X)
plot_results(
X,
dpgmm.predict(X),
dpgmm.means_,
dpgmm.covariances_,
2,
"Bayesian Gaussian mixture models with a Dirichlet process prior "
r"for $\gamma_0=100$",
)
X_s, y_s = dpgmm.sample(n_samples=2000)
plot_samples(
X_s,
y_s,
dpgmm.n_components,
1,
"Gaussian mixture with a Dirichlet process prior "
r"for $\gamma_0=100$ sampled with $2000$ samples.",
)
plt.show()
脚本总运行时间: (0 minutes 0.389 seconds)
相关示例