00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041 #ifndef __OPENCV_ML_HPP__
00042 #define __OPENCV_ML_HPP__
00043
00044
// Silence MSVC warning C4996 (deprecated API) on Visual Studio 2005 (_MSC_VER 1400)
// and newer. Checking `defined _MSC_VER` first keeps other compilers from evaluating
// an undefined identifier inside #if (reported by -Wundef); the result is unchanged.
#if defined _MSC_VER && _MSC_VER >= 1400
#pragma warning( disable : 4996 )
#endif
00048
00049 #ifndef SKIP_INCLUDES
00050
00051 #include "opencv2/core/core.hpp"
00052 #include <limits.h>
00053
00054 #if defined WIN32 || defined _WIN32
00055 #include <windows.h>
00056 #endif
00057
00058 #else // SKIP_INCLUDES
00059
00060 #if defined WIN32 || defined _WIN32
00061 #define CV_CDECL __cdecl
00062 #define CV_STDCALL __stdcall
00063 #else
00064 #define CV_CDECL
00065 #define CV_STDCALL
00066 #endif
00067
00068 #ifndef CV_EXTERN_C
00069 #ifdef __cplusplus
00070 #define CV_EXTERN_C extern "C"
00071 #define CV_DEFAULT(val) = val
00072 #else
00073 #define CV_EXTERN_C
00074 #define CV_DEFAULT(val)
00075 #endif
00076 #endif
00077
00078 #ifndef CV_EXTERN_C_FUNCPTR
00079 #ifdef __cplusplus
00080 #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; }
00081 #else
00082 #define CV_EXTERN_C_FUNCPTR(x) typedef x
00083 #endif
00084 #endif
00085
00086 #ifndef CV_INLINE
00087 #if defined __cplusplus
00088 #define CV_INLINE inline
00089 #elif (defined WIN32 || defined _WIN32) && !defined __GNUC__
00090 #define CV_INLINE __inline
00091 #else
00092 #define CV_INLINE static
00093 #endif
00094 #endif
00095
00096 #if (defined WIN32 || defined _WIN32) && defined CVAPI_EXPORTS
00097 #define CV_EXPORTS __declspec(dllexport)
00098 #else
00099 #define CV_EXPORTS
00100 #endif
00101
00102 #ifndef CVAPI
00103 #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL
00104 #endif
00105
00106 #endif // SKIP_INCLUDES
00107
00108
00109 #ifdef __cplusplus
00110
00111
00112
// Presumably drops a `check` macro that some platform header may define; such a macro
// would otherwise collide with the CvParamGrid::check() member declared in this header.
#undef check

// Natural logarithm of 2*pi.
#define CV_LOG2PI (1.8378770664093454835606594728112)

// Sample-layout flags (presumably: one sample per column / one sample per row of the
// training matrix, per the names).
#define CV_COL_SAMPLE 0
#define CV_ROW_SAMPLE 1

// Non-zero when the CV_ROW_SAMPLE bit is set in `flags`.
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
00129
00130 struct CvVectors
00131 {
00132 int type;
00133 int dims, count;
00134 CvVectors* next;
00135 union
00136 {
00137 uchar** ptr;
00138 float** fl;
00139 double** db;
00140 } data;
00141 };
00142
#if 0
/* Legacy parameter-lattice API; compiled out (kept for reference only). */

typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

/* Builds a lattice with ordered bounds and a step of at least 1. */
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice lattice =
    {
        MIN( min_val, max_val ),
        MAX( min_val, max_val ),
        MAX( log_step, 1. )
    };
    return lattice;
}

/* All-zero lattice. */
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice empty_lattice;
    empty_lattice.min_val = 0;
    empty_lattice.max_val = 0;
    empty_lattice.step = 0;
    return empty_lattice;
}
#endif
00171
00172
// Variable kinds. Numerical and ordered variables share the same code.
#define CV_VAR_NUMERICAL 0
#define CV_VAR_ORDERED 0
#define CV_VAR_CATEGORICAL 1

// Type-name strings of the individual models.
#define CV_TYPE_NAME_ML_SVM "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_EM "opencv-ml-em"
#define CV_TYPE_NAME_ML_BOOSTING "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_GBT "opencv-ml-gradient-boosting-trees"

// Error-kind selectors; presumably the values for the `type` argument of calc_error().
#define CV_TRAIN_ERROR 0
#define CV_TEST_ERROR 1
00190
00191 class CV_EXPORTS_W CvStatModel
00192 {
00193 public:
00194 CvStatModel();
00195 virtual ~CvStatModel();
00196
00197 virtual void clear();
00198
00199 CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
00200 CV_WRAP virtual void load( const char* filename, const char* name=0 );
00201
00202 virtual void write( CvFileStorage* storage, const char* name ) const;
00203 virtual void read( CvFileStorage* storage, CvFileNode* node );
00204
00205 protected:
00206 const char* default_model_name;
00207 };
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217
00218 class CvMLData;
00219
00220 struct CV_EXPORTS_W_MAP CvParamGrid
00221 {
00222
00223 enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
00224
00225 CvParamGrid()
00226 {
00227 min_val = max_val = step = 0;
00228 }
00229
00230 CvParamGrid( double _min_val, double _max_val, double log_step )
00231 {
00232 min_val = _min_val;
00233 max_val = _max_val;
00234 step = log_step;
00235 }
00236
00237 bool check() const;
00238
00239 CV_PROP_RW double min_val;
00240 CV_PROP_RW double max_val;
00241 CV_PROP_RW double step;
00242 };
00243
00244 class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel
00245 {
00246 public:
00247 CV_WRAP CvNormalBayesClassifier();
00248 virtual ~CvNormalBayesClassifier();
00249
00250 CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
00251 const CvMat* varIdx=0, const CvMat* sampleIdx=0 );
00252
00253 virtual bool train( const CvMat* trainData, const CvMat* responses,
00254 const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
00255
00256 virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;
00257 CV_WRAP virtual void clear();
00258
00259 #ifndef SWIG
00260 CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
00261 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
00262 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00263 const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00264 bool update=false );
00265 CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;
00266 #endif
00267
00268 virtual void write( CvFileStorage* storage, const char* name ) const;
00269 virtual void read( CvFileStorage* storage, CvFileNode* node );
00270
00271 protected:
00272 int var_count, var_all;
00273 CvMat* var_idx;
00274 CvMat* cls_labels;
00275 CvMat** count;
00276 CvMat** sum;
00277 CvMat** productsum;
00278 CvMat** avg;
00279 CvMat** inv_eigen_values;
00280 CvMat** cov_rotate_mats;
00281 CvMat* c;
00282 };
00283
00284
00285
00286
00287
00288
00289
00290 class CV_EXPORTS_W CvKNearest : public CvStatModel
00291 {
00292 public:
00293
00294 CV_WRAP CvKNearest();
00295 virtual ~CvKNearest();
00296
00297 CvKNearest( const CvMat* trainData, const CvMat* responses,
00298 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );
00299
00300 virtual bool train( const CvMat* trainData, const CvMat* responses,
00301 const CvMat* sampleIdx=0, bool is_regression=false,
00302 int maxK=32, bool updateBase=false );
00303
00304 virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
00305 const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;
00306
00307 #ifndef SWIG
00308 CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
00309 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
00310
00311 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00312 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
00313 int maxK=32, bool updateBase=false );
00314
00315 virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
00316 const float** neighbors=0, cv::Mat* neighborResponses=0,
00317 cv::Mat* dist=0 ) const;
00318 CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
00319 CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
00320 #endif
00321
00322 virtual void clear();
00323 int get_max_k() const;
00324 int get_var_count() const;
00325 int get_sample_count() const;
00326 bool is_regression() const;
00327
00328 protected:
00329
00330 virtual float write_results( int k, int k1, int start, int end,
00331 const float* neighbor_responses, const float* dist, CvMat* _results,
00332 CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
00333
00334 virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
00335 float* neighbor_responses, const float** neighbors, float* dist ) const;
00336
00337
00338 int max_k, var_count;
00339 int total;
00340 bool regression;
00341 CvVectors* samples;
00342 };
00343
00344
00345
00346
00347
00348
00349 struct CV_EXPORTS_W_MAP CvSVMParams
00350 {
00351 CvSVMParams();
00352 CvSVMParams( int _svm_type, int _kernel_type,
00353 double _degree, double _gamma, double _coef0,
00354 double Cvalue, double _nu, double _p,
00355 CvMat* _class_weights, CvTermCriteria _term_crit );
00356
00357 CV_PROP_RW int svm_type;
00358 CV_PROP_RW int kernel_type;
00359 CV_PROP_RW double degree;
00360 CV_PROP_RW double gamma;
00361 CV_PROP_RW double coef0;
00362
00363 CV_PROP_RW double C;
00364 CV_PROP_RW double nu;
00365 CV_PROP_RW double p;
00366 CvMat* class_weights;
00367 CV_PROP_RW CvTermCriteria term_crit;
00368 };
00369
00370
00371 struct CV_EXPORTS CvSVMKernel
00372 {
00373 typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
00374 const float* another, float* results );
00375 CvSVMKernel();
00376 CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
00377 virtual bool create( const CvSVMParams* params, Calc _calc_func );
00378 virtual ~CvSVMKernel();
00379
00380 virtual void clear();
00381 virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
00382
00383 const CvSVMParams* params;
00384 Calc calc_func;
00385
00386 virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
00387 const float* another, float* results,
00388 double alpha, double beta );
00389
00390 virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
00391 const float* another, float* results );
00392 virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
00393 const float* another, float* results );
00394 virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
00395 const float* another, float* results );
00396 virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
00397 const float* another, float* results );
00398 };
00399
00400
/* Doubly linked row descriptor: neighbours `prev`/`next` plus a pointer to its floats. */
struct CvSVMKernelRow
{
    CvSVMKernelRow* prev;   // previous row in the list
    CvSVMKernelRow* next;   // next row in the list
    float* data;            // row values (storage owned elsewhere)
};
00407
00408
/* Scalar outputs of an SVM solve, filled in by CvSVMSolver. */
struct CvSVMSolutionInfo
{
    double obj;            // objective value
    double rho;            // bias term
    double upper_bound_p;
    double upper_bound_n;
    double r;
};
00417
00418 class CV_EXPORTS CvSVMSolver
00419 {
00420 public:
00421 typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
00422 typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
00423 typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
00424
00425 CvSVMSolver();
00426
00427 CvSVMSolver( int count, int var_count, const float** samples, schar* y,
00428 int alpha_count, double* alpha, double Cp, double Cn,
00429 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00430 SelectWorkingSet select_working_set, CalcRho calc_rho );
00431 virtual bool create( int count, int var_count, const float** samples, schar* y,
00432 int alpha_count, double* alpha, double Cp, double Cn,
00433 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00434 SelectWorkingSet select_working_set, CalcRho calc_rho );
00435 virtual ~CvSVMSolver();
00436
00437 virtual void clear();
00438 virtual bool solve_generic( CvSVMSolutionInfo& si );
00439
00440 virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
00441 double Cp, double Cn, CvMemStorage* storage,
00442 CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
00443 virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
00444 CvMemStorage* storage, CvSVMKernel* kernel,
00445 double* alpha, CvSVMSolutionInfo& si );
00446 virtual bool solve_one_class( int count, int var_count, const float** samples,
00447 CvMemStorage* storage, CvSVMKernel* kernel,
00448 double* alpha, CvSVMSolutionInfo& si );
00449
00450 virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
00451 CvMemStorage* storage, CvSVMKernel* kernel,
00452 double* alpha, CvSVMSolutionInfo& si );
00453
00454 virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
00455 CvMemStorage* storage, CvSVMKernel* kernel,
00456 double* alpha, CvSVMSolutionInfo& si );
00457
00458 virtual float* get_row_base( int i, bool* _existed );
00459 virtual float* get_row( int i, float* dst );
00460
00461 int sample_count;
00462 int var_count;
00463 int cache_size;
00464 int cache_line_size;
00465 const float** samples;
00466 const CvSVMParams* params;
00467 CvMemStorage* storage;
00468 CvSVMKernelRow lru_list;
00469 CvSVMKernelRow* rows;
00470
00471 int alpha_count;
00472
00473 double* G;
00474 double* alpha;
00475
00476
00477 schar* alpha_status;
00478
00479 schar* y;
00480 double* b;
00481 float* buf[2];
00482 double eps;
00483 int max_iter;
00484 double C[2];
00485 CvSVMKernel* kernel;
00486
00487 SelectWorkingSet select_working_set_func;
00488 CalcRho calc_rho_func;
00489 GetRow get_row_func;
00490
00491 virtual bool select_working_set( int& i, int& j );
00492 virtual bool select_working_set_nu_svm( int& i, int& j );
00493 virtual void calc_rho( double& rho, double& r );
00494 virtual void calc_rho_nu_svm( double& rho, double& r );
00495
00496 virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
00497 virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
00498 virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
00499 };
00500
00501
/* One SVM decision function: bias, support-vector count, coefficients and SV indices. */
struct CvSVMDecisionFunc
{
    double rho;        // bias term
    int sv_count;      // number of entries in alpha / sv_index
    double* alpha;     // coefficients
    int* sv_index;     // support-vector indices
};
00509
00510
00511
00512 class CV_EXPORTS_W CvSVM : public CvStatModel
00513 {
00514 public:
00515
00516 enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
00517
00518
00519 enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
00520
00521
00522 enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
00523
00524 CV_WRAP CvSVM();
00525 virtual ~CvSVM();
00526
00527 CvSVM( const CvMat* trainData, const CvMat* responses,
00528 const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00529 CvSVMParams params=CvSVMParams() );
00530
00531 virtual bool train( const CvMat* trainData, const CvMat* responses,
00532 const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00533 CvSVMParams params=CvSVMParams() );
00534
00535 virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
00536 const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
00537 int kfold = 10,
00538 CvParamGrid Cgrid = get_default_grid(CvSVM::C),
00539 CvParamGrid gammaGrid = get_default_grid(CvSVM::GAMMA),
00540 CvParamGrid pGrid = get_default_grid(CvSVM::P),
00541 CvParamGrid nuGrid = get_default_grid(CvSVM::NU),
00542 CvParamGrid coeffGrid = get_default_grid(CvSVM::COEF),
00543 CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
00544 bool balanced=false );
00545
00546 virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
00547
00548 #ifndef SWIG
00549 CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
00550 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00551 CvSVMParams params=CvSVMParams() );
00552
00553 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00554 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00555 CvSVMParams params=CvSVMParams() );
00556
00557 CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
00558 const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
00559 int k_fold = 10,
00560 CvParamGrid Cgrid = CvSVM::get_default_grid(CvSVM::C),
00561 CvParamGrid gammaGrid = CvSVM::get_default_grid(CvSVM::GAMMA),
00562 CvParamGrid pGrid = CvSVM::get_default_grid(CvSVM::P),
00563 CvParamGrid nuGrid = CvSVM::get_default_grid(CvSVM::NU),
00564 CvParamGrid coeffGrid = CvSVM::get_default_grid(CvSVM::COEF),
00565 CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
00566 bool balanced=false);
00567 CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
00568 #endif
00569
00570 CV_WRAP virtual int get_support_vector_count() const;
00571 virtual const float* get_support_vector(int i) const;
00572 virtual CvSVMParams get_params() const { return params; };
00573 CV_WRAP virtual void clear();
00574
00575 static CvParamGrid get_default_grid( int param_id );
00576
00577 virtual void write( CvFileStorage* storage, const char* name ) const;
00578 virtual void read( CvFileStorage* storage, CvFileNode* node );
00579 CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
00580
00581 protected:
00582
00583 virtual bool set_params( const CvSVMParams& params );
00584 virtual bool train1( int sample_count, int var_count, const float** samples,
00585 const void* responses, double Cp, double Cn,
00586 CvMemStorage* _storage, double* alpha, double& rho );
00587 virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
00588 const CvMat* responses, CvMemStorage* _storage, double* alpha );
00589 virtual void create_kernel();
00590 virtual void create_solver();
00591
00592 virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
00593
00594 virtual void write_params( CvFileStorage* fs ) const;
00595 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00596
00597 CvSVMParams params;
00598 CvMat* class_labels;
00599 int var_all;
00600 float** sv;
00601 int sv_total;
00602 CvMat* var_idx;
00603 CvMat* class_weights;
00604 CvSVMDecisionFunc* decision_func;
00605 CvMemStorage* storage;
00606
00607 CvSVMSolver* solver;
00608 CvSVMKernel* kernel;
00609 };
00610
00611
00612
00613
00614
00615 struct CV_EXPORTS_W_MAP CvEMParams
00616 {
00617 CvEMParams() : nclusters(10), cov_mat_type(1),
00618 start_step(0), probs(0), weights(0), means(0), covs(0)
00619 {
00620 term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
00621 }
00622
00623 CvEMParams( int _nclusters, int _cov_mat_type=1,
00624 int _start_step=0,
00625 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
00626 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
00627 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
00628 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
00629 {}
00630
00631 CV_PROP_RW int nclusters;
00632 CV_PROP_RW int cov_mat_type;
00633 CV_PROP_RW int start_step;
00634 const CvMat* probs;
00635 const CvMat* weights;
00636 const CvMat* means;
00637 const CvMat** covs;
00638 CV_PROP_RW CvTermCriteria term_crit;
00639 };
00640
00641
00642 class CV_EXPORTS_W CvEM : public CvStatModel
00643 {
00644 public:
00645
00646 enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
00647
00648
00649 enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
00650
00651 CV_WRAP CvEM();
00652 CvEM( const CvMat* samples, const CvMat* sampleIdx=0,
00653 CvEMParams params=CvEMParams(), CvMat* labels=0 );
00654
00655
00656
00657 virtual ~CvEM();
00658
00659 virtual bool train( const CvMat* samples, const CvMat* sampleIdx=0,
00660 CvEMParams params=CvEMParams(), CvMat* labels=0 );
00661
00662 virtual float predict( const CvMat* sample, CV_OUT CvMat* probs ) const;
00663
00664 #ifndef SWIG
00665 CV_WRAP CvEM( const cv::Mat& samples, const cv::Mat& sampleIdx=cv::Mat(),
00666 CvEMParams params=CvEMParams() );
00667
00668 CV_WRAP virtual bool train( const cv::Mat& samples,
00669 const cv::Mat& sampleIdx=cv::Mat(),
00670 CvEMParams params=CvEMParams(),
00671 CV_OUT cv::Mat* labels=0 );
00672
00673 CV_WRAP virtual float predict( const cv::Mat& sample, CV_OUT cv::Mat* probs=0 ) const;
00674
00675 CV_WRAP int getNClusters() const;
00676 CV_WRAP cv::Mat getMeans() const;
00677 CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs) const;
00678 CV_WRAP cv::Mat getWeights() const;
00679 CV_WRAP cv::Mat getProbs() const;
00680
00681 CV_WRAP inline double getLikelihood() const { return log_likelihood; };
00682 #endif
00683
00684 CV_WRAP virtual void clear();
00685
00686 int get_nclusters() const;
00687 const CvMat* get_means() const;
00688 const CvMat** get_covs() const;
00689 const CvMat* get_weights() const;
00690 const CvMat* get_probs() const;
00691
00692 inline double get_log_likelihood () const { return log_likelihood; };
00693
00694
00695
00696
00697
00698 protected:
00699
00700 virtual void set_params( const CvEMParams& params,
00701 const CvVectors& train_data );
00702 virtual void init_em( const CvVectors& train_data );
00703 virtual double run_em( const CvVectors& train_data );
00704 virtual void init_auto( const CvVectors& samples );
00705 virtual void kmeans( const CvVectors& train_data, int nclusters,
00706 CvMat* labels, CvTermCriteria criteria,
00707 const CvMat* means );
00708 CvEMParams params;
00709 double log_likelihood;
00710
00711 CvMat* means;
00712 CvMat** covs;
00713 CvMat* weights;
00714 CvMat* probs;
00715
00716 CvMat* log_weight_div_det;
00717 CvMat* inv_eigen_values;
00718 CvMat** cov_rotate_mats;
00719 };
00720
00721
00722
// Pair of parallel buffer pointers: 16-bit unsigned (`u`) and 32-bit signed (`i`)
// element types, per the type name. The stray line-splice backslash that preceded
// this declaration has been removed.
struct CvPair16u32s
{
    unsigned short* u;
    int* i;
};
00729
00730
// Split direction of category `idx` in the bit set `subset` (an int array indexed by
// 32-bit words, as the >>5 / &31 arithmetic assumes): -1 if bit `idx` is set, +1 if it
// is clear. Both arguments are parenthesized so expression arguments (e.g. `p + 1`)
// expand correctly. The mask uses an unsigned 1 so bit 31 is never produced by shifting
// a signed int into its sign bit.
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*((((subset)[(idx)>>5]) & (1u << ((idx) & 31)))==0)-1)
00733
/*
 * One split of a decision-tree node. Splits form a linked list through `next`.
 * The anonymous union holds either a category bit set (`subset`) or the ordered-split
 * data (`ord`); both share storage.
 */
struct CvDTreeSplit
{
    int var_idx;
    int condensed_idx;
    int inversed;
    float quality;
    CvDTreeSplit* next;

    union
    {
        int subset[2];
        struct
        {
            float c;
            int split_point;
        } ord;
    };
};
00752
00753 struct CvDTreeNode
00754 {
00755 int class_idx;
00756 int Tn;
00757 double value;
00758
00759 CvDTreeNode* parent;
00760 CvDTreeNode* left;
00761 CvDTreeNode* right;
00762
00763 CvDTreeSplit* split;
00764
00765 int sample_count;
00766 int depth;
00767 int* num_valid;
00768 int offset;
00769 int buf_idx;
00770 double maxlr;
00771
00772
00773 int complexity;
00774 double alpha;
00775 double node_risk, tree_risk, tree_error;
00776
00777
00778 int* cv_Tn;
00779 double* cv_node_risk;
00780 double* cv_node_error;
00781
00782 int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
00783 void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
00784 };
00785
00786
00787 struct CV_EXPORTS_W_MAP CvDTreeParams
00788 {
00789 CV_PROP_RW int max_categories;
00790 CV_PROP_RW int max_depth;
00791 CV_PROP_RW int min_sample_count;
00792 CV_PROP_RW int cv_folds;
00793 CV_PROP_RW bool use_surrogates;
00794 CV_PROP_RW bool use_1se_rule;
00795 CV_PROP_RW bool truncate_pruned_tree;
00796 CV_PROP_RW float regression_accuracy;
00797 const float* priors;
00798
00799 CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
00800 cv_folds(10), use_surrogates(true), use_1se_rule(true),
00801 truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
00802 {}
00803
00804 CvDTreeParams( int _max_depth, int _min_sample_count,
00805 float _regression_accuracy, bool _use_surrogates,
00806 int _max_categories, int _cv_folds,
00807 bool _use_1se_rule, bool _truncate_pruned_tree,
00808 const float* _priors ) :
00809 max_categories(_max_categories), max_depth(_max_depth),
00810 min_sample_count(_min_sample_count), cv_folds (_cv_folds),
00811 use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
00812 truncate_pruned_tree(_truncate_pruned_tree),
00813 regression_accuracy(_regression_accuracy),
00814 priors(_priors)
00815 {}
00816 };
00817
00818
00819 struct CV_EXPORTS CvDTreeTrainData
00820 {
00821 CvDTreeTrainData();
00822 CvDTreeTrainData( const CvMat* trainData, int tflag,
00823 const CvMat* responses, const CvMat* varIdx=0,
00824 const CvMat* sampleIdx=0, const CvMat* varType=0,
00825 const CvMat* missingDataMask=0,
00826 const CvDTreeParams& params=CvDTreeParams(),
00827 bool _shared=false, bool _add_labels=false );
00828 virtual ~CvDTreeTrainData();
00829
00830 virtual void set_data( const CvMat* trainData, int tflag,
00831 const CvMat* responses, const CvMat* varIdx=0,
00832 const CvMat* sampleIdx=0, const CvMat* varType=0,
00833 const CvMat* missingDataMask=0,
00834 const CvDTreeParams& params=CvDTreeParams(),
00835 bool _shared=false, bool _add_labels=false,
00836 bool _update_data=false );
00837 virtual void do_responses_copy();
00838
00839 virtual void get_vectors( const CvMat* _subsample_idx,
00840 float* values, uchar* missing, float* responses, bool get_class_idx=false );
00841
00842 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
00843
00844 virtual void write_params( CvFileStorage* fs ) const;
00845 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00846
00847
00848 virtual void clear();
00849
00850 int get_num_classes() const;
00851 int get_var_type(int vi) const;
00852 int get_work_var_count() const {return work_var_count;}
00853
00854 virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
00855 virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
00856 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
00857 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
00858 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
00859 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
00860 const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
00861 virtual int get_child_buf_idx( CvDTreeNode* n );
00862
00864
00865 virtual bool set_params( const CvDTreeParams& params );
00866 virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
00867 int storage_idx, int offset );
00868
00869 virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
00870 int split_point, int inversed, float quality );
00871 virtual CvDTreeSplit* new_split_cat( int vi, float quality );
00872 virtual void free_node_data( CvDTreeNode* node );
00873 virtual void free_train_data();
00874 virtual void free_node( CvDTreeNode* node );
00875
00876 int sample_count, var_all, var_count, max_c_count;
00877 int ord_var_count, cat_var_count, work_var_count;
00878 bool have_labels, have_priors;
00879 bool is_classifier;
00880 int tflag;
00881
00882 const CvMat* train_data;
00883 const CvMat* responses;
00884 CvMat* responses_copy;
00885
00886 int buf_count, buf_size;
00887 bool shared;
00888 int is_buf_16u;
00889
00890 CvMat* cat_count;
00891 CvMat* cat_ofs;
00892 CvMat* cat_map;
00893
00894 CvMat* counts;
00895 CvMat* buf;
00896 CvMat* direction;
00897 CvMat* split_buf;
00898
00899 CvMat* var_idx;
00900 CvMat* var_type;
00901
00902
00903 CvMat* priors;
00904 CvMat* priors_mult;
00905
00906 CvDTreeParams params;
00907
00908 CvMemStorage* tree_storage;
00909 CvMemStorage* temp_storage;
00910
00911 CvDTreeNode* data_root;
00912
00913 CvSet* node_heap;
00914 CvSet* split_heap;
00915 CvSet* cv_heap;
00916 CvSet* nv_heap;
00917
00918 cv::RNG* rng;
00919 };
00920
// Forward declarations. The two split-finder helpers in namespace cv are named as
// friends by the tree classes later in this header.
class CvDTree;
class CvForestTree;

namespace cv
{
    struct DTreeBestSplitFinder;
    struct ForestTreeBestSplitFinder;
}
00929
00930 class CV_EXPORTS_W CvDTree : public CvStatModel
00931 {
00932 public:
00933 CV_WRAP CvDTree();
00934 virtual ~CvDTree();
00935
00936 virtual bool train( const CvMat* trainData, int tflag,
00937 const CvMat* responses, const CvMat* varIdx=0,
00938 const CvMat* sampleIdx=0, const CvMat* varType=0,
00939 const CvMat* missingDataMask=0,
00940 CvDTreeParams params=CvDTreeParams() );
00941
00942 virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );
00943
00944
00945 virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );
00946
00947 virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );
00948
00949 virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
00950 bool preprocessedInput=false ) const;
00951
00952 #ifndef SWIG
00953 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
00954 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
00955 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
00956 const cv::Mat& missingDataMask=cv::Mat(),
00957 CvDTreeParams params=CvDTreeParams() );
00958
00959 CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
00960 bool preprocessedInput=false ) const;
00961 CV_WRAP virtual cv::Mat getVarImportance();
00962 #endif
00963
00964 virtual const CvMat* get_var_importance();
00965 CV_WRAP virtual void clear();
00966
00967 virtual void read( CvFileStorage* fs, CvFileNode* node );
00968 virtual void write( CvFileStorage* fs, const char* name ) const;
00969
00970
00971 virtual void read( CvFileStorage* fs, CvFileNode* node,
00972 CvDTreeTrainData* data );
00973 virtual void write( CvFileStorage* fs ) const;
00974
00975 const CvDTreeNode* get_root() const;
00976 int get_pruned_tree_idx() const;
00977 CvDTreeTrainData* get_data();
00978
00979 protected:
00980 friend struct cv::DTreeBestSplitFinder;
00981
00982 virtual bool do_train( const CvMat* _subsample_idx );
00983
00984 virtual void try_split_node( CvDTreeNode* n );
00985 virtual void split_node_data( CvDTreeNode* n );
00986 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
00987 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
00988 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00989 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
00990 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00991 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
00992 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00993 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
00994 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
00995 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
00996 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
00997 virtual double calc_node_dir( CvDTreeNode* node );
00998 virtual void complete_node_dir( CvDTreeNode* node );
00999 virtual void cluster_categories( const int* vectors, int vector_count,
01000 int var_count, int* sums, int k, int* cluster_labels );
01001
01002 virtual void calc_node_value( CvDTreeNode* node );
01003
01004 virtual void prune_cv();
01005 virtual double update_tree_rnc( int T, int fold );
01006 virtual int cut_tree( int T, int fold, double min_alpha );
01007 virtual void free_prune_data(bool cut_tree);
01008 virtual void free_tree();
01009
01010 virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
01011 virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
01012 virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
01013 virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
01014 virtual void write_tree_nodes( CvFileStorage* fs ) const;
01015 virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
01016
01017 CvDTreeNode* root;
01018 CvMat* var_importance;
01019 CvDTreeTrainData* data;
01020
01021 public:
01022 int pruned_tree_idx;
01023 };
01024
01025
01026
01027
01028
01029
01030 class CvRTrees;
01031
01032 class CV_EXPORTS CvForestTree: public CvDTree
01033 {
01034 public:
01035 CvForestTree();
01036 virtual ~CvForestTree();
01037
01038 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );
01039
01040 virtual int get_var_count() const {return data ? data->var_count : 0;}
01041 virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
01042
01043
01044 virtual bool train( const CvMat* trainData, int tflag,
01045 const CvMat* responses, const CvMat* varIdx=0,
01046 const CvMat* sampleIdx=0, const CvMat* varType=0,
01047 const CvMat* missingDataMask=0,
01048 CvDTreeParams params=CvDTreeParams() );
01049
01050 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
01051 virtual void read( CvFileStorage* fs, CvFileNode* node );
01052 virtual void read( CvFileStorage* fs, CvFileNode* node,
01053 CvDTreeTrainData* data );
01054
01055
01056 protected:
01057 friend struct cv::ForestTreeBestSplitFinder;
01058
01059 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
01060 CvRTrees* forest;
01061 };
01062
01063
01064 struct CV_EXPORTS_W_MAP CvRTParams : public CvDTreeParams
01065 {
01066
01067 CV_PROP_RW bool calc_var_importance;
01068 CV_PROP_RW int nactive_vars;
01069 CV_PROP_RW CvTermCriteria term_crit;
01070
01071 CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
01072 calc_var_importance(false), nactive_vars(0)
01073 {
01074 term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
01075 }
01076
01077 CvRTParams( int _max_depth, int _min_sample_count,
01078 float _regression_accuracy, bool _use_surrogates,
01079 int _max_categories, const float* _priors, bool _calc_var_importance,
01080 int _nactive_vars, int max_num_of_trees_in_the_forest,
01081 float forest_accuracy, int termcrit_type ) :
01082 CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
01083 _use_surrogates, _max_categories, 0,
01084 false, false, _priors ),
01085 calc_var_importance(_calc_var_importance),
01086 nactive_vars(_nactive_vars)
01087 {
01088 term_crit = cvTermCriteria(termcrit_type,
01089 max_num_of_trees_in_the_forest, forest_accuracy);
01090 }
01091 };
01092
01093
01094 class CV_EXPORTS_W CvRTrees : public CvStatModel
01095 {
01096 public:
01097 CV_WRAP CvRTrees();
01098 virtual ~CvRTrees();
01099 virtual bool train( const CvMat* trainData, int tflag,
01100 const CvMat* responses, const CvMat* varIdx=0,
01101 const CvMat* sampleIdx=0, const CvMat* varType=0,
01102 const CvMat* missingDataMask=0,
01103 CvRTParams params=CvRTParams() );
01104
01105 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
01106 virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
01107 virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
01108
01109 #ifndef SWIG
01110 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01111 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01112 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01113 const cv::Mat& missingDataMask=cv::Mat(),
01114 CvRTParams params=CvRTParams() );
01115 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
01116 CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
01117 CV_WRAP virtual cv::Mat getVarImportance();
01118 #endif
01119
01120 CV_WRAP virtual void clear();
01121
01122 virtual const CvMat* get_var_importance();
01123 virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
01124 const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
01125
01126 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 );
01127
01128 virtual float get_train_error();
01129
01130 virtual void read( CvFileStorage* fs, CvFileNode* node );
01131 virtual void write( CvFileStorage* fs, const char* name ) const;
01132
01133 CvMat* get_active_var_mask();
01134 CvRNG* get_rng();
01135
01136 int get_tree_count() const;
01137 CvForestTree* get_tree(int i) const;
01138
01139 protected:
01140
01141 virtual bool grow_forest( const CvTermCriteria term_crit );
01142
01143
01144 CvForestTree** trees;
01145 CvDTreeTrainData* data;
01146 int ntrees;
01147 int nclasses;
01148 double oob_error;
01149 CvMat* var_importance;
01150 int nsamples;
01151
01152 cv::RNG* rng;
01153 CvMat* active_var_mask;
01154 };
01155
01156
01157
01158
01159 struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
01160 {
01161 virtual void set_data( const CvMat* trainData, int tflag,
01162 const CvMat* responses, const CvMat* varIdx=0,
01163 const CvMat* sampleIdx=0, const CvMat* varType=0,
01164 const CvMat* missingDataMask=0,
01165 const CvDTreeParams& params=CvDTreeParams(),
01166 bool _shared=false, bool _add_labels=false,
01167 bool _update_data=false );
01168 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
01169 const float** ord_values, const int** missing, int* sample_buf = 0 );
01170 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
01171 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
01172 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
01173 virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
01174 float* responses, bool get_class_idx=false );
01175 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
01176 const CvMat* missing_mask;
01177 };
01178
01179 class CV_EXPORTS CvForestERTree : public CvForestTree
01180 {
01181 protected:
01182 virtual double calc_node_dir( CvDTreeNode* node );
01183 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
01184 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01185 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
01186 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01187 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
01188 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01189 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
01190 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01191 virtual void split_node_data( CvDTreeNode* n );
01192 };
01193
01194 class CV_EXPORTS_W CvERTrees : public CvRTrees
01195 {
01196 public:
01197 CV_WRAP CvERTrees();
01198 virtual ~CvERTrees();
01199 virtual bool train( const CvMat* trainData, int tflag,
01200 const CvMat* responses, const CvMat* varIdx=0,
01201 const CvMat* sampleIdx=0, const CvMat* varType=0,
01202 const CvMat* missingDataMask=0,
01203 CvRTParams params=CvRTParams());
01204 #ifndef SWIG
01205 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01206 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01207 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01208 const cv::Mat& missingDataMask=cv::Mat(),
01209 CvRTParams params=CvRTParams());
01210 #endif
01211 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
01212 protected:
01213 virtual bool grow_forest( const CvTermCriteria term_crit );
01214 };
01215
01216
01217
01218
01219
01220
01221 struct CV_EXPORTS_W_MAP CvBoostParams : public CvDTreeParams
01222 {
01223 CV_PROP_RW int boost_type;
01224 CV_PROP_RW int weak_count;
01225 CV_PROP_RW int split_criteria;
01226 CV_PROP_RW double weight_trim_rate;
01227
01228 CvBoostParams();
01229 CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
01230 int max_depth, bool use_surrogates, const float* priors );
01231 };
01232
01233
01234 class CvBoost;
01235
01236 class CV_EXPORTS CvBoostTree: public CvDTree
01237 {
01238 public:
01239 CvBoostTree();
01240 virtual ~CvBoostTree();
01241
01242 virtual bool train( CvDTreeTrainData* trainData,
01243 const CvMat* subsample_idx, CvBoost* ensemble );
01244
01245 virtual void scale( double s );
01246 virtual void read( CvFileStorage* fs, CvFileNode* node,
01247 CvBoost* ensemble, CvDTreeTrainData* _data );
01248 virtual void clear();
01249
01250
01251 virtual bool train( const CvMat* trainData, int tflag,
01252 const CvMat* responses, const CvMat* varIdx=0,
01253 const CvMat* sampleIdx=0, const CvMat* varType=0,
01254 const CvMat* missingDataMask=0,
01255 CvDTreeParams params=CvDTreeParams() );
01256 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
01257
01258 virtual void read( CvFileStorage* fs, CvFileNode* node );
01259 virtual void read( CvFileStorage* fs, CvFileNode* node,
01260 CvDTreeTrainData* data );
01261
01262
01263 protected:
01264
01265 virtual void try_split_node( CvDTreeNode* n );
01266 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
01267 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
01268 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
01269 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01270 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
01271 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01272 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
01273 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01274 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
01275 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
01276 virtual void calc_node_value( CvDTreeNode* n );
01277 virtual double calc_node_dir( CvDTreeNode* n );
01278
01279 CvBoost* ensemble;
01280 };
01281
01282
01283 class CV_EXPORTS_W CvBoost : public CvStatModel
01284 {
01285 public:
01286
01287 enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
01288
01289
01290 enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
01291
01292 CV_WRAP CvBoost();
01293 virtual ~CvBoost();
01294
01295 CvBoost( const CvMat* trainData, int tflag,
01296 const CvMat* responses, const CvMat* varIdx=0,
01297 const CvMat* sampleIdx=0, const CvMat* varType=0,
01298 const CvMat* missingDataMask=0,
01299 CvBoostParams params=CvBoostParams() );
01300
01301 virtual bool train( const CvMat* trainData, int tflag,
01302 const CvMat* responses, const CvMat* varIdx=0,
01303 const CvMat* sampleIdx=0, const CvMat* varType=0,
01304 const CvMat* missingDataMask=0,
01305 CvBoostParams params=CvBoostParams(),
01306 bool update=false );
01307
01308 virtual bool train( CvMLData* data,
01309 CvBoostParams params=CvBoostParams(),
01310 bool update=false );
01311
01312 virtual float predict( const CvMat* sample, const CvMat* missing=0,
01313 CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
01314 bool raw_mode=false, bool return_sum=false ) const;
01315
01316 #ifndef SWIG
01317 CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
01318 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01319 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01320 const cv::Mat& missingDataMask=cv::Mat(),
01321 CvBoostParams params=CvBoostParams() );
01322
01323 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01324 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01325 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01326 const cv::Mat& missingDataMask=cv::Mat(),
01327 CvBoostParams params=CvBoostParams(),
01328 bool update=false );
01329
01330 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
01331 const cv::Range& slice=cv::Range::all(), bool rawMode=false,
01332 bool returnSum=false ) const;
01333 #endif
01334
01335 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 );
01336
01337 CV_WRAP virtual void prune( CvSlice slice );
01338
01339 CV_WRAP virtual void clear();
01340
01341 virtual void write( CvFileStorage* storage, const char* name ) const;
01342 virtual void read( CvFileStorage* storage, CvFileNode* node );
01343 virtual const CvMat* get_active_vars(bool absolute_idx=true);
01344
01345 CvSeq* get_weak_predictors();
01346
01347 CvMat* get_weights();
01348 CvMat* get_subtree_weights();
01349 CvMat* get_weak_response();
01350 const CvBoostParams& get_params() const;
01351 const CvDTreeTrainData* get_data() const;
01352
01353 protected:
01354
01355 virtual bool set_params( const CvBoostParams& params );
01356 virtual void update_weights( CvBoostTree* tree );
01357 virtual void trim_weights();
01358 virtual void write_params( CvFileStorage* fs ) const;
01359 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
01360
01361 CvDTreeTrainData* data;
01362 CvBoostParams params;
01363 CvSeq* weak;
01364
01365 CvMat* active_vars;
01366 CvMat* active_vars_abs;
01367 bool have_active_cat_vars;
01368
01369 CvMat* orig_response;
01370 CvMat* sum_response;
01371 CvMat* weak_eval;
01372 CvMat* subsample_mask;
01373 CvMat* weights;
01374 CvMat* subtree_weights;
01375 bool have_subsample;
01376 };
01377
01378
01379
01380
01381
01382
01383
01384
01385
01386
01387
01388
01389
01390
01391
01392
01393
01394
01395
01396
01397
01398
01399
01400 struct CV_EXPORTS_W_MAP CvGBTreesParams : public CvDTreeParams
01401 {
01402 CV_PROP_RW int weak_count;
01403 CV_PROP_RW int loss_function_type;
01404 CV_PROP_RW float subsample_portion;
01405 CV_PROP_RW float shrinkage;
01406
01407 CvGBTreesParams();
01408 CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
01409 float subsample_portion, int max_depth, bool use_surrogates );
01410 };
01411
01412
01413
01414
01415
01416
01417
01418
01419
01420
01421
01422
01423
01424
01425
01426
01427
01428
01429
01430
01431
01432
01433
01434
01435
01436
01437
01438
01439
01440
01441
01442
01443
01444
01445
01446
01447
01448
01449
01450
01451
01452
01453
01454
01455
01456
01457
01458
01459
01460 class CV_EXPORTS_W CvGBTrees : public CvStatModel
01461 {
01462 public:
01463
01464
01465
01466
01467
01468
01469
01470
01471
01472
01473
01474
01475
01476
01477
01478
01479
01480
01481
01482
01483
01484
01485
01486
01487 enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
01488
01489
01490
01491
01492
01493
01494
01495
01496
01497
01498
01499
01500
01501 CV_WRAP CvGBTrees();
01502
01503
01504
01505
01506
01507
01508
01509
01510
01511
01512
01513
01514
01515
01516
01517
01518
01519
01520
01521
01522
01523
01524
01525
01526
01527
01528
01529
01530
01531
01532
01533
01534
01535
01536
01537
01538
01539
01540
01541 CvGBTrees( const CvMat* trainData, int tflag,
01542 const CvMat* responses, const CvMat* varIdx=0,
01543 const CvMat* sampleIdx=0, const CvMat* varType=0,
01544 const CvMat* missingDataMask=0,
01545 CvGBTreesParams params=CvGBTreesParams() );
01546
01547
01548
01549
01550
01551 virtual ~CvGBTrees();
01552
01553
01554
01555
01556
01557
01558
01559
01560
01561
01562
01563
01564
01565
01566
01567
01568
01569
01570
01571
01572
01573
01574
01575
01576
01577
01578
01579
01580
01581
01582
01583
01584
01585
01586
01587
01588
01589
01590
01591
01592
01593 virtual bool train( const CvMat* trainData, int tflag,
01594 const CvMat* responses, const CvMat* varIdx=0,
01595 const CvMat* sampleIdx=0, const CvMat* varType=0,
01596 const CvMat* missingDataMask=0,
01597 CvGBTreesParams params=CvGBTreesParams(),
01598 bool update=false );
01599
01600
01601
01602
01603
01604
01605
01606
01607
01608
01609
01610
01611
01612
01613
01614
01615
01616
01617 virtual bool train( CvMLData* data,
01618 CvGBTreesParams params=CvGBTreesParams(),
01619 bool update=false );
01620
01621
01622
01623
01624
01625
01626
01627
01628
01629
01630
01631
01632
01633
01634
01635
01636
01637
01638
01639
01640
01641
01642
01643
01644
01645
01646
01647
01648
01649 virtual float predict( const CvMat* sample, const CvMat* missing=0,
01650 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01651 int k=-1 ) const;
01652
01653
01654
01655
01656
01657
01658
01659
01660
01661
01662
01663
01664
01665
01666
01667 CV_WRAP virtual void clear();
01668
01669
01670
01671
01672
01673
01674
01675
01676
01677
01678
01679
01680
01681
01682
01683
01684
01685 virtual float calc_error( CvMLData* _data, int type,
01686 std::vector<float> *resp = 0 );
01687
01688
01689
01690
01691
01692
01693
01694
01695
01696
01697
01698
01699
01700
01701
01702 virtual void write( CvFileStorage* fs, const char* name ) const;
01703
01704
01705
01706
01707
01708
01709
01710
01711
01712
01713
01714
01715
01716
01717
01718 virtual void read( CvFileStorage* fs, CvFileNode* node );
01719
01720
01721
01722 CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
01723 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01724 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01725 const cv::Mat& missingDataMask=cv::Mat(),
01726 CvGBTreesParams params=CvGBTreesParams() );
01727
01728 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01729 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01730 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01731 const cv::Mat& missingDataMask=cv::Mat(),
01732 CvGBTreesParams params=CvGBTreesParams(),
01733 bool update=false );
01734
01735 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
01736 const cv::Range& slice = cv::Range::all(),
01737 int k=-1 ) const;
01738
01739 protected:
01740
01741
01742
01743
01744
01745
01746
01747
01748
01749
01750
01751
01752
01753
01754
01755
01756 virtual void find_gradient( const int k = 0);
01757
01758
01759
01760
01761
01762
01763
01764
01765
01766
01767
01768
01769
01770
01771
01772
01773
01774
01775 virtual void change_values(CvDTree* tree, const int k = 0);
01776
01777
01778
01779
01780
01781
01782
01783
01784
01785
01786
01787
01788
01789
01790
01791
01792
01793
01794 virtual float find_optimal_value( const CvMat* _Idx );
01795
01796
01797
01798
01799
01800
01801
01802
01803
01804
01805
01806
01807
01808
01809
01810
01811 virtual void do_subsample();
01812
01813
01814
01815
01816
01817
01818
01819
01820
01821
01822
01823
01824
01825
01826
01827
01828 void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
01829
01830
01831
01832
01833
01834
01835
01836
01837
01838
01839
01840
01841
01842
01843
01844
01845 CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
01846
01847
01848
01849
01850
01851
01852
01853
01854
01855
01856
01857
01858
01859
01860
01861 virtual bool problem_type() const;
01862
01863
01864
01865
01866
01867
01868
01869
01870
01871
01872
01873
01874
01875
01876 virtual void write_params( CvFileStorage* fs ) const;
01877
01878
01879
01880
01881
01882
01883
01884
01885
01886
01887
01888
01889
01890
01891
01892
01893
01894
01895
01896 virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
01897
01898
01899 CvDTreeTrainData* data;
01900 CvGBTreesParams params;
01901
01902 CvSeq** weak;
01903 CvMat* orig_response;
01904 CvMat* sum_response;
01905 CvMat* sum_response_tmp;
01906 CvMat* weak_eval;
01907 CvMat* sample_idx;
01908 CvMat* subsample_train;
01909 CvMat* subsample_test;
01910 CvMat* missing;
01911 CvMat* class_labels;
01912
01913 cv::RNG* rng;
01914
01915 int class_count;
01916 float delta;
01917 float base_value;
01918
01919 };
01920
01921
01922
01923
01924
01925
01926
01928
01929 struct CV_EXPORTS_W_MAP CvANN_MLP_TrainParams
01930 {
01931 CvANN_MLP_TrainParams();
01932 CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
01933 double param1, double param2=0 );
01934 ~CvANN_MLP_TrainParams();
01935
01936 enum { BACKPROP=0, RPROP=1 };
01937
01938 CV_PROP_RW CvTermCriteria term_crit;
01939 CV_PROP_RW int train_method;
01940
01941
01942 CV_PROP_RW double bp_dw_scale, bp_moment_scale;
01943
01944
01945 CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
01946 };
01947
01948
01949 class CV_EXPORTS_W CvANN_MLP : public CvStatModel
01950 {
01951 public:
01952 CV_WRAP CvANN_MLP();
01953 CvANN_MLP( const CvMat* layerSizes,
01954 int activateFunc=CvANN_MLP::SIGMOID_SYM,
01955 double fparam1=0, double fparam2=0 );
01956
01957 virtual ~CvANN_MLP();
01958
01959 virtual void create( const CvMat* layerSizes,
01960 int activateFunc=CvANN_MLP::SIGMOID_SYM,
01961 double fparam1=0, double fparam2=0 );
01962
01963 virtual int train( const CvMat* inputs, const CvMat* outputs,
01964 const CvMat* sampleWeights, const CvMat* sampleIdx=0,
01965 CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01966 int flags=0 );
01967 virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;
01968
01969 #ifndef SWIG
01970 CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
01971 int activateFunc=CvANN_MLP::SIGMOID_SYM,
01972 double fparam1=0, double fparam2=0 );
01973
01974 CV_WRAP virtual void create( const cv::Mat& layerSizes,
01975 int activateFunc=CvANN_MLP::SIGMOID_SYM,
01976 double fparam1=0, double fparam2=0 );
01977
01978 CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
01979 const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
01980 CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01981 int flags=0 );
01982
01983 CV_WRAP virtual float predict( const cv::Mat& inputs, cv::Mat& outputs ) const;
01984 #endif
01985
01986 CV_WRAP virtual void clear();
01987
01988
01989 enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
01990
01991
01992 enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
01993
01994 virtual void read( CvFileStorage* fs, CvFileNode* node );
01995 virtual void write( CvFileStorage* storage, const char* name ) const;
01996
01997 int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
01998 const CvMat* get_layer_sizes() { return layer_sizes; }
01999 double* get_weights(int layer)
02000 {
02001 return layer_sizes && weights &&
02002 (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
02003 }
02004
02005 protected:
02006
02007 virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
02008 const CvMat* _sample_weights, const CvMat* sampleIdx,
02009 CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
02010
02011
02012 virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
02013
02014
02015 virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
02016
02017 virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
02018 virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
02019 virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
02020 double _f_param1=0, double _f_param2=0 );
02021 virtual void init_weights();
02022 virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
02023 virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
02024 virtual void calc_input_scale( const CvVectors* vecs, int flags );
02025 virtual void calc_output_scale( const CvVectors* vecs, int flags );
02026
02027 virtual void write_params( CvFileStorage* fs ) const;
02028 virtual void read_params( CvFileStorage* fs, CvFileNode* node );
02029
02030 CvMat* layer_sizes;
02031 CvMat* wbuf;
02032 CvMat* sample_weights;
02033 double** weights;
02034 double f_param1, f_param2;
02035 double min_val, max_val, min_val1, max_val1;
02036 int activ_func;
02037 int max_count, max_buf_sz;
02038 CvANN_MLP_TrainParams params;
02039 cv::RNG* rng;
02040 };
02041
02042
02043
02044
02045
02046
02047
02048 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
02049 CvRNG* rng CV_DEFAULT(0) );
02050
02051
02052 CVAPI(void) cvRandGaussMixture( CvMat* means[],
02053 CvMat* covs[],
02054 float weights[],
02055 int clsnum,
02056 CvMat* sample,
02057 CvMat* sampClasses CV_DEFAULT(0) );
02058
02059 #define CV_TS_CONCENTRIC_SPHERES 0
02060
02061
02062 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
02063 int num_samples,
02064 int num_features,
02065 CvMat** responses,
02066 int num_classes, ... );
02067
02068
02069 #endif
02070
02071
02072
02073
02074
02075 #include <map>
02076 #include <string>
02077 #include <iostream>
02078
02079 #define CV_COUNT 0
02080 #define CV_PORTION 1
02081
02082 struct CV_EXPORTS CvTrainTestSplit
02083 {
02084 public:
02085 CvTrainTestSplit();
02086 CvTrainTestSplit( int _train_sample_count, bool _mix = true);
02087 CvTrainTestSplit( float _train_sample_portion, bool _mix = true);
02088
02089 union
02090 {
02091 int count;
02092 float portion;
02093 } train_sample_part;
02094 int train_sample_part_mode;
02095
02096 union
02097 {
02098 int *count;
02099 float *portion;
02100 } *class_part;
02101 int class_part_mode;
02102
02103 bool mix;
02104 };
02105
02106 class CV_EXPORTS CvMLData
02107 {
02108 public:
02109 CvMLData();
02110 virtual ~CvMLData();
02111
02112
02113
02114
02115 int read_csv(const char* filename);
02116
02117 const CvMat* get_values(){ return values; };
02118
02119 const CvMat* get_responses();
02120
02121 const CvMat* get_missing(){ return missing; };
02122
02123 void set_response_idx( int idx );
02124
02125 int get_response_idx() { return response_idx; }
02126
02127 const CvMat* get_train_sample_idx() { return train_sample_idx; };
02128 const CvMat* get_test_sample_idx() { return test_sample_idx; };
02129 void mix_train_and_test_idx();
02130 void set_train_test_split( const CvTrainTestSplit * spl);
02131
02132 const CvMat* get_var_idx();
02133 void chahge_var_idx( int vi, bool state );
02134
02135 const CvMat* get_var_types();
02136 int get_var_type( int var_idx ) { return var_types->data.ptr[var_idx]; };
02137
02138
02139
02140 void set_var_types( const char* str );
02141
02142
02143 void change_var_type( int var_idx, int type);
02144
02145 void set_delimiter( char ch );
02146 char get_delimiter() { return delimiter; };
02147
02148 void set_miss_ch( char ch );
02149 char get_miss_ch() { return miss_ch; };
02150
02151 protected:
02152 virtual void clear();
02153
02154 void str_to_flt_elem( const char* token, float& flt_elem, int& type);
02155 void free_train_test_idx();
02156
02157 char delimiter;
02158 char miss_ch;
02159
02160
02161 CvMat* values;
02162 CvMat* missing;
02163 CvMat* var_types;
02164 CvMat* var_idx_mask;
02165
02166 CvMat* response_out;
02167 CvMat* var_idx_out;
02168 CvMat* var_types_out;
02169
02170 int response_idx;
02171
02172 int train_sample_count;
02173 bool mix;
02174
02175 int total_class_count;
02176 std::map<std::string, int> *class_map;
02177
02178 CvMat* train_sample_idx;
02179 CvMat* test_sample_idx;
02180 int* sample_idx;
02181
02182 cv::RNG* rng;
02183 };
02184
02185
02186 namespace cv
02187 {
02188
02189 typedef CvStatModel StatModel;
02190 typedef CvParamGrid ParamGrid;
02191 typedef CvNormalBayesClassifier NormalBayesClassifier;
02192 typedef CvKNearest KNearest;
02193 typedef CvSVMParams SVMParams;
02194 typedef CvSVMKernel SVMKernel;
02195 typedef CvSVMSolver SVMSolver;
02196 typedef CvSVM SVM;
02197 typedef CvEMParams EMParams;
02198 typedef CvEM ExpectationMaximization;
02199 typedef CvDTreeParams DTreeParams;
02200 typedef CvMLData TrainData;
02201 typedef CvDTree DecisionTree;
02202 typedef CvForestTree ForestTree;
02203 typedef CvRTParams RandomTreeParams;
02204 typedef CvRTrees RandomTrees;
02205 typedef CvERTreeTrainData ERTreeTRainData;
02206 typedef CvForestERTree ERTree;
02207 typedef CvERTrees ERTrees;
02208 typedef CvBoostParams BoostParams;
02209 typedef CvBoostTree BoostTree;
02210 typedef CvBoost Boost;
02211 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
02212 typedef CvANN_MLP NeuralNet_MLP;
02213 typedef CvGBTreesParams GradientBoostingTreeParams;
02214 typedef CvGBTrees GradientBoostingTrees;
02215
02216 template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj();
02217
02218 }
02219
02220 #endif
02221