changeset 1037:6a6c37eb3a74

added function to load prototype assignments
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Tue, 03 Jul 2018 23:58:16 -0400
parents 0d7e5e290ea3
children d24deb61f550
files trafficintelligence/moving.py trafficintelligence/storage.py
diffstat 2 files changed, 32 insertions(+), 10 deletions(-) [+]
line wrap: on
line diff
--- a/trafficintelligence/moving.py	Tue Jul 03 16:47:37 2018 -0400
+++ b/trafficintelligence/moving.py	Tue Jul 03 23:58:16 2018 -0400
@@ -1784,6 +1784,10 @@
         self.movingObject = o
     def __str__(self):
         return '{} {} {}'.format(self.filename, self.num, self.trajectoryType)
+    def __eq__(self, p2):
+        return self.filename == p2.filename and self.num == p2.num and self.trajectoryType == p2.trajectoryType
+    def __hash__(self):
+        return hash((self.filename, self.num, self.trajectoryType))
     
 ##################
 # Annotations
--- a/trafficintelligence/storage.py	Tue Jul 03 16:47:37 2018 -0400
+++ b/trafficintelligence/storage.py	Tue Jul 03 23:58:16 2018 -0400
@@ -22,10 +22,6 @@
               'object': 'objects',
               'objectfeatures': 'positions'}
 
-assignmentTableNames = {'feature':'positions',
-                        'object': 'objects',
-                        'objectfeatures': 'positions'}
-
 #########################
 # Sqlite
 #########################
@@ -593,16 +589,21 @@
             printDBError(error)
         connection.commit()
 
+def prototypeAssignmentNames(objectType):
+    tableName = objectType+'s_prototypes'
+    if objectType == 'feature':
+        #tableName = 'features_prototypes'
+        objectIdColumnName = 'trajectory_id'
+    elif objectType == 'object':
+        #tableName = 'objects_prototypes'
+        objectIdColumnName = 'object_id'
+    return tableName, objectIdColumnName
+        
 def savePrototypeAssignmentsToSqlite(filename, objectNumbers, objectType, labels, prototypes):
     with sqlite3.connect(filename) as connection:
         cursor = connection.cursor()
         try:
-            if objectType == 'feature':
-                tableName = 'features_prototypes'
-                objectIdColumnName = 'trajectory_id'
-            elif objectType == 'object':
-                tableName = 'objects_prototypes'
-                objectIdColumnName = 'object_id'
+            tableName, objectIdColumnName = prototypeAssignmentNames(objectType)
             cursor.execute('CREATE TABLE IF NOT EXISTS '+tableName+' ('+objectIdColumnName+' INTEGER, prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), PRIMARY KEY('+objectIdColumnName+', prototype_filename, prototype_id, trajectory_type))')
             for objNum, label in zip(objectNumbers, labels):
                 if label >=0:
@@ -612,6 +613,23 @@
             printDBError(error)
         connection.commit()
 
+def loadPrototypeAssignmentsFromSqlite(filename, objectType):
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            tableName, objectIdColumnName = prototypeAssignmentNames(objectType)
+            cursor.execute('SELECT * FROM '+tableName)
+            prototypeAssignments = {}
+            for row in cursor:
+                p = moving.Prototype(row[1], row[2], row[3])
+                if p in prototypeAssignments:
+                    prototypeAssignments[p].append(row[0])
+                else:
+                    prototypeAssignments[p] = [row[0]]
+            return prototypeAssignments
+        except sqlite3.OperationalError as error:
+            printDBError(error)   
+        
 def loadPrototypesFromSqlite(filename, withTrajectories = True):
     'Loads prototype ids and matchings (if stored)'
     prototypes = []