diff python/ml.py @ 807:52aa03260f03 opencv3

reversed all code to OpenCV 2.4.13
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 10 Jun 2016 15:26:19 -0400
parents 180b6b0231c0
children 5dc7a507353e
line wrap: on
line diff
--- a/python/ml.py	Fri Jun 10 12:29:58 2016 -0400
+++ b/python/ml.py	Fri Jun 10 15:26:19 2016 -0400
@@ -20,9 +20,7 @@
 #####################
 
 class StatModel(object):
-    '''Abstract class for loading/saving model
-
-    Issues with OpenCV, does not seem to work'''    
+    '''Abstract class for loading/saving model'''    
     def load(self, filename):
         if path.exists(filename):
             self.model.load(filename)
@@ -34,25 +32,22 @@
 
 class SVM(StatModel):
     '''wrapper for OpenCV SimpleVectorMachine algorithm'''
-    def __init__(self, svmType = cv2.ml.SVM_C_SVC, kernelType = cv2.ml.SVM_RBF, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0):
-        self.model = cv2.ml.SVM_create()
-        self.model.setType(svmType)
-        self.model.setKernel(kernelType)
-        self.model.setDegree(degree)
-        self.model.setGamma(gamma)
-        self.model.setCoef0(coef0)
-        self.model.setC(Cvalue)
-        self.model.setNu(nu)
-        self.model.setP(p)
+    def __init__(self, svmType = cv2.SVM_C_SVC, kernelType = cv2.SVM_RBF, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0):
+        self.model = cv2.SVM()
+        self.params = dict(svm_type = svmType, kernel_type = kernelType, degree = degree, gamma = gamma, coef0 = coef0, Cvalue = Cvalue, nu = nu, p = p)
+        # OpenCV3
+        # self.model = cv2.SVM()
+        # self.model.setType(svmType)
+        # self.model.setKernel(kernelType)
+        # self.model.setDegree(degree)
+        # self.model.setGamma(gamma)
+        # self.model.setCoef0(coef0)
+        # self.model.setC(Cvalue)
+        # self.model.setNu(nu)
+        # self.model.setP(p)
 
-    def load(self, filename):
-        if path.exists(filename):
-            cv2.ml.SVM_load(filename)
-        else:
-            print('Provided filename {} does not exist: model not loaded!'.format(filename))
-
-    def train(self, samples, layout, responses):
-        self.model.train(samples, layout, responses)
+    def train(self, samples, responses):
+        self.model.train(samples, responses, params = self.params)
 
     def predict(self, hog):
         return self.model.predict(hog)