changeset 791:1158a6e2d28e dev

temporary solution for classification, with corrected svm.cpp and ml.hpp for loading saved classifiers
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 29 Apr 2016 16:07:35 -0400
parents 5b970a5bc233
children ee3433fc0026
files python/ml.py scripts/classify-objects.py
diffstat 2 files changed, 13 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/python/ml.py	Thu Mar 24 16:37:37 2016 -0400
+++ b/python/ml.py	Fri Apr 29 16:07:35 2016 -0400
@@ -20,7 +20,9 @@
 #####################
 
 class StatModel(object):
-    '''Abstract class for loading/saving model'''    
+    '''Abstract class for loading/saving model
+
+    Issues with OpenCV, does not seem to work'''    
     def load(self, filename):
         if path.exists(filename):
             self.model.load(filename)
@@ -43,6 +45,12 @@
         self.model.setNu(nu)
         self.model.setP(p)
 
+    def load(self, filename):
+        if path.exists(filename):
+            cv2.ml.SVM_load(filename)
+        else:
+            print('Provided filename {} does not exist: model not loaded!'.format(filename))
+
     def train(self, samples, layout, responses):
         self.model.train(samples, layout, responses)
 
--- a/scripts/classify-objects.py	Thu Mar 24 16:37:37 2016 -0400
+++ b/scripts/classify-objects.py	Fri Apr 29 16:07:35 2016 -0400
@@ -4,7 +4,7 @@
 
 import numpy as np
 import sys, argparse
-from cv2 import SVM_RBF, SVM_C_SVC
+from cv2.ml import SVM_RBF, SVM_C_SVC
 import cv2
 from scipy.stats import norm, lognorm
 
@@ -35,6 +35,8 @@
 params.convertToFrames(3.6)
 if params.homography is not None:
     invHomography = np.linalg.inv(params.homography)
+else:
+    invHomography = None
 
 if params.speedAggregationMethod == 'median':
     speedAggregationFunc = np.median
@@ -74,7 +76,7 @@
 for obj in objects:
     #obj.setFeatures(features)
     intervals.append(obj.getTimeInterval())
-timeInterval = moving.unionIntervals(intervals)
+timeInterval = moving.TimeInterval.unionIntervals(intervals)
 
 capture = cv2.VideoCapture(videoFilename)
 width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))