comparison python/prediction.py @ 696:ae137e3b1990 dev

minor correction
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Tue, 14 Jul 2015 00:14:54 -0400
parents 9a258687af4c
children 8d99a9e16644
comparison
equal deleted inserted replaced
695:957126bfb456 696:ae137e3b1990
4 import moving 4 import moving
5 from utils import LCSS 5 from utils import LCSS
6 6
7 import math, random 7 import math, random
8 import numpy as np 8 import numpy as np
9 from multiprocessing import Pool
9 10
10 11
11 class PredictedTrajectory(object): 12 class PredictedTrajectory(object):
12 '''Class for predicted trajectories with lazy evaluation 13 '''Class for predicted trajectories with lazy evaluation
13 if the predicted position has not been already computed, compute it 14 if the predicted position has not been already computed, compute it
135 for p in points: 136 for p in points:
136 out.write('{0} {1} {2} {3}\n'.format(objNum1, objNum2, predictionInstant, p)) 137 out.write('{0} {1} {2} {3}\n'.format(objNum1, objNum2, predictionInstant, p))
137 138
138 @staticmethod 139 @staticmethod
139 def computeExpectedIndicator(points): 140 def computeExpectedIndicator(points):
140 from numpy import sum 141 return np.sum([p.indicator*p.probability for p in points])/sum([p.probability for p in points])
141 return sum([p.indicator*p.probability for p in points])/sum([p.probability for p in points])
142 142
143 def computeCollisionTime(predictedTrajectory1, predictedTrajectory2, collisionDistanceThreshold, timeHorizon): 143 def computeCollisionTime(predictedTrajectory1, predictedTrajectory2, collisionDistanceThreshold, timeHorizon):
144 '''Computes the first instant 144 '''Computes the first instant
145 at which two predicted trajectories are within some distance threshold 145 at which two predicted trajectories are within some distance threshold
146 Computes all the times including timeHorizon 146 Computes all the times including timeHorizon
159 159
160 def savePredictedTrajectoriesFigure(currentInstant, obj1, obj2, predictedTrajectories1, predictedTrajectories2, timeHorizon): 160 def savePredictedTrajectoriesFigure(currentInstant, obj1, obj2, predictedTrajectories1, predictedTrajectories2, timeHorizon):
161 from matplotlib.pyplot import figure, axis, title, close, savefig 161 from matplotlib.pyplot import figure, axis, title, close, savefig
162 figure() 162 figure()
163 for et in predictedTrajectories1: 163 for et in predictedTrajectories1:
164 et.predictPosition(timeHorizon) 164 et.predictPosition(int(np.round(timeHorizon)))
165 et.plot('rx') 165 et.plot('rx')
166 166
167 for et in predictedTrajectories2: 167 for et in predictedTrajectories2:
168 et.predictPosition(timeHorizon) 168 et.predictPosition(int(np.round(timeHorizon)))
169 et.plot('bx') 169 et.plot('bx')
170 obj1.plot('r') 170 obj1.plot('r')
171 obj2.plot('b') 171 obj2.plot('b')
172 title('instant {0}'.format(currentInstant)) 172 title('instant {0}'.format(currentInstant))
173 axis('equal') 173 axis('equal')
343 if len(cp) != 0: 343 if len(cp) != 0:
344 collisionPoints[i] = cp 344 collisionPoints[i] = cp
345 if len(cz) != 0: 345 if len(cz) != 0:
346 crossingZones[i] = cz 346 crossingZones[i] = cz
347 else: 347 else:
348 from multiprocessing import Pool
349 pool = Pool(processes = nProcesses) 348 pool = Pool(processes = nProcesses)
350 jobs = [pool.apply_async(computeCrossingsCollisionsAtInstant, args = (self, i, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ, debug,usePrototypes,route1,route2,prototypes,secondStepPrototypes,nMatching,objects,noiseEntryNums,noiseExitNums,minSimilarity,mostMatched,useDestination,useSpeedPrototype)) for i in list(commonTimeInterval)[:-1]] 349 jobs = [pool.apply_async(computeCrossingsCollisionsAtInstant, args = (self, i, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ, debug,usePrototypes,route1,route2,prototypes,secondStepPrototypes,nMatching,objects,noiseEntryNums,noiseExitNums,minSimilarity,mostMatched,useDestination,useSpeedPrototype)) for i in list(commonTimeInterval)[:-1]]
351 #results = [j.get() for j in jobs] 350 #results = [j.get() for j in jobs]
352 #results.sort() 351 #results.sort()
353 for j in jobs: 352 for j in jobs: