changeset 556:dc58ad777a72

modified prediction for multiprocessing, not sure how beneficial it is (single thread with instance method seems much faster
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Mon, 14 Jul 2014 01:56:51 -0400
parents f13220f765e0
children b91f33e098ee
files python/prediction.py
diffstat 1 files changed, 75 insertions(+), 57 deletions(-) [+]
line wrap: on
line diff
--- a/python/prediction.py	Sun Jul 13 23:34:00 2014 -0400
+++ b/python/prediction.py	Mon Jul 14 01:56:51 2014 -0400
@@ -124,6 +124,78 @@
         t += 1
     return t, p1, p2
 
+def computeCrossingsCollisionsAtInstant(predictionParams, currentInstant, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False):
+    '''returns the lists of collision points and crossing zones'''
+    predictedTrajectories1 = predictionParams.generatePredictedTrajectories(obj1, currentInstant)
+    predictedTrajectories2 = predictionParams.generatePredictedTrajectories(obj2, currentInstant)
+
+    collisionPoints = []
+    crossingZones = []
+    for et1 in predictedTrajectories1:
+        for et2 in predictedTrajectories2:
+            t, p1, p2 = computeCollisionTime(et1, et2, collisionDistanceThreshold, timeHorizon)
+
+            if t <= timeHorizon:
+                collisionPoints.append(SafetyPoint((p1+p2).multiply(0.5), et1.probability*et2.probability, t))
+            elif computeCZ: # check if there is a crossing zone
+                # TODO? zone should be around the points at which the traj are the closest
+                # look for CZ at different times, otherwise it would be a collision
+                # an approximation would be to look for close points at different times, ie the complementary of collision points
+                cz = None
+                t1 = 0
+                while not cz and t1 < timeHorizon: # t1 <= timeHorizon-1
+                    t2 = 0
+                    while not cz and t2 < timeHorizon:
+                        #if (et1.predictPosition(t1)-et2.predictPosition(t2)).norm2() < collisionDistanceThreshold:
+                        #    cz = (et1.predictPosition(t1)+et2.predictPosition(t2)).multiply(0.5)
+                        cz = moving.segmentIntersection(et1.predictPosition(t1), et1.predictPosition(t1+1), et2.predictPosition(t2), et2.predictPosition(t2+1))
+                        if cz:
+                            crossingZones.append(SafetyPoint(cz, et1.probability*et2.probability, abs(t1-t2)))
+                        t2 += 1
+                    t1 += 1                        
+
+    if debug:
+        from matplotlib.pyplot import figure, axis, title
+        figure()
+        for et in predictedTrajectories1:
+            et.predictPosition(timeHorizon)
+            et.plot('rx')
+
+        for et in predictedTrajectories2:
+            et.predictPosition(timeHorizon)
+            et.plot('bx')
+        obj1.plot('r')
+        obj2.plot('b')
+        title('instant {0}'.format(currentInstant))
+        axis('equal')
+
+    return currentInstant, collisionPoints, crossingZones
+
+def computeCrossingsCollisions(predictionParams, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False, timeInterval = None, nProcesses = 1):
+    '''Computes all crossing and collision points at each common instant for two road users. '''
+    collisionPoints={}
+    crossingZones={}
+    if timeInterval:
+        commonTimeInterval = timeInterval
+    else:
+        commonTimeInterval = obj1.commonTimeInterval(obj2)
+    if nProcesses == 1:
+        for i in list(commonTimeInterval)[:-1]: # do not look at the 1 last position/velocities, often with errors
+            i, collisionPoints[i], crossingZones[i] = computeCrossingsCollisionsAtInstant(predictionParams, i, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ, debug)
+    else:
+        from multiprocessing import Pool
+        pool = Pool(processes = nProcesses)
+        jobs = [pool.apply_async(computeCrossingsCollisionsAtInstant, args = (predictionParams, i, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ)) for i in list(commonTimeInterval)[:-1]]
+        #results = [j.get() for j in jobs]
+        #results.sort()
+        for j in jobs:
+            i, cp, cz = j.get()
+            #if len(cp) != 0 or len(cz) != 0:
+            collisionPoints[i] = cp
+            crossingZones[i] = cz
+        pool.close()
+    return collisionPoints, crossingZones
+
 class PredictionParameters:
     def __init__(self, name, maxSpeed):
         self.name = name
@@ -136,64 +208,10 @@
         return []
 
     def computeCrossingsCollisionsAtInstant(self, currentInstant, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False):
-        '''returns the lists of collision points and crossing zones'''
-        predictedTrajectories1 = self.generatePredictedTrajectories(obj1, currentInstant)
-        predictedTrajectories2 = self.generatePredictedTrajectories(obj2, currentInstant)
-
-        collisionPoints = []
-        crossingZones = []
-        for et1 in predictedTrajectories1:
-            for et2 in predictedTrajectories2:
-                t, p1, p2 = computeCollisionTime(et1, et2, collisionDistanceThreshold, timeHorizon)
-
-                if t <= timeHorizon:
-                    collisionPoints.append(SafetyPoint((p1+p2).multiply(0.5), et1.probability*et2.probability, t))
-                elif computeCZ: # check if there is a crossing zone
-                    # TODO? zone should be around the points at which the traj are the closest
-                    # look for CZ at different times, otherwise it would be a collision
-                    # an approximation would be to look for close points at different times, ie the complementary of collision points
-                    cz = None
-                    t1 = 0
-                    while not cz and t1 < timeHorizon: # t1 <= timeHorizon-1
-                        t2 = 0
-                        while not cz and t2 < timeHorizon:
-                            #if (et1.predictPosition(t1)-et2.predictPosition(t2)).norm2() < collisionDistanceThreshold:
-                            #    cz = (et1.predictPosition(t1)+et2.predictPosition(t2)).multiply(0.5)
-                            cz = moving.segmentIntersection(et1.predictPosition(t1), et1.predictPosition(t1+1), et2.predictPosition(t2), et2.predictPosition(t2+1))
-                            if cz:
-                                crossingZones.append(SafetyPoint(cz, et1.probability*et2.probability, abs(t1-t2)))
-                            t2 += 1
-                        t1 += 1                        
+        return computeCrossingsCollisionsAtInstant(self, currentInstant, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False)
 
-        if debug:
-            from matplotlib.pyplot import figure, axis, title
-            figure()
-            for et in predictedTrajectories1:
-                et.predictPosition(timeHorizon)
-                et.plot('rx')
-
-            for et in predictedTrajectories2:
-                et.predictPosition(timeHorizon)
-                et.plot('bx')
-            obj1.plot('r')
-            obj2.plot('b')
-            title('instant {0}'.format(currentInstant))
-            axis('equal')
-
-        return collisionPoints, crossingZones
-
-    def computeCrossingsCollisions(self, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False, timeInterval = None):
-        '''Computes all crossing and collision points at each common instant for two road users. '''
-        collisionPoints={}
-        crossingZones={}
-        if timeInterval:
-            commonTimeInterval = timeInterval
-        else:
-            commonTimeInterval = obj1.commonTimeInterval(obj2)
-        for i in list(commonTimeInterval)[:-1]: # do not look at the 1 last position/velocities, often with errors
-            collisionPoints[i], crossingZones[i] = self.computeCrossingsCollisionsAtInstant(i, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ, debug)
-
-        return collisionPoints, crossingZones
+    def computeCrossingsCollisions(self, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False, timeInterval = None, nProcesses = 1):
+        return computeCrossingsCollisions(self, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False, timeInterval = None, nProcesses = 1)
 
     def computeCollisionProbability(self, obj1, obj2, collisionDistanceThreshold, timeHorizon, debug = False, timeInterval = None):
         '''Computes only collision probabilities