Mercurial Hosting > traffic-intelligence
diff python/ml.py @ 680:da1352b89d02 dev
classification is working
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Fri, 05 Jun 2015 02:25:30 +0200 |
parents | 15e244d2a1b5 |
children | b02431a8234c |
line wrap: on
line diff
--- a/python/ml.py Wed Jun 03 16:00:46 2015 +0200 +++ b/python/ml.py Fri Jun 05 02:25:30 2015 +0200 @@ -6,25 +6,29 @@ class Model(object): '''Abstract class for loading/saving model''' - def load(self, fn): - self.model.load(fn) + def load(self, filename): + from os import path + if path.exists(filename): + self.model.load(filename) + else: + print('Provided filename {} does not exist: model not loaded!'.format(filename)) - def save(self, fn): - self.model.save(fn) + def save(self, filename): + self.model.save(filename) class SVM(Model): '''wrapper for OpenCV SimpleVectorMachine algorithm''' - def __init__(self, svm_type, kernel_type, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0): + def __init__(self): import cv2 self.model = cv2.SVM() + + def train(self, samples, responses, svm_type, kernel_type, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0): self.params = dict(svm_type = svm_type, kernel_type = kernel_type, degree = degree, gamma = gamma, coef0 = coef0, Cvalue = Cvalue, nu = nu, p = p) - - def train(self, samples, responses): self.model.train(samples, responses, params = self.params) - def predict(self, samples): - return np.float32([self.model.predict(s) for s in samples]) + def predict(self, hog): + return self.model.predict(hog) class Centroid(object):