diff scripts/safety-analysis.py @ 949:d6c1c05d11f5

modified multithreading at the interaction level for safety computations
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 21 Jul 2017 17:52:56 -0400
parents 584b9405e494
children eb42f2f51490
line wrap: on
line diff
--- a/scripts/safety-analysis.py	Fri Jul 21 12:11:55 2017 -0400
+++ b/scripts/safety-analysis.py	Fri Jul 21 17:52:56 2017 -0400
@@ -3,6 +3,7 @@
 import storage, prediction, events, moving
 
 import sys, argparse, random
+from multiprocessing import Pool
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -70,22 +71,22 @@
 objects = storage.loadTrajectoriesFromSqlite(params.databaseFilename, 'object', args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp'))
 
 interactions = events.createInteractions(objects)
-for inter in interactions:
-    print('processing interaction {}'.format(inter.getNum())
-    inter.computeIndicators()
-    if not args.noMotionPrediction:
-        inter.computeCrossingsCollisions(predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, nProcesses = args.nProcesses)
-
-if args.computePET:
-    for inter in interactions:
-        inter.computePET(params.collisionDistance)
-    
-storage.saveIndicatorsToSqlite(params.databaseFilename, interactions)
+if args.nProcesses == 1:
+    processed = events.computeIndicators(interactions, not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)
+else:
+    pool = Pool(processes = args.nProcesses)
+    nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses)))
+    jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)) for i in range(args.nProcesses)]
+    processed = []
+    for job in jobs:
+        processed += job.get()
+    pool.close()
+storage.saveIndicatorsToSqlite(params.databaseFilename, processed)
 
 if args.displayCollisionPoints:
     plt.figure()
     allCollisionPoints = []
-    for inter in interactions:
+    for inter in processed:
         for collisionPoints in inter.collisionPoints.values():
             allCollisionPoints += collisionPoints
     moving.Point.plotAll(allCollisionPoints)