changeset 844:5a68779d7777

added capability to save prototypes
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Thu, 14 Jul 2016 00:34:59 -0400
parents 5dc7a507353e
children aa98e773ac91
files python/storage.py scripts/delete-tables.py scripts/learn-motion-patterns.py scripts/learn-poi.py
diffstat 4 files changed, 40 insertions(+), 10 deletions(-) [+]
line wrap: on
line diff
--- a/python/storage.py	Wed Jul 13 23:45:47 2016 -0400
+++ b/python/storage.py	Thu Jul 14 00:34:59 2016 -0400
@@ -364,21 +364,44 @@
         printDBError(error)
     connection.close()
 
-def savePrototypesToSqlite(filename, prototypes, trajectoryType = 'feature'):
-    'Work in progress, do not use'
+def savePrototypesToSqlite(filename, prototypeIndices, trajectoryType, nMatchings = None):
+    '''save the prototype indices
+    nMatchings, if not None, is a dictionnary between indices and number of matches'''
     connection = sqlite3.connect(filename)
     cursor = connection.cursor()
     try:
-        cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (id INTEGER PRIMARY KEY, object_id INTEGER, trajectory_id INTEGER, nMatchings INTEGER, FOREIGN KEY(object_id) REFERENCES objects(id), FOREIGN KEY(trajectory_id) REFERENCES positions(trajectory_id))')
-        #for inter in interactions:
-        #    saveInteraction(cursor, inter)
+        cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (id INTEGER PRIMARY KEY, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nMatchings INTEGER)')
+        for i in prototypeIndices:
+            if nMatchings is not None:
+                n = nMatchings[i]
+            else:
+                n = 'NULL'
+            cursor.execute('INSERT INTO prototypes (id, trajectory_type, nMatchings) VALUES ({},\"{}\",{})'.format(i, trajectoryType, n))
     except sqlite3.OperationalError as error:
         printDBError(error)
     connection.commit()
     connection.close()
 
 def loadPrototypesFromSqlite(filename):
-    pass
+    'Loads prototype ids and matchings (if stored)'
+    connection = sqlite3.connect(filename)
+    cursor = connection.cursor()
+    prototypeIndices = []
+    trajectoryTypes = []
+    nMatchings = {}
+    try:
+        cursor.execute('SELECT * FROM prototypes')
+        for row in cursor:
+            prototypeIndices.append(row[0])
+            trajectoryTypes.append(row[1])
+            if row[2] is not None:
+                nMatchings[row[0]] = row[2]
+    except sqlite3.OperationalError as error:
+        printDBError(error)
+    connection.close()
+    if len(set(trajectoryTypes)) > 1:
+        print('Different types of prototypes in database ({}).'.format(set(trajectoryTypes)))
+    return prototypeIndices, trajectoryTypes[0], nMatchings
 
 def loadBBMovingObjectsFromSqlite(filename, objectType = 'bb', objectNumbers = None, timeStep = None):
     '''Loads bounding box moving object from an SQLite
@@ -418,6 +441,8 @@
             dropTables(connection, ['bounding_boxes'])
         elif dataType == 'pois':
             dropTables(connection, ['gaussians2d'])
+        elif dataType == 'prototype':
+            dropTables(connection, ['prototypes'])
         else:
             print('Unknown data type {} to delete from database'.format(dataType))
         connection.close()
--- a/scripts/delete-tables.py	Wed Jul 13 23:45:47 2016 -0400
+++ b/scripts/delete-tables.py	Thu Jul 14 00:34:59 2016 -0400
@@ -8,7 +8,7 @@
 parser = argparse.ArgumentParser(description='The program deletes (drops) the tables in the database before saving new results (for objects, tables object_features and objects are dropped; for interactions, the tables interactions and indicators are dropped')
 #parser.add_argument('configFilename', help = 'name of the configuration file')
 parser.add_argument('-d', dest = 'databaseFilename', help = 'name of the Sqlite database', required = True)
-parser.add_argument('-t', dest = 'dataType', help = 'type of the data to remove', required = True, choices = ['object','interaction', 'bb', 'pois'])
+parser.add_argument('-t', dest = 'dataType', help = 'type of the data to remove', required = True, choices = ['object','interaction', 'bb', 'pois', 'prototype'])
 args = parser.parse_args()
 
 storage.deleteFromSqlite(args.databaseFilename, args.dataType)
--- a/scripts/learn-motion-patterns.py	Wed Jul 13 23:45:47 2016 -0400
+++ b/scripts/learn-motion-patterns.py	Thu Jul 14 00:34:59 2016 -0400
@@ -50,13 +50,16 @@
 
 prototypeIndices, labels = ml.prototypeCluster(trajectories, similarities, args.minSimilarity, lambda x,y : lcss.computeNormalized(x, y), args.minClusterSize, args.randomInitialization) # this line can be called again without reinitializing similarities
 
-print(ml.computeClusterSizes(labels, prototypeIndices, -1))
+clusterSizes = ml.computeClusterSizes(labels, prototypeIndices, -1)
+print(clusterSizes)
+
+storage.savePrototypesToSqlite(args.databaseFilename, [objects[i].getNum() for i in prototypeIndices], args.trajectoryType, {objects[i].getNum():clusterSizes[i] for i in prototypeIndices})
 
 if args.saveSimilarities:
-    np.savetxt(utils.removeExtension(args.databaseFilename)+'-prototype-similarities.txt.gz', similarities, '%.4')
+    np.savetxt(utils.removeExtension(args.databaseFilename)+'-prototype-similarities.txt.gz', similarities, '%.4f')
 
 if args.display:
-    from matplotlib.pyplot import figure, show
+    from matplotlib.pyplot import figure, show, axis
     figure()
     for i,o in enumerate(objects):
         if i not in prototypeIndices:
@@ -66,6 +69,7 @@
                 o.plot(utils.colors[labels[i]])
     for i in prototypeIndices:
             objects[i].plot(utils.colors[i]+'o')
+    axis('equal')
     show()
 
 # TODO store the prototypes trajectories, add option so store similarities (the most expensive stuff) with limited accuracy
--- a/scripts/learn-poi.py	Wed Jul 13 23:45:47 2016 -0400
+++ b/scripts/learn-poi.py	Thu Jul 14 00:34:59 2016 -0400
@@ -59,6 +59,7 @@
     gmmId += 1
 
 if args.display:
+    plt.axis('equal')
     plt.show()
 
 # fig = plt.figure()