diff python/storage.py @ 918:3a06007a4bb7

modularized save trajectories, added slice to Trajectory, etc
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 05 Jul 2017 12:19:59 -0400
parents 89cc05867c4c
children 7b3f2e0a2652
line wrap: on
line diff
--- a/python/storage.py	Tue Jul 04 18:00:01 2017 -0400
+++ b/python/storage.py	Wed Jul 05 12:19:59 2017 -0400
@@ -375,7 +375,70 @@
     if len(missingObjectNumbers) > 0:
         print('List of missing objects to attach corresponding curvilinear trajectories: {}'.format(missingObjectNumbers))
 
-def saveTrajectoriesToSqlite(outputFilename, objects, trajectoryType, withFeatures = False):
+def saveTrajectoriesToTable(connection, objects, trajectoryType, tablePrefix = None):
+    'Saves trajectories in table tableName'
+    cursor = connection.cursor()
+    # Parse feature and/or object structure and commit to DB
+    if(trajectoryType == 'feature' or trajectoryType == 'object'):
+        # Extract features from objects
+        if trajectoryType == 'object':
+            features = []
+            for obj in objects:
+                if obj.hasFeatures():
+                    features += obj.getFeatures()
+            if len(features) == 0:
+                print('Warning, objects have no features') # todo save centroid trajectories?
+        elif trajectoryType == 'feature':
+            features = objects
+        # Setup feature queries
+        if tablePrefix is None:
+            prefix = ''
+        else:
+            prefix = tablePrefix+'_'
+        createTrajectoryTable(cursor, prefix+"positions")
+        createTrajectoryTable(cursor, prefix+"velocities")
+        positionQuery = insertTrajectoryQuery(prefix+"positions")
+        velocityQuery = insertTrajectoryQuery(prefix+"velocities")
+        # Setup object queries
+        if trajectoryType == 'object':    
+            createObjectsTable(cursor)
+            createObjectsFeaturesTable(cursor)
+            objectQuery = insertObjectQuery()
+            objectFeatureQuery = insertObjectFeatureQuery()
+        for feature in features:
+            num = feature.getNum()
+            frameNum = feature.getFirstInstant()
+            for p in feature.getPositions():
+                cursor.execute(positionQuery, (num, frameNum, p.x, p.y))
+                frameNum += 1
+            velocities = feature.getVelocities()
+            if velocities is not None:
+                frameNum = feature.getFirstInstant()
+                for v in velocities[:-1]:
+                    cursor.execute(velocityQuery, (num, frameNum, v.x, v.y))
+                    frameNum += 1
+        if trajectoryType == 'object':
+            for obj in objects:
+                if obj.hasFeatures():
+                    for feature in obj.getFeatures():
+                        featureNum = feature.getNum()
+                        cursor.execute(objectFeatureQuery, (obj.getNum(), featureNum))
+                cursor.execute(objectQuery, (obj.getNum(), obj.getUserType(), 1))   
+    # Parse curvilinear position structure
+    elif(trajectoryType == 'curvilinear'):
+        createCurvilinearTrajectoryTable(cursor)
+        curvilinearQuery = "insert into curvilinear_positions (trajectory_id, frame_number, s_coordinate, y_coordinate, lane) values (?,?,?,?,?)"
+        for obj in objects:
+            num = obj.getNum()
+            frameNum = obj.getFirstInstant()
+            for p in obj.getCurvilinearPositions():
+                cursor.execute(curvilinearQuery, (num, frameNum, p[0], p[1], p[2]))
+                frameNum += 1
+    else:
+        print('Unknown trajectory type {}'.format(trajectoryType))
+    connection.commit()
+
+def saveTrajectoriesToSqlite(outputFilename, objects, trajectoryType):
     '''Writes features, ie the trajectory positions (and velocities if exist)
     with their instants to a specified sqlite file
     Either feature positions (and velocities if they exist)
@@ -383,60 +446,7 @@
 
     connection = sqlite3.connect(outputFilename)
     try:
-        cursor = connection.cursor()
-        # Parse feature and/or object structure and commit to DB
-        if(trajectoryType == 'feature' or trajectoryType == 'object'):
-            # Extract features from objects
-            if(trajectoryType == 'object'):
-                features = []
-                for obj in objects:
-                    if(obj.hasFeatures()):
-                        features += obj.getFeatures()
-            elif(trajectoryType == 'feature'):
-                features = objects
-            # Setup feature queries
-            createTrajectoryTable(cursor, "positions")
-            createTrajectoryTable(cursor, "velocities")
-            positionQuery = insertTrajectoryQuery("positions")
-            velocityQuery = insertTrajectoryQuery("velocities")
-            # Setup object queries
-            if(trajectoryType == 'object'):    
-                createObjectsTable(cursor)
-                createObjectsFeaturesTable(cursor)
-                objectQuery = insertObjectQuery()
-                objectFeatureQuery = insertObjectFeatureQuery()
-            for feature in features:
-                num = feature.getNum()
-                frameNum = feature.getFirstInstant()
-                for position in feature.getPositions():
-                    cursor.execute(positionQuery, (num, frameNum, position.x, position.y))
-                    frameNum += 1
-                velocities = feature.getVelocities()
-                if velocities is not None:
-                    frameNum = feature.getFirstInstant()
-                    for i in xrange(velocities.length()-1):
-                        v = velocities[i]
-                        cursor.execute(velocityQuery, (num, frameNum, v.x, v.y))
-                        frameNum += 1    
-            if(trajectoryType == 'object'):
-                for obj in objects:
-                    for feature in obj.getFeatures():
-                        featureNum = feature.getNum()
-                        cursor.execute(objectFeatureQuery, (obj.getNum(), featureNum))
-                    cursor.execute(objectQuery, (obj.getNum(), obj.getUserType(), 1))   
-        # Parse curvilinear position structure
-        elif(trajectoryType == 'curvilinear'):
-            createCurvilinearTrajectoryTable(cursor)
-            curvilinearQuery = "insert into curvilinear_positions (trajectory_id, frame_number, s_coordinate, y_coordinate, lane) values (?,?,?,?,?)"
-            for obj in objects:
-                num = obj.getNum()
-                frameNum = obj.getFirstInstant()
-                for position in obj.getCurvilinearPositions():
-                    cursor.execute(curvilinearQuery, (num, frameNum, position[0], position[1], position[2]))
-                    frameNum += 1
-        else:
-            print('Unknown trajectory type {}'.format(trajectoryType))
-        connection.commit()
+        saveTrajectoriesToTable(connection, objects, trajectoryType, None)
     except sqlite3.OperationalError as error:
         printDBError(error)
     connection.close()
@@ -598,7 +608,7 @@
                 dbfn = filename
             cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings, positions_id) VALUES ({},\"{}\",\"{}\",{}, {})'.format(protoId, dbfn, trajectoryType, n, i))
         #cursor.execute('SELECT * from sqlite_master WHERE type = \"table\" and name = \"{}\"'.format(tableNames[trajectoryType]))
-        if objects is not None:
+        if objects is not None: # save positions and velocities
             pass 
     except sqlite3.OperationalError as error:
         printDBError(error)