【机器学习】支持向量机 SVM 从原理到实战(Python 全流程实现)

目录

一、前言

二、SVM 核心原理(从通俗到深入)

2.1 什么是 SVM?一个通俗的小故事

2.2 核心目标:最优超平面与最大间隔

2.2.1 超平面方程

2.2.2 点到超平面的距离

2.2.3 最大间隔的优化目标

2.2.4 什么是支持向量?

2.3 软间隔:解决噪声与线性不可分

2.4 核函数:低维解决高维非线性问题

2.5 SVM 的优缺点

优点

缺点

三、SVM 实战:基于 Python+sklearn 实现

3.1 环境准备

3.2 实战一:二维特征线性 SVM 可视化

​3.2.1 完整代码实现

3.2.2 结果可视化与解读

3.3 实战二:鸢尾花数据集全特征 SVM 多分类

3.3.1 数据集与预处理

3.3.2 完整代码实现

3.3.3 模型评估结果解读

训练集混淆矩阵

测试集混淆矩阵

测试集分类报告

四、SVM 核心 API 参数详解

五、总结

注:其中iris.csv数据集和SVM详细API文档都在我的主页资源中

一、前言

支持向量机(Support Vector Machine,SVM)是机器学习领域经典的有监督分类算法,自诞生以来凭借扎实的数学理论、优秀的小样本学习能力、强大的非线性拟合能力,在分类、回归等任务中得到了广泛应用。本文将从通俗的原理讲解入手,深入拆解 SVM 的核心逻辑,再基于 Python+sklearn 实现完整的 SVM 分类任务,包含可视化、模型训练、评估全流程,帮助读者从入门到实战彻底掌握 SVM。

二、SVM 核心原理(从通俗到深入)

2.1 什么是 SVM?一个通俗的小故事

我们用一个经典的故事理解 SVM 的核心思想:很久以前,公主被魔鬼绑架,王子需要完成魔鬼的挑战:用一根棍子分开桌子上两种颜色的球,并且要求后续加入更多球时,这根棍子依然能有效分类。

1.第一次王子随便放了棍子,结果新增的球直接越界,分类失效;

2.后来王子把棍子放在了两类球的中间,让棍子两边到最近的球的距离尽可能大,此时哪怕新增更多球,棍子依然能稳定分类;

3.魔鬼又把球摆成了非线性的布局,二维平面里根本没法用一根直线分开,王子一拍桌子让球飞到空中,用一张纸完美隔开了两类球。

对应到 SVM 的核心概念里:

两种颜色的球 = 我们的训练数据

棍子 / 纸 = 分类决策边界(超平面)

让棍子两边间隙最大的操作 = 最大间隔最优化

拍桌子让球飞起来 = 核函数(低维映射到高维)

离棍子最近、决定棍子位置的球 = 支持向量

2.2 核心目标:最优超平面与最大间隔

SVM 的核心目标,就是找到一个最优超平面,让不同类别的样本被完美分开,且两类样本到超平面的最小距离(间隔)最大化。

2.2.1 超平面方程

超平面是分类的决策边界,在不同维度空间有不同的表达形式:

二维平面:一条直线,方程为 ​

三维空间:一个平面,方程为 ​

更高维空间:超平面,通用方程为 ​。

其中 ​ω为超平面的法向量(决定超平面方向),​b为偏置项(决定超平面的位置)。

​最终的分类决策函数为:。

其中​sign为符号函数,输入大于 0 输出 1(正例),小于 0 输出 - 1(负例)。

2.2.2 点到超平面的距离

样本点到超平面的距离,是衡量分类置信度的核心指标,公式为:。

结合分类的正确性,我们可以得到几何间隔:当样本分类正确时,​,因此样本到超平面的几何间隔可写为:

2.2.3 最大间隔的优化目标

我们的目标是:让离超平面最近的样本点(支持向量)到超平面的距离最大化。通过数学放缩,我们可以约束支持向量满足 ,此时最大间隔的优化目标可转化为:。

约束条件:。这个优化问题可以通过拉格朗日乘子法转化为对偶问题求解,最终得到最优的​ω和​b,也就是超平面的参数。

2.2.4 什么是支持向量?

在求解过程中,只有满足 的样本点,对应的拉格朗日乘子,这些样本就是支持向量。

SVM 的核心特性之一就是:最终的决策超平面只由少数支持向量决定,哪怕移除其他所有样本,超平面的位置也不会改变。这也是 SVM 在小样本场景下表现优异的核心原因。

2.3 软间隔:解决噪声与线性不可分

现实场景中,很多数据存在噪声点,无法实现完美的线性可分,如果强行追求 100% 分类正确,会导致模型泛化能力极差。因此 SVM 引入了软间隔的概念:允许少数样本点违反约束、出现在间隔带内,甚至被误分类,以此提升模型的泛化能力。

我们引入松弛因子,将约束条件放宽为:。 同时优化目标更新为:

其中C为惩罚因子,是 SVM 的核心超参数:

C越大:对误分类的惩罚越重,模型越不允许出现误分类,容易过拟合,泛化能力弱;

C越小:对误分类的惩罚越轻,允许更多样本违反约束,模型泛化能力强,容易欠拟合。

2.4 核函数:低维解决高维非线性问题

对于完全线性不可分的数据,SVM 通过核函数解决问题:将低维空间的线性不可分数据,映射到高维特征空间,使其在高维空间中线性可分,再在高维空间中学习最优超平面。

直接在高维空间计算会带来巨大的计算量,而核函数的核心优势是:在低维空间完成高维空间的内积运算,结果完全一致,大幅降低计算复杂度。

常用的核函数有以下几种:

1.线性核:​,适用于线性可分的数据,计算速度快,可解释性强;

2.多项式核:,适用于中等规模的非线性数据,可通过 degree 调整多项式维度;

3.高斯核(RBF,径向基函数):,默认核函数,适用于绝大多数非线性场景,通过​调整映射范围:

越小:正态分布越 “胖”,辐射范围越大,过拟合风险越低;

越大:正态分布越 “瘦”,辐射范围越小,过拟合风险越高。

2.5 SVM 的优缺点

优点

有严格的数学理论支撑,可解释性强,不同于黑盒模型;

小样本场景下表现优异,最终决策仅由少数支持向量决定;

软间隔机制可有效提升模型泛化能力,适配带噪声的现实数据;

核函数可完美解决非线性分类问题,避免 “维数灾难”;

泛化能力强,在分类任务中不易过拟合。

缺点

对大规模训练样本适配性差,样本量超过 10 万时,核矩阵的存储和计算会耗费大量内存和时间;

对核函数和超参数的选择非常敏感,不同参数对模型效果影响极大;

预测速度与支持向量的数量成正比,支持向量过多时,预测效率较低。

三、SVM 实战:基于 Python+sklearn 实现

本次实战使用经典的鸢尾花数据集,分为两个部分:

二维特征线性 SVM 可视化,直观展示超平面、间隔、支持向量;

全特征 RBF 核 SVM 多分类,完成模型训练、混淆矩阵可视化、分类报告输出全流程。

3.1 环境准备

需要提前安装相关依赖库:

pip install pandas numpy matplotlib scikit-learn

3.2 实战一:二维特征线性 SVM 可视化

本部分选取鸢尾花数据集的 2 个特征,训练线性核 SVM,并可视化超平面、间隔边界和支持向量,直观理解 SVM 的核心逻辑。

​3.2.1 完整代码实现

import pandas as pd

data=pd.read_csv("iris.csv",header=None)

import matplotlib.pyplot as plt

data1=data.iloc[:50,:]

data2=data.iloc[50:100,:]

data3=data.iloc[100:,:]

plt.scatter(data1[1],data1[3],marker="^")

plt.scatter(data2[1],data2[3],marker="o")

#使用svm训练

from sklearn.svm import SVC

x = data.iloc[:,[1,3]]

y = data.iloc[:,-1]

svm = SVC(kernel="linear",C=100,random_state=0) #c无穷大float('inf'),则软间隔为0,不容有其他点进入软间隔内

svm.fit(x,y)

#可视化svm结果

#参数w【原始数据为二维数组】

w = svm.coef_[0]

b = svm.intercept_[0]

#超平面方程w1*1+w2*2+b=0

import numpy as np

x1 = np.linspace(0,5,700)

#超平面方程

x2 = -(w[0]*x1+b)/w[1]

#上超平面方程

x3 = (1-(w[0]*x1+b))/w[1]

#下超平面方程

x4 = (-1-(w[0]*x1+b))/w[1]

#可视化超平面

plt.plot(x1,x2,linewidth=2,color='r')

plt.plot(x1,x3,linewidth=1,color='r',linestyle='--')

plt.plot(x1,x4,linewidth=1,color='r',linestyle='--')

# #对坐标限制

# plt.xlim([2,6])

# plt.ylim([0,3])

#找到支持向量【二维数组】可视化向量

vet = svm.support_vectors_

plt.scatter(vet[:,0],vet[:,1],c="b",marker="+")

plt.show()

3.2.2 结果可视化与解读

从可视化结果中我们可以清晰看到:

红色实线为 SVM 学习到的最优分类超平面,完美分隔了两类样本;

两条红色虚线为间隔边界,两类样本的间隔被最大化;

蓝色 + 标记的点就是支持向量,这些点落在间隔边界上,是决定超平面位置的核心样本,其他样本的移除不会改变超平面的位置。

3.3 实战二:鸢尾花数据集全特征 SVM 多分类

本部分使用鸢尾花数据集的全部 4 个特征,基于 RBF 核 SVM 完成三分类任务,包含数据划分、模型训练、混淆矩阵可视化、分类报告输出全流程。

3.3.1 数据集与预处理

鸢尾花数据集包含 3 类鸢尾花,每类 50 个样本,共 150 条数据,4 个特征分别为花萼长度、花萼宽度、花瓣长度、花瓣宽度,我们按照 8:2 划分训练集和测试集。

3.3.2 完整代码实现

'''四个特征全训练'''

import pandas as pd

datas = pd.read_csv("iris.csv",header=None)

data = datas.iloc[:,:-1].values

target = datas.iloc[:,-1].values

#数据切分

from sklearn.model_selection import train_test_split

x_train,x_test,y_train,y_test = \

train_test_split(data,target,test_size=0.2,random_state=0)

#可视化混淆矩阵

def cm_plot(ah,yp):

from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt

cm = confusion_matrix(ah,yp)

plt.matshow(cm, cmap=plt.cm.Blues)

plt.colorbar()

for x in range(len(cm)):

for y in range(len(cm[x])):

plt.annotate(cm[x][y], xy=(y,x),horizontalalignment="center"

,color="white",verticalalignment="center")

plt.ylabel('True label')

plt.xlabel('Predicted label')

plt.show()

return plt

#模型训练

from sklearn.svm import SVC

svm = SVC(kernel='rbf',C=10)

svm.fit(x_train,y_train)

#模型自测试

y_pred = svm.predict(x_train)

cm_plot(y_train,y_pred)

#测试集测试

y_test_pred = svm.predict(x_test)

cm_plot(y_test,y_test_pred)

""""#测试集测试获得分类结果报告"""

from sklearn import metrics

test_predicted_big = svm.predict(x_test)

print(metrics.classification_report(y_test,test_predicted_big))

3.3.3 模型评估结果解读

训练集混淆矩阵

可视化视图

训练集共 120 个样本,仅出现 2 个误分类样本,整体分类准确率超过 98%,模型在训练数据上拟合效果良好,没有出现欠拟合。

测试集混淆矩阵

可视化视图

测试集共 30 个样本,所有样本均被正确分类,无任何误分类情况,模型在未见过的测试数据上表现完美,泛化能力优异。

测试集分类报告

从分类报告可以看到,3 个类别的精确率(precision)、召回率(recall)、F1-score 均为 1.00,整体准确率(accuracy)达到 100%,进一步验证了 SVM 在该分类任务上的优秀表现。

四、SVM 核心 API 参数详解

本文使用 sklearn 的SVC类实现 SVM 分类,核心参数如下:

参数名

作用

核心说明

C

惩罚因子

浮点数,默认 1.0。C 越大,对误分类惩罚越重,易过拟合;C 越小,容错率越高,易欠拟合

kernel

核函数

默认rbf,可选linear(线性核)、poly(多项式核)、sigmoid

degree

多项式维度

整数,默认 3,仅对poly核生效,其他核函数会忽略该参数

gamma

核函数系数

仅对rbf、poly、sigmoid生效。gamma 越大,过拟合风险越高;gamma 越小,泛化能力越强

random_state

随机种子

固定随机种子,保证实验结果可复现

其中,C、kernel、gamma是对模型效果影响最大的三个超参数,实际使用中建议通过网格搜索 + 交叉验证的方式选择最优参数组合。

五、总结

本文从通俗的原理入手,深入讲解了 SVM 的最优超平面、最大间隔、支持向量、软间隔、核函数等核心概念,同时基于 Python+sklearn 实现了完整的 SVM 分类任务,包含可视化、模型训练、评估全流程。SVM 作为经典的机器学习算法,在小样本、中等样本量的分类任务中有着不可替代的优势,掌握其核心原理和实战技巧,是机器学习入门的必备技能。读者可以基于本文的代码,更换自己的数据集,调整超参数,进一步深入理解 SVM 的特性。