changeset 1078:8cc3feb1c1c5

merged
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 20 Jul 2018 16:26:57 -0400
parents b123fa0e5440 (diff) 3939ae415be0 (current diff)
children 845d694af7b8
files
diffstat 3 files changed, 85 insertions(+), 54 deletions(-)
--- a/scripts/process.py	Fri Jul 20 14:03:34 2018 -0400
+++ b/scripts/process.py	Fri Jul 20 16:26:57 2018 -0400
@@ -60,7 +60,7 @@
 
 # analysis options
 parser.add_argument('--output', dest = 'output', help = 'kind of output to produce (interval means)', choices = ['figure', 'interval', 'event'])
-parser.add_argument('--min-user-duration', dest = 'minUserDuration', help = 'mininum duration we have to see the user to take into account in the analysis (s)', type = float, default = 0.1)
+parser.add_argument('--min-duration', dest = 'minDuration', help = 'minimum duration for which the user or interaction must be observed to be included in the analysis (s)', type = float)
 parser.add_argument('--interval-duration', dest = 'intervalDuration', help = 'length of time interval to aggregate data (min)', type = int, default = 15)
 parser.add_argument('--aggregation', dest = 'aggMethods', help = 'aggregation method per user/interaction and per interval', choices = ['mean', 'median', 'centile'], nargs = '*', default = ['median'])
 parser.add_argument('--aggregation-centiles', dest = 'aggCentiles', help = 'centile(s) to compute from the observations', nargs = '*', type = int)
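
The renamed --min-duration option is given in seconds but no longer carries a default or an internal unit conversion; each call site below multiplies it by the camera frame rate. A minimal sketch of that conversion, assuming a hypothetical 30 fps camera:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--min-duration', dest = 'minDuration', help = 'minimum duration for which the user or interaction must be observed to be included in the analysis (s)', type = float)
    args = parser.parse_args(['--min-duration', '0.1'])

    frameRate = 30. # hypothetical camera frame rate (fps)
    minDuration = frameRate*args.minDuration # threshold in frames, as passed to extractVideoSequenceSpeeds below
    print(minDuration) # 3.0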
@@ -110,7 +110,7 @@
 #################################
 # Report progress in the processing
 #################################
-if args.progress:
+if args.progress: # TODO find video sequences that have null camera view, to work with them
     print('Providing information on data progress')
     headers = ['site', 'vs', 'features', 'objects', 'interactions'] # todo add prototypes and object classification
     data = []
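
For reference, the progress report assembles one row per site and video sequence under these headers; a sketch of printing such a table, assuming the third-party tabulate package and placeholder counts:

    from tabulate import tabulate # assumed dependency for console table output

    headers = ['site', 'vs', 'features', 'objects', 'interactions']
    data = [['site-a', 'vs-01', 12040, 351, 87], # placeholder counts
            ['site-b', 'vs-02', 9012, 210, 45]]
    print(tabulate(data, headers = headers))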
@@ -217,32 +217,33 @@
         clusterSizes = ml.computeClusterSizes(labels, prototypeIndices, -1)
         storage.savePrototypesToSqlite(str(parentPath/site.getPath()/outputPrototypeDatabaseFilename), [moving.Prototype(object2VideoSequences[trainingObjects[i]].getDatabaseFilename(False), trainingObjects[i].getNum(), prototypeType, clusterSizes[i]) for i in prototypeIndices])
 
-
 elif args.process == 'interaction':
     # safety analysis TODO make function in safety analysis script
     if args.predictionMethod == 'cvd':
         predictionParameters = prediction.CVDirectPredictionParameters()
-    if args.predictionMethod == 'cve':
+    elif args.predictionMethod == 'cve':
         predictionParameters = prediction.CVExactPredictionParameters()
     for vs in videoSequences:
         print('Processing '+vs.getDatabaseFilename())
+        if args.configFilename is None:
+            params = storage.ProcessParameters(str(parentPath/vs.cameraView.getTrackingConfigurationFilename()))
+        else:
+            params = storage.ProcessParameters(args.configFilename)
         objects = storage.loadTrajectoriesFromSqlite(str(parentPath/vs.getDatabaseFilename()), 'object')#, args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp'))
         interactions = events.createInteractions(objects)
-        #if args.nProcesses == 1:
-        #print(str(parentPath/vs.cameraView.getTrackingConfigurationFilename()))
-        params = storage.ProcessParameters(str(parentPath/vs.cameraView.getTrackingConfigurationFilename()))
-        #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones)
-        processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)
+        if args.nProcesses == 1:
+            #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones)
+            processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, False, False, None) # params.crossingZones
+        else:
+            #pool = Pool(processes = args.nProcesses)
+            nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses)))
+            jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, False, False, None)) for i in range(args.nProcesses)] # params.crossingZones
+            processed = []
+            for job in jobs:
+                processed += job.get()
+            #pool.close()
         storage.saveIndicatorsToSqlite(str(parentPath/vs.getDatabaseFilename()), processed)
-    # else:
-    #     pool = Pool(processes = args.nProcesses)
-    #     nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses)))
-    #     jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)) for i in range(args.nProcesses)]
-    #     processed = []
-    #     for job in jobs:
-    #         processed += job.get()
-    #     pool.close()
-
+
 #################################
 # Analyze
 #################################
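
This hunk re-enables the previously commented-out parallel branch: the interactions are split into nProcesses contiguous slices, each slice is submitted with Pool.apply_async, and the partial results are concatenated in submission order (the pool itself is created earlier in the script). A self-contained sketch of the same pattern, with a toy worker standing in for events.computeIndicators:

    import numpy as np
    from multiprocessing import Pool

    def computeIndicators(interactions): # toy stand-in for events.computeIndicators
        return [i*i for i in interactions]

    if __name__ == '__main__':
        interactions = list(range(10))
        nProcesses = 4
        pool = Pool(processes = nProcesses)
        nInteractionPerProcess = int(np.ceil(len(interactions)/float(nProcesses))) # slice size per process
        jobs = [pool.apply_async(computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess],)) for i in range(nProcesses)]
        processed = []
        for job in jobs:
            processed += job.get() # blocks until each slice is done, preserving submission order
        pool.close()
        pool.join()
        print(processed)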
@@ -258,9 +259,9 @@
     headers.extend(tmpheaders)
     if args.nProcesses == 1:
         for vs in videoSequences:
-            data.extend(processing.extractVideoSequenceSpeeds(str(parentPath/vs.getDatabaseFilename()), vs.cameraView.site.name, args.nObjects, vs.startTime, vs.cameraView.cameraType.frameRate, args.minUserDuration, args.aggMethods, args.aggCentiles))
+            data.extend(processing.extractVideoSequenceSpeeds(str(parentPath/vs.getDatabaseFilename()), vs.cameraView.site.name, args.nObjects, vs.startTime, vs.cameraView.cameraType.frameRate, vs.cameraView.cameraType.frameRate*args.minDuration, args.aggMethods, args.aggCentiles))
     else:
-        jobs = [pool.apply_async(processing.extractVideoSequenceSpeeds, args = (str(parentPath/vs.getDatabaseFilename()), vs.cameraView.site.name, args.nObjects, vs.startTime, vs.cameraView.cameraType.frameRate, args.minUserDuration, args.aggMethods, args.aggCentiles)) for vs in videoSequences]
+        jobs = [pool.apply_async(processing.extractVideoSequenceSpeeds, args = (str(parentPath/vs.getDatabaseFilename()), vs.cameraView.site.name, args.nObjects, vs.startTime, vs.cameraView.cameraType.frameRate, vs.cameraView.cameraType.frameRate*args.minDuration, args.aggMethods, args.aggCentiles)) for vs in videoSequences]
         for job in jobs:
             data.extend(job.get())
         pool.close()
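
Both branches now pass frameRate*args.minDuration, i.e. a threshold in frames rather than in seconds. The per-row timestamps come from utils.framesToTime; a hypothetical equivalent of that helper, assuming it simply offsets the sequence start time by frameNum/frameRate seconds:

    from datetime import datetime, timedelta

    def framesToTime(frameNum, frameRate, initialTime):
        'hypothetical equivalent of utils.framesToTime: offset a time of day by frameNum/frameRate seconds'
        return (datetime.combine(datetime(2018, 7, 20), initialTime)+timedelta(seconds = frameNum/frameRate)).time()

    startTime = datetime(2018, 7, 20, 14, 3, 34)
    print(framesToTime(450, 30., startTime.time())) # 14:03:49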
@@ -279,33 +280,55 @@
 if args.analyze == 'interaction': # redo as for object, export in dataframe all interaction data
     indicatorIds = [2,5,7,10]
     conversionFactors = {2: 1., 5: 30.*3.6, 7:1./30, 10:1./30}
-    maxIndicatorValue = {2: float('inf'), 5: float('inf'), 7:10., 10:10.}
+    #maxIndicatorValue = {2: float('inf'), 5: float('inf'), 7:10., 10:10.}
+    data = [] # list of observations per site and interaction, with time
+    headers = ['site', 'date', 'time', events.Interaction.indicatorNames[10].replace(' ','-')] # user types?
+    aggFunctions, tmpheaders = utils.aggregationMethods(args.aggMethods, args.aggCentiles)
+    for i in indicatorIds[:3]:
+        for h in tmpheaders:
+            headers.append(events.Interaction.indicatorNames[i].replace(' ','-')+'-'+h)
     indicators = {}
     interactions = {}
     for vs in videoSequences:
-        if not vs.cameraView.siteIdx in interactions:
-            interactions[vs.cameraView.siteIdx] = []
-            indicators[vs.cameraView.siteIdx] = {}
-            for i in indicatorIds:
-                indicators[vs.cameraView.siteIdx][i] = []
-        interactions[vs.cameraView.siteIdx] += storage.loadInteractionsFromSqlite(str(parentPath/vs.getDatabaseFilename()))
-        print(vs.getDatabaseFilename(), len(interactions[vs.cameraView.siteIdx]))
-        for inter in interactions[vs.cameraView.siteIdx]:
-            for i in indicatorIds:
-                indic = inter.getIndicator(events.Interaction.indicatorNames[i])
-                if indic is not None:
-                    v = indic.getMostSevereValue()*conversionFactors[i]
-                    if v < maxIndicatorValue[i]:
-                        indicators[vs.cameraView.siteIdx][i].append(v)
-
-    for i in indicatorIds:
-        tmp = [indicators[siteId][i] for siteId in indicators]
-        plt.ioff()
-        plt.figure()
-        plt.boxplot(tmp, labels = [session.query(Site).get(siteId).name for siteId in indicators])
-        plt.ylabel(events.Interaction.indicatorNames[i]+' ('+events.Interaction.indicatorUnits[i]+')')
-        plt.savefig(events.Interaction.indicatorNames[i]+'.png', dpi=150)
-        plt.close()
+        print('Extracting SMoS from '+vs.getDatabaseFilename())
+        interactions = storage.loadInteractionsFromSqlite(str(parentPath/vs.getDatabaseFilename()))
+        minDuration = vs.cameraView.cameraType.frameRate*args.minDuration
+        for inter in interactions:
+            if inter.length() > minDuration:
+                d = vs.startTime.date()
+                t = vs.startTime.time()
+                row = [vs.cameraView.site.name, d, utils.framesToTime(inter.getFirstInstant(), vs.cameraView.cameraType.frameRate, t)]
+                pet = inter.getIndicator('Post Encroachment Time')
+                if pet is None:
+                    row.append(None)
+                else:
+                    row.append(conversionFactors[10]*pet.getValues()[0])
+                for i in indicatorIds[:3]:
+                    indic = inter.getIndicator(events.Interaction.indicatorNames[i])
+                    if indic is not None:
+                        #v = indic.getMostSevereValue()*
+                        tmp = list(indic.values.values())
+                        for method,func in aggFunctions.items():
+                            agg = conversionFactors[i]*func(tmp)
+                            if method == 'centile':
+                                row.extend(agg.tolist())
+                            else:
+                                row.append(agg)
+                    else:
+                        row.extend([None]*len(aggFunctions))
+                data.append(row)
+    data = pd.DataFrame(data, columns = headers)
+    if args.output == 'figure':
+        for i in indicatorIds:
+            pass # tmp = [indicators[siteId][i] for siteId in indicators]
+            # plt.ioff()
+            # plt.figure()
+            # plt.boxplot(tmp, labels = [session.query(Site).get(siteId).name for siteId in indicators])
+            # plt.ylabel(events.Interaction.indicatorNames[i]+' ('+events.Interaction.indicatorUnits[i]+')')
+            # plt.savefig(events.Interaction.indicatorNames[i]+'.png', dpi=150)
+            # plt.close()
+    elif args.output == 'event':
+        data.to_csv(args.eventFilename, index = False)
 
 if args.analyze == 'event': # aggregate event data by 15 min interval (args.intervalDuration), count events with thresholds
     data = pd.read_csv(args.eventFilename, parse_dates = [2])
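
The rewritten branch replaces the per-site boxplots with one row per interaction (PET plus aggregated values of the first three indicators), collected in a pandas DataFrame and written to CSV when --output event is requested. A sketch of the row-building convention, where a 'centile' aggregation returns an array spread over several columns while the other methods contribute one scalar; the headers, centiles and indicator values below are placeholders:

    import numpy as np
    import pandas as pd

    aggFunctions = {'median': np.median,
                    'centile': lambda v: np.percentile(v, [15, 85])} # placeholder centiles
    headers = ['site', 'median-speed', 'speed-centile-15', 'speed-centile-85']

    row = ['site-a']
    values = [1.2, 3.4, 2.2, 5.1] # placeholder indicator values
    for method, func in aggFunctions.items():
        agg = func(values)
        if method == 'centile':
            row.extend(agg.tolist()) # one column per requested centile
        else:
            row.append(agg)
    data = pd.DataFrame([row], columns = headers)
    data.to_csv('events.csv', index = False) # placeholder filename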
--- a/trafficintelligence/metadata.py	Fri Jul 20 14:03:34 2018 -0400
+++ b/trafficintelligence/metadata.py	Fri Jul 20 16:26:57 2018 -0400
@@ -1,4 +1,5 @@
 from datetime import datetime, timedelta
+from pathlib import Path
 from os import path, listdir, sep
 from math import floor
 
@@ -339,16 +340,24 @@
 
 def createDatabase(filename):
     'creates a session to query the filename'
-    engine = create_engine('sqlite:///'+filename)
-    Base.metadata.create_all(engine)
-    Session = sessionmaker(bind=engine)
-    return Session()
+    if Path(filename).is_file():
+        engine = create_engine('sqlite:///'+filename)
+        Base.metadata.create_all(engine)
+        Session = sessionmaker(bind=engine)
+        return Session()
+    else:
+        print('The file '+filename+' does not exist')
+        return None
 
 def connectDatabase(filename):
     'creates a session to query the filename'
-    engine = create_engine('sqlite:///'+filename)
-    Session = sessionmaker(bind=engine)
-    return Session()
+    if Path(filename).is_file():
+        engine = create_engine('sqlite:///'+filename)
+        Session = sessionmaker(bind=engine)
+        return Session()
+    else:
+        print('The file '+filename+' does not exist')
+        return None
 
 def getSite(session, siteId = None, name = None, description = None):
     'Returns the site(s) matching the index or the name'
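
Both session factories now check that the database file exists before creating the engine, so a bad path returns None instead of SQLAlchemy silently creating an empty SQLite file (which create_engine alone would do). Callers must therefore handle the None return; a self-contained sketch with a placeholder filename:

    import sys
    from pathlib import Path
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    def connectDatabase(filename):
        'creates a session to query the filename, or returns None if the file is missing'
        if Path(filename).is_file():
            engine = create_engine('sqlite:///'+filename)
            Session = sessionmaker(bind = engine)
            return Session()
        print('The file '+filename+' does not exist')
        return None

    session = connectDatabase('project-metadata.sqlite') # placeholder filename
    if session is None:
        sys.exit('cannot proceed without the metadata database')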
--- a/trafficintelligence/processing.py	Fri Jul 20 14:03:34 2018 -0400
+++ b/trafficintelligence/processing.py	Fri Jul 20 16:26:57 2018 -0400
@@ -18,16 +18,15 @@
             objectsNotInZone.append(o)
     return speeds, objectsNotInZone
 
-def extractVideoSequenceSpeeds(dbFilename, siteName, nObjects, startTime, frameRate, minUserDurationSeconds, aggMethods, aggCentiles):
+def extractVideoSequenceSpeeds(dbFilename, siteName, nObjects, startTime, frameRate, minDuration, aggMethods, aggCentiles):
     data = []
     d = startTime.date()
     t1 = startTime.time()
-    minUserDuration = minUserDurationSeconds*frameRate
     print('Extracting speed from '+dbFilename)
     aggFunctions, tmpheaders = utils.aggregationMethods(aggMethods, aggCentiles)
     objects = storage.loadTrajectoriesFromSqlite(dbFilename, 'object', nObjects)
     for o in objects:
-        if o.length() > minUserDuration:
+        if o.length() > minDuration:
             row = [siteName, d, utils.framesToTime(o.getFirstInstant(), frameRate, t1), o.getUserType()]
             tmp = o.getSpeeds()
             for method,func in aggFunctions.items():
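
Since the seconds-to-frames conversion moved to the callers, minDuration now arrives here already expressed in frames. A tiny sketch of the duration filter with placeholder trajectory lengths:

    frameRate = 30. # fps, hypothetical
    minDurationSeconds = 0.1 # value of --min-duration
    minDuration = frameRate*minDurationSeconds # threshold in frames, computed by the caller

    trajectoryLengths = [2, 5, 40] # placeholder per-object lengths in frames (o.length())
    kept = [n for n in trajectoryLengths if n > minDuration]
    print(kept) # [5, 40]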