changeset 872:c70adaeeddf5

solved issue with latest version of scikit-learn
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 08 Feb 2017 16:32:15 -0500
parents 6db83beb5350
children 6b474db46b45
files python/metadata.py python/storage.py python/tests/storage.txt
diffstat 3 files changed, 24 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/python/metadata.py	Fri Feb 03 16:26:18 2017 -0500
+++ b/python/metadata.py	Wed Feb 08 16:32:15 2017 -0500
@@ -140,9 +140,9 @@
             self.undistortedImageMultiplication = undistortedImageMultiplication
             
         if self.intrinsicCameraMatrix is not None:
-            self.intrinsicCameraMatrixStr = '{}'.format(self.intrinsicCameraMatrix.tolist())
+            self.intrinsicCameraMatrixStr = str(self.intrinsicCameraMatrix.tolist())
         if self.distortionCoefficients is not None and len(self.distortionCoefficients) == 5:
-            self.distortionCoefficientsStr = '{}'.format(self.distortionCoefficients)
+            self.distortionCoefficientsStr = str(self.distortionCoefficients)
 
     @orm.reconstructor
     def initOnLoad(self):
--- a/python/storage.py	Fri Feb 03 16:26:18 2017 -0500
+++ b/python/storage.py	Wed Feb 08 16:32:15 2017 -0500
@@ -576,9 +576,9 @@
         import sys
         sys.exit()
     try:
-        cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covar00 REAL, covar01 REAL, covar10 REAL, covar11 REAL, covariance_type VARCHAR, weight, mixture_id INTEGER, PRIMARY KEY(id, mixture_id))')
+        cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (poi_id INTEGER, id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covariance VARCHAR, covariance_type VARCHAR, weight, precisions_cholesky VARCHAR, PRIMARY KEY(poi_id, id))')
         for i in xrange(gmm.n_components):
-            cursor.execute('INSERT INTO gaussians2d VALUES({}, \'{}\', {}, {}, {}, {}, {}, {}, \'{}\', {}, {})'.format(i, gmmType, gmm.means_[i][0], gmm.means_[i][1], gmm.covariances_[i][0,0], gmm.covariances_[i][0,1], gmm.covariances_[i][1,0], gmm.covariances_[i][1,1], gmm.covariance_type, gmm.weights_[i], gmmId))
+            cursor.execute('INSERT INTO gaussians2d VALUES({}, {}, \'{}\', {}, {}, \'{}\', \'{}\', {}, \'{}\')'.format(gmmId, i, gmmType, gmm.means_[i][0], gmm.means_[i][1], str(gmm.covariances_[i].tolist()), gmm.covariance_type, gmm.weights_[i], str(gmm.precisions_cholesky_[i].tolist())))
         connection.commit()
     except sqlite3.OperationalError as error:
         printDBError(error)
@@ -587,6 +587,7 @@
 def loadPOIs(filename):
     'Loads all 2D Gaussians in the database'
     from sklearn import mixture # todo if not avalaible, load data in duck-typed class with same fields
+    from ast import literal_eval
     connection = sqlite3.connect(filename)
     cursor = connection.cursor()
     pois = []
@@ -595,32 +596,36 @@
         gmmId = None
         gmm = []
         for row in cursor:
-            if gmmId is None or row[10] != gmmId:
+            if gmmId is None or row[0] != gmmId:
                 if len(gmm) > 0:
                     tmp = mixture.GaussianMixture(len(gmm), covarianceType)
                     tmp.means_ = array([gaussian['mean'] for gaussian in gmm])
                     tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm])
                     tmp.weights_ = array([gaussian['weight'] for gaussian in gmm])
                     tmp.gmmTypes = [gaussian['type'] for gaussian in gmm]
+                    tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm])
                     pois.append(tmp)
-                gaussian = {'type': row[1],
-                            'mean': row[2:4],
-                            'covar': array(row[4:8]).reshape(2,2),
-                            'weight': row[9]}
+                gaussian = {'type': row[2],
+                            'mean': row[3:5],
+                            'covar': array(literal_eval(row[5])),
+                            'weight': row[7],
+                            'precisions': array(literal_eval(row[8]))}
                 gmm = [gaussian]
-                covarianceType = row[8]
-                gmmId = row[10]
+                covarianceType = row[6]
+                gmmId = row[0]
             else:
-                gmm.append({'type': row[1],
-                            'mean': row[2:4],
-                            'covar': array(row[4:8]).reshape(2,2),
-                            'weight': row[9]})
+                gmm.append({'type': row[2],
+                            'mean': row[3:5],
+                            'covar': array(literal_eval(row[5])),
+                            'weight': row[7],
+                            'precisions': array(literal_eval(row[8]))})
         if len(gmm) > 0:
             tmp = mixture.GaussianMixture(len(gmm), covarianceType)
             tmp.means_ = array([gaussian['mean'] for gaussian in gmm])
             tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm])
             tmp.weights_ = array([gaussian['weight'] for gaussian in gmm])
             tmp.gmmTypes = [gaussian['type'] for gaussian in gmm]
+            tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm])
             pois.append(tmp)
     except sqlite3.OperationalError as error:
         printDBError(error)
--- a/python/tests/storage.txt	Fri Feb 03 16:26:18 2017 -0500
+++ b/python/tests/storage.txt	Wed Feb 08 16:32:15 2017 -0500
@@ -92,11 +92,11 @@
 >>> points = random_sample(nPoints*2).reshape(nPoints,2)
 >>> gmm = GaussianMixture(4, covariance_type = 'full')
 >>> tmp = gmm.fit(points)
->>> id = 0
->>> savePOIs('pois-tmp.sqlite', gmm, 'end', id)
+>>> gmmId = 0
+>>> savePOIs('pois-tmp.sqlite', gmm, 'end', gmmId)
 >>> reloadedGmm = loadPOIs('pois-tmp.sqlite')
->>> sum(gmm.predict(points) == reloadedGmm[id].predict(points)) == nPoints
+>>> sum(gmm.predict(points) == reloadedGmm[gmmId].predict(points)) == nPoints
 True
->>> reloadedGmm[id].gmmTypes[0] == 'end'
+>>> reloadedGmm[gmmId].gmmTypes[0] == 'end'
 True
 >>> remove('pois-tmp.sqlite')