diff scripts/train-object-classification.py @ 708:a37c565f4b68

merged dev
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 22 Jul 2015 14:17:44 -0400
parents da1352b89d02
children 5b970a5bc233
line wrap: on
line diff
--- a/scripts/train-object-classification.py	Wed Jul 22 14:17:19 2015 -0400
+++ b/scripts/train-object-classification.py	Wed Jul 22 14:17:44 2015 -0400
@@ -36,31 +36,32 @@
 
 for k, v in imageDirectories.iteritems():
     print('Loading {} samples'.format(k))
-    trainingSamplesPBV[k], trainingLabelsPBV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock)
+    trainingSamples, trainingLabels = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock)
+    trainingSamplesPBV[k], trainingLabelsPBV[k] = trainingSamples, trainingLabels
     if k != 'pedestrian':
-	trainingSamplesBV[k], trainingLabelsBV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock)
+	trainingSamplesBV[k], trainingLabelsBV[k] = trainingSamples, trainingLabels
     if k != 'car':
-	trainingSamplesPB[k], trainingLabelsPB[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock)
+	trainingSamplesPB[k], trainingLabelsPB[k] = trainingSamples, trainingLabels
     if k != 'bicycle':
-	trainingSamplesPV[k], trainingLabelsPV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock)
+	trainingSamplesPV[k], trainingLabelsPV[k] = trainingSamples, trainingLabels
 
 # Training the Support Vector Machine
 print "Training Pedestrian-Cyclist-Vehicle Model"
-model = ml.SVM(args.svmType, args.kernelType)
-model.train(np.concatenate(trainingSamplesPBV.values()), np.concatenate(trainingLabelsPBV.values()))
+model = ml.SVM()
+model.train(np.concatenate(trainingSamplesPBV.values()), np.concatenate(trainingLabelsPBV.values()), args.svmType, args.kernelType)
 model.save(args.directoryName + "/modelPBV.xml")
 
 print "Training Cyclist-Vehicle Model"
-model = ml.SVM(args.svmType, args.kernelType)
-model.train(np.concatenate(trainingSamplesBV.values()), np.concatenate(trainingLabelsBV.values()))
+model = ml.SVM()
+model.train(np.concatenate(trainingSamplesBV.values()), np.concatenate(trainingLabelsBV.values()), args.svmType, args.kernelType)
 model.save(args.directoryName + "/modelBV.xml")
 
 print "Training Pedestrian-Cyclist Model"
-model = ml.SVM(args.svmType, args.kernelType)
-model.train(np.concatenate(trainingSamplesPB.values()), np.concatenate(trainingLabelsPB.values()))
+model = ml.SVM()
+model.train(np.concatenate(trainingSamplesPB.values()), np.concatenate(trainingLabelsPB.values()), args.svmType, args.kernelType)
 model.save(args.directoryName + "/modelPB.xml")
 
 print "Training Pedestrian-Vehicle Model"
-model = ml.SVM(args.svmType, args.kernelType)
-model.train(np.concatenate(trainingSamplesPV.values()), np.concatenate(trainingLabelsPV.values()))
+model = ml.SVM()
+model.train(np.concatenate(trainingSamplesPV.values()), np.concatenate(trainingLabelsPV.values()), args.svmType, args.kernelType)
 model.save(args.directoryName + "/modelPV.xml")