changeset 696:ae137e3b1990 dev

minor correction
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Tue, 14 Jul 2015 00:14:54 -0400
parents 957126bfb456
children 0421a5a0072c
files python/prediction.py
diffstat 1 files changed, 4 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/python/prediction.py	Thu Jul 09 00:58:08 2015 -0400
+++ b/python/prediction.py	Tue Jul 14 00:14:54 2015 -0400
@@ -6,6 +6,7 @@
 
 import math, random
 import numpy as np
+from multiprocessing import Pool
 
 
 class PredictedTrajectory(object):
@@ -137,8 +138,7 @@
 
     @staticmethod
     def computeExpectedIndicator(points):
-        from numpy import sum
-        return sum([p.indicator*p.probability for p in points])/sum([p.probability for p in points])
+        return np.sum([p.indicator*p.probability for p in points])/sum([p.probability for p in points])
 
 def computeCollisionTime(predictedTrajectory1, predictedTrajectory2, collisionDistanceThreshold, timeHorizon):
     '''Computes the first instant 
@@ -161,11 +161,11 @@
     from matplotlib.pyplot import figure, axis, title, close, savefig
     figure()
     for et in predictedTrajectories1:
-        et.predictPosition(timeHorizon)
+        et.predictPosition(int(np.round(timeHorizon)))
         et.plot('rx')
 
     for et in predictedTrajectories2:
-        et.predictPosition(timeHorizon)
+        et.predictPosition(int(np.round(timeHorizon)))
         et.plot('bx')
     obj1.plot('r')
     obj2.plot('b')
@@ -345,7 +345,6 @@
                     if len(cz) != 0:
                         crossingZones[i] = cz
         else:
-            from multiprocessing import Pool
             pool = Pool(processes = nProcesses)
             jobs = [pool.apply_async(computeCrossingsCollisionsAtInstant, args = (self, i, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ, debug,usePrototypes,route1,route2,prototypes,secondStepPrototypes,nMatching,objects,noiseEntryNums,noiseExitNums,minSimilarity,mostMatched,useDestination,useSpeedPrototype)) for i in list(commonTimeInterval)[:-1]]
             #results = [j.get() for j in jobs]