《统计学习方法》-支持向量机SVM学习笔记和python源码

作者:V_victor

支持向量机SVM的学习笔记。对书中关键知识点进行了摘录,并加入一些自己的理解。




------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Python代码主要参照《机器学习实战》中的Python代码和数据集。

#author=altman
import numpy as np 
import copy as cp
import matplotlib.pyplot as plt
'''
读取数据
'''
def readData(fileName):
    """Load a tab-separated, two-feature data set from a text file.

    Every non-blank line must consist of float-parsable fields separated by
    tabs. The first two fields are taken as the features and the last field
    as the label.

    Args:
        fileName: path of the text file to read.

    Returns:
        A ``(matrix, labels)`` tuple: ``matrix`` is a list of ``[x1, x2]``
        float pairs and ``labels`` is the list of float labels, both in file
        order.

    Raises:
        ValueError: if a non-blank line contains a field that is not a float.
        IndexError: if a non-blank line has fewer than two fields.
    """
    matrix = []
    labels = []
    # The context manager closes the handle even when a line fails to parse.
    with open(fileName) as fr:
        for line in fr:
            fields = line.strip().split("\t")
            if fields == [""]:
                # Blank line (e.g. a trailing newline at end of file): skip it
                # instead of failing on float('').
                continue
            data = list(map(float, fields))
            matrix.append([data[0], data[1]])
            labels.append(data[-1])
    return matrix, labels
def clipAlpha(aj, H, L):
    """Clamp ``aj`` so that it is no larger than ``H`` and no smaller than ``L``.

    The upper bound is applied first and the lower bound second, so ``L`` is
    returned when the bounds are inconsistent (``L > H``).
    """
    capped = H if aj > H else aj
    return L if L > capped else capped
def kernelTrans(X, A, kTup):
    """Evaluate the kernel between every row of ``X`` and the single sample ``A``.

    Args:
        X: (m, n) data, one sample per row (anything ``np.asmatrix`` accepts).
        A: one sample with n features, as a (1, n) row or a length-n sequence.
        kTup: kernel descriptor. ``kTup[0]`` is ``'lin'`` for the linear
            kernel ``x . a`` or ``'rbf'`` for ``exp(-||x - a||^2 / kTup[1]^2)``;
            ``kTup[1]`` is only read for ``'rbf'``.

    Returns:
        An (m, 1) ``np.matrix`` column whose j-th entry is ``K(X[j], A)``.

    Raises:
        ValueError: if ``kTup[0]`` names an unsupported kernel. (Previously an
            all-zero column was returned silently.)
    """
    # np.asmatrix rather than the np.mat alias, which NumPy 2.0 removed.
    X = np.asmatrix(X)
    A = np.asmatrix(A)
    kind = kTup[0]
    if kind == 'lin':
        return X * A.T
    if kind == 'rbf':
        # Row-wise squared Euclidean distance to A, computed for all rows at once.
        delta = X - A
        sqDist = np.sum(np.multiply(delta, delta), axis=1)
        return np.exp(sqDist / (-1 * kTup[1] ** 2))
    raise ValueError("unsupported kernel type: %r" % (kind,))
'''
数据结构
'''
class svmStrcut(object):
    """Mutable state shared by the SMO routines in this file.

    Attributes:
        C: box-constraint upper bound for every alpha.
        data: sample matrix; ``data.shape[0]`` is the sample count and rows
            are passed to ``kernelTrans`` one at a time.
        labels: label matrix, read elsewhere in this file as ``labels[0, k]``
            (i.e. a single row of labels).
        toler: tolerance used by ``innerLoop`` when testing for violations.
        m: number of samples.
        b: bias term, starts at 0.0.
        alphas: (m, 1) matrix of Lagrange multipliers, starts at zero.
        eCache: (m, 2) matrix; column 0 is a "valid" flag and column 1 the
            cached error for that sample.
        K: (m, m) kernel matrix, ``K[:, i]`` being the kernel of every sample
            against sample ``i``.
    """
    def __init__(self, data, labels, C, toler, kTup):
        self.C = C
        self.data = data
        self.labels = labels
        self.toler = toler
        self.m = self.data.shape[0]
        self.b = 0.0
        # np.asmatrix rather than the np.mat alias, which NumPy 2.0 removed.
        self.alphas = np.asmatrix(np.zeros((self.m, 1)))
        self.eCache = np.asmatrix(np.zeros((self.m, 2)))
        self.K = np.asmatrix(np.zeros((self.m, self.m)))
        # Pre-compute the full kernel matrix one column at a time.
        for i in range(self.m):
            self.K[:, i] = kernelTrans(self.data, self.data[i, :], kTup)
''' 
计算误差
error=g(x)-y
'''
def calcEk(sT, k):
    """Return the prediction error ``g(x_k) - y_k`` for sample ``k``.

    ``g(x_k)`` is ``sum_i(alpha_i * y_i * K[i, k]) + b``, read from the
    ``alphas``, ``labels``, ``K`` and ``b`` attributes of ``sT``.
    """
    weightedLabels = np.multiply(sT.alphas, sT.labels.T)
    prediction = float(weightedLabels.T * sT.K[:, k] + sT.b)
    return prediction - float(sT.labels[0, k])
'''
选择第二个改变alpha
'''
def selectJ(i, sT, ei):
    """Choose the second multiplier index ``j`` to optimise together with ``i``.

    Sample ``i`` is first marked valid in ``sT.eCache`` with error ``ei``.
    Among all other samples whose cache entry is valid, the one maximising
    ``|ei - ek|`` is chosen. If no such sample exists (the cache only holds
    ``i``, or every candidate has exactly the same error as ``i``) a random
    index different from ``i`` is returned instead.

    Args:
        i: index of the first multiplier.
        sT: the shared SMO state (``m``, ``eCache`` and whatever ``calcEk`` reads).
        ei: current error of sample ``i``.

    Returns:
        ``(j, ej)``: the chosen index and its freshly computed error.

    Raises:
        ValueError: if there are fewer than two samples, so no ``j != i`` exists.
    """
    if sT.m < 2:
        raise ValueError("selectJ needs at least two samples to pick j != i")
    # Flag i's cache entry as valid so later calls can consider it.
    sT.eCache[i] = [1, ei]
    validEcacheList = np.nonzero(sT.eCache[:, 0].A)[0]
    maxK = -1
    maxDeltaE = 0
    ej = 0
    for k in validEcacheList:
        if k == i:
            continue
        ek = calcEk(sT, k)
        deltaE = abs(ei - ek)
        if deltaE > maxDeltaE:
            maxDeltaE = deltaE
            ej = ek
            maxK = k
    if maxK >= 0:
        return maxK, ej
    # No usable cached candidate: fall back to a random integer index.
    # (np.random.uniform(sT.m) used m as the *lower* bound and produced a
    # float, which cannot index the matrices; and the old heuristic branch
    # could return j == -1 when every |ei - ek| was zero.)
    j = i
    while j == i:
        j = np.random.randint(0, sT.m)
    return j, calcEk(sT, j)
def updateEk(sT, k):
    """Recompute the error of sample ``k`` and store it, flagged valid, in ``sT.eCache``."""
    sT.eCache[k] = [1, calcEk(sT, k)]
'''
内循环,即第二alpha选择
'''
def innerLoop(i,sT):
    """Try one SMO step with ``i`` as the first multiplier.

    If sample ``i`` violates the tolerance test below, a partner ``j`` is
    chosen with ``selectJ`` and ``alphas[j]``, ``alphas[i]``, the error cache
    and ``b`` on ``sT`` are updated in place.

    Returns 1 when the pair was updated and 0 otherwise (sample ``i`` passed
    the test, ``L == H``, ``eta <= 0``, or ``alphas[j]`` moved by less than
    0.00001). The three early-exit cases print a short diagnostic message.
    """
    ei = calcEk(sT,i)
    # Proceed only if y_i*E_i is outside +/-toler while alpha_i can still move
    # in the corresponding direction (alpha_i < C, resp. alpha_i > 0).
    if ((sT.labels[0,i]*ei < -sT.toler and sT.alphas[i]< sT.C)or(sT.labels[0,i]*ei > sT.toler and sT.alphas[i]> 0)):
        # Choose j based on i and its error ei
        j,ej = selectJ(i,sT,ei)
        # Keep copies of the old alpha values for i and j
        alphaOldI = sT.alphas[i].copy()
        alphaOldJ = sT.alphas[j].copy()
        # Compute the lower/upper bounds L and H for the new alpha_j
        if (sT.labels[0,i] != sT.labels[0,j]):
            L = max(0, sT.alphas[j] - sT.alphas[i])
            H = min(sT.C, sT.C + sT.alphas[j] - sT.alphas[i])
        else:
            L = max(0, sT.alphas[j] + sT.alphas[i] - sT.C)
            H = min(sT.C, sT.alphas[j] + sT.alphas[i])
        if L == H :print("L==H");return 0
        # Update alpha_j; eta = K_ii + K_jj - 2*K_ij must be positive
        eta = sT.K[i,i]+sT.K[j,j]-2*sT.K[i,j]
        if eta <=0 :print("eta<=0");return 0 
        sT.alphas[j] = alphaOldJ + sT.labels[0,j]*(ei-ej)/eta
        sT.alphas[j] = clipAlpha(sT.alphas[j],H,L)
        updateEk(sT,j)
        # If alpha_j barely moved, give up on this pair
        if (abs(sT.alphas[j]-alphaOldJ) < 0.00001):
            print("j not move enough") ;return 0
        # Update alpha_i by the same amount as alpha_j, in the direction given by y_i*y_j
        sT.alphas[i] = alphaOldI + sT.labels[0,i]*sT.labels[0,j]*(alphaOldJ-sT.alphas[j])
        updateEk(sT,i)
        # Update b: prefer the candidate whose alpha lies strictly inside (0, C),
        # otherwise take the midpoint of the two candidates
        b1 = sT.b - ei- sT.labels[0,i]*(sT.alphas[i]-alphaOldI)*sT.K[i,i] - sT.labels[0,j]*(sT.alphas[j]-alphaOldJ)*sT.K[i,j]
        b2 = sT.b - ej- sT.labels[0,i]*(sT.alphas[i]-alphaOldI)*sT.K[i,j] - sT.labels[0,j]*(sT.alphas[j]-alphaOldJ)*sT.K[j,j]
        if (0 < sT.alphas[i]) and (sT.C > sT.alphas[i]): sT.b = b1
        elif (0 < sT.alphas[j]) and (sT.C > sT.alphas[j]): sT.b = b2
        else: sT.b = (b1 + b2)/2.0
        return 1
    else:
        return 0 
def smoP(dataMatIn, classLabels, C, toler, maxIter, kTup=('lin', 0)):
    """Train an SVM with the SMO algorithm.

    Each outer iteration first tries the non-bound sample (0 < alpha < C)
    with the largest absolute error as the first multiplier; if that changes
    nothing, every sample is tried in turn. Training stops when an iteration
    changes no pair or after ``maxIter`` iterations.

    Args:
        dataMatIn: samples, one row per sample.
        classLabels: sequence of labels, one per sample.
        C: box-constraint upper bound for the alphas.
        toler: tolerance passed on to ``innerLoop`` via the shared state.
        maxIter: maximum number of outer iterations.
        kTup: kernel descriptor forwarded to ``kernelTrans``.

    Returns:
        ``(b, alphas)``: the bias term and the (m, 1) matrix of multipliers.
    """
    # np.asmatrix rather than the np.mat alias, which NumPy 2.0 removed.
    sT = svmStrcut(np.asmatrix(dataMatIn), np.asmatrix(classLabels), C, toler, kTup)
    count = 0
    while count < maxIter:
        alphaPairsChanged = 0
        # Among the non-bound samples pick the one deviating most from E == 0.
        # The magnitude is compared so that large negative errors are
        # considered as well (the signed comparison ignored them).
        nonBoundIs = np.nonzero((sT.alphas.A > 0) * (sT.alphas.A < C))[0]
        worstI = None
        worstErr = 0.0
        for i in nonBoundIs:
            error = abs(calcEk(sT, i))
            if error > worstErr:
                worstErr = error
                worstI = i
        if worstI is not None:
            alphaPairsChanged += innerLoop(worstI, sT)
        # Nothing changed: sweep over the whole training set.
        if alphaPairsChanged == 0:
            for i in range(sT.m):
                alphaPairsChanged += innerLoop(i, sT)
                print ("fullSet, iter: %d i:%d, pairs changed %d" %(count,i,alphaPairsChanged))
        count += 1
        print ("iteration number: %d" % count)
        if alphaPairsChanged == 0:
            break
    return sT.b, sT.alphas
def calcWs(alphas, dataArr, classLabels):
    """Compute the primal weight vector ``w = sum_i(alpha_i * y_i * x_i)``.

    Args:
        alphas: the m multipliers, as an (m, 1) matrix/array or a flat sequence.
        dataArr: (m, n) samples, one row per sample.
        classLabels: the m labels as a flat sequence.

    Returns:
        An (n, 1) float ``np.ndarray`` holding ``w``.
    """
    # Plain arrays avoid the np.mat alias (removed in NumPy 2.0); reshaping to
    # columns lets alphas and labels be either flat or already (m, 1).
    X = np.asarray(dataArr, dtype=float)
    coeffs = (np.asarray(alphas, dtype=float).reshape(-1, 1)
              * np.asarray(classLabels, dtype=float).reshape(-1, 1))
    # One matrix product replaces the per-sample accumulation loop.
    return np.dot(X.T, coeffs)
def show(data, labels, w, b):
    """Scatter-plot the two classes and draw the line ``w[0]*x + w[1]*y + b = 0``.

    Samples whose label equals 1 are drawn with red edges, all others with
    black edges. The line is drawn in green between the smallest and largest
    first-feature value found in ``data``.
    """
    posX, posY, negX, negY = [], [], [], []
    for row, label in enumerate(labels):
        if label == 1:
            xs, ys = posX, posY
        else:
            xs, ys = negX, negY
        xs.append(data[row, 0])
        ys.append(data[row, 1])
    plt.scatter(posX, posY, edgecolors='r')
    plt.scatter(negX, negY, edgecolors='k')

    def boundaryY(x):
        # Solve w[0]*x + w[1]*y + b = 0 for y.
        return float(-(x*w[0]+b)/w[1])

    left = np.min(data[:, 0])
    right = np.max(data[:, 0])
    plt.plot([left, right], [boundaryY(left), boundaryY(right)], '-g')

    plt.show()

def _rbfErrorRate(dataArr, labelArr, sVs, labelSV, alphaSV, b, k1):
    """Return the fraction of rows in ``dataArr`` misclassified by the RBF SVM."""
    datMat = np.asmatrix(dataArr)
    m = np.shape(datMat)[0]
    # alpha_i * y_i of the support vectors does not depend on the sample.
    weights = np.multiply(labelSV, alphaSV)
    errorCount = 0
    for i in range(m):
        kernelEval = kernelTrans(sVs, datMat[i, :], ('rbf', k1))
        predict = kernelEval.T * weights + b
        if np.sign(predict) != np.sign(labelArr[i]):
            errorCount += 1
    return float(errorCount) / m


def testRbf(k1=1.3, trainFile='testSetRBF.txt', testFile='testSetRBF2.txt'):
    """Train an RBF-kernel SVM and print its training and test error rates.

    Args:
        k1: RBF kernel width, used for both training and prediction.
        trainFile: path of the training set read with ``readData``.
        testFile: path of the held-out set read with ``readData``.
    """
    dataArr, labelArr = readData(trainFile)
    b, alphas = smoP(dataArr, labelArr, 200, 0.0001, 10000, ('rbf', k1))
    # np.asmatrix rather than the np.mat alias, which NumPy 2.0 removed.
    datMat = np.asmatrix(dataArr)
    labelMat = np.asmatrix(labelArr).transpose()
    # Only samples with alpha > 0 (the support vectors) contribute to predictions.
    svInd = np.nonzero(alphas.A > 0)[0]
    sVs = datMat[svInd]
    labelSV = labelMat[svInd]
    alphaSV = alphas[svInd]
    print ("there are %d Support Vectors" % np.shape(sVs)[0])
    trainErr = _rbfErrorRate(dataArr, labelArr, sVs, labelSV, alphaSV, b, k1)
    print ("the training error rate is: %f" % trainErr)

    dataArr, labelArr = readData(testFile)
    testErr = _rbfErrorRate(dataArr, labelArr, sVs, labelSV, alphaSV, b, k1)
    print ("the test error rate is: %f" % testErr)
def test():
    """Train a linear SVM on ``testSet.txt`` and plot the separating line."""
    dataArr, labelArr = readData('testSet.txt')
    b, alphas = smoP(dataArr, labelArr, 0.6, 0.0001, 40, ('lin', 0))
    ws = calcWs(alphas, dataArr, labelArr)
    # np.asmatrix rather than the np.mat alias, which NumPy 2.0 removed.
    show(np.asmatrix(dataArr), np.array(labelArr), np.array(ws), b)
def main():
    """Script entry point; currently does nothing.

    NOTE(review): neither test() nor testRbf() is called from here, so running
    the file as a script produces no output — presumably intentional; confirm.
    """
    pass
# Call main() only when the file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
testSet.txt数据集
3.542485	1.977398	-1
3.018896	2.556416	-1
7.551510	-1.580030	1
2.114999	-0.004466	-1
8.127113	1.274372	1
7.108772	-0.986906	1
8.610639	2.046708	1
2.326297	0.265213	-1
3.634009	1.730537	-1
0.341367	-0.894998	-1
3.125951	0.293251	-1
2.123252	-0.783563	-1
0.887835	-2.797792	-1
7.139979	-2.329896	1
1.696414	-1.212496	-1
8.117032	0.623493	1
8.497162	-0.266649	1
4.658191	3.507396	-1
8.197181	1.545132	1
1.208047	0.213100	-1
1.928486	-0.321870	-1
2.175808	-0.014527	-1
7.886608	0.461755	1
3.223038	-0.552392	-1
3.628502	2.190585	-1
7.407860	-0.121961	1
7.286357	0.251077	1
2.301095	-0.533988	-1
-0.232542	-0.547690	-1
3.457096	-0.082216	-1
3.023938	-0.057392	-1
8.015003	0.885325	1
8.991748	0.923154	1
7.916831	-1.781735	1
7.616862	-0.217958	1
2.450939	0.744967	-1
7.270337	-2.507834	1
1.749721	-0.961902	-1
1.803111	-0.176349	-1
8.804461	3.044301	1
1.231257	-0.568573	-1
2.074915	1.410550	-1
-0.743036	-1.736103	-1
3.536555	3.964960	-1
8.410143	0.025606	1
7.382988	-0.478764	1
6.960661	-0.245353	1
8.234460	0.701868	1
8.168618	-0.903835	1
1.534187	-0.622492	-1
9.229518	2.066088	1
7.886242	0.191813	1
2.893743	-1.643468	-1
1.870457	-1.040420	-1
5.286862	-2.358286	1
6.080573	0.418886	1
2.544314	1.714165	-1
6.016004	-3.753712	1
0.926310	-0.564359	-1
0.870296	-0.109952	-1
2.369345	1.375695	-1
1.363782	-0.254082	-1
7.279460	-0.189572	1
1.896005	0.515080	-1
8.102154	-0.603875	1
2.529893	0.662657	-1
1.963874	-0.365233	-1
8.132048	0.785914	1
8.245938	0.372366	1
6.543888	0.433164	1
-0.236713	-5.766721	-1
8.112593	0.295839	1
9.803425	1.495167	1
1.497407	-0.552916	-1
1.336267	-1.632889	-1
9.205805	-0.586480	1
1.966279	-1.840439	-1
8.398012	1.584918	1
7.239953	-1.764292	1
7.556201	0.241185	1
9.015509	0.345019	1
8.266085	-0.230977	1
8.545620	2.788799	1
9.295969	1.346332	1
2.404234	0.570278	-1
2.037772	0.021919	-1
1.727631	-0.453143	-1
1.979395	-0.050773	-1
8.092288	-1.372433	1
1.667645	0.239204	-1
9.854303	1.365116	1
7.921057	-1.327587	1
8.500757	1.492372	1
1.339746	-0.291183	-1
3.107511	0.758367	-1
2.609525	0.902979	-1
3.263585	1.367898	-1
2.912122	-0.202359	-1
1.731786	0.589096	-1
2.387003	1.573131	-1
testSetRBF.txt数据集
-0.214824	0.662756	-1.000000
-0.061569	-0.091875	1.000000
0.406933	0.648055	-1.000000
0.223650	0.130142	1.000000
0.231317	0.766906	-1.000000
-0.748800	-0.531637	-1.000000
-0.557789	0.375797	-1.000000
0.207123	-0.019463	1.000000
0.286462	0.719470	-1.000000
0.195300	-0.179039	1.000000
-0.152696	-0.153030	1.000000
0.384471	0.653336	-1.000000
-0.117280	-0.153217	1.000000
-0.238076	0.000583	1.000000
-0.413576	0.145681	1.000000
0.490767	-0.680029	-1.000000
0.199894	-0.199381	1.000000
-0.356048	0.537960	-1.000000
-0.392868	-0.125261	1.000000
0.353588	-0.070617	1.000000
0.020984	0.925720	-1.000000
-0.475167	-0.346247	-1.000000
0.074952	0.042783	1.000000
0.394164	-0.058217	1.000000
0.663418	0.436525	-1.000000
0.402158	0.577744	-1.000000
-0.449349	-0.038074	1.000000
0.619080	-0.088188	-1.000000
0.268066	-0.071621	1.000000
-0.015165	0.359326	1.000000
0.539368	-0.374972	-1.000000
-0.319153	0.629673	-1.000000
0.694424	0.641180	-1.000000
0.079522	0.193198	1.000000
0.253289	-0.285861	1.000000
-0.035558	-0.010086	1.000000
-0.403483	0.474466	-1.000000
-0.034312	0.995685	-1.000000
-0.590657	0.438051	-1.000000
-0.098871	-0.023953	1.000000
-0.250001	0.141621	1.000000
-0.012998	0.525985	-1.000000
0.153738	0.491531	-1.000000
0.388215	-0.656567	-1.000000
0.049008	0.013499	1.000000
0.068286	0.392741	1.000000
0.747800	-0.066630	-1.000000
0.004621	-0.042932	1.000000
-0.701600	0.190983	-1.000000
0.055413	-0.024380	1.000000
0.035398	-0.333682	1.000000
0.211795	0.024689	1.000000
-0.045677	0.172907	1.000000
0.595222	0.209570	-1.000000
0.229465	0.250409	1.000000
-0.089293	0.068198	1.000000
0.384300	-0.176570	1.000000
0.834912	-0.110321	-1.000000
-0.307768	0.503038	-1.000000
-0.777063	-0.348066	-1.000000
0.017390	0.152441	1.000000
-0.293382	-0.139778	1.000000
-0.203272	0.286855	1.000000
0.957812	-0.152444	-1.000000
0.004609	-0.070617	1.000000
-0.755431	0.096711	-1.000000
-0.526487	0.547282	-1.000000
-0.246873	0.833713	-1.000000
0.185639	-0.066162	1.000000
0.851934	0.456603	-1.000000
-0.827912	0.117122	-1.000000
0.233512	-0.106274	1.000000
0.583671	-0.709033	-1.000000
-0.487023	0.625140	-1.000000
-0.448939	0.176725	1.000000
0.155907	-0.166371	1.000000
0.334204	0.381237	-1.000000
0.081536	-0.106212	1.000000
0.227222	0.527437	-1.000000
0.759290	0.330720	-1.000000
0.204177	-0.023516	1.000000
0.577939	0.403784	-1.000000
-0.568534	0.442948	-1.000000
-0.011520	0.021165	1.000000
0.875720	0.422476	-1.000000
0.297885	-0.632874	-1.000000
-0.015821	0.031226	1.000000
0.541359	-0.205969	-1.000000
-0.689946	-0.508674	-1.000000
-0.343049	0.841653	-1.000000
0.523902	-0.436156	-1.000000
0.249281	-0.711840	-1.000000
0.193449	0.574598	-1.000000
-0.257542	-0.753885	-1.000000
-0.021605	0.158080	1.000000
0.601559	-0.727041	-1.000000
-0.791603	0.095651	-1.000000
-0.908298	-0.053376	-1.000000
0.122020	0.850966	-1.000000
-0.725568	-0.292022	-1.000000
testSetRBF2.txt数据集
0.676771	-0.486687	-1.000000
0.008473	0.186070	1.000000
-0.727789	0.594062	-1.000000
0.112367	0.287852	1.000000
0.383633	-0.038068	1.000000
-0.927138	-0.032633	-1.000000
-0.842803	-0.423115	-1.000000
-0.003677	-0.367338	1.000000
0.443211	-0.698469	-1.000000
-0.473835	0.005233	1.000000
0.616741	0.590841	-1.000000
0.557463	-0.373461	-1.000000
-0.498535	-0.223231	-1.000000
-0.246744	0.276413	1.000000
-0.761980	-0.244188	-1.000000
0.641594	-0.479861	-1.000000
-0.659140	0.529830	-1.000000
-0.054873	-0.238900	1.000000
-0.089644	-0.244683	1.000000
-0.431576	-0.481538	-1.000000
-0.099535	0.728679	-1.000000
-0.188428	0.156443	1.000000
0.267051	0.318101	1.000000
0.222114	-0.528887	-1.000000
0.030369	0.113317	1.000000
0.392321	0.026089	1.000000
0.298871	-0.915427	-1.000000
-0.034581	-0.133887	1.000000
0.405956	0.206980	1.000000
0.144902	-0.605762	-1.000000
0.274362	-0.401338	1.000000
0.397998	-0.780144	-1.000000
0.037863	0.155137	1.000000
-0.010363	-0.004170	1.000000
0.506519	0.486619	-1.000000
0.000082	-0.020625	1.000000
0.057761	-0.155140	1.000000
0.027748	-0.553763	-1.000000
-0.413363	-0.746830	-1.000000
0.081500	-0.014264	1.000000
0.047137	-0.491271	1.000000
-0.267459	0.024770	1.000000
-0.148288	-0.532471	-1.000000
-0.225559	-0.201622	1.000000
0.772360	-0.518986	-1.000000
-0.440670	0.688739	-1.000000
0.329064	-0.095349	1.000000
0.970170	-0.010671	-1.000000
-0.689447	-0.318722	-1.000000
-0.465493	-0.227468	-1.000000
-0.049370	0.405711	1.000000
-0.166117	0.274807	1.000000
0.054483	0.012643	1.000000
0.021389	0.076125	1.000000
-0.104404	-0.914042	-1.000000
0.294487	0.440886	-1.000000
0.107915	-0.493703	-1.000000
0.076311	0.438860	1.000000
0.370593	-0.728737	-1.000000
0.409890	0.306851	-1.000000
0.285445	0.474399	-1.000000
-0.870134	-0.161685	-1.000000
-0.654144	-0.675129	-1.000000
0.285278	-0.767310	-1.000000
0.049548	-0.000907	1.000000
0.030014	-0.093265	1.000000
-0.128859	0.278865	1.000000
0.307463	0.085667	1.000000
0.023440	0.298638	1.000000
0.053920	0.235344	1.000000
0.059675	0.533339	-1.000000
0.817125	0.016536	-1.000000
-0.108771	0.477254	1.000000
-0.118106	0.017284	1.000000
0.288339	0.195457	1.000000
0.567309	-0.200203	-1.000000
-0.202446	0.409387	1.000000
-0.330769	-0.240797	1.000000
-0.422377	0.480683	-1.000000
-0.295269	0.326017	1.000000
0.261132	0.046478	1.000000
-0.492244	-0.319998	-1.000000
-0.384419	0.099170	1.000000
0.101882	-0.781145	-1.000000
0.234592	-0.383446	1.000000
-0.020478	-0.901833	-1.000000
0.328449	0.186633	1.000000
-0.150059	-0.409158	1.000000
-0.155876	-0.843413	-1.000000
-0.098134	-0.136786	1.000000
0.110575	-0.197205	1.000000
0.219021	0.054347	1.000000
0.030152	0.251682	1.000000
0.033447	-0.122824	1.000000
-0.686225	-0.020779	-1.000000
-0.911211	-0.262011	-1.000000
0.572557	0.377526	-1.000000
-0.073647	-0.519163	-1.000000
-0.281830	-0.797236	-1.000000
-0.555263	0.126232	-1.000000



发表评论

0个评论

我要留言×

技术领域:

我要留言×

留言成功,我们将在审核后加至投票列表中!

提示x

人工智能机器学习知识库已成功保存至我的图谱，现在你可以用它来管理自己的知识内容了。

删除图谱提示×

你保存在该图谱下的知识内容也会被删除,建议你先将内容移到其他图谱中。你确定要删除知识图谱及其内容吗?

删除节点提示×

无法删除该知识节点,因该节点下仍保存有相关知识内容!

删除节点提示×

你确定要删除该知识节点吗?