view scripts/process.py @ 1004:75601be6019f

work on process
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Sun, 03 Jun 2018 00:21:18 -0400
parents 75af46516b2b
children 666b38437d9a
line wrap: on
line source

#! /usr/bin/env python3

import sys, argparse
from pathlib import Path
from multiprocessing.pool import Pool

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from numpy import percentile

import storage, events, prediction, cvutils
from metadata import *

# Command line interface: which metadata database to use, which videos/sites
# to work on, and which data to delete, process, display or analyze.
parser = argparse.ArgumentParser(description='This program manages the processing of several files based on a description of the sites and video data in an SQLite database following the metadata module.')
parser.add_argument('--db', dest = 'metadataFilename', help = 'name of the metadata file', required = True)
parser.add_argument('--videos', dest = 'videoIds', help = 'indices of the video sequences', nargs = '*', type = int)
parser.add_argument('--sites', dest = 'siteIds', help = 'indices of the sites', nargs = '*', type = int)
parser.add_argument('--cfg', dest = 'configFilename', help = 'name of the configuration file (overrides the tracking configuration file of each camera view)')
# NOTE(review): nObjects is not used by the processing code below yet
parser.add_argument('-n', dest = 'nObjects', help = 'number of objects/interactions to process', type = int)
parser.add_argument('--prediction-method', dest = 'predictionMethod', help = 'prediction method (constant velocity (cvd: vector computation (approximate); cve: equation solving; cv: discrete time (approximate)), normal adaptation, point set prediction)', choices = ['cvd', 'cve', 'cv', 'na', 'ps', 'mp'])
parser.add_argument('--pet', dest = 'computePET', help = 'computes PET', action = 'store_true')
# override other tracking config, erase sqlite?
parser.add_argument('--delete', dest = 'delete', help = 'data to delete', choices = ['feature', 'object', 'classification', 'interaction'])
parser.add_argument('--process', dest = 'process', help = 'data to process', choices = ['feature', 'object', 'classification', 'interaction'])
parser.add_argument('--display', dest = 'display', help = 'data to display (replay over video)', choices = ['feature', 'object', 'classification', 'interaction'])
parser.add_argument('--analyze', dest = 'analyze', help = 'data to analyze (results)', choices = ['feature', 'object', 'classification', 'interaction'])

# need way of selecting sites as similar as possible to sql alchemy syntax
# override tracking.cfg from db
# manage cfg files, overwrite them (or a subset of parameters)
# delete sqlite files

# info of metadata

parser.add_argument('--nthreads', dest = 'nProcesses', help = 'number of processes to run in parallel', type = int, default = 1)

args = parser.parse_args()
# files are relative to metadata location

session = connectDatabase(args.metadataFilename)
# directory against which the file names stored in the metadata are resolved
parentDir = Path(args.metadataFilename).parent

# Deletion of processed data from the SQLite file of each selected video sequence
if args.delete is not None:
    if args.delete in ['object', 'interaction']:
        #parser.add_argument('-t', dest = 'dataType', help = 'type of the data to remove', required = True, choices = ['object','interaction', 'bb', 'pois', 'prototype'])
        if args.videoIds is None:
            # --videos is optional: without this check, iterating None raises a TypeError
            print('No video to delete {} data from (use --videos)'.format(args.delete))
        else:
            for videoId in args.videoIds:
                vs = session.query(VideoSequence).get(videoId)
                if vs is None:
                    print('No video sequence with index {}'.format(videoId))
                else:
                    storage.deleteFromSqlite(str(parentDir/vs.getDatabaseFilename()), args.delete)
    else:
        # the parser accepts these choices, but deleting them is not handled here
        print('Deletion of {} data is not supported'.format(args.delete))

import time

def track(i):
    '''Debugging stand-in for a parallel job: waits one second, then
    prints which job it is.

    i: identifier of the job, only used in the printed message'''
    delay = 1 # seconds
    time.sleep(delay)
    print(f'process {i}')

            
def _getVideoSequences():
    '''Returns the video sequences selected with --videos, or else all the
    video sequences of the camera views of the sites selected with --sites.

    Unknown video sequence indices are reported and skipped.'''
    videoSequences = []
    if args.videoIds is not None:
        for videoId in args.videoIds:
            vs = session.query(VideoSequence).get(videoId)
            if vs is None:
                print('No video sequence with index {}'.format(videoId))
            else:
                videoSequences.append(vs)
    elif args.siteIds is not None:
        for siteId in args.siteIds:
            for site in getSite(session, siteId):
                for cv in site.cameraViews:
                    videoSequences += cv.videoSequences
    else:
        print('No video/site to process')
    return videoSequences

def _trackingArguments(vs):
    '''Returns the tuple of positional arguments of cvutils.tracking for
    video sequence vs.

    File names from the metadata are made absolute from the metadata file
    directory (like the configuration file in the interaction processing),
    so they do not depend on the current working directory (needed for the
    worker processes). A configuration file given with --cfg is used as is.'''
    absoluteDir = parentDir.absolute()
    if args.configFilename is None:
        configFilename = str(absoluteDir/vs.cameraView.getTrackingConfigurationFilename())
    else:
        configFilename = args.configFilename
    cameraType = vs.cameraView.cameraType
    if cameraType is None:
        cameraArguments = (False, None, None)
    else:
        cameraArguments = (True, cameraType.intrinsicCameraMatrix, cameraType.distortionCoefficients)
    return (configFilename, args.process == 'object', str(absoluteDir/vs.getVideoSequenceFilename()), str(absoluteDir/vs.getDatabaseFilename()), str(absoluteDir/vs.cameraView.getHomographyFilename()), str(absoluteDir/vs.cameraView.getMaskFilename()))+cameraArguments+(True,)

if args.process in ['feature', 'object']: # tracking
    # video sequences whose SQLite file does not exist yet, with their tracking arguments
    jobs = []
    for vs in _getVideoSequences():
        dbFilename = parentDir/vs.getDatabaseFilename()
        if dbFilename.exists():
            print('SQLite already exists: {}'.format(dbFilename))
        else:
            jobs.append((dbFilename, _trackingArguments(vs)))
    if args.nProcesses == 1:
        for dbFilename, trackingArgs in jobs:
            try:
                cvutils.tracking(*trackingArgs)
            except Exception as e: # report and go on with the other videos
                print('Tracking failed for {}: {}'.format(dbFilename, e))
    else:
        pool = Pool(args.nProcesses)
        results = [(dbFilename, pool.apply_async(cvutils.tracking, args = trackingArgs)) for dbFilename, trackingArgs in jobs]
        pool.close()
        for dbFilename, result in results:
            try:
                # get() re-raises here an exception raised in the worker,
                # which would otherwise be lost silently
                result.get()
            except Exception as e:
                print('Tracking failed for {}: {}'.format(dbFilename, e))
        pool.join()

elif args.process == 'interaction':
    # safety analysis TODO make function in safety analysis script
    if args.predictionMethod == 'cvd':
        predictionParameters = prediction.CVDirectPredictionParameters()
    elif args.predictionMethod == 'cve':
        predictionParameters = prediction.CVExactPredictionParameters()
    else:
        # only the cvd and cve parameters are constructed here: stop instead
        # of failing later on an undefined predictionParameters
        sys.exit('Prediction method {} is not supported for interaction processing (use cvd or cve)'.format(args.predictionMethod))
    for vs in _getVideoSequences():
        print('Processing '+vs.getDatabaseFilename())
        objects = storage.loadTrajectoriesFromSqlite(str(parentDir/vs.getDatabaseFilename()), 'object')#, args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp'))
        interactions = events.createInteractions(objects)
        #if args.nProcesses == 1:
        #print(str(parentDir/vs.cameraView.getTrackingConfigurationFilename()))
        params = storage.ProcessParameters(str(parentDir/vs.cameraView.getTrackingConfigurationFilename()))
        #print(len(interactions), args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones)
        processed = events.computeIndicators(interactions, True, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)
        storage.saveIndicatorsToSqlite(str(parentDir/vs.getDatabaseFilename()), processed)
    # else:
    #     pool = Pool(processes = args.nProcesses)
    #     nInteractionPerProcess = int(np.ceil(len(interactions)/float(args.nProcesses)))
    #     jobs = [pool.apply_async(events.computeIndicators, args = (interactions[i*nInteractionPerProcess:(i+1)*nInteractionPerProcess], not args.noMotionPrediction, args.computePET, predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, False, None)) for i in range(args.nProcesses)]
    #     processed = []
    #     for job in jobs:
    #         processed += job.get()
    #     pool.close()

if args.analyze == 'object': # user speed for now
    # Box plots, per site, of the median and 85th centile speeds of the objects
    frameRate = 30 # NOTE(review): assumed frame rate (frames/s) of every video - confirm
    minLength = 2*frameRate # objects with at most this many instants are ignored
    medianSpeeds = {}
    speeds85 = {}
    if args.videoIds is None:
        print('No video to analyze (use --videos)')
        videoIds = []
    else:
        videoIds = args.videoIds
    for videoId in videoIds:
        vs = session.query(VideoSequence).get(videoId)
        if vs is None:
            print('No video sequence with index {}'.format(videoId))
            continue
        siteIdx = vs.cameraView.siteIdx
        if siteIdx not in medianSpeeds:
            medianSpeeds[siteIdx] = []
            speeds85[siteIdx] = []
        print('Extracting speed from '+vs.getDatabaseFilename())
        objects = storage.loadTrajectoriesFromSqlite(str(parentDir/vs.getDatabaseFilename()), 'object')
        for o in objects:
            if o.length() > minLength:
                # presumably per-frame speeds in m/frame, converted to km/h
                speeds = frameRate*3.6*percentile(o.getSpeeds(), [50, 85])
                medianSpeeds[siteIdx].append(speeds[0])
                speeds85[siteIdx].append(speeds[1])
    if medianSpeeds: # nothing to plot without any loaded video
        siteNames = [session.query(Site).get(siteId).name for siteId in medianSpeeds]
        for speeds, name in zip([medianSpeeds, speeds85], ['Median', '85th Centile']):
            plt.ioff()
            plt.figure()
            plt.boxplot(list(speeds.values()), labels = siteNames)
            plt.ylabel(name+' Speeds (km/h)')
            plt.savefig(name.lower()+'-speeds.png', dpi=150)
            plt.close()

if args.analyze == 'interaction':
    # Box plots, per site, of the most severe values of some indicators
    # (converted with conversionFactors and kept only below maxIndicatorValue)
    indicatorIds = [2,5,7,10]
    conversionFactors = {2: 1., 5: 30.*3.6, 7:1./30, 10:1./30}
    maxIndicatorValue = {2: float('inf'), 5: float('inf'), 7:10., 10:10.}
    indicators = {}
    interactions = {}
    if args.videoIds is None:
        print('No video to analyze (use --videos)')
        videoIds = []
    else:
        videoIds = args.videoIds
    for videoId in videoIds:
        vs = session.query(VideoSequence).get(videoId)
        if vs is None:
            print('No video sequence with index {}'.format(videoId))
            continue
        siteIdx = vs.cameraView.siteIdx
        if siteIdx not in interactions:
            interactions[siteIdx] = []
            indicators[siteIdx] = {i: [] for i in indicatorIds}
        videoInteractions = storage.loadInteractionsFromSqlite(str(parentDir/vs.getDatabaseFilename()))
        interactions[siteIdx] += videoInteractions
        print(vs.getDatabaseFilename(), len(videoInteractions))
        # only this video's interactions: the per-site list also contains those
        # of the site's previous videos, whose indicators are already collected
        for inter in videoInteractions:
            for i in indicatorIds:
                indic = inter.getIndicator(events.Interaction.indicatorNames[i])
                if indic is not None:
                    v = indic.getMostSevereValue()*conversionFactors[i]
                    if v < maxIndicatorValue[i]:
                        indicators[siteIdx][i].append(v)

    if indicators: # nothing to plot without any loaded video
        siteNames = [session.query(Site).get(siteId).name for siteId in indicators]
        for i in indicatorIds:
            tmp = [indicators[siteId][i] for siteId in indicators]
            plt.ioff()
            plt.figure()
            plt.boxplot(tmp, labels = siteNames)
            plt.ylabel(events.Interaction.indicatorNames[i]+' ('+events.Interaction.indicatorUnits[i]+')')
            plt.savefig(events.Interaction.indicatorNames[i]+'.png', dpi=150)
            plt.close()