changeset 871:6db83beb5350

work in progress to update gaussian mixtures
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 03 Feb 2017 16:26:18 -0500
parents 1535251a1f40
children c70adaeeddf5
files python/storage.py python/tests/storage.txt scripts/learn-poi.py
diffstat 3 files changed, 10 insertions(+), 10 deletions(-) [+]
line wrap: on
line diff
--- a/python/storage.py	Fri Feb 03 16:15:06 2017 -0500
+++ b/python/storage.py	Fri Feb 03 16:26:18 2017 -0500
@@ -567,8 +567,8 @@
 #########################
 
 def savePOIs(filename, gmm, gmmType, gmmId):
-    '''Saves a Gaussian mixture model (of class sklearn.mixture.GMM)
-    gmmType is a type of GMM, learnt either from beginnings or ends of trajectories'''
+    '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture)
+    gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories'''
     connection = sqlite3.connect(filename)
     cursor = connection.cursor()
     if gmmType not in ['beginning', 'end']:
@@ -578,7 +578,7 @@
     try:
         cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covar00 REAL, covar01 REAL, covar10 REAL, covar11 REAL, covariance_type VARCHAR, weight, mixture_id INTEGER, PRIMARY KEY(id, mixture_id))')
         for i in xrange(gmm.n_components):
-            cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covars_[i][0,0], gmm.covars_[i][0,1], gmm.covars_[i][1,0], gmm.covars_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId))
+            cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covariances_[i][0,0], gmm.covariances_[i][0,1], gmm.covariances_[i][1,0], gmm.covariances_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId))
         connection.commit()
     except sqlite3.OperationalError as error:
         printDBError(error)
@@ -597,9 +597,9 @@
         for row in cursor:
             if gmmId is None or row[10] != gmmId:
                 if len(gmm) > 0:
-                    tmp = mixture.GMM(len(gmm), covarianceType)
+                    tmp = mixture.GaussianMixture(len(gmm), covarianceType)
                     tmp.means_ = array([gaussian['mean'] for gaussian in gmm])
-                    tmp.covars_ = array([gaussian['covar'] for gaussian in gmm])
+                    tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm])
                     tmp.weights_ = array([gaussian['weight'] for gaussian in gmm])
                     tmp.gmmTypes = [gaussian['type'] for gaussian in gmm]
                     pois.append(tmp)
@@ -616,9 +616,9 @@
                             'covar': array(row[4:8]).reshape(2,2),
                             'weight': row[9]})
         if len(gmm) > 0:
-            tmp = mixture.GMM(len(gmm), covarianceType)
+            tmp = mixture.GaussianMixture(len(gmm), covarianceType)
             tmp.means_ = array([gaussian['mean'] for gaussian in gmm])
-            tmp.covars_ = array([gaussian['covar'] for gaussian in gmm])
+            tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm])
             tmp.weights_ = array([gaussian['weight'] for gaussian in gmm])
             tmp.gmmTypes = [gaussian['type'] for gaussian in gmm]
             pois.append(tmp)
--- a/python/tests/storage.txt	Fri Feb 03 16:15:06 2017 -0500
+++ b/python/tests/storage.txt	Fri Feb 03 16:26:18 2017 -0500
@@ -86,11 +86,11 @@
 >>> readline(strio, '%#')
 'sadlkfjsdlakjf'
 
->>> from sklearn.mixture import GMM
+>>> from sklearn.mixture import GaussianMixture
 >>> from numpy.random import random_sample
 >>> nPoints = 50
 >>> points = random_sample(nPoints*2).reshape(nPoints,2)
->>> gmm = GMM(4, covariance_type = 'full')
+>>> gmm = GaussianMixture(4, covariance_type = 'full')
 >>> tmp = gmm.fit(points)
 >>> id = 0
 >>> savePOIs('pois-tmp.sqlite', gmm, 'end', id)
--- a/scripts/learn-poi.py	Fri Feb 03 16:15:06 2017 -0500
+++ b/scripts/learn-poi.py	Fri Feb 03 16:26:18 2017 -0500
@@ -40,7 +40,7 @@
                                    [beginnings, ends],
                                    ['beginning', 'end']):
     # estimation
-    gmm = mixture.GMM(n_components=nClusters, covariance_type = args.covarianceType)
+    gmm = mixture.GaussianMixture(n_components=nClusters, covariance_type = args.covarianceType)
     model=gmm.fit(beginnings)
     if not model.converged_:
         print('Warning: model for '+gmmType+' points did not converge')