RBF SVM 参数#

本示例说明了径向基函数 (RBF) 核 SVM 参数 gammaC 的影响。

直观上,gamma 参数定义了单个训练样本的影响范围,低值表示“远”,高值表示“近”。gamma 参数可以看作是模型选择为支持向量的样本影响半径的倒数。

参数 C 在正确分类训练样本和最大化决策函数间隔之间进行权衡。对于较大的 C 值,如果决策函数能更好地正确分类所有训练点,则会接受较小的间隔。较低的 C 值会鼓励更大的间隔,从而得到更简单的决策函数,但会牺牲训练精度。换句话说,C 在 SVM 中充当正则化参数。

第一张图是针对一个简化分类问题(仅涉及 2 个输入特征和 2 个可能的类(二分类))的多种参数值下的决策函数可视化。请注意,对于具有更多特征或目标类的问题,无法进行此类绘图。

第二张图是分类器交叉验证准确率的 heatmap,它是 Cgamma 的函数。在本示例中,我们为了说明目的探索了一个相对大的网格。在实践中,从 \(10^{-3}\)\(10^3\) 的对数网格通常就足够了。如果最佳参数位于网格的边界上,则可以在随后的搜索中向该方向扩展。

请注意,热力图具有特殊的颜色条,其中心值接近表现最佳模型的得分值,以便一目了然地识别它们之间的差异。

模型的行为对 gamma 参数非常敏感。如果 gamma 过大,支持向量的影响区域半径将只包括支持向量本身,并且任何 C 的正则化都无法阻止过拟合。

gamma 非常小时,模型受到的约束过大,无法捕捉数据的复杂性或“形状”。任何选定支持向量的影响区域将包括整个训练集。由此产生的模型将类似于一个线性模型,其中一组超平面分离了任意两类高密度区域的中心。

对于中间值,我们可以在第二张图上看到,良好的模型可以在 Cgamma 的对角线上找到。通过增加正确分类每个点的重要性(较大的 C 值),平滑模型(较低的 gamma 值)可以变得更复杂,因此形成了良好性能模型的对角线。

最后,还可以观察到,对于某些中间的 gamma 值,当 C 变得非常大时,我们得到了性能相同的模型。这表明支持向量集不再变化。RBF 核的半径本身就起到了很好的结构正则化作用。进一步增加 C 没有帮助,可能是因为不再有违反(在间隔内或错误分类)的训练点,或者至少找不到更好的解决方案。在分数相同的情况下,使用较小的 C 值可能更合理,因为非常高的 C 值通常会增加拟合时间。

另一方面,较低的 C 值通常会导致更多的支持向量,这可能会增加预测时间。因此,降低 C 值需要在拟合时间和预测时间之间进行权衡。

我们还应该注意,得分的微小差异是由于交叉验证过程中的随机分割造成的。通过增加交叉验证迭代次数 n_splits,可以消除这些虚假的变化,但代价是计算时间增加。增加 C_rangegamma_range 步长的数量将提高超参数热力图的分辨率。

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

用于将颜色映射的中间点移动到感兴趣值附近的实用类。

import numpy as np
from matplotlib.colors import Normalize


class MidpointNormalize(Normalize):
    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        Normalize.__init__(self, vmin, vmax, clip)

    def __call__(self, value, clip=None):
        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, x, y))

加载和准备数据集#

用于网格搜索的数据集

from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target

用于决策函数可视化的数据集:我们只保留 X 中的前两个特征,并对数据集进行子采样,使其仅包含 2 个类别,从而使其成为一个二分类问题。

X_2d = X[:, :2]
X_2d = X_2d[y > 0]
y_2d = y[y > 0]
y_2d -= 1

通常,对 SVM 训练数据进行缩放是一个好主意。在本例中,我们有点“作弊”,因为我们缩放了所有数据,而不是在训练集上拟合变换然后只将其应用于测试集。

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X = scaler.fit_transform(X)
X_2d = scaler.fit_transform(X_2d)

训练分类器#

对于初始搜索,以 10 为底的对数网格通常很有帮助。使用以 2 为底的网格可以实现更精细的调优,但成本要高得多。

from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit
from sklearn.svm import SVC

C_range = np.logspace(-2, 10, 13)
gamma_range = np.logspace(-9, 3, 13)
param_grid = dict(gamma=gamma_range, C=C_range)
cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv)
grid.fit(X, y)

print(
    "The best parameters are %s with a score of %0.2f"
    % (grid.best_params_, grid.best_score_)
)
The best parameters are {'C': np.float64(1.0), 'gamma': np.float64(0.1)} with a score of 0.97

现在我们需要为 2D 版本中的所有参数拟合一个分类器(这里我们使用较小的参数集,因为训练需要一些时间)

C_2d_range = [1e-2, 1, 1e2]
gamma_2d_range = [1e-1, 1, 1e1]
classifiers = []
for C in C_2d_range:
    for gamma in gamma_2d_range:
        clf = SVC(C=C, gamma=gamma)
        clf.fit(X_2d, y_2d)
        classifiers.append((C, gamma, clf))

可视化#

绘制参数效果的可视化

import matplotlib.pyplot as plt

plt.figure(figsize=(8, 6))
xx, yy = np.meshgrid(np.linspace(-3, 3, 200), np.linspace(-3, 3, 200))
for k, (C, gamma, clf) in enumerate(classifiers):
    # evaluate decision function in a grid
    Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)

    # visualize decision function for these parameters
    plt.subplot(len(C_2d_range), len(gamma_2d_range), k + 1)
    plt.title("gamma=10^%d, C=10^%d" % (np.log10(gamma), np.log10(C)), size="medium")

    # visualize parameter's effect on decision function
    plt.pcolormesh(xx, yy, -Z, cmap=plt.cm.RdBu)
    plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y_2d, cmap=plt.cm.RdBu_r, edgecolors="k")
    plt.xticks(())
    plt.yticks(())
    plt.axis("tight")

scores = grid.cv_results_["mean_test_score"].reshape(len(C_range), len(gamma_range))
gamma=10^-1, C=10^-2, gamma=10^0, C=10^-2, gamma=10^1, C=10^-2, gamma=10^-1, C=10^0, gamma=10^0, C=10^0, gamma=10^1, C=10^0, gamma=10^-1, C=10^2, gamma=10^0, C=10^2, gamma=10^1, C=10^2

绘制验证准确率作为 gamma 和 C 函数的热力图

分数以热力颜色映射(从深红到亮黄)编码。由于最有趣的得分都位于 0.92 到 0.97 范围内,我们使用自定义归一化器将中点设置为 0.92,以便更容易地可视化感兴趣范围内的分数微小变化,同时避免将所有低分数粗暴地折叠为相同的颜色。

plt.figure(figsize=(8, 6))
plt.subplots_adjust(left=0.2, right=0.95, bottom=0.15, top=0.95)
plt.imshow(
    scores,
    interpolation="nearest",
    cmap=plt.cm.hot,
    norm=MidpointNormalize(vmin=0.2, midpoint=0.92),
)
plt.xlabel("gamma")
plt.ylabel("C")
plt.colorbar()
plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
plt.yticks(np.arange(len(C_range)), C_range)
plt.title("Validation accuracy")
plt.show()
Validation accuracy

脚本总运行时间: (0 分 4.929 秒)

相关示例

鸢尾花数据集上半监督分类器与 SVM 的决策边界

鸢尾花数据集上半监督分类器与 SVM 的决策边界

绘制不同 SVM 核的分类边界

绘制不同 SVM 核的分类边界

在鸢尾花数据集上绘制不同 SVM 分类器

在鸢尾花数据集上绘制不同 SVM 分类器

SVM 间隔示例

SVM 间隔示例

由 Sphinx-Gallery 生成的图库