diff python/ml.py @ 993:e8eabef7857c

update to OpenCV3 for python
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 16 May 2018 21:06:52 -0400
parents 23f98ebb113f
children 933670761a57
line wrap: on
line diff
--- a/python/ml.py	Fri Apr 13 16:48:02 2018 -0400
+++ b/python/ml.py	Wed May 16 21:06:52 2018 -0400
@@ -25,50 +25,52 @@
 #####################
 
 def computeConfusionMatrix(model, samples, responses):
-    'computes the confusion matrix of the classifier (model)'
+    '''computes the confusion matrix of the classifier (model)
+
+    samples should be n samples by m variables'''
     classifications = {}
-    for x,y in zip(samples, responses):
-        predicted = model.predict(x)
+    predictions = model.predict(samples)
+    for predicted, y in zip(predictions, responses):
         classifications[(y, predicted)] = classifications.get((y, predicted), 0)+1
     return classifications
 
-class StatModel(object):
-    '''Abstract class for loading/saving model'''    
-    def load(self, filename):
-        if path.exists(filename):
-            self.model.load(filename)
-        else:
-            print('Provided filename {} does not exist: model not loaded!'.format(filename))
-
-    def save(self, filename):
-        self.model.save(filename)
-
 if opencvAvailable:
-    class SVM(StatModel):
+    class SVM(object):
         '''wrapper for OpenCV SimpleVectorMachine algorithm'''
-        def __init__(self, svmType = cv2.SVM_C_SVC, kernelType = cv2.SVM_RBF, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0):
-            self.model = cv2.SVM()
-            self.params = dict(svm_type = svmType, kernel_type = kernelType, degree = degree, gamma = gamma, coef0 = coef0, Cvalue = Cvalue, nu = nu, p = p)
-            # OpenCV3
-            # self.model = cv2.SVM()
-            # self.model.setType(svmType)
-            # self.model.setKernel(kernelType)
-            # self.model.setDegree(degree)
-            # self.model.setGamma(gamma)
-            # self.model.setCoef0(coef0)
-            # self.model.setC(Cvalue)
-            # self.model.setNu(nu)
-            # self.model.setP(p)
+        def __init__(self, svmType = cv2.ml.SVM_C_SVC, kernelType = cv2.ml.SVM_RBF, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0):
+            self.model = cv2.ml.SVM_create()
+            self.model.setType(svmType)
+            self.model.setKernel(kernelType)
+            self.model.setDegree(degree)
+            self.model.setGamma(gamma)
+            self.model.setCoef0(coef0)
+            self.model.setC(Cvalue)
+            self.model.setNu(nu)
+            self.model.setP(p)
 
-        def train(self, samples, responses, computePerformance = False):
-            self.model.train(samples, responses, params = self.params)
+        def save(self, filename):
+            self.model.save(filename)
+            
+        def train(self, samples, layout, responses, computePerformance = False):
+            self.model.train(samples, layout, responses)
             if computePerformance:
                 return computeConfusionMatrix(self, samples, responses)
 
         def predict(self, hog):
-            return self.model.predict(hog)
+            retval, predictions = self.model.predict(hog)
+            if hog.shape[0] == 1:
+                return predictions[0][0]
+            else:
+                return np.asarray(predictions, dtype = np.int).ravel().tolist()
 
-
+    def SVM_load(filename):
+        if path.exists(filename):
+            svm = SVM()
+            svm.model = cv2.ml.SVM_load(filename)
+            return svm
+        else:
+            print('Provided filename {} does not exist: model not loaded!'.format(filename))
+        
 #####################
 # Clustering
 #####################