Mercurial Hosting > traffic-intelligence
changeset 1071:58994b08be42
added multithreading for safety
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Wed, 18 Jul 2018 02:12:47 -0400 |
parents | 0154133e77df |
children | c67f8c36ebc7 8ab92ee3cbef |
files | scripts/process.py |
diffstat | 1 files changed, 17 insertions(+), 16 deletions(-) [+] |
line wrap: on
line diff
--- a/scripts/process.py Tue Jul 17 10:34:39 2018 -0400 +++ b/scripts/process.py Wed Jul 18 02:12:47 2018 -0400 @@ -217,32 +217,33 @@ clusterSizes = ml.computeClusterSizes(labels, prototypeIndices, -1) storage.savePrototypesToSqlite(str(parentPath/site.getPath()/outputPrototypeDatabaseFilename), [moving.Prototype(object2VideoSequences[trainingObjects[i]].getDatabaseFilename(False), trainingObjects[i].getNum(), prototypeType, clusterSizes[i]) for i in prototypeIndices]) - elif args.process == 'interaction': # safety analysis TODO make function in safety analysis script if args.predictionMethod == 'cvd': predictionParameters = prediction.CVDirectPredictionParameters() - if args.predictionMethod == 'cve': + elif args.predictionMethod == 'cve': predictionParameters = prediction.CVExactPredictionParameters() for vs in videoSequences: print('Processing '+vs.getDatabaseFilename()) + if args.configFilename is None: + params = storage.ProcessParameters(str(parentPath/vs.cameraView.getTrackingConfigurationFilename())) + else: + params = storage.ProcessParameters(args.configFilename) objects = storage.loadTrajectoriesFromSqlite(str(parentPath/vs.getDatabaseFilename()), 'object')#, args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp')) interactions = events.createInteractions(objects) - #if args.nProcesses == 1: - #print(str(parentPath/vs.cameraView.getTrackingConfigurationFilename())) - params = storage.ProcessParameters(str(parentPath/vs.cameraView.getTrackingConfigurationFilename())) - #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones) - processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None) + if args.nProcesses == 1: + #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones) + processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, False, False, None) # params.crossingZones + else: + #pool = Pool(processes = args.nProcesses) + nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses))) + jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, False, False, None)) for i in range(args.nProcesses)] # params.crossingZones + processed = [] + for job in jobs: + processed += job.get() + #pool.close() storage.saveIndicatorsToSqlite(str(parentPath/vs.getDatabaseFilename()), processed) - # else: - # pool = Pool(processes = args.nProcesses) - # nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses))) - # jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)) for i in range(args.nProcesses)] - # processed = [] - # for job in jobs: - # processed += job.get() - # pool.close() - + ################################# # Analyze #################################