view scripts/safety-analysis.py @ 398:3399bd48cb40

Ajout d'une méthode pour obtenir le nombre de FPS. Méthode de capture des trames vidéo plus résistante aux erreurs. Utilisation d'un dictionnaire pour les fichiers de configuration afin de garder le nom des sections.
author Jean-Philippe Jodoin <jpjodoin@gmail.com>
date Mon, 29 Jul 2013 13:46:07 -0400
parents 91679eb2ff2c
children 6551a3cf1750
line wrap: on
line source

#! /usr/bin/env python

import utils, storage, prediction, events, moving

import sys, argparse

import matplotlib.pyplot as plt
import numpy as np

# TODO: very slow if there are too many predicted trajectories
# TODO: add computation of the probability of unsuccessful evasive action

# Command-line interface. Parameters are read from the configuration file;
# --prediction-method, when given, overrides the method configured there.
parser = argparse.ArgumentParser(description='The program processes indicators for all pairs of road users in the scene')
parser.add_argument('--cfg', dest = 'configFilename', help = 'name of the configuration file')
parser.add_argument('--prediction-method', dest = 'predictionMethod', help = 'prediction method (constant velocity (vector computation), constant velocity, normal adaptation, point set prediction)', choices = ['cvd', 'cv', 'na', 'ps'])
# Boolean switch: without store_true the option required a value, and any
# non-empty string (even "0" or "False") would have enabled the display.
parser.add_argument('--display-cp', dest = 'displayCollisionPoints', help = 'display collision points', action = 'store_true')
args = parser.parse_args()

# Load the tracking/prediction parameters from the configuration file.
params = utils.TrackingParameters()
params.loadConfigFile(args.configFilename)

# parameters for prediction methods
# The command-line option takes precedence over the configuration file.
if args.predictionMethod:
    predictionMethod = args.predictionMethod
else:
    predictionMethod = params.predictionMethod

# Build the parameter object matching the selected prediction method.
if predictionMethod == 'cvd':
    predictionParameters = prediction.CVDirectPredictionParameters()
elif predictionMethod == 'cv':
    predictionParameters = prediction.ConstantPredictionParameters(params.maxPredictedSpeed)
elif predictionMethod == 'na':
    predictionParameters = prediction.NormalAdaptationPredictionParameters(params.maxPredictedSpeed, 
                                                                           params.nPredictedTrajectories, 
                                                                           params.maxAcceleration,
                                                                           params.maxSteering,
                                                                           params.useFeaturesForPrediction)
elif predictionMethod == 'ps':
    predictionParameters = prediction.PointSetPredictionParameters(params.nPredictedTrajectories,
                                                                   params.maxPredictedSpeed)
else:
    # argparse restricts --prediction-method to the valid choices, so this is
    # reached when the method comes from the configuration file. Report the
    # error on stderr and exit with a non-zero status (the bare sys.exit()
    # used before reported success).
    sys.exit('Prediction method {} is not valid. See help.'.format(predictionMethod))

# Parameters for evasive-action prediction, taken from the same configuration
# object as the prediction methods above (plus params.minAcceleration).
# NOTE(review): not referenced anywhere below in this script; see the TODO
# about the probability of unsuccessful evasive action.
evasiveActionPredictionParameters = prediction.EvasiveActionPredictionParameters(
    params.maxPredictedSpeed, params.nPredictedTrajectories,
    params.minAcceleration, params.maxAcceleration,
    params.maxSteering, params.useFeaturesForPrediction)

# Object (road user) trajectories are always loaded from the database. Feature
# trajectories are loaded and attached to every object only when prediction
# uses features (needed e.g. for normal adaptation).
objects = storage.loadTrajectoriesFromSqlite(params.databaseFilename, 'object')
if params.useFeaturesForPrediction:
    features = storage.loadTrajectoriesFromSqlite(params.databaseFilename, 'feature')
    for roadUser in objects:
        roadUser.setFeatures(features)

# Create the interactions between road users (per the script description, one
# per pair), compute each one's indicators and crossings/collisions with the
# selected prediction method, then save the indicators back to the database.
interactions = events.createInteractions(objects)
for interaction in interactions:
    interaction.computeIndicators()
    interaction.computeCrossingsCollisions(predictionParameters,
                                           params.collisionDistance,
                                           params.predictionTimeHorizon,
                                           params.crossingZones)
storage.saveIndicators(params.databaseFilename, interactions)

# Optionally plot the collision points of all interactions in one figure.
if args.displayCollisionPoints:
    plt.figure()
    allCollisionPoints = []
    for inter in interactions:
        # inter.collisionPoints is a mapping; each value is presumably a
        # sequence of points -- TODO confirm against events.Interaction.
        for collisionPoints in inter.collisionPoints.values():
            allCollisionPoints.extend(collisionPoints)
    moving.Point.plotAll(allCollisionPoints)
    plt.axis('equal')
    # The script ends right after this block, so without show() the figure
    # would never be displayed in a non-interactive run.
    plt.show()