comparison python/storage.py @ 588:c5406edbcf12

added loading of ground truth annotations from polytrack format
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 05 Dec 2014 00:54:38 -0500
parents cf578ba866da
children 5800a87f11ae 0954aaf28231
comparison
equal deleted inserted replaced
587:cf578ba866da 588:c5406edbcf12
32 for tableName in tableNames: 32 for tableName in tableNames:
33 cursor.execute('DROP TABLE IF EXISTS '+tableName) 33 cursor.execute('DROP TABLE IF EXISTS '+tableName)
34 except sqlite3.OperationalError as error: 34 except sqlite3.OperationalError as error:
35 printDBError(error) 35 printDBError(error)
36 36
37 # TODO: add test if database connection is open
37 # IO to sqlite 38 # IO to sqlite
38 def writeTrajectoriesToSqlite(objects, outputFilename, trajectoryType, objectNumbers = -1): 39 def writeTrajectoriesToSqlite(objects, outputFilename, trajectoryType, objectNumbers = -1):
39 """ 40 """
40 This function writers trajectories to a specified sqlite file 41 This function writers trajectories to a specified sqlite file
41 @param[in] objects -> a list of trajectories 42 @param[in] objects -> a list of trajectories
269 270
270 returns a moving object''' 271 returns a moving object'''
271 cursor = connection.cursor() 272 cursor = connection.cursor()
272 273
273 try: 274 try:
274 trajectoryIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType) 275 idQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
275 if trajectoryType == 'feature': 276 if trajectoryType == 'feature':
276 queryStatement = 'SELECT * from '+tableName+' '+trajectoryIdQuery+'ORDER BY trajectory_id, frame_number' 277 queryStatement = 'SELECT * from '+tableName+' '+idQuery+'ORDER BY trajectory_id, frame_number'
277 cursor.execute(queryStatement) 278 cursor.execute(queryStatement)
278 logging.debug(queryStatement) 279 logging.debug(queryStatement)
279 elif trajectoryType == 'object': 280 elif trajectoryType == 'object':
280 queryStatement = 'SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id '+objectIdQuery+'group by OF.object_id, P.frame_number ORDER BY OF.object_id, P.frame_number' 281 queryStatement = 'SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id '+idQuery+'group by OF.object_id, P.frame_number ORDER BY OF.object_id, P.frame_number'
281 cursor.execute(queryStatement) 282 cursor.execute(queryStatement)
282 logging.debug(queryStatement) 283 logging.debug(queryStatement)
283 elif trajectoryType == 'bbtop' or trajectoryType == 'bbbottom': 284 elif trajectoryType in ['bbtop', 'bbbottom']:
284 if trajectoryType == 'bbtop': 285 if trajectoryType == 'bbtop':
285 corner = 'top_left' 286 corner = 'top_left'
286 elif trajectoryType == 'bbbottom': 287 elif trajectoryType == 'bbbottom':
287 corner = 'bottom_right' 288 corner = 'bottom_right'
288 queryStatement = 'SELECT object_id, frame_number, x_'+corner+', y_'+corner+' FROM '+tableName+' '+trajectoryIdQuery+'ORDER BY object_id, frame_number' 289 queryStatement = 'SELECT object_id, frame_number, x_'+corner+', y_'+corner+' FROM '+tableName+' '+trajectoryIdQuery+'ORDER BY object_id, frame_number'
298 obj = None 299 obj = None
299 objects = [] 300 objects = []
300 for row in cursor: 301 for row in cursor:
301 if row[0] != objId: 302 if row[0] != objId:
302 objId = row[0] 303 objId = row[0]
303 if obj != None: 304 if obj != None and obj.length() == obj.positions.length():
304 objects.append(obj) 305 objects.append(obj)
306 elif obj != None:
307 print('Object {} is missing {} positions'.format(obj.getNum(), int(obj.length())-obj.positions.length()))
305 obj = moving.MovingObject(row[0], timeInterval = moving.TimeInterval(row[1], row[1]), positions = moving.Trajectory([[row[2]],[row[3]]])) 308 obj = moving.MovingObject(row[0], timeInterval = moving.TimeInterval(row[1], row[1]), positions = moving.Trajectory([[row[2]],[row[3]]]))
306 else: 309 else:
307 obj.timeInterval.last = row[1] 310 obj.timeInterval.last = row[1]
308 obj.positions.addPositionXY(row[2],row[3]) 311 obj.positions.addPositionXY(row[2],row[3])
309 312
310 if obj: 313 if obj != None and obj.length() == obj.positions.length():
311 objects.append(obj) 314 objects.append(obj)
315 elif obj != None:
316 print('Object {} is missing {} positions'.format(obj.getNum(), int(obj.length())-obj.positions.length()))
312 317
313 return objects 318 return objects
319
def loadUserTypesFromTable(cursor, trajectoryType, objectNumbers):
    '''Reads the road user type of each object from the objects table

    cursor: database cursor used to run the query
    trajectoryType, objectNumbers: forwarded to getTrajectoryIdQuery to build
        the id restriction; an empty restriction selects every object

    returns a dict mapping object_id to road_user_type'''
    idRestriction = getTrajectoryIdQuery(objectNumbers, trajectoryType)
    if idRestriction == '':
        statement = 'SELECT object_id, road_user_type from objects'
    else:
        # NOTE(review): [7:] drops a fixed 7-character prefix from the string
        # built by getTrajectoryIdQuery so the remainder can follow "where";
        # this relies on that helper's output format -- confirm there
        statement = 'SELECT object_id, road_user_type from objects where '+idRestriction[7:]
    cursor.execute(statement)
    return {row[0]: row[1] for row in cursor}
314 330
315 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None): 331 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None):
316 '''Loads the first objectNumbers objects or the indices in objectNumbers from the database''' 332 '''Loads the first objectNumbers objects or the indices in objectNumbers from the database'''
317 connection = sqlite3.connect(filename) # add test if it open 333 connection = sqlite3.connect(filename)
318 334
319 objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers) 335 objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers)
320 objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers) 336 objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers)
321 337
322 if len(objectVelocities) > 0: 338 if len(objectVelocities) > 0:
346 362
347 for obj in objects: 363 for obj in objects:
348 obj.featureNumbers = featureNumbers[obj.getNum()] 364 obj.featureNumbers = featureNumbers[obj.getNum()]
349 365
350 # load userType 366 # load userType
351 if objectIdQuery == '': 367 userTypes = loadUserTypesFromTable(cursor, trajectoryType, objectNumbers)
352 cursor.execute('SELECT object_id, road_user_type from objects')
353 else:
354 cursor.execute('SELECT object_id, road_user_type from objects where '+objectIdQuery[7:])
355 userTypes = {}
356 for row in cursor:
357 userTypes[row[0]] = row[1]
358
359 for obj in objects: 368 for obj in objects:
360 obj.userType = userTypes[obj.getNum()] 369 obj.userType = userTypes[obj.getNum()]
361 370
362 except sqlite3.OperationalError as error: 371 except sqlite3.OperationalError as error:
363 printDBError(error) 372 printDBError(error)
364 return [] 373 objects = []
365 374
366 connection.close() 375 connection.close()
367 return objects 376 return objects
377
def loadGroundTruthFromSqlite(filename, gtType, gtNumbers = None):
    '''Loads bounding box annotations (ground truth) from an SQLite database

    filename: path of the SQLite database
    gtType: type of annotation; only 'bb' (bounding boxes, read from the
        bounding_boxes table) is supported, any other value prints a message
        and yields an empty list
    gtNumbers: selects which annotations to load; passed unchanged to
        loadTrajectoriesFromTable and loadUserTypesFromTable (None by default)

    returns a list of moving.BBAnnotation, one per annotation for which both
    the top-left and bottom-right corner trajectories were loaded'''
    connection = sqlite3.connect(filename)
    gt = []
    try:
        if gtType == 'bb':
            # NOTE(review): in this revision the 'bbtop'/'bbbottom' branch of
            # loadTrajectoriesFromTable still uses trajectoryIdQuery, which was
            # renamed idQuery in that function -- confirm it does not raise NameError
            topCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbtop', gtNumbers)
            bottomCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbbottom', gtNumbers)
            userTypes = loadUserTypesFromTable(connection.cursor(), 'object', gtNumbers) # string format is same as object

            # pair corners by annotation number rather than list position, so a
            # mismatch between the two lists cannot shift (and silently drop)
            # every subsequent annotation as a positional zip would
            bottomCornersByNum = {b.getNum(): b for b in bottomCorners}
            for t in topCorners:
                num = t.getNum()
                b = bottomCornersByNum.get(num)
                if b is None:
                    print('Annotation {} has no bottom corner trajectory'.format(num))
                else:
                    gt.append(moving.BBAnnotation(num, t.getTimeInterval(), t, b, userTypes[num]))
        else:
            print ('Unknown type of annotation {}'.format(gtType))
    finally:
        # always release the database connection, even if loading raises
        connection.close()
    return gt
368 398
369 def deleteFromSqlite(filename, dataType): 399 def deleteFromSqlite(filename, dataType):
370 'Deletes (drops) some tables in the filename depending on type of data' 400 'Deletes (drops) some tables in the filename depending on type of data'
371 import os 401 import os
372 if os.path.isfile(filename): 402 if os.path.isfile(filename):