diff scripts/train-object-classification.py @ 680:da1352b89d02 dev

classification is working
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 05 Jun 2015 02:25:30 +0200
parents ce40a89bd6ae
children 5b970a5bc233
line wrap: on
line diff
--- a/scripts/train-object-classification.py	Wed Jun 03 16:00:46 2015 +0200
+++ b/scripts/train-object-classification.py	Fri Jun 05 02:25:30 2015 +0200
@@ -36,31 +36,32 @@
 
 for k, v in imageDirectories.iteritems():
     print('Loading {} samples'.format(k))
-    trainingSamplesPBV[k], trainingLabelsPBV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock)
+    trainingSamples, trainingLabels = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock)
+    trainingSamplesPBV[k], trainingLabelsPBV[k] = trainingSamples, trainingLabels
     if k != 'pedestrian':
-	trainingSamplesBV[k], trainingLabelsBV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock)
+	trainingSamplesBV[k], trainingLabelsBV[k] = trainingSamples, trainingLabels
     if k != 'car':
-	trainingSamplesPB[k], trainingLabelsPB[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock)
+	trainingSamplesPB[k], trainingLabelsPB[k] = trainingSamples, trainingLabels
     if k != 'bicycle':
-	trainingSamplesPV[k], trainingLabelsPV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock)
+	trainingSamplesPV[k], trainingLabelsPV[k] = trainingSamples, trainingLabels
 
 # Training the Support Vector Machine
 print "Training Pedestrian-Cyclist-Vehicle Model"
-model = ml.SVM(args.svmType, args.kernelType)
-model.train(np.concatenate(trainingSamplesPBV.values()), np.concatenate(trainingLabelsPBV.values()))
+model = ml.SVM()
+model.train(np.concatenate(trainingSamplesPBV.values()), np.concatenate(trainingLabelsPBV.values()), args.svmType, args.kernelType)
 model.save(args.directoryName + "/modelPBV.xml")
 
 print "Training Cyclist-Vehicle Model"
-model = ml.SVM(args.svmType, args.kernelType)
-model.train(np.concatenate(trainingSamplesBV.values()), np.concatenate(trainingLabelsBV.values()))
+model = ml.SVM()
+model.train(np.concatenate(trainingSamplesBV.values()), np.concatenate(trainingLabelsBV.values()), args.svmType, args.kernelType)
 model.save(args.directoryName + "/modelBV.xml")
 
 print "Training Pedestrian-Cyclist Model"
-model = ml.SVM(args.svmType, args.kernelType)
-model.train(np.concatenate(trainingSamplesPB.values()), np.concatenate(trainingLabelsPB.values()))
+model = ml.SVM()
+model.train(np.concatenate(trainingSamplesPB.values()), np.concatenate(trainingLabelsPB.values()), args.svmType, args.kernelType)
 model.save(args.directoryName + "/modelPB.xml")
 
 print "Training Pedestrian-Vehicle Model"
-model = ml.SVM(args.svmType, args.kernelType)
-model.train(np.concatenate(trainingSamplesPV.values()), np.concatenate(trainingLabelsPV.values()))
+model = ml.SVM()
+model.train(np.concatenate(trainingSamplesPV.values()), np.concatenate(trainingLabelsPV.values()), args.svmType, args.kernelType)
 model.save(args.directoryName + "/modelPV.xml")