comparison python/storage.py @ 329:a70c205ebdd9

added sqlite code, in particular to load and save road user type
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Thu, 13 Jun 2013 00:42:40 -0400
parents 42f2b46ec210
children 00800ebae698
comparison
equal deleted inserted replaced
328:5e43b7389c25 329:a70c205ebdd9
10 10
11 ngsimUserTypes = {'twowheels':1, 11 ngsimUserTypes = {'twowheels':1,
12 'car':2, 12 'car':2,
13 'truck':3} 13 'truck':3}
14 14
15 sqliteUserTypeNames = ['unknown',
16 'car',
17 'pedestrian',
18 'motorcycle',
19 'bicycle',
20 'bus',
21 'truck']
22
23
24 #########################
25 # txt files
26 #########################
27
28
29
30 ######################### 15 #########################
31 # Sqlite 16 # Sqlite
32 ######################### 17 #########################
33 18
34 def writeTrajectoriesToSqlite(objects, outFile, trajectoryType, objectNumbers = -1): 19 def writeTrajectoriesToSqlite(objects, outFilename, trajectoryType, objectNumbers = -1):
35 """ 20 """
36 This function writers trajectories to a specified sqlite file 21 This function writers trajectories to a specified sqlite file
37 @param[in] objects -> a list of trajectories 22 @param[in] objects -> a list of trajectories
38 @param[in] trajectoryType - 23 @param[in] trajectoryType -
39 @param[out] outFile -> the .sqlite file containting the written objects 24 @param[out] outFile -> the .sqlite file containting the written objects
40 @param[in] objectNumber : number of objects loaded 25 @param[in] objectNumber : number of objects loaded
41 """ 26 """
42 27
43 import sqlite3 28 import sqlite3
44 connection = sqlite3.connect(outFile) 29 connection = sqlite3.connect(outFilename)
45 cursor = connection.cursor() 30 cursor = connection.cursor()
46 31
47 schema = "CREATE TABLE \"positions\"(trajectory_id INTEGER,frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))" 32 schema = "CREATE TABLE \"positions\"(trajectory_id INTEGER,frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))"
48 cursor.execute(schema) 33 cursor.execute(schema)
49 34
57 for position in trajectory.getPositions(): 42 for position in trajectory.getPositions():
58 frame_number += 1 43 frame_number += 1
59 query = "insert into positions (trajectory_id, frame_number, x_coordinate, y_coordinate) values (?,?,?,?)" 44 query = "insert into positions (trajectory_id, frame_number, x_coordinate, y_coordinate) values (?,?,?,?)"
60 cursor.execute(query,(trajectory_id,frame_number,position.x,position.y)) 45 cursor.execute(query,(trajectory_id,frame_number,position.x,position.y))
61 46
62 connection.commit() 47 connection.commit()
48 connection.close()
49
50 def setRoadUserTypes(filename, objects):
51 import sqlite3
52 connection = sqlite3.connect(filename)
53 cursor = connection.cursor()
54 for obj in objects:
55 cursor.execute('update objects set road_user_type = {} where object_id = {}'.format(obj.getUserType(), obj.getNum()))
56 connection.commit()
63 connection.close() 57 connection.close()
64 58
65 def loadPrototypeMatchIndexesFromSqlite(filename): 59 def loadPrototypeMatchIndexesFromSqlite(filename):
66 """ 60 """
67 This function loads the prototypes table in the database of name <filename>. 61 This function loads the prototypes table in the database of name <filename>.
85 connection.close() 79 connection.close()
86 return matched_indexes 80 return matched_indexes
87 81
88 def getTrajectoryIdQuery(objectNumbers, trajectoryType): 82 def getTrajectoryIdQuery(objectNumbers, trajectoryType):
89 if trajectoryType == 'feature': 83 if trajectoryType == 'feature':
90 statementBeginning = ' where trajectory_id' 84 statementBeginning = 'trajectory_id'
91 elif trajectoryType == 'object': 85 elif trajectoryType == 'object':
92 statementBeginning = ' and OF.object_id' 86 statementBeginning = 'object_id'
93 else: 87 else:
94 print('no trajectory type was chosen') 88 print('no trajectory type was chosen')
95 89
96 if type(objectNumbers) == int: 90 if type(objectNumbers) == int:
97 if objectNumbers == -1: 91 if objectNumbers == -1:
112 cursor = connection.cursor() 106 cursor = connection.cursor()
113 107
114 try: 108 try:
115 if trajectoryType == 'feature': 109 if trajectoryType == 'feature':
116 trajectoryIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType) 110 trajectoryIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
117 cursor.execute('SELECT * from '+tableName+trajectoryIdQuery+' order by trajectory_id, frame_number') 111 cursor.execute('SELECT * from '+tableName+' where '+trajectoryIdQuery+' order by trajectory_id, frame_number')
118 elif trajectoryType == 'object': 112 elif trajectoryType == 'object':
119 objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType) 113 objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
120 cursor.execute('SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id '+objectIdQuery+' group by object_id, frame_number') 114 cursor.execute('SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id and OF.'+objectIdQuery+' group by OF.object_id, P.frame_number order by OF.object_id, P.frame_number')
121 else: 115 else:
122 print('no trajectory type was chosen') 116 print('no trajectory type was chosen')
123 except sqlite3.OperationalError as err: 117 except sqlite3.OperationalError as err:
124 print('DB Error: {0}'.format(err)) 118 print('DB Error: {0}'.format(err))
125 return [] 119 return []
154 objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers) 148 objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers)
155 objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers) 149 objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers)
156 150
157 if len(objectVelocities) > 0: 151 if len(objectVelocities) > 0:
158 for o,v in zip(objects, objectVelocities): 152 for o,v in zip(objects, objectVelocities):
159 if o.num == v.num: 153 if o.getNum() == v.getNum():
160 o.velocities = v.positions 154 o.velocities = v.positions
161 else: 155 else:
162 print('Could not match positions {0} with velocities {1}'.format(o.num, v.num)) 156 print('Could not match positions {0} with velocities {1}'.format(o.getNum(), v.getNum()))
163 157
164 if trajectoryType == 'object': 158 if trajectoryType == 'object':
165 cursor = connection.cursor() 159 cursor = connection.cursor()
166 try: 160 try:
161 # attribute feature numbers to objects
167 objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType) 162 objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
168 cursor.execute('SELECT P.trajectory_id, OF.object_id from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id '+objectIdQuery+' group by P.trajectory_id order by OF.object_id') 163 cursor.execute('SELECT P.trajectory_id, OF.object_id from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id and OF.'+objectIdQuery+' group by P.trajectory_id order by OF.object_id') # order is important to group all features per object
169 164
170 # attribute feature numbers to objects
171 objId = -1
172 featureNumbers = {} 165 featureNumbers = {}
173 for row in cursor: 166 for row in cursor:
174 if row[1] != objId: 167 objId = row[1]
175 objId = row[1] 168 if objId not in featureNumbers:
176 featureNumbers[objId] = [row[0]] 169 featureNumbers[objId] = [row[0]]
177 else: 170 else:
178 featureNumbers[objId].append(row[0]) 171 featureNumbers[objId].append(row[0])
179 172
180 for obj in objects: 173 for obj in objects:
181 obj.featureNumbers = featureNumbers[obj.num] 174 obj.featureNumbers = featureNumbers[obj.getNum()]
175
176 # load userType
177 cursor.execute('SELECT object_id, road_user_type from objects where '+objectIdQuery)
178 userTypes = {}
179 for row in cursor:
180 userTypes[row[0]] = row[1]
181
182 for obj in objects:
183 obj.userType = userTypes[obj.getNum()]
184
182 except sqlite3.OperationalError as err: 185 except sqlite3.OperationalError as err:
183 print('DB Error: {0}'.format(err)) 186 print('DB Error: {0}'.format(err))
184 return [] 187 return []
185 188
186 connection.close() 189 connection.close()
190 'Removes the objects and object_features tables in the filename' 193 'Removes the objects and object_features tables in the filename'
191 import sqlite3 194 import sqlite3
192 connection = sqlite3.connect(filename) 195 connection = sqlite3.connect(filename)
193 utils.dropTables(connection, ['objects', 'objects_features']) 196 utils.dropTables(connection, ['objects', 'objects_features'])
194 connection.close() 197 connection.close()
198
199
200 #########################
201 # txt files
202 #########################
195 203
196 def loadTrajectoriesFromNgsimFile(filename, nObjects = -1, sequenceNum = -1): 204 def loadTrajectoriesFromNgsimFile(filename, nObjects = -1, sequenceNum = -1):
197 '''Reads data from the trajectory data provided by NGSIM project 205 '''Reads data from the trajectory data provided by NGSIM project
198 and returns the list of Feature objects''' 206 and returns the list of Feature objects'''
199 objects = [] 207 objects = []
229 if (len(numbers) > 0): 237 if (len(numbers) > 0):
230 obj = createObject(numbers) 238 obj = createObject(numbers)
231 239
232 for line in input: 240 for line in input:
233 numbers = line.strip().split() 241 numbers = line.strip().split()
234 if obj.num != int(numbers[0]): 242 if obj.getNum() != int(numbers[0]):
235 # check and adapt the length to deal with issues in NGSIM data 243 # check and adapt the length to deal with issues in NGSIM data
236 if (obj.length() != obj.positions.length()): 244 if (obj.length() != obj.positions.length()):
237 print 'length pb with object %s (%d,%d)' % (obj.num,obj.length(),obj.positions.length()) 245 print 'length pb with object %s (%d,%d)' % (obj.getNum(),obj.length(),obj.positions.length())
238 obj.last = obj.getFirstInstant()+obj.positions.length()-1 246 obj.last = obj.getFirstInstant()+obj.positions.length()-1
239 #obj.velocities = utils.computeVelocities(f.positions) # compare norm to speeds ? 247 #obj.velocities = utils.computeVelocities(f.positions) # compare norm to speeds ?
240 objects.append(obj) 248 objects.append(obj)
241 if (nObjects>0) and (len(objects)>=nObjects): 249 if (nObjects>0) and (len(objects)>=nObjects):
242 break 250 break
250 obj.followingVehicles.append(int(numbers[15])) 258 obj.followingVehicles.append(int(numbers[15]))
251 obj.spaceHeadways.append(float(numbers[16])) 259 obj.spaceHeadways.append(float(numbers[16]))
252 obj.timeHeadways.append(float(numbers[17])) 260 obj.timeHeadways.append(float(numbers[17]))
253 261
254 if (obj.size[0] != float(numbers[8])): 262 if (obj.size[0] != float(numbers[8])):
255 print 'changed length obj %d' % (f.num) 263 print 'changed length obj %d' % (obj.getNum())
256 if (obj.size[1] != float(numbers[9])): 264 if (obj.size[1] != float(numbers[9])):
257 print 'changed width obj %d' % (f.num) 265 print 'changed width obj %d' % (obj.getNum())
258 266
259 input.close() 267 input.close()
260 return objects 268 return objects
261 269
262 def convertNgsimFile(inFile, outFile, append = False, nObjects = -1, sequenceNum = 0): 270 def convertNgsimFile(inFile, outFile, append = False, nObjects = -1, sequenceNum = 0):