diff scripts/process.py @ 1023:a13f47c8931d

work on processing large datasets (generate speed data)
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 06 Jun 2018 16:51:15 -0400
parents 16932cefabc1
children 73b124160911
line wrap: on
line diff
--- a/scripts/process.py	Wed Jun 06 10:35:06 2018 -0400
+++ b/scripts/process.py	Wed Jun 06 16:51:15 2018 -0400
@@ -4,30 +4,47 @@
 from pathlib import Path
 from multiprocessing.pool import Pool
 
-import matplotlib
-matplotlib.use('Agg')
+#import matplotlib
+#matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 from numpy import percentile
+from pandas import DataFrame
 
-import storage, events, prediction, cvutils
+import storage, events, prediction, cvutils, utils
 from metadata import *
 
 parser = argparse.ArgumentParser(description='This program manages the processing of several files based on a description of the sites and video data in an SQLite database following the metadata module.')
+# input
 parser.add_argument('--db', dest = 'metadataFilename', help = 'name of the metadata file', required = True)
 parser.add_argument('--videos', dest = 'videoIds', help = 'indices of the video sequences', nargs = '*', type = int)
 parser.add_argument('--sites', dest = 'siteIds', help = 'indices of the sites', nargs = '*', type = int)
-parser.add_argument('--cfg', dest = 'configFilename', help = 'name of the configuration file')
-parser.add_argument('-n', dest = 'nObjects', help = 'number of objects/interactions to process', type = int)
-parser.add_argument('--prediction-method', dest = 'predictionMethod', help = 'prediction method (constant velocity (cvd: vector computation (approximate); cve: equation solving; cv: discrete time (approximate)), normal adaptation, point set prediction)', choices = ['cvd', 'cve', 'cv', 'na', 'ps', 'mp'])
-parser.add_argument('--pet', dest = 'computePET', help = 'computes PET', action = 'store_true')
-# override other tracking config, erase sqlite?
+
+# main function
 parser.add_argument('--delete', dest = 'delete', help = 'data to delete', choices = ['feature', 'object', 'classification', 'interaction'])
 parser.add_argument('--process', dest = 'process', help = 'data to process', choices = ['feature', 'object', 'classification', 'interaction'])
 parser.add_argument('--display', dest = 'display', help = 'data to display (replay over video)', choices = ['feature', 'object', 'classification', 'interaction'])
 parser.add_argument('--analyze', dest = 'analyze', help = 'data to analyze (results)', choices = ['feature', 'object', 'classification', 'interaction'])
+
+# common options
+parser.add_argument('--cfg', dest = 'configFilename', help = 'name of the configuration file')
+parser.add_argument('-n', dest = 'nObjects', help = 'number of objects/interactions to process', type = int)
 parser.add_argument('--dry', dest = 'dryRun', help = 'dry run of processing', action = 'store_true')
 parser.add_argument('--nthreads', dest = 'nProcesses', help = 'number of processes to run in parallel', type = int, default = 1)
 
+# analysis options
+parser.add_argument('--output', dest = 'output', help = 'kind of output to produce (interval means)', choices = ['figure', 'interval', 'event'])
+parser.add_argument('--min-user-duration', dest = 'minUserDuration', help = 'minimum duration we have to see the user to take into account in the analysis (s)', type = float, default = 0.1)
+parser.add_argument('--interval-duration', dest = 'intervalDuration', help = 'length of time interval to aggregate data (min)', type = float, default = 15.)
+parser.add_argument('--aggregation', dest = 'aggMethod', help = 'aggregation method per user/event and per interval', choices = ['mean', 'median', 'centile'], nargs = '*', default = ['median'])
+parser.add_argument('--aggregation-centile', dest = 'aggCentiles', help = 'centile(s) to compute from the observations', nargs = '*', type = int)
+dpi = 150
+# unit of analysis: site or video sequence?
+
+# safety analysis
+parser.add_argument('--prediction-method', dest = 'predictionMethod', help = 'prediction method (constant velocity (cvd: vector computation (approximate); cve: equation solving; cv: discrete time (approximate)), normal adaptation, point set prediction)', choices = ['cvd', 'cve', 'cv', 'na', 'ps', 'mp'])
+parser.add_argument('--pet', dest = 'computePET', help = 'computes PET', action = 'store_true')
+# override other tracking config, erase sqlite?
+
 # need way of selecting sites as similar as possible to sql alchemy syntax
 # override tracking.cfg from db
 # manage cfg files, overwrite them (or a subset of parameters)
@@ -52,6 +69,9 @@
 else:
     print('No video/site to process')
 
+if args.nProcesses > 1:
+    pool = Pool(args.nProcesses)
+
 #################################
 # Delete
 #################################
@@ -81,7 +101,6 @@
             else:
                 print('SQLite already exists: {}'.format(parentPath/vs.getDatabaseFilename()))
     else:
-        pool = Pool(args.nProcesses)
         for vs in videoSequences:
             if not (parentPath/vs.getDatabaseFilename()).exists() or args.process == 'object':
                 if args.configFilename is None:
@@ -125,29 +144,52 @@
 #################################
 # Analyze
 #################################
-if args.analyze == 'object': # user speed for now
-    medianSpeeds = {}
-    speeds85 = {}
-    minLength = 2*30
+if args.analyze == 'object':
+    # user speeds, accelerations
+    # aggregation per site
+    data = [] # list of observation per site-user with time
+    headers = ['sites', 'date', 'time', 'user_type']
+    aggFunctions = {}
+    for method in args.aggMethod:
+        if method == 'centile':
+            aggFunctions[method] = utils.aggregationFunction(method, args.aggCentiles)
+            for c in args.aggCentiles:
+                headers.append('{}{}'.format(method,c))
+        else:
+            aggFunctions[method] = utils.aggregationFunction(method)
+            headers.append(method)
     for vs in videoSequences:
-        if not vs.cameraView.siteIdx in medianSpeeds:
-            medianSpeeds[vs.cameraView.siteIdx] = []
-            speeds85[vs.cameraView.siteIdx] = []
+        d = vs.startTime.date()
+        t1 = vs.startTime.time()
+        minUserDuration = args.minUserDuration*vs.cameraView.cameraType.frameRate
         print('Extracting speed from '+vs.getDatabaseFilename())
-        objects = storage.loadTrajectoriesFromSqlite(str(parentPath/vs.getDatabaseFilename()), 'object')
+        objects = storage.loadTrajectoriesFromSqlite(str(parentPath/vs.getDatabaseFilename()), 'object', args.nObjects)
         for o in objects:
-            if o.length() > minLength:
-                speeds = 30*3.6*percentile(o.getSpeeds(), [50, 85])
-                medianSpeeds[vs.cameraView.siteIdx].append(speeds[0])
-                speeds85[vs.cameraView.siteIdx].append(speeds[1])
-    for speeds, name in zip([medianSpeeds, speeds85], ['Median', '85th Centile']):
-        plt.ioff()
-        plt.figure()
-        plt.boxplot(list(speeds.values()), labels = [session.query(Site).get(siteId).name for siteId in speeds])
-        plt.ylabel(name+' Speeds (km/h)')
-        plt.savefig(name.lower()+'-speeds.png', dpi=150)
-        plt.close()
-
+            if o.length() > minUserDuration:
+                row = [vs.cameraView.siteIdx, d, utils.framesToTime(o.getFirstInstant(), vs.cameraView.cameraType.frameRate, t1), o.getUserType()]
+                tmp = o.getSpeeds()
+                for method,func in aggFunctions.items():
+                    aggSpeeds = vs.cameraView.cameraType.frameRate*3.6*func(tmp)
+                    if method == 'centile':
+                        row += aggSpeeds.tolist()
+                    else:
+                        row.append(aggSpeeds)
+                data.append(row)
+    data = DataFrame(data, columns = headers)
+    if args.siteIds is None:
+        siteIds = set([vs.cameraView.siteIdx for vs in videoSequences])
+    else:
+        siteIds = set(args.siteIds)
+    if args.output == 'figure':
+        for name in headers[4:]:
+            plt.ioff()
+            plt.figure()
+            plt.boxplot([data.loc[data['sites']==siteId, name] for siteId in siteIds], labels = [session.query(Site).get(siteId).name for siteId in siteIds])
+            plt.ylabel(name+' Speeds (km/h)')
+            plt.savefig(name.lower()+'-speeds.png', dpi=dpi)
+            plt.close()
+    elif args.output == 'event':
+        data.to_csv('speeds.csv', index = False)
 if args.analyze == 'interaction':
     indicatorIds = [2,5,7,10]
     conversionFactors = {2: 1., 5: 30.*3.6, 7:1./30, 10:1./30}