diff python/storage.py @ 615:0954aaf28231

Merge
author MohamedGomaa
date Wed, 10 Dec 2014 14:12:06 -0500
parents 84690dfe5560 c5406edbcf12
children dc2d0a0d7fe1
line wrap: on
line diff
--- a/python/storage.py	Thu Dec 04 19:07:55 2014 -0500
+++ b/python/storage.py	Wed Dec 10 14:12:06 2014 -0500
@@ -34,6 +34,7 @@
     except sqlite3.OperationalError as error:
         printDBError(error)
 
+# TODO: add test if database connection is open
 # IO to sqlite
 def writeTrajectoriesToSqlite(objects, outputFilename, trajectoryType, objectNumbers = -1):
     """
@@ -293,20 +294,21 @@
     if trajectoryType == 'feature':
         statementBeginning = 'where trajectory_id '
     elif trajectoryType == 'object':
-        statementBeginning =  'and OF.object_id '
+        statementBeginning = 'and OF.object_id '
+    elif trajectoryType == 'bbtop' or 'bbbottom':
+        statementBeginning = 'where object_id '
     else:
         print('no trajectory type was chosen')
 
-    if type(objectNumbers) == int:
-        if objectNumbers == -1:
-            query = ''
-        else:
-            query = statementBeginning+'between 0 and {0} '.format(objectNumbers)
+    if objectNumbers is None:
+        query = ''
+    elif type(objectNumbers) == int:
+        query = statementBeginning+'between 0 and {0} '.format(objectNumbers)
     elif type(objectNumbers) == list:
         query = statementBeginning+'in ('+', '.join([str(n) for n in objectNumbers])+') '
     return query
 
-def loadTrajectoriesFromTable(connection, tableName, trajectoryType, objectNumbers = -1):
+def loadTrajectoriesFromTable(connection, tableName, trajectoryType, objectNumbers = None):
     '''Loads trajectories (in the general sense) from the given table
     can be positions or velocities
 
@@ -314,14 +316,21 @@
     cursor = connection.cursor()
 
     try:
+        idQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
         if trajectoryType == 'feature':
-            trajectoryIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
-            queryStatement = 'SELECT * from '+tableName+' '+trajectoryIdQuery+'order by trajectory_id, frame_number'
+            queryStatement = 'SELECT * from '+tableName+' '+idQuery+'ORDER BY trajectory_id, frame_number'
             cursor.execute(queryStatement)
             logging.debug(queryStatement)
         elif trajectoryType == 'object':
-            objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
-            queryStatement = 'SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id '+objectIdQuery+'group by OF.object_id, P.frame_number order by OF.object_id, P.frame_number'
+            queryStatement = 'SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id '+idQuery+'group by OF.object_id, P.frame_number ORDER BY OF.object_id, P.frame_number'
+            cursor.execute(queryStatement)
+            logging.debug(queryStatement)
+        elif trajectoryType in ['bbtop', 'bbbottom']:
+            if trajectoryType == 'bbtop':
+                corner = 'top_left'
+            elif trajectoryType == 'bbbottom':
+                corner = 'bottom_right'
+            queryStatement = 'SELECT object_id, frame_number, x_'+corner+', y_'+corner+' FROM '+tableName+' '+trajectoryIdQuery+'ORDER BY object_id, frame_number'
             cursor.execute(queryStatement)
             logging.debug(queryStatement)
         else:
@@ -336,21 +345,36 @@
     for row in cursor:
         if row[0] != objId:
             objId = row[0]
-            if obj:
+            if obj != None and obj.length() == obj.positions.length():
                 objects.append(obj)
+            elif obj != None:
+                print('Object {} is missing {} positions'.format(obj.getNum(), int(obj.length())-obj.positions.length()))
             obj = moving.MovingObject(row[0], timeInterval = moving.TimeInterval(row[1], row[1]), positions = moving.Trajectory([[row[2]],[row[3]]]))
         else:
             obj.timeInterval.last = row[1]
             obj.positions.addPositionXY(row[2],row[3])
 
-    if obj:
+    if obj != None and obj.length() == obj.positions.length():
         objects.append(obj)
+    elif obj != None:
+        print('Object {} is missing {} positions'.format(obj.getNum(), int(obj.length())-obj.positions.length()))
 
     return objects
 
-def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = -1):
-    '''Loads nObjects or the indices in objectNumbers from the database'''
-    connection = sqlite3.connect(filename) # add test if it open
+def loadUserTypesFromTable(cursor, trajectoryType, objectNumbers):
+    objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
+    if objectIdQuery == '':
+        cursor.execute('SELECT object_id, road_user_type from objects')
+    else:
+        cursor.execute('SELECT object_id, road_user_type from objects where '+objectIdQuery[7:])
+    userTypes = {}
+    for row in cursor:
+        userTypes[row[0]] = row[1]
+    return userTypes
+
+def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None):
+    '''Loads the first objectNumbers objects or the indices in objectNumbers from the database'''
+    connection = sqlite3.connect(filename)
 
     objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers)
     objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers)
@@ -384,26 +408,40 @@
                 obj.featureNumbers = featureNumbers[obj.getNum()]
 
             # load userType
-            if objectIdQuery == '':
-                cursor.execute('SELECT object_id, road_user_type from objects')
-            else:
-                cursor.execute('SELECT object_id, road_user_type from objects where '+objectIdQuery[7:])
-            userTypes = {}
-            for row in cursor:
-                userTypes[row[0]] = row[1]
-            
+            userTypes = loadUserTypesFromTable(cursor, trajectoryType, objectNumbers)
             for obj in objects:
                 obj.userType = userTypes[obj.getNum()]
              
         except sqlite3.OperationalError as error:
             printDBError(error)
-            return []
+            objects = []
 
     connection.close()
     return objects
 
-def removeFromSqlite(filename, dataType):
-    'Removes some tables in the filename depending on type of data'
+def loadGroundTruthFromSqlite(filename, gtType, gtNumbers = None):
+    'Loads bounding box annotations (ground truth) from an SQLite '
+    connection = sqlite3.connect(filename)
+    gt = []
+
+    if gtType == 'bb':
+        topCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbtop', gtNumbers)
+        bottomCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbbottom', gtNumbers)
+        userTypes = loadUserTypesFromTable(connection.cursor(), 'object', gtNumbers) # string format is same as object
+        
+        for t, b in zip(topCorners, bottomCorners):
+            num = t.getNum()
+            if t.getNum() == b.getNum():
+                annotation = moving.BBAnnotation(num, t.getTimeInterval(), t, b, userTypes[num])
+                gt.append(annotation)
+    else:
+        print ('Unknown type of annotation {}'.format(gtType))
+
+    connection.close()
+    return gt
+
+def deleteFromSqlite(filename, dataType):
+    'Deletes (drops) some tables in the filename depending on type of data'
     import os
     if os.path.isfile(filename):
         connection = sqlite3.connect(filename)
@@ -525,7 +563,7 @@
     connection.commit()
     connection.close()
 
-def loadBoundingBoxTable(filename):
+def loadBoundingBoxTableForDisplay(filename):
     connection = sqlite3.connect(filename)
     cursor = connection.cursor()
     boundingBoxes = {} # list of bounding boxes for each instant
@@ -534,9 +572,7 @@
         result = [row for row in cursor]
         if len(result) > 0:
             cursor.execute('SELECT * FROM bounding_boxes')
-            #objId = -1
             for row in cursor:
-                #if row[0] != objId:
                 boundingBoxes.setdefault(row[1], []).append([moving.Point(row[2], row[3]), moving.Point(row[4], row[5])])
     except sqlite3.OperationalError as error:
         printDBError(error)
@@ -544,6 +580,19 @@
     connection.close()
     return boundingBoxes
 
+def loadBoundingBoxTable(filename):
+    connection = sqlite3.connect(filename)
+    cursor = connection.cursor()
+    boundingBoxes = []
+    
+    try:
+        pass
+    except sqlite3.OperationalError as error:
+        printDBError(error)
+        return boundingBoxes
+    connection.close()
+    return boundingBoxes
+
 
 #########################
 # txt files