diff python/storage.py @ 619:dc2d0a0d7fe1

merged code from Mohamed Gomaa Mohamed for the use of points of interests in mation pattern learning and motion prediction (TRB 2015)
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 10 Dec 2014 15:27:08 -0500
parents 5800a87f11ae 0954aaf28231
children 977407c9f815
line wrap: on
line diff
--- a/python/storage.py	Sun Dec 07 23:01:02 2014 -0500
+++ b/python/storage.py	Wed Dec 10 15:27:08 2014 -0500
@@ -64,7 +64,7 @@
                     
     connection.commit()
     connection.close()
-	
+    
 def writeFeaturesToSqlite(objects, outputFilename, trajectoryType, objectNumbers = -1):
     '''write features trajectories maintain trajectory ID,velocities dataset  '''
     connection = sqlite3.connect(outputFilename)
@@ -72,7 +72,7 @@
 
     cursor.execute("CREATE TABLE IF NOT EXISTS \"positions\"(trajectory_id INTEGER,frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))")
     cursor.execute("CREATE TABLE IF NOT EXISTS \"velocities\"(trajectory_id INTEGER,frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))")
-	
+    
     if trajectoryType == 'feature':
         if type(objectNumbers) == int and objectNumbers == -1:
             for trajectory in objects:
@@ -85,14 +85,14 @@
                     
     connection.commit()
     connection.close()
-	
+    
 def writePrototypesToSqlite(prototypes,nMatching, outputFilename):
     """ prototype dataset is a dictionary with  keys== routes, values== prototypes Ids """
     connection = sqlite3.connect(outputFilename)
     cursor = connection.cursor()
 
     cursor.execute("CREATE TABLE IF NOT EXISTS \"prototypes\"(prototype_id INTEGER,routeIDstart INTEGER,routeIDend INTEGER, nMatching INTEGER, PRIMARY KEY(prototype_id))")
-	
+    
     for route in prototypes.keys():
         if prototypes[route]!=[]:
             for i in prototypes[route]:
@@ -100,7 +100,7 @@
                     
     connection.commit()
     connection.close()
-	
+    
 def loadPrototypesFromSqlite(filename):
     """
     This function loads the prototype file in the database 
@@ -127,7 +127,7 @@
 
     connection.close()
     return prototypes,nMatching
-	
+    
 def writeLabelsToSqlite(labels, outputFilename):
     """ labels is a dictionary with  keys: routes, values: prototypes Ids
     """
@@ -135,7 +135,7 @@
     cursor = connection.cursor()
 
     cursor.execute("CREATE TABLE IF NOT EXISTS \"labels\"(object_id INTEGER,routeIDstart INTEGER,routeIDend INTEGER, prototype_id INTEGER, PRIMARY KEY(object_id))")
-	
+    
     for route in labels.keys():
         if labels[route]!=[]:
             for i in labels[route]:
@@ -144,7 +144,7 @@
                     
     connection.commit()
     connection.close()
-	
+    
 def loadLabelsFromSqlite(filename):
     labels = {}
 
@@ -168,6 +168,50 @@
 
     connection.close()
     return labels
+def writeSpeedPrototypeToSqlite(prototypes,nmatching, outFilename):
+    """ to match the format of second layer prototypes"""
+    connection = sqlite3.connect(outFilename)
+    cursor = connection.cursor()
+
+    cursor.execute("CREATE TABLE IF NOT EXISTS \"speedprototypes\"(spdprototype_id INTEGER,prototype_id INTEGER,routeID_start INTEGER, routeID_end INTEGER, nMatching INTEGER, PRIMARY KEY(spdprototype_id))")
+    
+    for route in prototypes.keys():
+        if prototypes[route]!={}:
+            for i in prototypes[route]:
+                if prototypes[route][i]!= []:
+                    for j in prototypes[route][i]:
+                        cursor.execute("insert into speedprototypes (spdprototype_id,prototype_id, routeID_start, routeID_end, nMatching) values (?,?,?,?,?)",(j,i,route[0],route[1],nmatching[j]))
+                    
+    connection.commit()
+    connection.close()
+    
+def loadSpeedPrototypeFromSqlite(filename):
+    """
+    This function loads the prototypes table in the database of name <filename>.
+    """
+    prototypes = {}
+    nMatching={}
+    connection = sqlite3.connect(filename)
+    cursor = connection.cursor()
+
+    try:
+        cursor.execute('SELECT * from speedprototypes order by spdprototype_id,prototype_id, routeID_start, routeID_end, nMatching')
+    except sqlite3.OperationalError as error:
+        utils.printDBError(error)
+        return []
+
+    for row in cursor:
+        route=(row[2],row[3])
+        if route not in prototypes.keys():
+            prototypes[route]={}
+        if row[1] not in prototypes[route].keys():
+            prototypes[route][row[1]]=[]
+        prototypes[route][row[1]].append(row[0])
+        nMatching[row[0]]=row[4]
+
+    connection.close()
+    return prototypes,nMatching
+
 
 def writeRoutesToSqlite(Routes, outputFilename):
     """ This function writes the activity path define by start and end IDs"""
@@ -175,7 +219,7 @@
     cursor = connection.cursor()
 
     cursor.execute("CREATE TABLE IF NOT EXISTS \"routes\"(object_id INTEGER,routeIDstart INTEGER,routeIDend INTEGER, PRIMARY KEY(object_id))")
-	
+    
     for route in Routes.keys():
         if Routes[route]!=[]:
             for i in Routes[route]:
@@ -183,7 +227,7 @@
                     
     connection.commit()
     connection.close()
-	
+    
 def loadRoutesFromSqlite(filename):
     Routes = {}
 
@@ -203,14 +247,14 @@
         Routes[route].append(row[0])
 
     connection.close()
-    return Routes	
+    return Routes
 
 def setRoutes(filename, objects):
     connection = sqlite3.connect(filename)
     cursor = connection.cursor()
     for obj in objects:
         cursor.execute('update objects set startRouteID = {} where object_id = {}'.format(obj.startRouteID, obj.getNum()))
-        cursor.execute('update objects set endRouteID = {} where object_id = {}'.format(obj.endRouteID, obj.getNum()))	        
+        cursor.execute('update objects set endRouteID = {} where object_id = {}'.format(obj.endRouteID, obj.getNum()))        
     connection.commit()
     connection.close()
 
@@ -843,6 +887,7 @@
         return configDict
 
 
+
 if __name__ == "__main__":
     import doctest
     import unittest