diff scripts/learn-poi.py @ 916:7345f0d51faa

added display of paths
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Tue, 04 Jul 2017 17:36:24 -0400
parents 13434f5017dd
children c030f735c594
line wrap: on
line diff
--- a/scripts/learn-poi.py	Tue Jul 04 17:03:29 2017 -0400
+++ b/scripts/learn-poi.py	Tue Jul 04 17:36:24 2017 -0400
@@ -17,8 +17,9 @@
 parser.add_argument('--covariance-type', dest = 'covarianceType', help = 'type of covariance of Gaussian model', default = "full")
 parser.add_argument('-w', dest = 'worldImageFilename', help = 'filename of the world image')
 parser.add_argument('-u', dest = 'unitsPerPixel', help = 'number of units of distance per pixel', type = float, default = 1.)
-parser.add_argument('--display', dest = 'display', help = 'display points of interests', action = 'store_true') # default is manhattan distance
-parser.add_argument('--assign', dest = 'assign', help = 'display points of interests', action = 'store_true')
+parser.add_argument('--display', dest = 'display', help = 'displays points of interests', action = 'store_true') # default is manhattan distance
+parser.add_argument('--assign', dest = 'assign', help = 'assigns the trajectories to the POIs and saves the assignments', action = 'store_true')
+parser.add_argument('--display-paths', dest = 'displayPaths', help = 'displays all possible origin destination if assignment is done', action = 'store_true')
 
 # TODO test Variational Bayesian Gaussian Mixture BayesianGaussianMixture
 
@@ -76,6 +77,21 @@
 
 if args.assign:
     storage.savePOIAssignments(args.databaseFilename, objects)
+    if args.displayPaths:
+        for i in xrange(args.nOriginClusters):
+            for j in xrange(args.nDestinationClusters):
+                odObjects = [o for o in objects if o.od[0] == i and o.od[1] == j]
+                if len(odObjects) > 0:
+                    fig = plt.figure()
+                    ax = fig.add_subplot(111)
+                    ml.plotGMM(models['beginning'].means_[i], models['beginning'].covariances_[i], i, fig, 'b')
+                    ml.plotGMM(models['end'].means_[j], models['end'].covariances_[j], j, fig, 'r')
+                    for o in odObjects:
+                        o.plot(withOrigin = True)
+                    plt.title('OD {} to {}'.format(i,j))
+                    plt.axis('equal')
+                    plt.show()
+
 
 if args.display:
     plt.axis('equal')