changeset 330:00800ebae698

corrected bug in db loading
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Thu, 13 Jun 2013 23:05:28 -0400
parents a70c205ebdd9
children 40790d93200e
files python/storage.py
diffstat 1 files changed, 13 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/python/storage.py	Thu Jun 13 00:42:40 2013 -0400
+++ b/python/storage.py	Thu Jun 13 23:05:28 2013 -0400
@@ -48,6 +48,8 @@
     connection.close()
 
 def setRoadUserTypes(filename, objects):
+    '''Saves the user types of the objects in the sqlite database stored in filename
+    The objects should exist in the objects table'''
     import sqlite3
     connection = sqlite3.connect(filename)
     cursor = connection.cursor()
@@ -81,9 +83,9 @@
 
 def getTrajectoryIdQuery(objectNumbers, trajectoryType):
     if trajectoryType == 'feature':
-        statementBeginning = 'trajectory_id'
+        statementBeginning = 'where trajectory_id '
     elif trajectoryType == 'object':
-        statementBeginning =  'object_id'
+        statementBeginning =  'and OF.object_id '
     else:
         print('no trajectory type was chosen')
 
@@ -91,9 +93,9 @@
         if objectNumbers == -1:
             query = ''
         else:
-            query = statementBeginning+' between 0 and {0}'.format(objectNumbers)
+            query = statementBeginning+'between 0 and {0} '.format(objectNumbers)
     elif type(objectNumbers) == list:
-        query = statementBeginning+' in ('+', '.join([str(n) for n in objectNumbers])+')'
+        query = statementBeginning+'in ('+', '.join([str(n) for n in objectNumbers])+') '
     return query
 
 def loadTrajectoriesFromTable(connection, tableName, trajectoryType, objectNumbers = -1):
@@ -108,10 +110,10 @@
     try:
         if trajectoryType == 'feature':
             trajectoryIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
-            cursor.execute('SELECT * from '+tableName+' where '+trajectoryIdQuery+' order by trajectory_id, frame_number')
+            cursor.execute('SELECT * from '+tableName+' '+trajectoryIdQuery+'order by trajectory_id, frame_number')
         elif trajectoryType == 'object':
             objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
-            cursor.execute('SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id and OF.'+objectIdQuery+' group by OF.object_id, P.frame_number order by OF.object_id, P.frame_number')
+            cursor.execute('SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id '+objectIdQuery+'group by OF.object_id, P.frame_number order by OF.object_id, P.frame_number')
         else:
             print('no trajectory type was chosen')
     except sqlite3.OperationalError as err:
@@ -160,7 +162,7 @@
         try:
             # attribute feature numbers to objects
             objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
-            cursor.execute('SELECT P.trajectory_id, OF.object_id from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id and OF.'+objectIdQuery+' group by P.trajectory_id order by OF.object_id') # order is important to group all features per object
+            cursor.execute('SELECT P.trajectory_id, OF.object_id from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id '+objectIdQuery+'group by P.trajectory_id order by OF.object_id') # order is important to group all features per object
 
             featureNumbers = {}
             for row in cursor:
@@ -174,7 +176,10 @@
                 obj.featureNumbers = featureNumbers[obj.getNum()]
 
             # load userType
-            cursor.execute('SELECT object_id, road_user_type from objects where '+objectIdQuery)
+            if objectIdQuery == '':
+                cursor.execute('SELECT object_id, road_user_type from objects')
+            else:
+                cursor.execute('SELECT object_id, road_user_type from objects where '+objectIdQuery[7:])
             userTypes = {}
             for row in cursor:
                 userTypes[row[0]] = row[1]