view scripts/process.py @ 987:f026ce2af637

found bug with direct ttc computation
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 07 Mar 2018 23:37:00 -0500
parents 3be8aaa47651
children dc0be55e2bf5
line wrap: on
line source

#! /usr/bin/env python

import sys, argparse
from pathlib2 import Path

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import storage, events, prediction
from metadata import *

# Command line interface: selects the video sequences (by index in the
# metadata database) and the data to delete, process or analyze for them.
parser = argparse.ArgumentParser(description='This program manages the processing of several files based on a description of the sites and video data in an SQLite database following the metadata module.')
parser.add_argument('--db', dest = 'metadataFilename', help = 'name of the metadata file', required = True)
parser.add_argument('--videos', dest = 'videoIds', help = 'indices of the video sequences', nargs = '*', type = int)
parser.add_argument('--prediction-method', dest = 'predictionMethod', help = 'prediction method (constant velocity (cvd: vector computation (approximate); cve: equation solving; cv: discrete time (approximate)), normal adaptation, point set prediction)', choices = ['cvd', 'cve', 'cv', 'na', 'ps', 'mp'])
parser.add_argument('--pet', dest = 'computePET', help = 'computes PET', action = 'store_true')
parser.add_argument('--delete', dest = 'delete', help = 'data to delete', choices = ['feature', 'object', 'classification', 'interaction'])
parser.add_argument('--process', dest = 'process', help = 'data to process', choices = ['feature', 'object', 'classification', 'interaction'])
parser.add_argument('--analyze', dest = 'analyze', help = 'data to analyze (results)', choices = ['feature', 'object', 'classification', 'interaction'])

# need way of selecting sites as similar as possible to sql alchemy syntax
# override tracking.cfg from db
# manage cfg files, overwrite them (or a subset of parameters)
# delete sqlite files

parser.add_argument('--nthreads', dest = 'nProcesses', help = 'number of processes to run in parallel', type = int, default = 1)

args = parser.parse_args()
# --videos has no default: without it args.videoIds is None, and every action
# below iterates over it, so fail early with a usage message instead of a
# TypeError in the middle of the processing
if args.videoIds is None and (args.delete is not None or args.process is not None or args.analyze is not None):
    parser.error('--videos is required with --delete, --process or --analyze')
# Open the SQLite metadata database through the metadata module. The database
# and configuration files of the video sequences it describes are resolved
# relative to the directory containing the metadata file (parentDir).
session = connectDatabase(args.metadataFilename)
metadataFile = Path(args.metadataFilename)
parentDir = metadataFile.parent

if args.delete is not None:
    if args.delete in ['object', 'interaction']:
        # NOTE(review): an earlier version of the command line also listed
        # 'bb', 'pois' and 'prototype' as data types for deleteFromSqlite;
        # confirm against the storage module before exposing them here
        for videoId in args.videoIds:
            vs = session.query(VideoSequence).get(videoId)
            # remove the requested data from the database of this video sequence
            dbFilename = str(parentDir/vs.getDatabaseFilename())
            storage.deleteFromSqlite(dbFilename, args.delete)
    else:
        # 'feature' and 'classification' are accepted on the command line but
        # not handled here: report it instead of silently doing nothing
        # (sys.stderr.write keeps the script Python 2 compatible)
        sys.stderr.write('Deleting {} data is not supported yet\n'.format(args.delete))

if args.process == 'interaction':
    # safety analysis TODO make function in safety analysis script
    # Only the constant velocity methods cvd and cve are implemented here.
    # Any other value (cv, na, ps, mp, or no --prediction-method at all)
    # previously left predictionParameters unbound and failed with a NameError
    # while processing the first video, so reject it before loading any data.
    if args.predictionMethod == 'cvd':
        predictionParameters = prediction.CVDirectPredictionParameters()
    elif args.predictionMethod == 'cve':
        predictionParameters = prediction.CVExactPredictionParameters()
    else:
        parser.error('prediction method {} is not supported for interaction processing (use cvd or cve)'.format(args.predictionMethod))
    for videoId in args.videoIds:
        vs = session.query(VideoSequence).get(videoId)
        print('Processing '+vs.getDatabaseFilename())
        dbFilename = str(parentDir/vs.getDatabaseFilename())
        # TODO load features as well when the prediction needs them
        # (params.useFeaturesForPrediction, or prediction methods ps and mp)
        objects = storage.loadTrajectoriesFromSqlite(dbFilename, 'object')
        interactions = events.createInteractions(objects)
        # the safety analysis parameters (collision distance, prediction time
        # horizon, crossing zones) come from the camera view tracking configuration
        params = storage.ProcessParameters(str(parentDir/vs.cameraView.getTrackingConfigurationFilename()))
        # second argument True: motion prediction is always computed;
        # NOTE(review): the trailing False, None are presumably debug and
        # time interval options of computeIndicators - confirm in events
        processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)
        storage.saveIndicatorsToSqlite(dbFilename, processed)
    # TODO --nthreads (args.nProcesses) is parsed but not used yet: the
    # indicators could be computed in parallel by splitting the interactions
    # of each video into chunks processed by a multiprocessing.Pool

if args.analyze == 'interaction':
    # indices in events.Interaction.indicatorNames of the indicators to plot
    indicatorIds = [2,5,7,10]
    # factors applied to the most severe value of each indicator; they look
    # like unit conversions assuming 30 frames per second (x30x3.6: per frame
    # to km/h, /30: frames to seconds)
    # NOTE(review): the frame rate is hard-coded - confirm for all videos
    conversionFactors = {2: 1., 5: 30.*3.6, 7:1./30, 10:1./30}
    indicators = {} # site index -> indicator index -> list of converted most severe values
    interactions = {} # site index -> list of all interactions loaded for the site
    for videoId in args.videoIds:
        vs = session.query(VideoSequence).get(videoId)
        siteIdx = vs.cameraView.siteIdx
        if siteIdx not in interactions:
            interactions[siteIdx] = []
            indicators[siteIdx] = {i: [] for i in indicatorIds}
        videoInteractions = storage.loadInteractionsFromSqlite(str(parentDir/vs.getDatabaseFilename()))
        interactions[siteIdx] += videoInteractions
        # cumulative number of interactions loaded so far for this site
        print(vs.getDatabaseFilename(), len(interactions[siteIdx]))
        # only the interactions of this video: iterating over interactions[siteIdx]
        # would count again those of the previous videos of the same site
        for inter in videoInteractions:
            for i in indicatorIds:
                indic = inter.getIndicator(events.Interaction.indicatorNames[i])
                if indic is not None:
                    indicators[siteIdx][i].append(indic.getMostSevereValue()*conversionFactors[i])

    # one boxplot figure per indicator, with one box per site; the site order
    # is fixed once so that data and labels stay aligned
    plt.ioff()
    siteIds = list(indicators)
    siteNames = [session.query(Site).get(siteId).name for siteId in siteIds]
    for i in indicatorIds:
        indicatorName = events.Interaction.indicatorNames[i]
        plt.figure()
        plt.boxplot([indicators[siteId][i] for siteId in siteIds], labels = siteNames)
        plt.title(indicatorName)
        plt.savefig(indicatorName+'.png')
        plt.close()