Mercurial Hosting > traffic-intelligence
comparison python/storage.py @ 872:c70adaeeddf5
solved issue with latest version of scikit-learn
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Wed, 08 Feb 2017 16:32:15 -0500 |
parents | 6db83beb5350 |
children | f9ea5083588e |
comparison
equal
deleted
inserted
replaced
871:6db83beb5350 | 872:c70adaeeddf5 |
---|---|
574 if gmmType not in ['beginning', 'end']: | 574 if gmmType not in ['beginning', 'end']: |
575 print('Unknown POI type {}. Exiting'.format(gmmType)) | 575 print('Unknown POI type {}. Exiting'.format(gmmType)) |
576 import sys | 576 import sys |
577 sys.exit() | 577 sys.exit() |
578 try: | 578 try: |
579 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covar00 REAL, covar01 REAL, covar10 REAL, covar11 REAL, covariance_type VARCHAR, weight, mixture_id INTEGER, PRIMARY KEY(id, mixture_id))') | 579 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (poi_id INTEGER, id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covariance VARCHAR, covariance_type VARCHAR, weight, precisions_cholesky VARCHAR, PRIMARY KEY(poi_id, id))') |
580 for i in xrange(gmm.n_components): | 580 for i in xrange(gmm.n_components): |
581 cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covariances_[i][0,0], gmm.covariances_[i][0,1], gmm.covariances_[i][1,0], gmm.covariances_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId)) | 581 cursor.execute('INSERT INTO gaussians2d VALUES({}, {}, \'{}\', {}, {}, \'{}\', \'{}\', {}, \'{}\')'.format(gmmId, i, gmmType, gmm.means_[i][0], gmm.means_[i][1], str(gmm.covariances_[i].tolist()), gmm.covariance_type, gmm.weights_[i], str(gmm.precisions_cholesky_[i].tolist()))) |
582 connection.commit() | 582 connection.commit() |
583 except sqlite3.OperationalError as error: | 583 except sqlite3.OperationalError as error: |
584 printDBError(error) | 584 printDBError(error) |
585 connection.close() | 585 connection.close() |
586 | 586 |
587 def loadPOIs(filename): | 587 def loadPOIs(filename): |
588 'Loads all 2D Gaussians in the database' | 588 'Loads all 2D Gaussians in the database' |
589 from sklearn import mixture # todo if not avalaible, load data in duck-typed class with same fields | 589 from sklearn import mixture # todo if not avalaible, load data in duck-typed class with same fields |
590 from ast import literal_eval | |
590 connection = sqlite3.connect(filename) | 591 connection = sqlite3.connect(filename) |
591 cursor = connection.cursor() | 592 cursor = connection.cursor() |
592 pois = [] | 593 pois = [] |
593 try: | 594 try: |
594 cursor.execute('SELECT * from gaussians2d') | 595 cursor.execute('SELECT * from gaussians2d') |
595 gmmId = None | 596 gmmId = None |
596 gmm = [] | 597 gmm = [] |
597 for row in cursor: | 598 for row in cursor: |
598 if gmmId is None or row[10] != gmmId: | 599 if gmmId is None or row[0] != gmmId: |
599 if len(gmm) > 0: | 600 if len(gmm) > 0: |
600 tmp = mixture.GaussianMixture(len(gmm), covarianceType) | 601 tmp = mixture.GaussianMixture(len(gmm), covarianceType) |
601 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) | 602 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) |
602 tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) | 603 tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) |
603 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) | 604 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) |
604 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] | 605 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] |
606 tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm]) | |
605 pois.append(tmp) | 607 pois.append(tmp) |
606 gaussian = {'type': row[1], | 608 gaussian = {'type': row[2], |
607 'mean': row[2:4], | 609 'mean': row[3:5], |
608 'covar': array(row[4:8]).reshape(2,2), | 610 'covar': array(literal_eval(row[5])), |
609 'weight': row[9]} | 611 'weight': row[7], |
612 'precisions': array(literal_eval(row[8]))} | |
610 gmm = [gaussian] | 613 gmm = [gaussian] |
611 covarianceType = row[8] | 614 covarianceType = row[6] |
612 gmmId = row[10] | 615 gmmId = row[0] |
613 else: | 616 else: |
614 gmm.append({'type': row[1], | 617 gmm.append({'type': row[2], |
615 'mean': row[2:4], | 618 'mean': row[3:5], |
616 'covar': array(row[4:8]).reshape(2,2), | 619 'covar': array(literal_eval(row[5])), |
617 'weight': row[9]}) | 620 'weight': row[7], |
621 'precisions': array(literal_eval(row[8]))}) | |
618 if len(gmm) > 0: | 622 if len(gmm) > 0: |
619 tmp = mixture.GaussianMixture(len(gmm), covarianceType) | 623 tmp = mixture.GaussianMixture(len(gmm), covarianceType) |
620 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) | 624 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) |
621 tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) | 625 tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) |
622 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) | 626 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) |
623 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] | 627 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] |
628 tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm]) | |
624 pois.append(tmp) | 629 pois.append(tmp) |
625 except sqlite3.OperationalError as error: | 630 except sqlite3.OperationalError as error: |
626 printDBError(error) | 631 printDBError(error) |
627 connection.close() | 632 connection.close() |
628 return pois | 633 return pois |