comparison scripts/classify-objects.py @ 854:33d296984dd8

rework and more info on speed probabilities for classification
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Thu, 22 Sep 2016 17:50:35 -0400
parents b9ec0cc2677d
children ff92801e5c54
comparison
equal deleted inserted replaced
853:95e7622b11be 854:33d296984dd8
14 parser.add_argument('--cfg', dest = 'configFilename', help = 'name of the configuration file', required = True) 14 parser.add_argument('--cfg', dest = 'configFilename', help = 'name of the configuration file', required = True)
15 parser.add_argument('-d', dest = 'databaseFilename', help = 'name of the Sqlite database file (overrides the configuration file)') 15 parser.add_argument('-d', dest = 'databaseFilename', help = 'name of the Sqlite database file (overrides the configuration file)')
16 parser.add_argument('-i', dest = 'videoFilename', help = 'name of the video file (overrides the configuration file)') 16 parser.add_argument('-i', dest = 'videoFilename', help = 'name of the video file (overrides the configuration file)')
17 parser.add_argument('-n', dest = 'nObjects', help = 'number of objects to classify', type = int, default = None) 17 parser.add_argument('-n', dest = 'nObjects', help = 'number of objects to classify', type = int, default = None)
18 parser.add_argument('--plot-speed-distributions', dest = 'plotSpeedDistribution', help = 'simply plots the distributions used for each user type', action = 'store_true') 18 parser.add_argument('--plot-speed-distributions', dest = 'plotSpeedDistribution', help = 'simply plots the distributions used for each user type', action = 'store_true')
19 parser.add_argument('--max-speed-distribution-plot', dest = 'maxSpeedDistributionPlot', help = 'if plotting the user distributions, the maximum speed to display', type = float, default = 50.) 19 parser.add_argument('--max-speed-distribution-plot', dest = 'maxSpeedDistributionPlot', help = 'if plotting the user distributions, the maximum speed to display (km/h)', type = float, default = 50.)
20 20
21 args = parser.parse_args() 21 args = parser.parse_args()
22 params = storage.ProcessParameters(args.configFilename) 22 params = storage.ProcessParameters(args.configFilename)
23 classifierParams = storage.ClassifierParameters(params.classifierFilename) 23 classifierParams = storage.ClassifierParameters(params.classifierFilename)
24 classifierParams.convertToFrames(params.videoFrameRate, 3.6) # conversion from km/h to m/frame
24 25
25 if args.videoFilename is not None: 26 if args.videoFilename is not None:
26 videoFilename = args.videoFilename 27 videoFilename = args.videoFilename
27 else: 28 else:
28 videoFilename = params.videoFilename 29 videoFilename = params.videoFilename
29 if args.databaseFilename is not None: 30 if args.databaseFilename is not None:
30 databaseFilename = args.databaseFilename 31 databaseFilename = args.databaseFilename
31 else: 32 else:
32 databaseFilename = params.databaseFilename 33 databaseFilename = params.databaseFilename
33 34
34 classifierParams.convertToFrames(params.videoFrameRate, 3.6) # conversion from km/h to m/s
35 if params.homography is not None: 35 if params.homography is not None:
36 invHomography = np.linalg.inv(params.homography) 36 invHomography = np.linalg.inv(params.homography)
37 else: 37 else:
38 invHomography = None 38 invHomography = None
39 39
51 pedBikeCarSVM.load(classifierParams.pedBikeCarSVMFilename) 51 pedBikeCarSVM.load(classifierParams.pedBikeCarSVMFilename)
52 bikeCarSVM = ml.SVM() 52 bikeCarSVM = ml.SVM()
53 bikeCarSVM.load(classifierParams.bikeCarSVMFilename) 53 bikeCarSVM.load(classifierParams.bikeCarSVMFilename)
54 54
55 # log logistic for ped and bik otherwise ((pedBeta/pedAlfa)*((sMean/pedAlfa)**(pedBeta-1)))/((1+(sMean/pedAlfa)**pedBeta)**2.) 55 # log logistic for ped and bik otherwise ((pedBeta/pedAlfa)*((sMean/pedAlfa)**(pedBeta-1)))/((1+(sMean/pedAlfa)**pedBeta)**2.)
56 speedProbabilities = {'car': lambda s: norm(classifierParams.meanVehicleSpeed, classifierParams.stdVehicleSpeed).pdf(s), 56 carNorm = norm(classifierParams.meanVehicleSpeed, classifierParams.stdVehicleSpeed)
57 'pedestrian': lambda s: norm(classifierParams.meanPedestrianSpeed, classifierParams.stdPedestrianSpeed).pdf(s), 57 pedNorm = norm(classifierParams.meanPedestrianSpeed, classifierParams.stdPedestrianSpeed)
58 'bicycle': lambda s: lognorm(classifierParams.scaleCyclistSpeed, loc = 0., scale = np.exp(classifierParams.locationCyclistSpeed)).pdf(s)} # numpy lognorm shape, loc, scale: shape for numpy is scale (std of the normal) and scale for numpy is location (mean of the normal) 58 # numpy lognorm shape, loc, scale: shape for numpy is scale (std of the normal) and scale for numpy is exp(location) (loc=mean of the normal)
59 bicLogNorm = lognorm(classifierParams.scaleCyclistSpeed, loc = 0., scale = np.exp(classifierParams.locationCyclistSpeed))
60 speedProbabilities = {'car': lambda s: carNorm.pdf(s),
61 'pedestrian': lambda s: pedNorm.pdf(s),
62 'bicycle': lambda s: bicLogNorm.pdf(s)}
59 63
60 if args.plotSpeedDistribution: 64 if args.plotSpeedDistribution:
61 import matplotlib.pyplot as plt 65 import matplotlib.pyplot as plt
62 plt.figure() 66 plt.figure()
63 for k in speedProbabilities: 67 for k in speedProbabilities:
64 plt.plot(np.arange(0.1, args.maxSpeedDistributionPlot, 0.1), [speedProbabilities[k](s/3.6/25) for s in np.arange(0.1, args.maxSpeedDistributionPlot, 0.1)], label = k) 68 plt.plot(np.arange(0.1, args.maxSpeedDistributionPlot, 0.1), [speedProbabilities[k](s/(3.6*params.videoFrameRate)) for s in np.arange(0.1, args.maxSpeedDistributionPlot, 0.1)], label = k)
69 maxProb = -1.
70 for k in speedProbabilities:
71 maxProb = max(maxProb, np.max([speedProbabilities[k](s/(3.6*params.videoFrameRate)) for s in np.arange(0.1, args.maxSpeedDistributionPlot, 0.1)]))
72 plt.plot([classifierParams.minSpeedEquiprobable*3.6*params.videoFrameRate]*2, [0., maxProb], 'k-')
73 plt.text(classifierParams.minSpeedEquiprobable*3.6*params.videoFrameRate, maxProb, 'threshold for equiprobable class')
65 plt.xlabel('Speed (km/h)') 74 plt.xlabel('Speed (km/h)')
66 plt.ylabel('Probability') 75 plt.ylabel('Probability')
67 plt.legend() 76 plt.legend()
68 plt.title('Probability Density Function') 77 plt.title('Probability Density Function')
69 plt.show() 78 plt.show()