view scripts/process.py @ 1009:0d29b75f74ea

cleaning
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Mon, 04 Jun 2018 10:25:30 -0400
parents 192de96e5255
children 16932cefabc1
line wrap: on
line source

#! /usr/bin/env python3

import sys, argparse
from pathlib import Path
from multiprocessing.pool import Pool

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from numpy import percentile

import storage, events, prediction, cvutils
from metadata import *

# Command line interface. --delete, --process, --display and --analyze each select one type of data
# to act on; the selected video sequences come from --videos or --sites (see data preparation below)
parser = argparse.ArgumentParser(description='This program manages the processing of several files based on a description of the sites and video data in an SQLite database following the metadata module.')
parser.add_argument('--db', dest = 'metadataFilename', help = 'name of the metadata file', required = True)
parser.add_argument('--videos', dest = 'videoIds', help = 'indices of the video sequences', nargs = '*', type = int)
parser.add_argument('--sites', dest = 'siteIds', help = 'indices of the sites', nargs = '*', type = int)
parser.add_argument('--cfg', dest = 'configFilename', help = 'name of the configuration file')
# NOTE(review): nObjects is not used by the processing code in this script yet
parser.add_argument('-n', dest = 'nObjects', help = 'number of objects/interactions to process', type = int)
# NOTE(review): interaction processing currently only sets up parameters for cvd and cve
parser.add_argument('--prediction-method', dest = 'predictionMethod', help = 'prediction method (constant velocity (cvd: vector computation (approximate); cve: equation solving; cv: discrete time (approximate)), normal adaptation, point set prediction)', choices = ['cvd', 'cve', 'cv', 'na', 'ps', 'mp'])
parser.add_argument('--pet', dest = 'computePET', help = 'computes PET', action = 'store_true')
# TODO override other tracking config, erase sqlite?
parser.add_argument('--delete', dest = 'delete', help = 'data to delete', choices = ['feature', 'object', 'classification', 'interaction'])
parser.add_argument('--process', dest = 'process', help = 'data to process', choices = ['feature', 'object', 'classification', 'interaction'])
# NOTE(review): --display is parsed but not handled in this script yet
parser.add_argument('--display', dest = 'display', help = 'data to display (replay over video)', choices = ['feature', 'object', 'classification', 'interaction'])
parser.add_argument('--analyze', dest = 'analyze', help = 'data to analyze (results)', choices = ['feature', 'object', 'classification', 'interaction'])
parser.add_argument('--dry', dest = 'dryRun', help = 'dry run of processing', action = 'store_true')
parser.add_argument('--nthreads', dest = 'nProcesses', help = 'number of processes to run in parallel', type = int, default = 1)

# TODO
# - need way of selecting sites as similar as possible to sql alchemy syntax
# - override tracking.cfg from db
# - manage cfg files, overwrite them (or a subset of parameters)
# - delete sqlite files
# - info of metadata

args = parser.parse_args()

#################################
# Data preparation
#################################
# Selects the video sequences to work on, either by index (--videos) or as all the video sequences
# of all the camera views of the given sites (--sites); --videos takes precedence
session = connectDatabase(args.metadataFilename)
parentDir = Path(args.metadataFilename).parent # files are relative to metadata location
videoSequences = []
if args.videoIds is not None:
    for videoId in args.videoIds:
        videoSequence = session.query(VideoSequence).get(videoId)
        if videoSequence is None:
            # unknown index: warn and skip instead of failing later on attributes of None
            print('Video sequence {} not found in {}'.format(videoId, args.metadataFilename))
        else:
            videoSequences.append(videoSequence)
elif args.siteIds is not None:
    for siteId in args.siteIds:
        for site in getSite(session, siteId):
            for cv in site.cameraViews:
                videoSequences += cv.videoSequences
if len(videoSequences) == 0: # nothing selected, or the selection matched nothing
    print('No video/site to process')

#################################
# Delete
#################################
# Deletes the requested type of data from the SQLite database of each selected video sequence
if args.delete is not None:
    if args.delete in ['object', 'interaction']:
        # NOTE(review): storage.deleteFromSqlite may also accept other data types
        # (an earlier option listed 'bb', 'pois', 'prototype') - confirm before exposing them
        for vs in videoSequences:
            storage.deleteFromSqlite(str(parentDir/vs.getDatabaseFilename()), args.delete)
    else: # 'feature' and 'classification': tell the user rather than silently doing nothing
        print('Deleting {} data is not implemented'.format(args.delete))

#################################
# Process
#################################
def _configurationFilename(vs):
    '''Returns the configuration filename to use for video sequence vs:
    the file given with --cfg if any, otherwise the tracking configuration
    of its camera view (relative to the metadata file location)'''
    if args.configFilename is None:
        return str(parentDir/vs.cameraView.getTrackingConfigurationFilename())
    else:
        return args.configFilename

def _trackingArguments(vs):
    '''Returns the tuple of positional arguments of cvutils.tracking for video sequence vs
    (video, database, homography and mask paths are made absolute)'''
    absoluteParentDir = parentDir.absolute()
    cameraType = vs.cameraView.cameraType
    if cameraType is None:
        cameraArguments = (False, None, None)
    else:
        # NOTE(review): the flag presumably enables undistortion in cvutils.tracking - confirm
        cameraArguments = (True, cameraType.intrinsicCameraMatrix, cameraType.distortionCoefficients)
    return (_configurationFilename(vs), args.process == 'object',
            str(absoluteParentDir/vs.getVideoSequenceFilename()),
            str(absoluteParentDir/vs.getDatabaseFilename()),
            str(absoluteParentDir/vs.cameraView.getHomographyFilename()),
            str(absoluteParentDir/vs.cameraView.getMaskFilename())) + cameraArguments + (args.dryRun,)

if args.process in ['feature', 'object']: # tracking
    # sequential with at most one process, otherwise dispatched to a pool of worker processes
    pool = Pool(args.nProcesses) if args.nProcesses > 1 else None
    jobs = []
    for vs in videoSequences:
        databaseFilename = parentDir/vs.getDatabaseFilename()
        # an existing database is skipped for feature tracking, but always processed for objects
        if databaseFilename.exists() and args.process != 'object':
            print('SQLite already exists: {}'.format(databaseFilename))
        elif pool is None:
            cvutils.tracking(*_trackingArguments(vs))
        else:
            jobs.append(pool.apply_async(cvutils.tracking, args = _trackingArguments(vs)))
    if pool is not None:
        pool.close()
        pool.join()
        # AsyncResult.get re-raises an exception raised in a worker, which would otherwise be lost
        for job in jobs:
            job.get()

elif args.process == 'interaction':
    # safety analysis TODO make function in safety analysis script
    if args.predictionMethod == 'cvd':
        predictionParameters = prediction.CVDirectPredictionParameters()
    elif args.predictionMethod == 'cve':
        predictionParameters = prediction.CVExactPredictionParameters()
    else:
        # the other methods need parameters that are not set up here (previously a NameError)
        sys.exit('Prediction method {} is not supported for interaction processing: use --prediction-method cvd or cve'.format(args.predictionMethod))
    for vs in videoSequences:
        print('Processing '+vs.getDatabaseFilename())
        # TODO use args.nObjects and load features when needed by the prediction method (ps, mp)
        objects = storage.loadTrajectoriesFromSqlite(str(parentDir/vs.getDatabaseFilename()), 'object')
        interactions = events.createInteractions(objects)
        params = storage.ProcessParameters(_configurationFilename(vs))
        processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)
        storage.saveIndicatorsToSqlite(str(parentDir/vs.getDatabaseFilename()), processed)
    # TODO compute the indicators in parallel over chunks of interactions when args.nProcesses > 1

#################################
# Analyze
#################################
if args.analyze == 'object': # user speed for now
    # For each site, collects the median and 85th centile speed of every object longer than
    # minLength, then saves one box plot per statistic (one box per site) in the working directory.
    # The factor 30*3.6 converts the speeds to km/h (per the plot label); 30 is presumably the
    # video frame rate - TODO confirm and read it from the metadata
    minLength = 2*30
    medianSpeeds = {}
    speeds85 = {}
    for vs in videoSequences:
        siteIdx = vs.cameraView.siteIdx
        siteMedianSpeeds = medianSpeeds.setdefault(siteIdx, [])
        site85Speeds = speeds85.setdefault(siteIdx, [])
        print('Extracting speed from '+vs.getDatabaseFilename())
        objects = storage.loadTrajectoriesFromSqlite(str(parentDir/vs.getDatabaseFilename()), 'object')
        for o in objects:
            if not o.length() > minLength:
                continue
            medianSpeed, speed85 = 30*3.6*percentile(o.getSpeeds(), [50, 85])
            siteMedianSpeeds.append(medianSpeed)
            site85Speeds.append(speed85)
    for statisticName, siteSpeeds in (('Median', medianSpeeds), ('85th Centile', speeds85)):
        siteNames = [session.query(Site).get(siteId).name for siteId in siteSpeeds]
        plt.ioff()
        plt.figure()
        plt.boxplot(list(siteSpeeds.values()), labels = siteNames)
        plt.ylabel(statisticName+' Speeds (km/h)')
        plt.savefig(statisticName.lower()+'-speeds.png', dpi=150)
        plt.close()

if args.analyze == 'interaction':
    # For each site, collects the most severe value of a few indicators of each interaction,
    # then saves one box plot per indicator (one box per site) in the working directory
    indicatorIds = [2,5,7,10] # indices in events.Interaction.indicatorNames
    # factors applied to the most severe values; 30 is presumably the video frame rate
    # (and 3.6 a m/s to km/h conversion) - TODO confirm and read it from the metadata
    conversionFactors = {2: 1., 5: 30.*3.6, 7:1./30, 10:1./30}
    # converted values greater than or equal to these bounds are left out
    maxIndicatorValue = {2: float('inf'), 5: float('inf'), 7:10., 10:10.}
    indicators = {} # site index -> indicator id -> list of converted values
    interactions = {} # site index -> list of interactions
    for vs in videoSequences:
        siteIdx = vs.cameraView.siteIdx
        if siteIdx not in interactions:
            interactions[siteIdx] = []
            indicators[siteIdx] = {i: [] for i in indicatorIds}
        siteInteractions = interactions[siteIdx]
        nPreviousInteractions = len(siteInteractions)
        siteInteractions += storage.loadInteractionsFromSqlite(str(parentDir/vs.getDatabaseFilename()))
        # only the interactions of this video sequence: those of the previous sequences
        # of the same site were already counted (iterating over all of them double counted)
        newInteractions = siteInteractions[nPreviousInteractions:]
        print(vs.getDatabaseFilename(), len(newInteractions))
        for inter in newInteractions:
            for i in indicatorIds:
                indic = inter.getIndicator(events.Interaction.indicatorNames[i])
                if indic is not None:
                    v = indic.getMostSevereValue()*conversionFactors[i]
                    if v < maxIndicatorValue[i]:
                        indicators[siteIdx][i].append(v)

    for i in indicatorIds:
        tmp = [indicators[siteId][i] for siteId in indicators]
        plt.ioff()
        plt.figure()
        plt.boxplot(tmp, labels = [session.query(Site).get(siteId).name for siteId in indicators])
        plt.ylabel(events.Interaction.indicatorNames[i]+' ('+events.Interaction.indicatorUnits[i]+')')
        plt.savefig(events.Interaction.indicatorNames[i]+'.png', dpi=150)
        plt.close()