Mercurial Hosting > traffic-intelligence
comparison python/storage.py @ 588:c5406edbcf12
added loading ground truth annotations (ground truth) from polytrack format
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Fri, 05 Dec 2014 00:54:38 -0500 |
parents | cf578ba866da |
children | 5800a87f11ae 0954aaf28231 |
comparison
equal
deleted
inserted
replaced
587:cf578ba866da | 588:c5406edbcf12 |
---|---|
32 for tableName in tableNames: | 32 for tableName in tableNames: |
33 cursor.execute('DROP TABLE IF EXISTS '+tableName) | 33 cursor.execute('DROP TABLE IF EXISTS '+tableName) |
34 except sqlite3.OperationalError as error: | 34 except sqlite3.OperationalError as error: |
35 printDBError(error) | 35 printDBError(error) |
36 | 36 |
37 # TODO: add test if database connection is open | |
37 # IO to sqlite | 38 # IO to sqlite |
38 def writeTrajectoriesToSqlite(objects, outputFilename, trajectoryType, objectNumbers = -1): | 39 def writeTrajectoriesToSqlite(objects, outputFilename, trajectoryType, objectNumbers = -1): |
39 """ | 40 """ |
40 This function writers trajectories to a specified sqlite file | 41 This function writers trajectories to a specified sqlite file |
41 @param[in] objects -> a list of trajectories | 42 @param[in] objects -> a list of trajectories |
269 | 270 |
270 returns a moving object''' | 271 returns a moving object''' |
271 cursor = connection.cursor() | 272 cursor = connection.cursor() |
272 | 273 |
273 try: | 274 try: |
274 trajectoryIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType) | 275 idQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType) |
275 if trajectoryType == 'feature': | 276 if trajectoryType == 'feature': |
276 queryStatement = 'SELECT * from '+tableName+' '+trajectoryIdQuery+'ORDER BY trajectory_id, frame_number' | 277 queryStatement = 'SELECT * from '+tableName+' '+idQuery+'ORDER BY trajectory_id, frame_number' |
277 cursor.execute(queryStatement) | 278 cursor.execute(queryStatement) |
278 logging.debug(queryStatement) | 279 logging.debug(queryStatement) |
279 elif trajectoryType == 'object': | 280 elif trajectoryType == 'object': |
280 queryStatement = 'SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id '+objectIdQuery+'group by OF.object_id, P.frame_number ORDER BY OF.object_id, P.frame_number' | 281 queryStatement = 'SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id '+idQuery+'group by OF.object_id, P.frame_number ORDER BY OF.object_id, P.frame_number' |
281 cursor.execute(queryStatement) | 282 cursor.execute(queryStatement) |
282 logging.debug(queryStatement) | 283 logging.debug(queryStatement) |
283 elif trajectoryType == 'bbtop' or trajectoryType == 'bbbottom': | 284 elif trajectoryType in ['bbtop', 'bbbottom']: |
284 if trajectoryType == 'bbtop': | 285 if trajectoryType == 'bbtop': |
285 corner = 'top_left' | 286 corner = 'top_left' |
286 elif trajectoryType == 'bbbottom': | 287 elif trajectoryType == 'bbbottom': |
287 corner = 'bottom_right' | 288 corner = 'bottom_right' |
288 queryStatement = 'SELECT object_id, frame_number, x_'+corner+', y_'+corner+' FROM '+tableName+' '+trajectoryIdQuery+'ORDER BY object_id, frame_number' | 289 queryStatement = 'SELECT object_id, frame_number, x_'+corner+', y_'+corner+' FROM '+tableName+' '+trajectoryIdQuery+'ORDER BY object_id, frame_number' |
298 obj = None | 299 obj = None |
299 objects = [] | 300 objects = [] |
300 for row in cursor: | 301 for row in cursor: |
301 if row[0] != objId: | 302 if row[0] != objId: |
302 objId = row[0] | 303 objId = row[0] |
303 if obj != None: | 304 if obj != None and obj.length() == obj.positions.length(): |
304 objects.append(obj) | 305 objects.append(obj) |
306 elif obj != None: | |
307 print('Object {} is missing {} positions'.format(obj.getNum(), int(obj.length())-obj.positions.length())) | |
305 obj = moving.MovingObject(row[0], timeInterval = moving.TimeInterval(row[1], row[1]), positions = moving.Trajectory([[row[2]],[row[3]]])) | 308 obj = moving.MovingObject(row[0], timeInterval = moving.TimeInterval(row[1], row[1]), positions = moving.Trajectory([[row[2]],[row[3]]])) |
306 else: | 309 else: |
307 obj.timeInterval.last = row[1] | 310 obj.timeInterval.last = row[1] |
308 obj.positions.addPositionXY(row[2],row[3]) | 311 obj.positions.addPositionXY(row[2],row[3]) |
309 | 312 |
310 if obj: | 313 if obj != None and obj.length() == obj.positions.length(): |
311 objects.append(obj) | 314 objects.append(obj) |
315 elif obj != None: | |
316 print('Object {} is missing {} positions'.format(obj.getNum(), int(obj.length())-obj.positions.length())) | |
312 | 317 |
313 return objects | 318 return objects |
319 | |
320 def loadUserTypesFromTable(cursor, trajectoryType, objectNumbers): | |
321 objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType) | |
322 if objectIdQuery == '': | |
323 cursor.execute('SELECT object_id, road_user_type from objects') | |
324 else: | |
325 cursor.execute('SELECT object_id, road_user_type from objects where '+objectIdQuery[7:]) | |
326 userTypes = {} | |
327 for row in cursor: | |
328 userTypes[row[0]] = row[1] | |
329 return userTypes | |
314 | 330 |
315 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None): | 331 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None): |
316 '''Loads the first objectNumbers objects or the indices in objectNumbers from the database''' | 332 '''Loads the first objectNumbers objects or the indices in objectNumbers from the database''' |
317 connection = sqlite3.connect(filename) # add test if it open | 333 connection = sqlite3.connect(filename) |
318 | 334 |
319 objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers) | 335 objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers) |
320 objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers) | 336 objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers) |
321 | 337 |
322 if len(objectVelocities) > 0: | 338 if len(objectVelocities) > 0: |
346 | 362 |
347 for obj in objects: | 363 for obj in objects: |
348 obj.featureNumbers = featureNumbers[obj.getNum()] | 364 obj.featureNumbers = featureNumbers[obj.getNum()] |
349 | 365 |
350 # load userType | 366 # load userType |
351 if objectIdQuery == '': | 367 userTypes = loadUserTypesFromTable(cursor, trajectoryType, objectNumbers) |
352 cursor.execute('SELECT object_id, road_user_type from objects') | |
353 else: | |
354 cursor.execute('SELECT object_id, road_user_type from objects where '+objectIdQuery[7:]) | |
355 userTypes = {} | |
356 for row in cursor: | |
357 userTypes[row[0]] = row[1] | |
358 | |
359 for obj in objects: | 368 for obj in objects: |
360 obj.userType = userTypes[obj.getNum()] | 369 obj.userType = userTypes[obj.getNum()] |
361 | 370 |
362 except sqlite3.OperationalError as error: | 371 except sqlite3.OperationalError as error: |
363 printDBError(error) | 372 printDBError(error) |
364 return [] | 373 objects = [] |
365 | 374 |
366 connection.close() | 375 connection.close() |
367 return objects | 376 return objects |
377 | |
378 def loadGroundTruthFromSqlite(filename, gtType, gtNumbers = None): | |
379 'Loads bounding box annotations (ground truth) from an SQLite ' | |
380 connection = sqlite3.connect(filename) | |
381 gt = [] | |
382 | |
383 if gtType == 'bb': | |
384 topCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbtop', gtNumbers) | |
385 bottomCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbbottom', gtNumbers) | |
386 userTypes = loadUserTypesFromTable(connection.cursor(), 'object', gtNumbers) # string format is same as object | |
387 | |
388 for t, b in zip(topCorners, bottomCorners): | |
389 num = t.getNum() | |
390 if t.getNum() == b.getNum(): | |
391 annotation = moving.BBAnnotation(num, t.getTimeInterval(), t, b, userTypes[num]) | |
392 gt.append(annotation) | |
393 else: | |
394 print ('Unknown type of annotation {}'.format(gtType)) | |
395 | |
396 connection.close() | |
397 return gt | |
368 | 398 |
369 def deleteFromSqlite(filename, dataType): | 399 def deleteFromSqlite(filename, dataType): |
370 'Deletes (drops) some tables in the filename depending on type of data' | 400 'Deletes (drops) some tables in the filename depending on type of data' |
371 import os | 401 import os |
372 if os.path.isfile(filename): | 402 if os.path.isfile(filename): |