diff scripts/learn-motion-patterns.py @ 818:181bcb6dad3a

added option to learn motion patterns and show to display results
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Tue, 21 Jun 2016 17:08:07 -0400
parents 0e875a7f5759
children f3ae72d86762
line wrap: on
line diff
--- a/scripts/learn-motion-patterns.py	Mon Jun 20 10:56:41 2016 -0400
+++ b/scripts/learn-motion-patterns.py	Tue Jun 21 17:08:07 2016 -0400
@@ -11,6 +11,7 @@
 #parser.add_argument('--cfg', dest = 'configFilename', help = 'name of the configuration file')
 parser.add_argument('-d', dest = 'databaseFilename', help = 'name of the Sqlite database file', required = True)
 parser.add_argument('-t', dest = 'trajectoryType', help = 'type of trajectories to display', choices = ['objectfeatures', 'feature', 'object'], default = 'objectfeatures')
+parser.add_argument('--max-nobjectfeatures', dest = 'maxNObjectFeatures', help = 'maximum number of features per object to load', type = int, default = 3)
 parser.add_argument('-n', dest = 'nTrajectories', help = 'number of the object or feature trajectories to load', type = int, default = None)
 parser.add_argument('-e', dest = 'epsilon', help = 'distance for the similarity of trajectory points', type = float, required = True)
 parser.add_argument('--metric', dest = 'metric', help = 'metric for the similarity of trajectory points', default = 'cityblock') # default is manhattan distance
@@ -34,7 +35,7 @@
     features = []
     for o in objects:
         tmp = utils.sortByLength(o.getFeatures(), reverse = True)
-        features += tmp[:min(len(tmp), 3)]
+        features += tmp[:min(len(tmp), args.maxNObjectFeatures)]
     objects = features
 
 trajectories = [o.getPositions().asArray().T for o in objects]
@@ -50,8 +51,10 @@
 
 prototypeIndices, labels = ml.prototypeCluster(trajectories, similarities, args.minSimilarity, lambda x,y : lcss.computeNormalized(x, y), args.minClusterSize) # this line can be called again without reinitializing similarities
 
+print(ml.computeClusterSizes(labels, prototypeIndices, -1))
+
 if args.display:
-    from matplotlib.pyplot import figure
+    from matplotlib.pyplot import figure, show
     figure()
     for i,o in enumerate(objects):
         if i not in prototypeIndices:
@@ -61,5 +64,6 @@
                 o.plot(utils.colors[labels[i]])
     for i in prototypeIndices:
             objects[i].plot(utils.colors[i]+'o')
+    show()
 
 # TODO store the prototypes (if features, easy, if objects, info must be stored about the type)