diff python/storage.py @ 231:249d65ff6c35

merged modifications for windows
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Mon, 02 Jul 2012 23:49:39 -0400
parents b5772df11b37
children 584613399513
line wrap: on
line diff
--- a/python/storage.py	Fri Jun 29 16:15:13 2012 -0400
+++ b/python/storage.py	Mon Jul 02 23:49:39 2012 -0400
@@ -1,4 +1,5 @@
 #! /usr/bin/env python
+# -*- coding: utf-8 -*-
 '''Various utilities to save and load data'''
 
 import utils
@@ -11,35 +12,93 @@
                   'car':2,
                   'truck':3}
 
+def writeTrajectoriesToSqlite(objects, outFile, trajectoryType, objectNumbers = -1):
+    """
+    This function writers trajectories to a specified sqlite file
+    @param[in] objects -> a list of trajectories
+    @param[in] trajectoryType -
+    @param[out] outFile -> the .sqlite file containting the written objects
+    @param[in] objectNumber : number of objects loaded
+    """
+
+    import sqlite3
+    connection = sqlite3.connect(outFile)
+    cursor = connection.cursor()
+
+    schema = "CREATE TABLE \"positions\"(trajectory_id INTEGER,frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))"
+    cursor.execute(schema)
+
+    trajectory_id = 0
+    frame_number = 0
+    if trajectoryType == 'feature':
+        if type(objectNumbers) == int and objectNumbers == -1:
+            for trajectory in objects:
+                trajectory_id += 1
+                frame_number = 0
+                for position in trajectory.getPositions():
+                    frame_number += 1
+                    query = "insert into positions (trajectory_id, frame_number, x_coordinate, y_coordinate) values (?,?,?,?)"
+                    cursor.execute(query,(trajectory_id,frame_number,position.x,position.y))
+                    
+    connection.commit()            
+    connection.close()
+
+def loadPrototypeMatchIndexesFromSqlite(filename):
+    """
+    This function loads the prototypes table in the database of name <filename>.
+    It returns a list of tuples representing matching ids : [(prototype_id, matched_trajectory_id),...]
+    """
+    matched_indexes = []
+
+    import sqlite3    
+    connection = sqlite3.connect(filename)
+    cursor = connection.cursor()
+
+    try:
+        cursor.execute('SELECT * from prototypes order by prototype_id, trajectory_id_matched')
+    except sqlite3.OperationalError as err:
+        print('DB Error: {0}'.format(err))
+        return []
+
+    for row in cursor:
+        matched_indexes.append((row[0],row[1]))
+
+    connection.close()
+    return matched_indexes
+
 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = -1):
     '''Loads nObjects or the indices in objectNumbers from the database 
     TODO: load velocities (replace table name 'positions' by 'velocities'
     TODO: load features as well, other ways of averaging trajectories
     '''
     import sqlite3
-
+    
     connection = sqlite3.connect(filename) # add test if it open
     cursor = connection.cursor()
 
-    if trajectoryType == 'feature':
-        if type(objectNumbers) == int:
-            if objectNumbers == -1:
-                cursor.execute('SELECT * from positions order by trajectory_id, frame_number')
-            else:
-                cursor.execute('SELECT * from positions where trajectory_id between 0 and {0} order by trajectory_id, frame_number'.format(objectNumbers))
-        elif type(objectNumbers) == list:
-            cursor.execute('SELECT * from positions where trajectory_id in ('+', '.join([str(n) for n in objectNumbers])+') order by trajectory_id, frame_number')
-    elif trajectoryType == 'object':
-        if type(objectNumbers) == int:
-            if objectNumbers == -1:
-                cursor.execute('SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id group by object_id, frame_number')
-            else:
-                cursor.execute('SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id and trajectory_id between 0 and {0} group by object_id, frame_number'.format(objectNumbers))
-        elif type(objectNumbers) == list:
-            cursor.execute('SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id and trajectory_id in ('+', '.join([str(n) for n in objectNumbers])+') group by object_id, frame_number')
-    else:
-        print('no trajectory type was chosen')
-
+    try:
+        if trajectoryType == 'feature':
+            if type(objectNumbers) == int:
+                if objectNumbers == -1:
+                    cursor.execute('SELECT * from positions order by trajectory_id, frame_number')
+                else:
+                    cursor.execute('SELECT * from positions where trajectory_id between 0 and {0} order by trajectory_id, frame_number'.format(objectNumbers))
+            elif type(objectNumbers) == list:
+                cursor.execute('SELECT * from positions where trajectory_id in ('+', '.join([str(n) for n in objectNumbers])+') order by trajectory_id, frame_number')
+        elif trajectoryType == 'object':
+            if type(objectNumbers) == int:
+                if objectNumbers == -1:
+                    cursor.execute('SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id group by object_id, frame_number')
+                else:
+                    cursor.execute('SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id and OF.object_id between 0 and {0} group by object_id, frame_number'.format(objectNumbers))
+            elif type(objectNumbers) == list:
+                cursor.execute('SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id and OF.object_id in ('+', '.join([str(n) for n in objectNumbers])+') group by object_id, frame_number')
+        else:
+            print('no trajectory type was chosen')
+    except sqlite3.OperationalError as err:
+        print('DB Error: {0}'.format(err))
+        return []
+    
     objId = -1
     obj = None
     objects = []
@@ -143,13 +202,10 @@
         
     out.close()
 
-
-
-
-# if __name__ == "__main__":
-#     import doctest
-#     import unittest
-#     suite = doctest.DocFileSuite('tests/ubc_utils.txt')
-#     unittest.TextTestRunner().run(suite)
+if __name__ == "__main__":
+    import doctest
+    import unittest
+    suite = doctest.DocFileSuite('tests/storage.txt')
+    unittest.TextTestRunner().run(suite)
 #     #doctest.testmod()
 #     #doctest.testfile("example.txt")