comparison python/ml.py @ 993:e8eabef7857c

update to OpenCV3 for python
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 16 May 2018 21:06:52 -0400
parents 23f98ebb113f
children 933670761a57
comparison
equal deleted inserted replaced
992:2cd1ce245024 993:e8eabef7857c
23 ##################### 23 #####################
24 # OpenCV ML models 24 # OpenCV ML models
25 ##################### 25 #####################
26 26
def computeConfusionMatrix(model, samples, responses):
    '''computes the confusion matrix of the classifier (model)

    samples should be n samples by m variables
    model.predict is called once on all the samples
    responses are the expected responses, in the same order as the samples

    returns a dictionary of the number of samples
    for each pair (expected response, predicted response)'''
    predictions = model.predict(samples)
    try:
        iter(predictions)
    except TypeError:
        # a single prediction came back as a bare value
        # (eg SVM.predict when there is only one sample): make it iterable for zip
        predictions = [predictions]
    classifications = {}
    for predicted, y in zip(predictions, responses):
        classifications[(y, predicted)] = classifications.get((y, predicted), 0)+1
    return classifications
34 36
if opencvAvailable:
    class SVM(object):
        '''wrapper for the OpenCV Support Vector Machine algorithm (cv2.ml, OpenCV 3)

        the OpenCV model is stored in the model attribute'''
        def __init__(self, svmType = cv2.ml.SVM_C_SVC, kernelType = cv2.ml.SVM_RBF, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0):
            'creates the OpenCV SVM and sets its type, kernel and parameters'
            self.model = cv2.ml.SVM_create()
            self.model.setType(svmType)
            self.model.setKernel(kernelType)
            self.model.setDegree(degree)
            self.model.setGamma(gamma)
            self.model.setCoef0(coef0)
            self.model.setC(Cvalue)
            self.model.setNu(nu)
            self.model.setP(p)

        def save(self, filename):
            'saves the OpenCV model in the file filename'
            self.model.save(filename)

        def train(self, samples, layout, responses, computePerformance = False):
            '''trains the model on the samples and responses

            layout is passed as is to the OpenCV train method
            (presumably cv2.ml.ROW_SAMPLE or cv2.ml.COL_SAMPLE: to be confirmed by the caller)

            returns the confusion matrix on the training samples
            if computePerformance is True (None otherwise)'''
            self.model.train(samples, layout, responses)
            if computePerformance:
                return computeConfusionMatrix(self, samples, responses)

        def predict(self, hog):
            '''predicts the response(s) for the sample(s) hog

            returns the bare predicted value if hog has only one row,
            the list of predictions as integers otherwise'''
            retval, predictions = self.model.predict(hog)
            if hog.shape[0] == 1:
                return predictions[0][0]
            else:
                # builtin int is the dtype np.int used to be an alias of (np.int was removed from NumPy)
                return np.asarray(predictions, dtype = int).ravel().tolist()

    def SVM_load(filename):
        '''loads a model saved in filename and returns it wrapped in a SVM object

        prints a message and returns None if filename does not exist'''
        if path.exists(filename):
            svm = SVM()
            # replaces the default model created by the constructor by the loaded one
            svm.model = cv2.ml.SVM_load(filename)
            return svm
        else:
            print('Provided filename {} does not exist: model not loaded!'.format(filename))
            return None
70
71
72 ##################### 74 #####################
73 # Clustering 75 # Clustering
74 ##################### 76 #####################
75 77
76 class Centroid(object): 78 class Centroid(object):