Mercurial Hosting > traffic-intelligence
comparison scripts/train-object-classification.py @ 680:da1352b89d02 dev
classification is working
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Fri, 05 Jun 2015 02:25:30 +0200 |
parents | ce40a89bd6ae |
children | 5b970a5bc233 |
comparison
equal
deleted
inserted
replaced
678:97c305108460 | 680:da1352b89d02 |
---|---|
34 trainingSamplesPV = {} | 34 trainingSamplesPV = {} |
35 trainingLabelsPV = {} | 35 trainingLabelsPV = {} |
36 | 36 |
37 for k, v in imageDirectories.iteritems(): | 37 for k, v in imageDirectories.iteritems(): |
38 print('Loading {} samples'.format(k)) | 38 print('Loading {} samples'.format(k)) |
39 trainingSamplesPBV[k], trainingLabelsPBV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) | 39 trainingSamples, trainingLabels = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) |
| 40 trainingSamplesPBV[k], trainingLabelsPBV[k] = trainingSamples, trainingLabels |
40 if k != 'pedestrian': | 41 if k != 'pedestrian': |
41 trainingSamplesBV[k], trainingLabelsBV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) | 42 trainingSamplesBV[k], trainingLabelsBV[k] = trainingSamples, trainingLabels |
42 if k != 'car': | 43 if k != 'car': |
43 trainingSamplesPB[k], trainingLabelsPB[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) | 44 trainingSamplesPB[k], trainingLabelsPB[k] = trainingSamples, trainingLabels |
44 if k != 'bicycle': | 45 if k != 'bicycle': |
45 trainingSamplesPV[k], trainingLabelsPV[k] = cvutils.createHOGTrainingSet(v, moving.userType2Num[k], rescaleSize, args.nOrientations, nPixelsPerCell, nCellsPerBlock) | 46 trainingSamplesPV[k], trainingLabelsPV[k] = trainingSamples, trainingLabels |
46 | 47 |
47 # Training the Support Vector Machine | 48 # Training the Support Vector Machine |
48 print "Training Pedestrian-Cyclist-Vehicle Model" | 49 print "Training Pedestrian-Cyclist-Vehicle Model" |
49 model = ml.SVM(args.svmType, args.kernelType) | 50 model = ml.SVM() |
50 model.train(np.concatenate(trainingSamplesPBV.values()), np.concatenate(trainingLabelsPBV.values())) | 51 model.train(np.concatenate(trainingSamplesPBV.values()), np.concatenate(trainingLabelsPBV.values()), args.svmType, args.kernelType) |
51 model.save(args.directoryName + "/modelPBV.xml") | 52 model.save(args.directoryName + "/modelPBV.xml") |
52 | 53 |
53 print "Training Cyclist-Vehicle Model" | 54 print "Training Cyclist-Vehicle Model" |
54 model = ml.SVM(args.svmType, args.kernelType) | 55 model = ml.SVM() |
55 model.train(np.concatenate(trainingSamplesBV.values()), np.concatenate(trainingLabelsBV.values())) | 56 model.train(np.concatenate(trainingSamplesBV.values()), np.concatenate(trainingLabelsBV.values()), args.svmType, args.kernelType) |
56 model.save(args.directoryName + "/modelBV.xml") | 57 model.save(args.directoryName + "/modelBV.xml") |
57 | 58 |
58 print "Training Pedestrian-Cyclist Model" | 59 print "Training Pedestrian-Cyclist Model" |
59 model = ml.SVM(args.svmType, args.kernelType) | 60 model = ml.SVM() |
60 model.train(np.concatenate(trainingSamplesPB.values()), np.concatenate(trainingLabelsPB.values())) | 61 model.train(np.concatenate(trainingSamplesPB.values()), np.concatenate(trainingLabelsPB.values()), args.svmType, args.kernelType) |
61 model.save(args.directoryName + "/modelPB.xml") | 62 model.save(args.directoryName + "/modelPB.xml") |
62 | 63 |
63 print "Training Pedestrian-Vehicle Model" | 64 print "Training Pedestrian-Vehicle Model" |
64 model = ml.SVM(args.svmType, args.kernelType) | 65 model = ml.SVM() |
65 model.train(np.concatenate(trainingSamplesPV.values()), np.concatenate(trainingLabelsPV.values())) | 66 model.train(np.concatenate(trainingSamplesPV.values()), np.concatenate(trainingLabelsPV.values()), args.svmType, args.kernelType) |
66 model.save(args.directoryName + "/modelPV.xml") | 67 model.save(args.directoryName + "/modelPV.xml") |