changeset 919:7b3f2e0a2652

saving and loading prototype trajectories
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 05 Jul 2017 13:16:47 -0400
parents 3a06007a4bb7
children 499154254f37
files python/storage.py scripts/learn-motion-patterns.py
diffstat 2 files changed, 33 insertions(+), 13 deletions(-) [+]
line wrap: on
line diff
--- a/python/storage.py	Wed Jul 05 12:19:59 2017 -0400
+++ b/python/storage.py	Wed Jul 05 13:16:47 2017 -0400
@@ -6,6 +6,7 @@
 from base import VideoFilenameAddable
 
 from os import path
+from copy import copy
 import sqlite3, logging
 from numpy import log, min as npmin, max as npmax, round as npround, array, sum as npsum, loadtxt, floor as npfloor, ceil as npceil, linalg
 from pandas import read_csv, merge
@@ -53,17 +54,17 @@
         elif dataType == 'pois':
             dropTables(connection, ['gaussians2d', 'objects_pois'])
         elif dataType == 'prototype':
-            dropTables(connection, ['prototypes'])
+            dropTables(connection, ['prototypes', 'prototype_positions', 'prototype_velocities'])
         else:
             print('Unknown data type {} to delete from database'.format(dataType))
         connection.close()
     else:
         print('{} does not exist'.format(filename))
 
-def tableExists(filename, tableName):
+def tableExists(connection, tableName):
     'indicates if the table exists in the database'
     try:
-        connection = sqlite3.connect(filename)
+        #connection = sqlite3.connect(filename)
         cursor = connection.cursor()
         cursor.execute('SELECT COUNT(*) FROM SQLITE_MASTER WHERE type = \'table\' AND name = \''+tableName+'\'')
         return cursor.fetchone()[0] == 1
@@ -71,7 +72,7 @@
         printDBError(error)        
 
 def createTrajectoryTable(cursor, tableName):
-    if tableName in ['positions', 'velocities']:
+    if tableName.endswith('positions') or tableName.endswith('velocities'):
         cursor.execute("CREATE TABLE IF NOT EXISTS "+tableName+" (trajectory_id INTEGER, frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))")
     else:
         print('Unallowed name {} for trajectory table'.format(tableName))
@@ -266,15 +267,19 @@
         userTypes[row[0]] = row[1]
     return userTypes
 
-def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None, withFeatures = False, timeStep = None):
+def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None, withFeatures = False, timeStep = None, tablePrefix = None):
     '''Loads the trajectories (in the general sense, 
     either features, objects (feature groups) or bounding box series) 
     The number loaded is either the first objectNumbers objects,
     or the indices in objectNumbers from the database'''
     connection = sqlite3.connect(filename)
 
-    objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers, timeStep)
-    objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers, timeStep)
+    if tablePrefix is None:
+        prefix = ''
+    else:
+        prefix = tablePrefix + '_'
+    objects = loadTrajectoriesFromTable(connection, prefix+'positions', trajectoryType, objectNumbers, timeStep)
+    objectVelocities = loadTrajectoriesFromTable(connection, prefix+'velocities', trajectoryType, objectNumbers, timeStep)
 
     if len(objectVelocities) > 0:
         for o,v in zip(objects, objectVelocities):
@@ -590,9 +595,12 @@
 def savePrototypesToSqlite(filename, prototypeIndices, trajectoryType, objects = None, nMatchings = None, dbFilenames = None):
     '''save the prototype indices
     if objects is not None, the trajectories are also saved in prototype_positions and _velocities
-    (prototypeIndices have to be in objects)
+    (prototypeIndices have to be in objects
+    objects will be saved as features, with the centroid trajectory as if it is a feature)
     nMatchings, if not None, is a list of the number of matches
-    dbFilenames, if not None, is a list of the DB filenames'''
+    dbFilenames, if not None, is a list of the DB filenames
+
+    The order of prototypeIndices, objects, nMatchings and dbFilenames should be consistent'''
     connection = sqlite3.connect(filename)
     cursor = connection.cursor()
     try:
@@ -607,9 +615,13 @@
             else:
                 dbfn = filename
             cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings, positions_id) VALUES ({},\"{}\",\"{}\",{}, {})'.format(protoId, dbfn, trajectoryType, n, i))
-        #cursor.execute('SELECT * from sqlite_master WHERE type = \"table\" and name = \"{}\"'.format(tableNames[trajectoryType]))
         if objects is not None: # save positions and velocities
-            pass 
+            features = []
+            for i, o in enumerate(objects):
+                f = copy(o)
+                f.num = i
+                features.append(f)
+            saveTrajectoriesToTable(connection, features, 'feature', 'prototype')
     except sqlite3.OperationalError as error:
         printDBError(error)
     connection.commit()
@@ -626,6 +638,7 @@
     dbFilenames = []
     trajectoryTypes = []
     nMatchings = []
+    trajectoryNumbers = []
     try:
         cursor.execute('SELECT * FROM prototypes')
         for row in cursor:
@@ -634,12 +647,18 @@
             trajectoryTypes.append(row[2])
             if row[3] is not None:
                 nMatchings.append(row[3])
+            if row[4] is not None:
+                trajectoryNumbers.append(row[4])
+        if tableExists(connection, 'prototype_positions'): # load prototypes trajectories
+            objects = loadTrajectoriesFromSqlite(filename, 'feature', trajectoryNumbers, tablePrefix = 'prototype')
+        else:
+            objects = None
     except sqlite3.OperationalError as error:
         printDBError(error)
     connection.close()
     if len(set(trajectoryTypes)) > 1:
         print('Different types of prototypes in database ({}).'.format(set(trajectoryTypes)))
-    return prototypeIndices, dbFilenames, trajectoryTypes, nMatchings
+    return prototypeIndices, dbFilenames, trajectoryTypes, nMatchings, objects
 
 def savePOIs(filename, gmm, gmmType, gmmId):
     '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture)
--- a/scripts/learn-motion-patterns.py	Wed Jul 05 12:19:59 2017 -0400
+++ b/scripts/learn-motion-patterns.py	Wed Jul 05 13:16:47 2017 -0400
@@ -67,7 +67,8 @@
 clusterSizes = ml.computeClusterSizes(labels, prototypeIndices, -1)
 print(clusterSizes)
 
-storage.savePrototypesToSqlite(args.databaseFilename, [objects[i].getNum() for i in prototypeIndices], prototypeType, [clusterSizes[i] for i in prototypeIndices]) # if saving filenames, add for example [objects[i].dbFilename for i in prototypeIndices]
+prototypes = [objects[i] for i in prototypeIndices]
+storage.savePrototypesToSqlite(args.databaseFilename, [p.getNum() for p in prototypes], prototypeType, prototypes, [clusterSizes[i] for i in prototypeIndices]) # if saving filenames, add for example [objects[i].dbFilename for i in prototypeIndices]
 
 if args.saveSimilarities:
     np.savetxt(utils.removeExtension(args.databaseFilename)+'-prototype-similarities.txt.gz', similarities, '%.4f')