comparison python/storage.py @ 871:6db83beb5350

work in progress to update gaussian mixtures
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 03 Feb 2017 16:26:18 -0500
parents 2d6249fe905a
children c70adaeeddf5
comparison
equal deleted inserted replaced
870:1535251a1f40 871:6db83beb5350
565 ######################### 565 #########################
566 # saving and loading for scene interpretation 566 # saving and loading for scene interpretation
567 ######################### 567 #########################
568 568
569 def savePOIs(filename, gmm, gmmType, gmmId): 569 def savePOIs(filename, gmm, gmmType, gmmId):
570 '''Saves a Gaussian mixture model (of class sklearn.mixture.GMM) 570 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture)
571 gmmType is a type of GMM, learnt either from beginnings or ends of trajectories''' 571 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories'''
572 connection = sqlite3.connect(filename) 572 connection = sqlite3.connect(filename)
573 cursor = connection.cursor() 573 cursor = connection.cursor()
574 if gmmType not in ['beginning', 'end']: 574 if gmmType not in ['beginning', 'end']:
575 print('Unknown POI type {}. Exiting'.format(gmmType)) 575 print('Unknown POI type {}. Exiting'.format(gmmType))
576 import sys 576 import sys
577 sys.exit() 577 sys.exit()
578 try: 578 try:
579 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covar00 REAL, covar01 REAL, covar10 REAL, covar11 REAL, covariance_type VARCHAR, weight, mixture_id INTEGER, PRIMARY KEY(id, mixture_id))') 579 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covar00 REAL, covar01 REAL, covar10 REAL, covar11 REAL, covariance_type VARCHAR, weight, mixture_id INTEGER, PRIMARY KEY(id, mixture_id))')
580 for i in xrange(gmm.n_components): 580 for i in xrange(gmm.n_components):
581 cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covars_[i][0,0], gmm.covars_[i][0,1], gmm.covars_[i][1,0], gmm.covars_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId)) 581 cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covariances_[i][0,0], gmm.covariances_[i][0,1], gmm.covariances_[i][1,0], gmm.covariances_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId))
582 connection.commit() 582 connection.commit()
583 except sqlite3.OperationalError as error: 583 except sqlite3.OperationalError as error:
584 printDBError(error) 584 printDBError(error)
585 connection.close() 585 connection.close()
586 586
595 gmmId = None 595 gmmId = None
596 gmm = [] 596 gmm = []
597 for row in cursor: 597 for row in cursor:
598 if gmmId is None or row[10] != gmmId: 598 if gmmId is None or row[10] != gmmId:
599 if len(gmm) > 0: 599 if len(gmm) > 0:
600 tmp = mixture.GMM(len(gmm), covarianceType) 600 tmp = mixture.GaussianMixture(len(gmm), covarianceType)
601 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) 601 tmp.means_ = array([gaussian['mean'] for gaussian in gmm])
602 tmp.covars_ = array([gaussian['covar'] for gaussian in gmm]) 602 tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm])
603 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) 603 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm])
604 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] 604 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm]
605 pois.append(tmp) 605 pois.append(tmp)
606 gaussian = {'type': row[1], 606 gaussian = {'type': row[1],
607 'mean': row[2:4], 607 'mean': row[2:4],
614 gmm.append({'type': row[1], 614 gmm.append({'type': row[1],
615 'mean': row[2:4], 615 'mean': row[2:4],
616 'covar': array(row[4:8]).reshape(2,2), 616 'covar': array(row[4:8]).reshape(2,2),
617 'weight': row[9]}) 617 'weight': row[9]})
618 if len(gmm) > 0: 618 if len(gmm) > 0:
619 tmp = mixture.GMM(len(gmm), covarianceType) 619 tmp = mixture.GaussianMixture(len(gmm), covarianceType)
620 tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) 620 tmp.means_ = array([gaussian['mean'] for gaussian in gmm])
621 tmp.covars_ = array([gaussian['covar'] for gaussian in gmm]) 621 tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm])
622 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) 622 tmp.weights_ = array([gaussian['weight'] for gaussian in gmm])
623 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] 623 tmp.gmmTypes = [gaussian['type'] for gaussian in gmm]
624 pois.append(tmp) 624 pois.append(tmp)
625 except sqlite3.OperationalError as error: 625 except sqlite3.OperationalError as error:
626 printDBError(error) 626 printDBError(error)