comparison scripts/train-object-classification.py @ 680:da1352b89d02 dev

classification is working
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 05 Jun 2015 02:25:30 +0200
parents ce40a89bd6ae
children 5b970a5bc233
comparison
legend: equal | deleted | inserted | replaced
678:97c305108460 680:da1352b89d02
34 trainingSamplesPV = {} 34 trainingSamplesPV = {}
35 trainingLabelsPV = {} 35 trainingLabelsPV = {}
36 36
37 for k, v in imageDirectories.iteritems(): 37 for k, v in imageDirectories.iteritems():
38 print('Loading {} samples'.format(k)) 38 print('Loading {} samples'.format(k))
39 trainingSamplesPBV[k], trainingLabelsPBV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) 39 trainingSamples, trainingLabels = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock)
40 trainingSamplesPBV[k], trainingLabelsPBV[k] = trainingSamples, trainingLabels
40 if k != 'pedestrian': 41 if k != 'pedestrian':
41 trainingSamplesBV[k], trainingLabelsBV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) 42 trainingSamplesBV[k], trainingLabelsBV[k] = trainingSamples, trainingLabels
42 if k != 'car': 43 if k != 'car':
43 trainingSamplesPB[k], trainingLabelsPB[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) 44 trainingSamplesPB[k], trainingLabelsPB[k] = trainingSamples, trainingLabels
44 if k != 'bicycle': 45 if k != 'bicycle':
45 trainingSamplesPV[k], trainingLabelsPV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) 46 trainingSamplesPV[k], trainingLabelsPV[k] = trainingSamples, trainingLabels
46 47
47 # Training the Support Vector Machine 48 # Training the Support Vector Machine
48 print "Training Pedestrian-Cyclist-Vehicle Model" 49 print "Training Pedestrian-Cyclist-Vehicle Model"
49 model = ml.SVM(args.svmType, args.kernelType) 50 model = ml.SVM()
50 model.train(np.concatenate(trainingSamplesPBV.values()), np.concatenate(trainingLabelsPBV.values())) 51 model.train(np.concatenate(trainingSamplesPBV.values()), np.concatenate(trainingLabelsPBV.values()), args.svmType, args.kernelType)
51 model.save(args.directoryName + "/modelPBV.xml") 52 model.save(args.directoryName + "/modelPBV.xml")
52 53
53 print "Training Cyclist-Vehicle Model" 54 print "Training Cyclist-Vehicle Model"
54 model = ml.SVM(args.svmType, args.kernelType) 55 model = ml.SVM()
55 model.train(np.concatenate(trainingSamplesBV.values()), np.concatenate(trainingLabelsBV.values())) 56 model.train(np.concatenate(trainingSamplesBV.values()), np.concatenate(trainingLabelsBV.values()), args.svmType, args.kernelType)
56 model.save(args.directoryName + "/modelBV.xml") 57 model.save(args.directoryName + "/modelBV.xml")
57 58
58 print "Training Pedestrian-Cyclist Model" 59 print "Training Pedestrian-Cyclist Model"
59 model = ml.SVM(args.svmType, args.kernelType) 60 model = ml.SVM()
60 model.train(np.concatenate(trainingSamplesPB.values()), np.concatenate(trainingLabelsPB.values())) 61 model.train(np.concatenate(trainingSamplesPB.values()), np.concatenate(trainingLabelsPB.values()), args.svmType, args.kernelType)
61 model.save(args.directoryName + "/modelPB.xml") 62 model.save(args.directoryName + "/modelPB.xml")
62 63
63 print "Training Pedestrian-Vehicle Model" 64 print "Training Pedestrian-Vehicle Model"
64 model = ml.SVM(args.svmType, args.kernelType) 65 model = ml.SVM()
65 model.train(np.concatenate(trainingSamplesPV.values()), np.concatenate(trainingLabelsPV.values())) 66 model.train(np.concatenate(trainingSamplesPV.values()), np.concatenate(trainingLabelsPV.values()), args.svmType, args.kernelType)
66 model.save(args.directoryName + "/modelPV.xml") 67 model.save(args.directoryName + "/modelPV.xml")