注意
转到结尾 下载完整的示例代码。或通过JupyterLite或Binder在浏览器中运行此示例。
RBF SVM参数#
此示例说明径向基函数 (RBF) 核 SVM 的参数gamma
和C
的影响。
直观地说,gamma
参数定义单个训练样本的影响范围,值越低表示范围越“远”,值越高表示范围越“近”。gamma
参数可以看作是模型选择为支持向量的样本影响半径的倒数。
C
参数权衡了训练样本的正确分类与决策函数裕度的最大化。对于较大的C
值,如果决策函数能够更好地对所有训练点进行正确分类,则会接受较小的裕度。较低的C
值会鼓励更大的裕度,因此会得到更简单的决策函数,但代价是训练精度降低。换句话说,C
在SVM中充当正则化参数。
第一个图是在一个简化的分类问题上,针对各种参数值对决策函数进行可视化,该问题仅涉及 2 个输入特征和 2 个可能的目标类别(二元分类)。请注意,对于具有更多特征或目标类别的这种问题,这种图是不可能绘制的。
第二个图是分类器的交叉验证精度作为C
和gamma
函数的热图。在此示例中,出于说明目的,我们探索了一个相对较大的网格。在实践中,从\(10^{-3}\)到\(10^3\)的对数网格通常就足够了。如果最佳参数位于网格的边界上,则可以在后续搜索中向该方向扩展网格。
请注意,热图具有特殊的色条,其中间值接近最佳性能模型的分数值,以便轻松快速地区分它们。
模型的行为对gamma
参数非常敏感。如果gamma
过大,支持向量的区域影响半径仅包含支持向量本身,并且任何使用C
进行的正则化都无法防止过拟合。
当gamma
非常小时,模型过于受限,无法捕捉数据的复杂性或“形状”。任何选定支持向量的区域影响都将包含整个训练集。生成的模型的行为类似于具有能够分离任何一对两类的高密度中心的超平面的线性模型。
对于中间值,我们可以在第二个图上看到,可以在C
和gamma
的对角线上找到良好的模型。可以通过增加正确分类每个点的权重(更大的C
值)来使平滑模型(较低的gamma
值)更加复杂,因此形成了良好性能模型的对角线。
最后,还可以观察到,对于gamma
的一些中间值,当C
变得非常大时,我们会得到性能相同的模型。这表明支持向量集不再发生变化。RBF 核的半径本身就充当了良好的结构正则化器。进一步增加C
无济于事,这可能是因为没有更多违规的训练点(在裕度内或分类错误),或者至少找不到更好的解决方案。分数相等的情况下,使用较小的C
值可能更有意义,因为非常大的C
值通常会增加拟合时间。
另一方面,较低的C
值通常会导致更多的支持向量,这可能会增加预测时间。因此,降低C
值涉及拟合时间和预测时间之间的权衡。
我们还应该注意,分数的微小差异是由交叉验证过程的随机分割造成的。可以通过增加CV迭代次数n_splits
来消除这些虚假变化,但代价是计算时间增加。增加C_range
和gamma_range
步数的值将提高超参数热图的分辨率。
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
实用程序类,用于将颜色图的中点移动到感兴趣的值附近。
import numpy as np
from matplotlib.colors import Normalize
class MidpointNormalize(Normalize):
def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
self.midpoint = midpoint
Normalize.__init__(self, vmin, vmax, clip)
def __call__(self, value, clip=None):
x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
return np.ma.masked_array(np.interp(value, x, y))
加载和准备数据集#
用于网格搜索的数据集
用于决策函数可视化的数据集:我们只保留X中的前两个特征,并对数据集进行子采样,只保留2个类别,使其成为一个二元分类问题。
X_2d = X[:, :2]
X_2d = X_2d[y > 0]
y_2d = y[y > 0]
y_2d -= 1
通常,缩放数据对于SVM训练来说是一个好主意。在这个例子中,我们有点作弊,对所有数据进行缩放,而不是在训练集上拟合变换,只在测试集上应用它。
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X = scaler.fit_transform(X)
X_2d = scaler.fit_transform(X_2d)
训练分类器#
对于初始搜索,使用基数为 10 的对数网格通常很有帮助。使用基数为 2,可以实现更精细的调整,但代价要高得多。
from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit
from sklearn.svm import SVC
C_range = np.logspace(-2, 10, 13)
gamma_range = np.logspace(-9, 3, 13)
param_grid = dict(gamma=gamma_range, C=C_range)
cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)
grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv)
grid.fit(X, y)
print(
"The best parameters are %s with a score of %0.2f"
% (grid.best_params_, grid.best_score_)
)
The best parameters are {'C': np.float64(1.0), 'gamma': np.float64(0.1)} with a score of 0.97
现在我们需要为二维版本中的所有参数拟合一个分类器(我们在这里使用较小的参数集,因为训练需要一段时间)。
C_2d_range = [1e-2, 1, 1e2]
gamma_2d_range = [1e-1, 1, 1e1]
classifiers = []
for C in C_2d_range:
for gamma in gamma_2d_range:
clf = SVC(C=C, gamma=gamma)
clf.fit(X_2d, y_2d)
classifiers.append((C, gamma, clf))
可视化#
绘制参数影响的可视化图
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 6))
xx, yy = np.meshgrid(np.linspace(-3, 3, 200), np.linspace(-3, 3, 200))
for k, (C, gamma, clf) in enumerate(classifiers):
# evaluate decision function in a grid
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
# visualize decision function for these parameters
plt.subplot(len(C_2d_range), len(gamma_2d_range), k + 1)
plt.title("gamma=10^%d, C=10^%d" % (np.log10(gamma), np.log10(C)), size="medium")
# visualize parameter's effect on decision function
plt.pcolormesh(xx, yy, -Z, cmap=plt.cm.RdBu)
plt.scatter(X_2d[:, 0], X_2d[:, 1], c=y_2d, cmap=plt.cm.RdBu_r, edgecolors="k")
plt.xticks(())
plt.yticks(())
plt.axis("tight")
scores = grid.cv_results_["mean_test_score"].reshape(len(C_range), len(gamma_range))
绘制验证精度作为gamma和C函数的热图
分数使用热颜色图编码,该颜色图从深红色变化到亮黄色。由于最有趣的分数都位于0.92到0.97的范围内,因此我们使用自定义归一化器将中点设置为0.92,以便更容易地可视化有趣范围内分数值的微小变化,同时不会粗暴地将所有低分数值压缩到相同的颜色。
plt.figure(figsize=(8, 6))
plt.subplots_adjust(left=0.2, right=0.95, bottom=0.15, top=0.95)
plt.imshow(
scores,
interpolation="nearest",
cmap=plt.cm.hot,
norm=MidpointNormalize(vmin=0.2, midpoint=0.92),
)
plt.xlabel("gamma")
plt.ylabel("C")
plt.colorbar()
plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
plt.yticks(np.arange(len(C_range)), C_range)
plt.title("Validation accuracy")
plt.show()
脚本总运行时间:(0 分钟 5.290 秒)
相关示例