diff scripts/safety-analysis.py @ 943:b1e8453c207c

work on motion prediction using motion patterns
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 19 Jul 2017 18:02:38 -0400
parents a2f3f3ca241e
children 84ebe1b031f1
line wrap: on
line diff
--- a/scripts/safety-analysis.py	Tue Jul 18 18:01:16 2017 -0400
+++ b/scripts/safety-analysis.py	Wed Jul 19 18:02:38 2017 -0400
@@ -14,8 +14,8 @@
 parser.add_argument('--cfg', dest = 'configFilename', help = 'name of the configuration file', required = True)
 parser.add_argument('-n', dest = 'nObjects', help = 'number of objects to analyse', type = int)
 # TODO analyze only 
-parser.add_argument('--prediction-method', dest = 'predictionMethod', help = 'prediction method (constant velocity (cvd: vector computation (approximate); cve: equation solving; cv: discrete time (approximate)), normal adaptation, point set prediction)', choices = ['cvd', 'cve', 'cv', 'na', 'ps', 'proto'])
-parser.add_argument('--cfg', dest = 'prototypeDatabaseFilename', help = 'name of the database containing the prototypes')
+parser.add_argument('--prediction-method', dest = 'predictionMethod', help = 'prediction method (constant velocity (cvd: vector computation (approximate); cve: equation solving; cv: discrete time (approximate)), normal adaptation, point set prediction)', choices = ['cvd', 'cve', 'cv', 'na', 'ps', 'mp'])
+parser.add_argument('--prototypeDatabaseFilename', dest = 'prototypeDatabaseFilename', help = 'name of the database containing the prototypes')
 parser.add_argument('--pet', dest = 'computePET', help = 'computes PET', action = 'store_true')
 parser.add_argument('--display-cp', dest = 'displayCollisionPoints', help = 'display collision points', action = 'store_true')
 parser.add_argument('--nthreads', dest = 'nProcesses', help = 'number of processes to run in parallel', type = int, default = 1)
@@ -48,10 +48,15 @@
                                                                            params.useFeaturesForPrediction)
 elif predictionMethod == 'ps':
     predictionParameters = prediction.PointSetPredictionParameters(params.maxPredictedSpeed)
-elif predictionMethod == 'proto':
-    prototypes = storage.loadPrototypesFromSqlite(args.prototypeDatabaseFilename)
+elif predictionMethod == 'mp':
+    if args.prototypeDatabaseFilename is None:
+        prototypes = storage.loadPrototypesFromSqlite(params.databaseFilename)
+    else:
+        prototypes = storage.loadPrototypesFromSqlite(args.prototypeDatabaseFilename)
     for p in prototypes:
-        p.getMovingObject().getPositions().computeCumulativeDistances()
+        p.getMovingObject().computeCumulativeDistances()
+    predictionParameters = prediction.PrototypePredictionParameters(prototypes, params.nPredictedTrajectories, 2., 0.5, 'cityblock', 10, params.constantSpeedPrototypePrediction, params.useFeaturesForPrediction)
+# else:
 # no else required, since parameters is required as argument
 
 # evasiveActionPredictionParameters = prediction.EvasiveActionPredictionParameters(params.maxPredictedSpeed, 
@@ -61,7 +66,7 @@
 #                                                                                  params.maxExtremeSteering,
 #                                                                                  params.useFeaturesForPrediction)
 
-objects = storage.loadTrajectoriesFromSqlite(params.databaseFilename, 'object', args.nObjects, withFeatures = (params.useFeaturesForPrediction or (predictionMethod == 'ps')))
+objects = storage.loadTrajectoriesFromSqlite(params.databaseFilename, 'object', args.nObjects, withFeatures = (params.useFeaturesForPrediction or predictionMethod == 'ps' or predictionMethod == 'mp'))
 # if params.useFeaturesForPrediction:
 #     features = storage.loadTrajectoriesFromSqlite(params.databaseFilename,'feature') # needed if normal adaptation
 #     for obj in objects:
@@ -70,7 +75,7 @@
 interactions = events.createInteractions(objects)
 for inter in interactions:
     inter.computeIndicators()
-    inter.computeCrossingsCollisions(predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, nProcesses = args.nProcesses)
+    inter.computeCrossingsCollisions(predictionParameters, params.collisionDistance, params.predictionTimeHorizon, params.crossingZones, nProcesses = args.nProcesses, debug = True)
 
 if args.computePET:
     for inter in interactions: