changeset 557:b91f33e098ee

refactored some more code in compute crossing and collisions (parallel code works)
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Mon, 14 Jul 2014 17:33:43 -0400
parents dc58ad777a72
children a80ef6931fd8
files python/prediction.py
diffstat 1 files changed, 22 insertions(+), 30 deletions(-) [+]
line wrap: on
line diff
--- a/python/prediction.py	Mon Jul 14 01:56:51 2014 -0400
+++ b/python/prediction.py	Mon Jul 14 17:33:43 2014 -0400
@@ -124,6 +124,23 @@
         t += 1
     return t, p1, p2
 
+def savePredictedTrajectoriesFigure(currentInstant, obj1, obj2, predictedTrajectories1, predictedTrajectories2, timeHorizon):
+    from matplotlib.pyplot import figure, axis, title, close, savefig
+    figure()
+    for et in predictedTrajectories1:
+        et.predictPosition(timeHorizon)
+        et.plot('rx')
+
+    for et in predictedTrajectories2:
+        et.predictPosition(timeHorizon)
+        et.plot('bx')
+    obj1.plot('r')
+    obj2.plot('b')
+    title('instant {0}'.format(currentInstant))
+    axis('equal')
+    savefig('predicted-trajectories-t-{0}.png'.format(currentInstant))
+    close()
+
 def computeCrossingsCollisionsAtInstant(predictionParams, currentInstant, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False):
     '''returns the lists of collision points and crossing zones'''
     predictedTrajectories1 = predictionParams.generatePredictedTrajectories(obj1, currentInstant)
@@ -155,19 +172,7 @@
                     t1 += 1                        
 
     if debug:
-        from matplotlib.pyplot import figure, axis, title
-        figure()
-        for et in predictedTrajectories1:
-            et.predictPosition(timeHorizon)
-            et.plot('rx')
-
-        for et in predictedTrajectories2:
-            et.predictPosition(timeHorizon)
-            et.plot('bx')
-        obj1.plot('r')
-        obj2.plot('b')
-        title('instant {0}'.format(currentInstant))
-        axis('equal')
+        savePredictedTrajectoriesFigure(currentInstant, obj1, obj2, predictedTrajectories1, predictedTrajectories2, timeHorizon)
 
     return currentInstant, collisionPoints, crossingZones
 
@@ -185,7 +190,7 @@
     else:
         from multiprocessing import Pool
         pool = Pool(processes = nProcesses)
-        jobs = [pool.apply_async(computeCrossingsCollisionsAtInstant, args = (predictionParams, i, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ)) for i in list(commonTimeInterval)[:-1]]
+        jobs = [pool.apply_async(computeCrossingsCollisionsAtInstant, args = (predictionParams, i, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ, debug)) for i in list(commonTimeInterval)[:-1]]
         #results = [j.get() for j in jobs]
         #results.sort()
         for j in jobs:
@@ -208,10 +213,10 @@
         return []
 
     def computeCrossingsCollisionsAtInstant(self, currentInstant, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False):
-        return computeCrossingsCollisionsAtInstant(self, currentInstant, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False)
+        return computeCrossingsCollisionsAtInstant(self, currentInstant, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ, debug)
 
     def computeCrossingsCollisions(self, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False, timeInterval = None, nProcesses = 1):
-        return computeCrossingsCollisions(self, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ = False, debug = False, timeInterval = None, nProcesses = 1)
+        return computeCrossingsCollisions(self, obj1, obj2, collisionDistanceThreshold, timeHorizon, computeCZ, debug, timeInterval, nProcesses)
 
     def computeCollisionProbability(self, obj1, obj2, collisionDistanceThreshold, timeHorizon, debug = False, timeInterval = None):
         '''Computes only collision probabilities
@@ -223,7 +228,6 @@
             commonTimeInterval = obj1.commonTimeInterval(obj2)
         for i in list(commonTimeInterval)[:-1]:
             nCollisions = 0
-            print(obj1.num, obj2.num, i)
             predictedTrajectories1 = self.generatePredictedTrajectories(obj1, i)
             predictedTrajectories2 = self.generatePredictedTrajectories(obj2, i)
             for et1 in predictedTrajectories1:
@@ -236,19 +240,7 @@
             collisionProbabilities[i] = [nSamples, float(nCollisions)/nSamples]
 
             if debug:
-                from matplotlib.pyplot import figure, axis, title
-                figure()
-                for et in predictedTrajectories1:
-                    et.predictPosition(timeHorizon)
-                    et.plot('rx')
-
-                for et in predictedTrajectories2:
-                    et.predictPosition(timeHorizon)
-                    et.plot('bx')
-                obj1.plot('r')
-                obj2.plot('b')
-                title('instant {0}'.format(i))
-                axis('equal')
+                savePredictedTrajectoriesFigure(i, obj1, obj2, predictedTrajectories1, predictedTrajectories2, timeHorizon)
 
         return collisionProbabilities