comparison python/ml.py @ 788:5b970a5bc233 dev

updated classifying code to OpenCV 3.x (bug in function to load classification models)
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Thu, 24 Mar 2016 16:37:37 -0400
parents 0a428b449b80
children 1158a6e2d28e
comparison
equal deleted inserted replaced
787:0a428b449b80 788:5b970a5bc233
9 from matplotlib.pylab import text 9 from matplotlib.pylab import text
10 import matplotlib as mpl 10 import matplotlib as mpl
11 import matplotlib.pyplot as plt 11 import matplotlib.pyplot as plt
12 from scipy.cluster.vq import kmeans, whiten, vq 12 from scipy.cluster.vq import kmeans, whiten, vq
13 from sklearn import mixture 13 from sklearn import mixture
14 import cv2
14 15
15 import utils 16 import utils
16 17
17 ##################### 18 #####################
18 # OpenCV ML models 19 # OpenCV ML models
19 ##################### 20 #####################
20 21
21 class Model(object): 22 class StatModel(object):
22 '''Abstract class for loading/saving model''' 23 '''Abstract class for loading/saving model'''
23 def load(self, filename): 24 def load(self, filename):
24 if path.exists(filename): 25 if path.exists(filename):
25 self.model.load(filename) 26 self.model.load(filename)
26 else: 27 else:
27 print('Provided filename {} does not exist: model not loaded!'.format(filename)) 28 print('Provided filename {} does not exist: model not loaded!'.format(filename))
28 29
29 def save(self, filename): 30 def save(self, filename):
30 self.model.save(filename) 31 self.model.save(filename)
31 32
32 class SVM(Model): 33 class SVM(StatModel):
33 '''wrapper for OpenCV SimpleVectorMachine algorithm''' 34 '''wrapper for OpenCV SimpleVectorMachine algorithm'''
34 35 def __init__(self, svmType = cv2.ml.SVM_C_SVC, kernelType = cv2.ml.SVM_RBF, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0):
35 def __init__(self): 36 self.model = cv2.ml.SVM_create()
36 import cv2 37 self.model.setType(svmType)
37 self.model = cv2.SVM() 38 self.model.setKernel(kernelType)
38 39 self.model.setDegree(degree)
39 def train(self, samples, responses, svm_type, kernel_type, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0): 40 self.model.setGamma(gamma)
40 self.params = dict(svm_type = svm_type, kernel_type = kernel_type, degree = degree, gamma = gamma, coef0 = coef0, Cvalue = Cvalue, nu = nu, p = p) 41 self.model.setCoef0(coef0)
41 self.model.train(samples, responses, params = self.params) 42 self.model.setC(Cvalue)
43 self.model.setNu(nu)
44 self.model.setP(p)
45
46 def train(self, samples, layout, responses):
47 self.model.train(samples, layout, responses)
42 48
43 def predict(self, hog): 49 def predict(self, hog):
44 return self.model.predict(hog) 50 return self.model.predict(hog)
45 51
46 52