Mercurial Hosting > traffic-intelligence
comparison scripts/learn-poi.py @ 913:1cd878812529
work in progress
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Wed, 28 Jun 2017 17:57:06 -0400 |
parents | 6db83beb5350 |
children | f228fd649644 |
comparison
equal
deleted
inserted
replaced
912:fd057a6b04db | 913:1cd878812529 |
---|---|
15 parser.add_argument('-ndestinations', dest = 'nDestinationClusters', help = 'number of clusters for trajectory destinations (=norigins if not provided)', type = int) | 15 parser.add_argument('-ndestinations', dest = 'nDestinationClusters', help = 'number of clusters for trajectory destinations (=norigins if not provided)', type = int) |
16 parser.add_argument('--covariance-type', dest = 'covarianceType', help = 'type of covariance of Gaussian model', default = "full") | 16 parser.add_argument('--covariance-type', dest = 'covarianceType', help = 'type of covariance of Gaussian model', default = "full") |
17 parser.add_argument('-w', dest = 'worldImageFilename', help = 'filename of the world image') | 17 parser.add_argument('-w', dest = 'worldImageFilename', help = 'filename of the world image') |
18 parser.add_argument('-u', dest = 'unitsPerPixel', help = 'number of units of distance per pixel', type = float, default = 1.) | 18 parser.add_argument('-u', dest = 'unitsPerPixel', help = 'number of units of distance per pixel', type = float, default = 1.) |
19 parser.add_argument('--display', dest = 'display', help = 'display points of interests', action = 'store_true') # default is manhattan distance | 19 parser.add_argument('--display', dest = 'display', help = 'display points of interests', action = 'store_true') # default is manhattan distance |
20 parser.add_argument('--assign', dest = 'display', help = 'display points of interests', action = 'store_true') # default is manhattan distance | |
20 | 21 |
21 args = parser.parse_args() | 22 args = parser.parse_args() |
22 | 23 |
23 objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, args.trajectoryType) | 24 objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, args.trajectoryType) |
24 | 25 |
35 if args.nDestinationClusters is None: | 36 if args.nDestinationClusters is None: |
36 nDestinationClusters = args.nOriginClusters | 37 nDestinationClusters = args.nOriginClusters |
37 | 38 |
38 gmmId=0 | 39 gmmId=0 |
39 for nClusters, points, gmmType in zip([args.nOriginClusters, nDestinationClusters], | 40 for nClusters, points, gmmType in zip([args.nOriginClusters, nDestinationClusters], |
40 [beginnings, ends], | 41 [beginnings, ends], |
41 ['beginning', 'end']): | 42 ['beginning', 'end']): |
42 # estimation | 43 # estimation |
43 gmm = mixture.GaussianMixture(n_components=nClusters, covariance_type = args.covarianceType) | 44 gmm = mixture.GaussianMixture(n_components=nClusters, covariance_type = args.covarianceType) |
44 model=gmm.fit(beginnings) | 45 model=gmm.fit(points) |
45 if not model.converged_: | 46 if not model.converged_: |
46 print('Warning: model for '+gmmType+' points did not converge') | 47 print('Warning: model for '+gmmType+' points did not converge') |
47 # plot | 48 # plot |
48 if args.display: | 49 if args.display: |
49 fig = plt.figure() | 50 fig = plt.figure() |
50 if args.worldImageFilename is not None and args.unitsPerPixel is not None: | 51 if args.worldImageFilename is not None and args.unitsPerPixel is not None: |
51 img = plt.imread(args.worldImageFilename) | 52 img = plt.imread(args.worldImageFilename) |
52 plt.imshow(img) | 53 plt.imshow(img) |
53 labels = ml.plotGMMClusters(model, points, fig, nUnitsPerPixel = args.unitsPerPixel) | 54 labels = model.predict(points) |
55 labels = ml.plotGMMClusters(model, labels, points, fig, nUnitsPerPixel = args.unitsPerPixel) | |
54 plt.axis('image') | 56 plt.axis('image') |
55 plt.title(gmmType) | 57 plt.title(gmmType) |
56 print(gmmType+' Clusters:\n{}'.format(ml.computeClusterSizes(labels, range(model.n_components)))) | 58 print(gmmType+' Clusters:\n{}'.format(ml.computeClusterSizes(labels, range(model.n_components)))) |
57 # save | 59 # save |
58 storage.savePOIs(args.databaseFilename, model, gmmType, gmmId) | 60 storage.savePOIs(args.databaseFilename, model, gmmType, gmmId) |
61 # save assignments | |
62 if args.assign: | |
63 pass | |
59 gmmId += 1 | 64 gmmId += 1 |
60 | 65 |
61 if args.display: | 66 if args.display: |
62 plt.axis('equal') | 67 plt.axis('equal') |
63 plt.show() | 68 plt.show() |