OpenCv学习笔记---OpenCv中支持向量机模块SVM------源代码分析

作者:maweifei

/****************************************************************************************
                                 Support Vector Machines                              
****************************************************************************************/

// SVM training parameters
//【1】SVM的训练参数
//【2】在上一篇教程中,我们以--线性可分--的例子,简单讲解了SVM的基本原理。然而,SVM的实际应用情形可能
       //复杂的多(比如1--线性可分问题;2--非线性可分问题;3--SVM核函数的选择问题等等).总而言之,我
	   //们在训练之前,需要对SVM做一些参数设定,这类参数就保存在---CvSVMParams这个类中
struct CV_EXPORTS_W_MAP CvSVMParams
{
	//【1】CvSVMParams的默认构造函数
    CvSVMParams();
	//【2】CvSVMParams的带参构造函数
    CvSVMParams( int svm_type, int kernel_type,
                 double degree, double gamma, double coef0,
                 double Cvalue, double nu, double p,
                 CvMat* class_weights, CvTermCriteria term_crit );
    //【3】svm_type,SVM的类型
	        //【1】C_SVM----分类器---允许用异常值惩罚因子C进行不完全分类
			//【2】NU_SVC---类似然不完全分类的--分类器.参数nu取代了c,其值在区间[0,1]中,nu越大,
			      //决策边界越平滑
			//【3】ONE_CLASS--单分类器,所有饿训练数据提取自同一个类里,然后SVM建立了一个分界线以
			      //以分割该类在特征空间中所占区域与其他类在特征空间中所占区域
			//【4】EPS_SVR----回归--训练集中的特征向量和拟合出来的超平面的距离需要小于p.异常值惩
			     //罚因子C被采用
    CV_PROP_RW int         svm_type;
	//【4】kernel_type--核类型:
	        //【1】CvSVM::LINEAR---没有任何向量映射至高维空间,线性区分(或回归)在原始特征空间中被
			     //完成,这是最快的选择.d(x,y)=x*y=(x,y)
			//【2】CvSVM::POLY-----多项式核d(x,y)=(gamma*(x*y)+core0)degree
			//【3】CvSVM::RBF------径向基,这对大多数情况都是一个比较好的选择d(x,y)=exp(-gramma*|x-y|2)
			//【4】CvSVM::SIGMOID---sigmoid函数被用作核函数:d(x,y)=tanh(gamma*(x*y)+coref0)
    CV_PROP_RW int         kernel_type;
	//【5】degree,gramma,coref0都是核函数的参数,具体的参见上面的核函数方程
    CV_PROP_RW double      degree; // for poly
    CV_PROP_RW double      gamma;  // for poly/rbf/sigmoid
    CV_PROP_RW double      coef0;  // for poly/sigmoid
    //【6】C,nu,p---在一般的SVM优化求解时的参数
    CV_PROP_RW double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
    CV_PROP_RW double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
    CV_PROP_RW double      p; // for CV_SVM_EPS_SVR
	//【8】class_weights--可选权重,赋给指定的类别.一般乘以C以后去影响不同类别的错误分类惩罚项.
	       //权重越大,某一类别的误分类数据的惩罚项就越大
    CvMat*      class_weights; // for CV_SVM_C_SVC
	//【9】迭代训练过程的--终止--解决了部分受约束二次最优问题
    CV_PROP_RW CvTermCriteria term_crit; // termination criteria
};

//【1】CvSVM核函数类
struct CV_EXPORTS CvSVMKernel
{
    typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
                                       const float* another, float* results );
	//【1】核函数类的构造函数
    CvSVMKernel();
    CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
	//【2】
    virtual bool create( const CvSVMParams* params, Calc _calc_func );
	//【3】析构函数
    virtual ~CvSVMKernel();
    //【4】
    virtual void clear();
	//【5】
    virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
    //【6】指向CvSVM的参数类的---类对象指针
    const CvSVMParams* params;
    Calc calc_func;
    //【7】虚函数
    virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
                                    const float* another, float* results,
                                    double alpha, double beta );

    virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
                              const float* another, float* results );
    virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
                           const float* another, float* results );
    virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
                            const float* another, float* results );
    virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
                               const float* another, float* results );
};


struct CvSVMKernelRow
{
    CvSVMKernelRow* prev;
    CvSVMKernelRow* next;
    float* data;
};


struct CvSVMSolutionInfo
{
    double obj;
    double rho;
    double upper_bound_p;
    double upper_bound_n;
    double r;   // for Solver_NU
};

class CV_EXPORTS CvSVMSolver
{
public:
    typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
    typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
    typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );

    CvSVMSolver();

    CvSVMSolver( int count, int var_count, const float** samples, schar* y,
                 int alpha_count, double* alpha, double Cp, double Cn,
                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                 SelectWorkingSet select_working_set, CalcRho calc_rho );
    virtual bool create( int count, int var_count, const float** samples, schar* y,
                 int alpha_count, double* alpha, double Cp, double Cn,
                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                 SelectWorkingSet select_working_set, CalcRho calc_rho );
    virtual ~CvSVMSolver();

    virtual void clear();
    virtual bool solve_generic( CvSVMSolutionInfo& si );

    virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
                              double Cp, double Cn, CvMemStorage* storage,
                              CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_one_class( int count, int var_count, const float** samples,
                                  CvMemStorage* storage, CvSVMKernel* kernel,
                                  double* alpha, CvSVMSolutionInfo& si );

    virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
                                CvMemStorage* storage, CvSVMKernel* kernel,
                                double* alpha, CvSVMSolutionInfo& si );

    virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );

    virtual float* get_row_base( int i, bool* _existed );
    virtual float* get_row( int i, float* dst );

    int sample_count;
    int var_count;
    int cache_size;
    int cache_line_size;
    const float** samples;
    const CvSVMParams* params;
    CvMemStorage* storage;
    CvSVMKernelRow lru_list;
    CvSVMKernelRow* rows;

    int alpha_count;

    double* G;
    double* alpha;

    // -1 - lower bound, 0 - free, 1 - upper bound
    schar* alpha_status;

    schar* y;
    double* b;
    float* buf[2];
    double eps;
    int max_iter;
    double C[2];  // C[0] == Cn, C[1] == Cp
    CvSVMKernel* kernel;

    SelectWorkingSet select_working_set_func;
    CalcRho calc_rho_func;
    GetRow get_row_func;

    virtual bool select_working_set( int& i, int& j );
    virtual bool select_working_set_nu_svm( int& i, int& j );
    virtual void calc_rho( double& rho, double& r );
    virtual void calc_rho_nu_svm( double& rho, double& r );

    virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
    virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
    virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
};


struct CvSVMDecisionFunc
{
    double rho;
    int sv_count;
    double* alpha;
    int* sv_index;
};


// SVM model
//【1】支持向量机CvSVM,继承自基类CvStatModel

class CV_EXPORTS_W CvSVM : public CvStatModel
{
public:
    // SVM type
	//【1】SVM的类型
	//【2】如果选择SVC--则是分类器
	//【3】如果选择SVR--则SVR是SVM的回归
	//【1】C_SVC----分类器---允许用异常值惩罚因子C进行不完全分类
	//【2】NU_SVC---类似然不完全分类的--分类器.参数nu取代了c,其值在区间[0,1]中,nu越大,
		   //决策边界越平滑
	//【3】ONE_CLASS--单分类器,所有饿训练数据提取自同一个类里,然后SVM建立了一个分界线以
		   //以分割该类在特征空间中所占区域与其他类在特征空间中所占区域
	//【4】EPS_SVR----回归--训练集中的特征向量和拟合出来的超平面的距离需要小于p.异常值惩
		   //罚因子C被采用
    enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };

    // SVM kernel type
	//【2】SVM提供四种核函数,分别是:
	        //【1】LINEAR----线性
			//【2】POLY------多项式
			//【3】RBF-------径向基
			//【4】SIGMOID---sigmoid型函数
    enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };

    // SVM params type
	//【3】SVM的参数类型
	        //【1】
			//【2】
    enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
    //【4】CvSVM的默认构造函数和析构函数
    CV_WRAP CvSVM();
    virtual ~CvSVM();
    //【5】CvSVM的带参构造函数
    CvSVM( const CvMat* trainData, const CvMat* responses,
           const CvMat* varIdx=0, const CvMat* sampleIdx=0,
           CvSVMParams params=CvSVMParams() );
    //【6】训练支持向量机,调用CvSVM::train来建立SVM模型
	//【7】该方法训练支持向量机模型,它遵循的泛型“方法”约定具有如下的限制:
			//【1】仅仅支持CV_ROW_SAMPLE--行样本的数据布局
			//【2】所有的输入变量总是有序的
			//【3】所有的params参数都由CvSVMParams结构体收集
    virtual bool train( const CvMat* trainData, const CvMat* responses,
                        const CvMat* varIdx=0, const CvMat* sampleIdx=0,
                        CvSVMParams params=CvSVMParams() );
    //【8】使用最佳的,最理想的参数训练SVM支持向量机模型
    virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
        int kfold = 10,
        CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
        CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
        CvParamGrid pGrid      = get_default_grid(CvSVM::P),
        CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
        CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
        CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
        bool balanced=false );
    //【8】函数CvSVM::predit通过重建完毕的支持向量机来将输入的样本分类.本例中,我们通过该函数给向量空间着色,以及
	      //将图像中的每个像素当做笛卡尔平面上的一点,每一点的着色取决于SVM对该点的分类类别:绿色表示标记为1的点,
		  //蓝色表示标记为-1的点
    virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
    virtual float predict( const CvMat* samples, CV_OUT CvMat* results ) const;
   
    CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
          const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
          CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
                       CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
                            const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
                            int k_fold = 10,
                            CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
                            CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
                            CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
                            CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
                            CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
                            CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
                            bool balanced=false);
    CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
    CV_WRAP_AS(predict_all) void predict( cv::InputArray samples, cv::OutputArray results ) const;
    //【10】得到支持向量的个数
    CV_WRAP virtual int get_support_vector_count() const;
    virtual const float* get_support_vector(int i) const;
    virtual CvSVMParams get_params() const { return params; };
    CV_WRAP virtual void clear();

    static CvParamGrid get_default_grid( int param_id );

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );
    CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }

protected:

    virtual bool set_params( const CvSVMParams& params );
    virtual bool train1( int sample_count, int var_count, const float** samples,
                    const void* responses, double Cp, double Cn,
                    CvMemStorage* _storage, double* alpha, double& rho );
    virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
                    const CvMat* responses, CvMemStorage* _storage, double* alpha );
    virtual void create_kernel();
    virtual void create_solver();

    virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    void optimize_linear_svm();

    CvSVMParams params;
    CvMat* class_labels;
    int var_all;
    float** sv;
    int sv_total;
    CvMat* var_idx;
    CvMat* class_weights;
    CvSVMDecisionFunc* decision_func;
    CvMemStorage* storage;

    CvSVMSolver* solver;
    CvSVMKernel* kernel;

private:
    CvSVM(const CvSVM&);
    CvSVM& operator = (const CvSVM&);
};

发表评论

0个评论

我要留言×

技术领域:

我要留言×

留言成功,我们将在审核后加至投票列表中!

提示x

人工智能机器学习知识库已成功保存至我的图谱现在你可以用它来管理自己的知识内容了

删除图谱提示×

你保存在该图谱下的知识内容也会被删除,建议你先将内容移到其他图谱中。你确定要删除知识图谱及其内容吗?

删除节点提示×

无法删除该知识节点,因该节点下仍保存有相关知识内容!

删除节点提示×

你确定要删除该知识节点吗?