scripts/process.py @ 1026:73b124160911 (v0.2)
commit: more plumbing
author: Nicolas Saunier <nicolas.saunier@polymtl.ca>
date: Wed, 13 Jun 2018 14:55:22 -0400

#! /usr/bin/env python3

import sys, argparse
from pathlib import Path
from multiprocessing.pool import Pool

#import matplotlib
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
from numpy import percentile
from pandas import DataFrame

import storage, events, prediction, cvutils, utils
from metadata import *

parser = argparse.ArgumentParser(description='This program manages the processing of several files based on a description of the sites and video data in an SQLite database following the metadata module.')
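# Illustrative invocations (file names and ids below are examples, not actual data):
#   process.py --db metadata.sqlite --sites 1 2 --process feature
#   process.py --db metadata.sqlite --videos 3 --process interaction --prediction-method cve --pet
#   process.py --db metadata.sqlite --sites 1 2 --analyze object --output figure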
# input
parser.add_argument('--db', dest = 'metadataFilename', help = 'name of the metadata file', required = True)
parser.add_argument('--videos', dest = 'videoIds', help = 'indices of the video sequences', nargs = '*', type = int)
parser.add_argument('--sites', dest = 'siteIds', help = 'indices of the sites', nargs = '*', type = int)

# main function
parser.add_argument('--delete', dest = 'delete', help = 'data to delete', choices = ['feature', 'object', 'classification', 'interaction'])
parser.add_argument('--process', dest = 'process', help = 'data to process', choices = ['feature', 'object', 'classification', 'interaction'])
parser.add_argument('--display', dest = 'display', help = 'data to display (replay over video)', choices = ['feature', 'object', 'classification', 'interaction'])
parser.add_argument('--analyze', dest = 'analyze', help = 'data to analyze (results)', choices = ['feature', 'object', 'classification', 'interaction'])

# common options
parser.add_argument('--cfg', dest = 'configFilename', help = 'name of the configuration file')
parser.add_argument('-n', dest = 'nObjects', help = 'number of objects/interactions to process', type = int)
parser.add_argument('--dry', dest = 'dryRun', help = 'dry run of processing', action = 'store_true')
parser.add_argument('--nthreads', dest = 'nProcesses', help = 'number of processes to run in parallel', type = int, default = 1)

# analysis options
parser.add_argument('--output', dest = 'output', help = 'kind of output to produce (figure, interval means, or per-event data)', choices = ['figure', 'interval', 'event'])
parser.add_argument('--min-user-duration', dest = 'minUserDuration', help = 'minimum duration a user must be observed to be included in the analysis (s)', type = float, default = 0.1)
parser.add_argument('--interval-duration', dest = 'intervalDuration', help = 'length of time interval to aggregate data (min)', type = float, default = 15.)
parser.add_argument('--aggregation', dest = 'aggMethod', help = 'aggregation method per user/event and per interval', choices = ['mean', 'median', 'centile'], nargs = '*', default = ['median'])
parser.add_argument('--aggregation-centile', dest = 'aggCentiles', help = 'centile(s) to compute from the observations', nargs = '*', type = int)
dpi = 150 # resolution (dots per inch) of saved figures
# unit of analysis: site or video sequence?

# safety analysis
parser.add_argument('--prediction-method', dest = 'predictionMethod', help = 'prediction method: constant velocity (cvd: vector computation (approximate); cve: equation solving; cv: discrete time (approximate)), normal adaptation (na), point set prediction (ps), motion pattern prediction (mp)', choices = ['cvd', 'cve', 'cv', 'na', 'ps', 'mp'])
parser.add_argument('--pet', dest = 'computePET', help = 'computes PET', action = 'store_true')
# override other tracking config, erase sqlite?

# need way of selecting sites as similar as possible to sql alchemy syntax
# override tracking.cfg from db
# manage cfg files, overwrite them (or a subset of parameters)
# delete sqlite files
# info of metadata

args = parser.parse_args()

#################################
# Data preparation
#################################
session = connectDatabase(args.metadataFilename)
parentPath = Path(args.metadataFilename).parent # files are relative to metadata location
videoSequences = []
if args.videoIds is not None:
    videoSequences = [session.query(VideoSequence).get(videoId) for videoId in args.videoIds]
elif args.siteIds is not None:
    for siteId in args.siteIds:
        for site in getSite(session, siteId):
            for cv in site.cameraViews:
                videoSequences += cv.videoSequences
else:
    print('No video/site to process')

if args.nProcesses > 1:
    pool = Pool(args.nProcesses)

#################################
# Delete
#################################
if args.delete is not None:
    if args.delete == 'feature':
        response = input('Are you sure you want to delete the tracking results (SQLite files) of all these sites (y/n)?')
        if response == 'y':
            for vs in videoSequences:
                p = parentPath.absolute()/vs.getDatabaseFilename()
                p.unlink()
    elif args.delete in ['object', 'interaction']:
        #parser.add_argument('-t', dest = 'dataType', help = 'type of the data to remove', required = True, choices = ['object','interaction', 'bb', 'pois', 'prototype'])
        for vs in videoSequences:
            storage.deleteFromSqlite(str(parentPath/vs.getDatabaseFilename()), args.delete)

#################################
# Process
#################################
if args.process in ['feature', 'object']: # tracking
    for vs in videoSequences:
        if not (parentPath/vs.getDatabaseFilename()).exists() or args.process == 'object':
            if args.configFilename is None:
                configFilename = str(parentPath/vs.cameraView.getTrackingConfigurationFilename())
            else:
                configFilename = args.configFilename
            # build the argument tuple for cvutils.tracking, with or without intrinsic camera parameters
            if vs.cameraView.cameraType is None:
                trackingArgs = (configFilename, args.process == 'object', str(parentPath.absolute()/vs.getVideoSequenceFilename()), str(parentPath.absolute()/vs.getDatabaseFilename()), str(parentPath.absolute()/vs.cameraView.getHomographyFilename()), str(parentPath.absolute()/vs.cameraView.getMaskFilename()), False, None, None, args.dryRun)
            else:
                trackingArgs = (configFilename, args.process == 'object', str(parentPath.absolute()/vs.getVideoSequenceFilename()), str(parentPath.absolute()/vs.getDatabaseFilename()), str(parentPath.absolute()/vs.cameraView.getHomographyFilename()), str(parentPath.absolute()/vs.cameraView.getMaskFilename()), True, vs.cameraView.cameraType.intrinsicCameraMatrix, vs.cameraView.cameraType.distortionCoefficients, args.dryRun)
            if args.nProcesses == 1:
                cvutils.tracking(*trackingArgs)
            else:
                pool.apply_async(cvutils.tracking, args = trackingArgs)
        else:
            print('SQLite already exists: {}'.format(parentPath/vs.getDatabaseFilename()))
    if args.nProcesses > 1:
        pool.close()
        pool.join()

elif args.process == 'interaction':
    # safety analysis TODO make function in safety analysis script
    if args.predictionMethod == 'cvd':
        predictionParameters = prediction.CVDirectPredictionParameters()
    elif args.predictionMethod == 'cve':
        predictionParameters = prediction.CVExactPredictionParameters()
    else:
        print('Prediction method {} is not available in this script'.format(args.predictionMethod))
        sys.exit()
    for vs in videoSequences:
        print('Processing '+vs.getDatabaseFilename())
        objects = storage.loadTrajectoriesFromSqlite(str(parentPath/vs.getDatabaseFilename()), 'object')#, args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp'))
        interactions = events.createInteractions(objects)
        #if args.nProcesses == 1:
        #print(str(parentPath/vs.cameraView.getTrackingConfigurationFilename()))
        params = storage.ProcessParameters(str(parentPath/vs.cameraView.getTrackingConfigurationFilename()))
        #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones)
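        # compute the safety indicators (and PET if requested) for each interaction, using the selected motion prediction method and the site's tracking parameters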
        processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)
        storage.saveIndicatorsToSqlite(str(parentPath/vs.getDatabaseFilename()), processed)
    # else:
    #     pool = Pool(processes = args.nProcesses)
    #     nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses)))
    #     jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)) for i in range(args.nProcesses)]
    #     processed = []
    #     for job in jobs:
    #         processed += job.get()
    #     pool.close()

#################################
# Analyze
#################################
if args.analyze == 'object':
    # user speeds, accelerations
    # aggregation per site
    data = [] # list of observation per site-user with time
    headers = ['sites', 'date', 'time', 'user_type']
    aggFunctions = {}
    for method in args.aggMethod:
        if method == 'centile':
            aggFunctions[method] = utils.aggregationFunction(method, args.aggCentiles)
            for c in args.aggCentiles:
                headers.append('{}{}'.format(method,c))
        else:
            aggFunctions[method] = utils.aggregationFunction(method)
            headers.append(method)
    for vs in videoSequences:
        d = vs.startTime.date()
        t1 = vs.startTime.time()
        minUserDuration = args.minUserDuration*vs.cameraView.cameraType.frameRate # convert duration from seconds to frames
        print('Extracting speed from '+vs.getDatabaseFilename())
        objects = storage.loadTrajectoriesFromSqlite(str(parentPath/vs.getDatabaseFilename()), 'object', args.nObjects)
        for o in objects:
            if o.length() > minUserDuration:
                row = [vs.cameraView.siteIdx, d, utils.framesToTime(o.getFirstInstant(), vs.cameraView.cameraType.frameRate, t1), o.getUserType()]
                tmp = o.getSpeeds()
                for method,func in aggFunctions.items():
                    aggSpeeds = vs.cameraView.cameraType.frameRate*3.6*func(tmp) # convert speeds from distance units per frame to km/h (assuming world coordinates in metres)
                    if method == 'centile':
                        row += aggSpeeds.tolist()
                    else:
                        row.append(aggSpeeds)
                data.append(row)
    data = DataFrame(data, columns = headers)
    if args.siteIds is None:
        siteIds = set([vs.cameraView.siteIdx for vs in videoSequences])
    else:
        siteIds = set(args.siteIds)
    if args.output == 'figure':
        for name in headers[4:]:
            plt.ioff()
            plt.figure()
            plt.boxplot([data.loc[data['sites']==siteId, name] for siteId in siteIds], labels = [session.query(Site).get(siteId).name for siteId in siteIds])
            plt.ylabel(name+' Speeds (km/h)')
            plt.savefig(name.lower()+'-speeds.png', dpi=dpi)
            plt.close()
    elif args.output == 'event':
        data.to_csv('speeds.csv', index = False)
if args.analyze == 'interaction':
    indicatorIds = [2,5,7,10] # indices into events.Interaction.indicatorNames of the indicators to analyze
    conversionFactors = {2: 1., 5: 30.*3.6, 7:1./30, 10:1./30} # convert indicator values to their display units (factors assume 30 fps video)
    maxIndicatorValue = {2: float('inf'), 5: float('inf'), 7:10., 10:10.} # discard values above these thresholds
    indicators = {}
    interactions = {}
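    # group interactions and indicator values by site, aggregating over all video sequences of each site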
    for vs in videoSequences:
        if not vs.cameraView.siteIdx in interactions:
            interactions[vs.cameraView.siteIdx] = []
            indicators[vs.cameraView.siteIdx] = {}
            for i in indicatorIds:
                indicators[vs.cameraView.siteIdx][i] = []
        interactions[vs.cameraView.siteIdx] += storage.loadInteractionsFromSqlite(str(parentPath/vs.getDatabaseFilename()))
        print(vs.getDatabaseFilename(), len(interactions[vs.cameraView.siteIdx]))
        for inter in interactions[vs.cameraView.siteIdx]:
            for i in indicatorIds:
                indic = inter.getIndicator(events.Interaction.indicatorNames[i])
                if indic is not None:
                    v = indic.getMostSevereValue()*conversionFactors[i]
                    if v < maxIndicatorValue[i]:
                        indicators[vs.cameraView.siteIdx][i].append(v)

    for i in indicatorIds:
        tmp = [indicators[siteId][i] for siteId in indicators]
        plt.ioff()
        plt.figure()
        plt.boxplot(tmp, labels = [session.query(Site).get(siteId).name for siteId in indicators])
        plt.ylabel(events.Interaction.indicatorNames[i]+' ('+events.Interaction.indicatorUnits[i]+')')
        plt.savefig(events.Interaction.indicatorNames[i]+'.png', dpi=dpi)
        plt.close()