diff python/ml.py @ 788:5b970a5bc233 dev

updated classifying code to OpenCV 3.x (bug in function to load classification models)
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Thu, 24 Mar 2016 16:37:37 -0400
parents 0a428b449b80
children 1158a6e2d28e
line wrap: on
line diff
--- a/python/ml.py	Thu Mar 17 16:01:19 2016 -0400
+++ b/python/ml.py	Thu Mar 24 16:37:37 2016 -0400
@@ -11,6 +11,7 @@
 import matplotlib.pyplot as plt
 from scipy.cluster.vq import kmeans, whiten, vq
 from sklearn import mixture
+import cv2
 
 import utils
 
@@ -18,7 +19,7 @@
 # OpenCV ML models
 #####################
 
-class Model(object):
+class StatModel(object):
     '''Abstract class for loading/saving model'''    
     def load(self, filename):
         if path.exists(filename):
@@ -29,16 +30,21 @@
     def save(self, filename):
         self.model.save(filename)
 
-class SVM(Model):
+class SVM(StatModel):
     '''wrapper for OpenCV SimpleVectorMachine algorithm'''
+    def __init__(self, svmType = cv2.ml.SVM_C_SVC, kernelType = cv2.ml.SVM_RBF, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0):
+        self.model = cv2.ml.SVM_create()
+        self.model.setType(svmType)
+        self.model.setKernel(kernelType)
+        self.model.setDegree(degree)
+        self.model.setGamma(gamma)
+        self.model.setCoef0(coef0)
+        self.model.setC(Cvalue)
+        self.model.setNu(nu)
+        self.model.setP(p)
 
-    def __init__(self):
-        import cv2
-        self.model = cv2.SVM()
-
-    def train(self, samples, responses, svm_type, kernel_type, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0):
-        self.params = dict(svm_type = svm_type, kernel_type = kernel_type, degree = degree, gamma = gamma, coef0 = coef0, Cvalue = Cvalue, nu = nu, p = p)
-        self.model.train(samples, responses, params = self.params)
+    def train(self, samples, layout, responses):
+        self.model.train(samples, layout, responses)
 
     def predict(self, hog):
         return self.model.predict(hog)