diff python/ml.py @ 680:da1352b89d02 dev

classification is working
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 05 Jun 2015 02:25:30 +0200
parents 15e244d2a1b5
children b02431a8234c
line wrap: on
line diff
--- a/python/ml.py	Wed Jun 03 16:00:46 2015 +0200
+++ b/python/ml.py	Fri Jun 05 02:25:30 2015 +0200
@@ -6,25 +6,29 @@
 
 class Model(object):
     '''Abstract class for loading/saving model'''    
-    def load(self, fn):
-        self.model.load(fn)
+    def load(self, filename):
+        from os import path
+        if path.exists(filename):
+            self.model.load(filename)
+        else:
+            print('Provided filename {} does not exist: model not loaded!'.format(filename))
 
-    def save(self, fn):
-        self.model.save(fn)
+    def save(self, filename):
+        self.model.save(filename)
 
 class SVM(Model):
     '''wrapper for OpenCV SimpleVectorMachine algorithm'''
 
-    def __init__(self, svm_type, kernel_type, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0):
+    def __init__(self):
         import cv2
         self.model = cv2.SVM()
+
+    def train(self, samples, responses, svm_type, kernel_type, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0):
         self.params = dict(svm_type = svm_type, kernel_type = kernel_type, degree = degree, gamma = gamma, coef0 = coef0, Cvalue = Cvalue, nu = nu, p = p)
-
-    def train(self, samples, responses):
         self.model.train(samples, responses, params = self.params)
 
-    def predict(self, samples):
-        return np.float32([self.model.predict(s) for s in samples])
+    def predict(self, hog):
+        return self.model.predict(hog)
 
 
 class Centroid(object):