Mercurial Hosting > traffic-intelligence
comparison python/storage.py @ 871:6db83beb5350
work in progress to update gaussian mixtures
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Fri, 03 Feb 2017 16:26:18 -0500 |
parents | 2d6249fe905a |
children | c70adaeeddf5 |
comparison
equal
deleted
inserted
replaced
870:1535251a1f40 | 871:6db83beb5350 |
---|---|
565 ######################### | 565 ######################### |
566 # saving and loading for scene interpretation | 566 # saving and loading for scene interpretation |
567 ######################### | 567 ######################### |
568 | 568 |
569 def savePOIs(filename, gmm, gmmType, gmmId): | 569 def savePOIs(filename, gmm, gmmType, gmmId): |
570 '''Saves a Gaussian mixture model (of class sklearn.mixture.GMM) | 570 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) |
571 gmmType is a type of GMM, learnt either from beginnings or ends of trajectories''' | 571 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' |
572 connection = sqlite3.connect(filename) | 572 connection = sqlite3.connect(filename) |
573 cursor = connection.cursor() | 573 cursor = connection.cursor() |
574 if gmmType not in ['beginning', 'end']: | 574 if gmmType not in ['beginning', 'end']: |
575 print('Unknown POI type {}. Exiting'.format(gmmType)) | 575 print('Unknown POI type {}. Exiting'.format(gmmType)) |
576 import sys | 576 import sys |
577 sys.exit() | 577 sys.exit() |
578 try: | 578 try: |
579 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covar00 REAL, covar01 REAL, covar10 REAL, covar11 REAL, covariance_type VARCHAR, weight, mixture_id INTEGER, PRIMARY KEY(id, mixture_id))') | 579 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covar00 REAL, covar01 REAL, covar10 REAL, covar11 REAL, covariance_type VARCHAR, weight, mixture_id INTEGER, PRIMARY KEY(id, mixture_id))') |
580 for i in xrange(gmm.n_components): | 580 for i in xrange(gmm.n_components): |
581 cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covars_[i][0,0], gmm.covars_[i][0,1], gmm.covars_[i][1,0], gmm.covars_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId)) | 581 cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covariances_[i][0,0], gmm.covariances_[i][0,1], gmm.covariances_[i][1,0], gmm.covariances_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId)) |
582 connection.commit() | 582 connection.commit() |
583 except sqlite3.OperationalError as error: | 583 except sqlite3.OperationalError as error: |
584 printDBError(error) | 584 printDBError(error) |
585 connection.close() | 585 connection.close() |
586 | 586 |
595 gmmId = None | 595 gmmId = None |
596 gmm = [] | 596 gmm = [] |
597 for row in cursor: | 597 for row in cursor: |
598 if gmmId is None or row[10] != gmmId: | 598 if gmmId is None or row[10] != gmmId: |
599 if len(gmm) > 0: | 599 if len(gmm) > 0: |
600 tmp = mixture.GMM(len(gmm), covarianceType) | 600 tmp = mixture.GaussianMixture(len(gmm), covarianceType) |
601 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) | 601 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) |
602 tmp.covars_ = array([gaussian['covar'] for gaussian in gmm]) | 602 tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) |
603 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) | 603 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) |
604 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] | 604 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] |
605 pois.append(tmp) | 605 pois.append(tmp) |
606 gaussian = {'type': row[1], | 606 gaussian = {'type': row[1], |
607 'mean': row[2:4], | 607 'mean': row[2:4], |
614 gmm.append({'type': row[1], | 614 gmm.append({'type': row[1], |
615 'mean': row[2:4], | 615 'mean': row[2:4], |
616 'covar': array(row[4:8]).reshape(2,2), | 616 'covar': array(row[4:8]).reshape(2,2), |
617 'weight': row[9]}) | 617 'weight': row[9]}) |
618 if len(gmm) > 0: | 618 if len(gmm) > 0: |
619 tmp = mixture.GMM(len(gmm), covarianceType) | 619 tmp = mixture.GaussianMixture(len(gmm), covarianceType) |
620 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) | 620 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) |
621 tmp.covars_ = array([gaussian['covar'] for gaussian in gmm]) | 621 tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) |
622 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) | 622 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) |
623 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] | 623 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] |
624 pois.append(tmp) | 624 pois.append(tmp) |
625 except sqlite3.OperationalError as error: | 625 except sqlite3.OperationalError as error: |
626 printDBError(error) | 626 printDBError(error) |