Mercurial Hosting > traffic-intelligence
diff scripts/safety-analysis.py @ 949:d6c1c05d11f5
modified multithreading at the interaction level for safety computations
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Fri, 21 Jul 2017 17:52:56 -0400 |
parents | 584b9405e494 |
children | eb42f2f51490 |
line wrap: on
line diff
--- a/scripts/safety-analysis.py Fri Jul 21 12:11:55 2017 -0400 +++ b/scripts/safety-analysis.py Fri Jul 21 17:52:56 2017 -0400 @@ -3,6 +3,7 @@ import storage, prediction, events, moving import sys, argparse, random +from multiprocessing import Pool import matplotlib.pyplot as plt import numpy as np @@ -70,22 +71,22 @@ objects = storage.loadTrajectoriesFromSqlite(params.databaseFilename, 'object', args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp')) interactions = events.createInteractions(objects) -for inter in interactions: - print('processing interaction {}'.format(inter.getNum()) - inter.computeIndicators() - if not args.noMotionPrediction: - inter.computeCrossingsCollisions(predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, nProcesses = args.nProcesses) - -if args.computePET: - for inter in interactions: - inter.computePET(params.collisionDistance) - -storage.saveIndicatorsToSqlite(params.databaseFilename, interactions) +if args.nProcesses == 1: + processed = events.computeIndicators(interactions, not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None) +else: + pool = Pool(processes = args.nProcesses) + nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses))) + jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)) for i in range(args.nProcesses)] + processed = [] + for job in jobs: + processed += job.get() + pool.close() +storage.saveIndicatorsToSqlite(params.databaseFilename, processed) if args.displayCollisionPoints: plt.figure() allCollisionPoints = [] - for inter in interactions: + for inter in processed: for collisionPoints in inter.collisionPoints.values(): allCollisionPoints += collisionPoints moving.Point.plotAll(allCollisionPoints)