changeset 1071:58994b08be42

added multithreading for safety
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 18 Jul 2018 02:12:47 -0400
parents 0154133e77df
children c67f8c36ebc7 8ab92ee3cbef
files scripts/process.py
diffstat 1 files changed, 17 insertions(+), 16 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/process.py	Tue Jul 17 10:34:39 2018 -0400
+++ b/scripts/process.py	Wed Jul 18 02:12:47 2018 -0400
@@ -217,32 +217,33 @@
         clusterSizes = ml.computeClusterSizes(labels, prototypeIndices, -1)
         storage.savePrototypesToSqlite(str(parentPath/site.getPath()/outputPrototypeDatabaseFilename), [moving.Prototype(object2VideoSequences[trainingObjects[i]].getDatabaseFilename(False), trainingObjects[i].getNum(), prototypeType, clusterSizes[i]) for i in prototypeIndices])
 
-
 elif args.process == 'interaction':
     # safety analysis TODO make function in safety analysis script
     if args.predictionMethod == 'cvd':
         predictionParameters = prediction.CVDirectPredictionParameters()
-    if args.predictionMethod == 'cve':
+    elif args.predictionMethod == 'cve':
         predictionParameters = prediction.CVExactPredictionParameters()
     for vs in videoSequences:
         print('Processing '+vs.getDatabaseFilename())
+        if args.configFilename is None:
+            params = storage.ProcessParameters(str(parentPath/vs.cameraView.getTrackingConfigurationFilename()))
+        else:
+            params = storage.ProcessParameters(args.configFilename)  
         objects = storage.loadTrajectoriesFromSqlite(str(parentPath/vs.getDatabaseFilename()), 'object')#, args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp'))
         interactions = events.createInteractions(objects)
-        #if args.nProcesses == 1:
-        #print(str(parentPath/vs.cameraView.getTrackingConfigurationFilename()))
-        params = storage.ProcessParameters(str(parentPath/vs.cameraView.getTrackingConfigurationFilename()))
-        #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones)
-        processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)
+        if args.nProcesses == 1:
+            #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones)
+            processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, False, False, None) # params.crossingZones
+        else:
+            #pool = Pool(processes = args.nProcesses)
+            nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses)))
+            jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, False, False, None)) for i in range(args.nProcesses)] # params.crossingZones
+            processed = []
+            for job in jobs:
+                processed += job.get()
+            #pool.close()
         storage.saveIndicatorsToSqlite(str(parentPath/vs.getDatabaseFilename()), processed)
-    # else:
-    #     pool = Pool(processes = args.nProcesses)
-    #     nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses)))
-    #     jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)) for i in range(args.nProcesses)]
-    #     processed = []
-    #     for job in jobs:
-    #         processed += job.get()
-    #     pool.close()
-
+            
 #################################
 # Analyze
 #################################