Mercurial Hosting > traffic-intelligence
comparison python/ml.py @ 993:e8eabef7857c
update to OpenCV3 for python
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Wed, 16 May 2018 21:06:52 -0400 |
parents | 23f98ebb113f |
children | 933670761a57 |
comparison legend: equal | deleted | inserted | replaced
992:2cd1ce245024 | 993:e8eabef7857c |
---|---|
23 ##################### | 23 ##################### |
24 # OpenCV ML models | 24 # OpenCV ML models |
25 ##################### | 25 ##################### |
26 | 26 |
27 def computeConfusionMatrix(model, samples, responses): | 27 def computeConfusionMatrix(model, samples, responses): |
28 'computes the confusion matrix of the classifier (model)' | 28 '''computes the confusion matrix of the classifier (model) |
29 | |
30 samples should be n samples by m variables''' | |
29 classifications = {} | 31 classifications = {} |
30 for x,y in zip(samples, responses): | 32 predictions = model.predict(samples) |
31 predicted = model.predict(x) | 33 for predicted, y in zip(predictions, responses): |
32 classifications[(y, predicted)] = classifications.get((y, predicted), 0)+1 | 34 classifications[(y, predicted)] = classifications.get((y, predicted), 0)+1 |
33 return classifications | 35 return classifications |
34 | 36 |
35 class StatModel(object): | 37 if opencvAvailable: |
36 '''Abstract class for loading/saving model''' | 38 class SVM(object): |
37 def load(self, filename): | 39 '''wrapper for OpenCV SimpleVectorMachine algorithm''' |
40 def __init__(self, svmType = cv2.ml.SVM_C_SVC, kernelType = cv2.ml.SVM_RBF, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0): | |
41 self.model = cv2.ml.SVM_create() | |
42 self.model.setType(svmType) | |
43 self.model.setKernel(kernelType) | |
44 self.model.setDegree(degree) | |
45 self.model.setGamma(gamma) | |
46 self.model.setCoef0(coef0) | |
47 self.model.setC(Cvalue) | |
48 self.model.setNu(nu) | |
49 self.model.setP(p) | |
50 | |
51 def save(self, filename): | |
52 self.model.save(filename) | |
53 | |
54 def train(self, samples, layout, responses, computePerformance = False): | |
55 self.model.train(samples, layout, responses) | |
56 if computePerformance: | |
57 return computeConfusionMatrix(self, samples, responses) | |
58 | |
59 def predict(self, hog): | |
60 retval, predictions = self.model.predict(hog) | |
61 if hog.shape[0] == 1: | |
62 return predictions[0][0] | |
63 else: | |
64 return np.asarray(predictions, dtype = np.int).ravel().tolist() | |
65 | |
66 def SVM_load(filename): | |
38 if path.exists(filename): | 67 if path.exists(filename): |
39 self.model.load(filename) | 68 svm = SVM() |
69 svm.model = cv2.ml.SVM_load(filename) | |
70 return svm | |
40 else: | 71 else: |
41 print('Provided filename {} does not exist: model not loaded!'.format(filename)) | 72 print('Provided filename {} does not exist: model not loaded!'.format(filename)) |
42 | 73 |
43 def save(self, filename): | |
44 self.model.save(filename) | |
45 | |
46 if opencvAvailable: | |
47 class SVM(StatModel): | |
48 '''wrapper for OpenCV SimpleVectorMachine algorithm''' | |
49 def __init__(self, svmType = cv2.SVM_C_SVC, kernelType = cv2.SVM_RBF, degree = 0, gamma = 1, coef0 = 0, Cvalue = 1, nu = 0, p = 0): | |
50 self.model = cv2.SVM() | |
51 self.params = dict(svm_type = svmType, kernel_type = kernelType, degree = degree, gamma = gamma, coef0 = coef0, Cvalue = Cvalue, nu = nu, p = p) | |
52 # OpenCV3 | |
53 # self.model = cv2.SVM() | |
54 # self.model.setType(svmType) | |
55 # self.model.setKernel(kernelType) | |
56 # self.model.setDegree(degree) | |
57 # self.model.setGamma(gamma) | |
58 # self.model.setCoef0(coef0) | |
59 # self.model.setC(Cvalue) | |
60 # self.model.setNu(nu) | |
61 # self.model.setP(p) | |
62 | |
63 def train(self, samples, responses, computePerformance = False): | |
64 self.model.train(samples, responses, params = self.params) | |
65 if computePerformance: | |
66 return computeConfusionMatrix(self, samples, responses) | |
67 | |
68 def predict(self, hog): | |
69 return self.model.predict(hog) | |
70 | |
71 | |
72 ##################### | 74 ##################### |
73 # Clustering | 75 # Clustering |
74 ##################### | 76 ##################### |
75 | 77 |
76 class Centroid(object): | 78 class Centroid(object): |