Mercurial Hosting > traffic-intelligence
diff python/storage.py @ 918:3a06007a4bb7
modularized save trajectories, added slice to Trajectory, etc
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Wed, 05 Jul 2017 12:19:59 -0400 |
parents | 89cc05867c4c |
children | 7b3f2e0a2652 |
line wrap: on
line diff
--- a/python/storage.py Tue Jul 04 18:00:01 2017 -0400 +++ b/python/storage.py Wed Jul 05 12:19:59 2017 -0400 @@ -375,7 +375,70 @@ if len(missingObjectNumbers) > 0: print('List of missing objects to attach corresponding curvilinear trajectories: {}'.format(missingObjectNumbers)) -def saveTrajectoriesToSqlite(outputFilename, objects, trajectoryType, withFeatures = False): +def saveTrajectoriesToTable(connection, objects, trajectoryType, tablePrefix = None): + 'Saves trajectories in table tableName' + cursor = connection.cursor() + # Parse feature and/or object structure and commit to DB + if(trajectoryType == 'feature' or trajectoryType == 'object'): + # Extract features from objects + if trajectoryType == 'object': + features = [] + for obj in objects: + if obj.hasFeatures(): + features += obj.getFeatures() + if len(features) == 0: + print('Warning, objects have no features') # todo save centroid trajectories? + elif trajectoryType == 'feature': + features = objects + # Setup feature queries + if tablePrefix is None: + prefix = '' + else: + prefix = tablePrefix+'_' + createTrajectoryTable(cursor, prefix+"positions") + createTrajectoryTable(cursor, prefix+"velocities") + positionQuery = insertTrajectoryQuery(prefix+"positions") + velocityQuery = insertTrajectoryQuery(prefix+"velocities") + # Setup object queries + if trajectoryType == 'object': + createObjectsTable(cursor) + createObjectsFeaturesTable(cursor) + objectQuery = insertObjectQuery() + objectFeatureQuery = insertObjectFeatureQuery() + for feature in features: + num = feature.getNum() + frameNum = feature.getFirstInstant() + for p in feature.getPositions(): + cursor.execute(positionQuery, (num, frameNum, p.x, p.y)) + frameNum += 1 + velocities = feature.getVelocities() + if velocities is not None: + frameNum = feature.getFirstInstant() + for v in velocities[:-1]: + cursor.execute(velocityQuery, (num, frameNum, v.x, v.y)) + frameNum += 1 + if trajectoryType == 'object': + for obj in objects: + if obj.hasFeatures(): + for feature in obj.getFeatures(): + featureNum = feature.getNum() + cursor.execute(objectFeatureQuery, (obj.getNum(), featureNum)) + cursor.execute(objectQuery, (obj.getNum(), obj.getUserType(), 1)) + # Parse curvilinear position structure + elif(trajectoryType == 'curvilinear'): + createCurvilinearTrajectoryTable(cursor) + curvilinearQuery = "insert into curvilinear_positions (trajectory_id, frame_number, s_coordinate, y_coordinate, lane) values (?,?,?,?,?)" + for obj in objects: + num = obj.getNum() + frameNum = obj.getFirstInstant() + for p in obj.getCurvilinearPositions(): + cursor.execute(curvilinearQuery, (num, frameNum, p[0], p[1], p[2])) + frameNum += 1 + else: + print('Unknown trajectory type {}'.format(trajectoryType)) + connection.commit() + +def saveTrajectoriesToSqlite(outputFilename, objects, trajectoryType): '''Writes features, ie the trajectory positions (and velocities if exist) with their instants to a specified sqlite file Either feature positions (and velocities if they exist) @@ -383,60 +446,7 @@ connection = sqlite3.connect(outputFilename) try: - cursor = connection.cursor() - # Parse feature and/or object structure and commit to DB - if(trajectoryType == 'feature' or trajectoryType == 'object'): - # Extract features from objects - if(trajectoryType == 'object'): - features = [] - for obj in objects: - if(obj.hasFeatures()): - features += obj.getFeatures() - elif(trajectoryType == 'feature'): - features = objects - # Setup feature queries - createTrajectoryTable(cursor, "positions") - createTrajectoryTable(cursor, "velocities") - positionQuery = insertTrajectoryQuery("positions") - velocityQuery = insertTrajectoryQuery("velocities") - # Setup object queries - if(trajectoryType == 'object'): - createObjectsTable(cursor) - createObjectsFeaturesTable(cursor) - objectQuery = insertObjectQuery() - objectFeatureQuery = insertObjectFeatureQuery() - for feature in features: - num = feature.getNum() - frameNum = feature.getFirstInstant() - for position in feature.getPositions(): - cursor.execute(positionQuery, (num, frameNum, position.x, position.y)) - frameNum += 1 - velocities = feature.getVelocities() - if velocities is not None: - frameNum = feature.getFirstInstant() - for i in xrange(velocities.length()-1): - v = velocities[i] - cursor.execute(velocityQuery, (num, frameNum, v.x, v.y)) - frameNum += 1 - if(trajectoryType == 'object'): - for obj in objects: - for feature in obj.getFeatures(): - featureNum = feature.getNum() - cursor.execute(objectFeatureQuery, (obj.getNum(), featureNum)) - cursor.execute(objectQuery, (obj.getNum(), obj.getUserType(), 1)) - # Parse curvilinear position structure - elif(trajectoryType == 'curvilinear'): - createCurvilinearTrajectoryTable(cursor) - curvilinearQuery = "insert into curvilinear_positions (trajectory_id, frame_number, s_coordinate, y_coordinate, lane) values (?,?,?,?,?)" - for obj in objects: - num = obj.getNum() - frameNum = obj.getFirstInstant() - for position in obj.getCurvilinearPositions(): - cursor.execute(curvilinearQuery, (num, frameNum, position[0], position[1], position[2])) - frameNum += 1 - else: - print('Unknown trajectory type {}'.format(trajectoryType)) - connection.commit() + saveTrajectoriesToTable(connection, objects, trajectoryType, None) except sqlite3.OperationalError as error: printDBError(error) connection.close() @@ -598,7 +608,7 @@ dbfn = filename cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings, positions_id) VALUES ({},\"{}\",\"{}\",{}, {})'.format(protoId, dbfn, trajectoryType, n, i)) #cursor.execute('SELECT * from sqlite_master WHERE type = \"table\" and name = \"{}\"'.format(tableNames[trajectoryType])) - if objects is not None: + if objects is not None: # save positions and velocities pass except sqlite3.OperationalError as error: printDBError(error)