Mercurial Hosting > traffic-intelligence
comparison python/ml.py @ 680:da1352b89d02 dev
classification is working
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Fri, 05 Jun 2015 02:25:30 +0200 |
parents | 15e244d2a1b5 |
children | b02431a8234c |
comparison
equal
deleted
inserted
replaced
678:97c305108460 | 680:da1352b89d02 |
---|---|
4 import numpy as np | 4 import numpy as np |
5 | 5 |
6 | 6 |
7 class Model(object): | 7 class Model(object): |
8 '''Abstract class for loading/saving model''' | 8 '''Abstract class for loading/saving model''' |
9 def load(self, fn): | 9 def load(self, filename): |
10 self.model.load(fn) | 10 from os import path |
| | 11 if path.exists(filename): |
| | 12 self.model.load(filename) |
| | 13 else: |
| | 14 print('Provided filename {} does not exist: model not loaded!'.format(filename)) |
11 | 15 |
12 def save(self, fn): | 16 def save(self, filename): |
13 self.model.save(fn) | 17 self.model.save(filename) |
14 | 18 |
15 class SVM(Model): | 19 class SVM(Model): |
16 '''wrapper for OpenCV SimpleVectorMachine algorithm''' | 20 '''wrapper for OpenCV SimpleVectorMachine algorithm''' |
17 | 21 |
18 def __init__(self, svm_type, kernel_type, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0): | 22 def __init__(self): |
19 import cv2 | 23 import cv2 |
20 self.model = cv2.SVM() | 24 self.model = cv2.SVM() |
| | 25 |
| | 26 def train(self, samples, responses, svm_type, kernel_type, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0): |
21 self.params = dict(svm_type = svm_type, kernel_type = kernel_type, degree = degree, gamma = gamma, coef0 = coef0, Cvalue = Cvalue, nu = nu, p = p) | 27 self.params = dict(svm_type = svm_type, kernel_type = kernel_type, degree = degree, gamma = gamma, coef0 = coef0, Cvalue = Cvalue, nu = nu, p = p) |
22 | |
23 def train(self, samples, responses): | |
24 self.model.train(samples, responses, params = self.params) | 28 self.model.train(samples, responses, params = self.params) |
25 | 29 |
26 def predict(self, samples): | 30 def predict(self, hog): |
27 return np.float32([self.model.predict(s) for s in samples]) | 31 return self.model.predict(hog) |
28 | 32 |
29 | 33 |
30 class Centroid(object): | 34 class Centroid(object): |
31 'Wrapper around instances to add a counter' | 35 'Wrapper around instances to add a counter' |
32 | 36 |