Mercurial Hosting > traffic-intelligence
diff python/storage.py @ 938:fbf12382f3f8
replaced db connection using with
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Mon, 17 Jul 2017 16:11:18 -0400 |
parents | 0e63a918a1ca |
children | b1e8453c207c |
line wrap: on
line diff
--- a/python/storage.py Mon Jul 17 01:38:06 2017 -0400 +++ b/python/storage.py Mon Jul 17 16:11:18 2017 -0400 @@ -44,27 +44,25 @@ def deleteFromSqlite(filename, dataType): 'Deletes (drops) some tables in the filename depending on type of data' if path.isfile(filename): - connection = sqlite3.connect(filename) - if dataType == 'object': - dropTables(connection, ['objects', 'objects_features']) - elif dataType == 'interaction': - dropTables(connection, ['interactions', 'indicators']) - elif dataType == 'bb': - dropTables(connection, ['bounding_boxes']) - elif dataType == 'pois': - dropTables(connection, ['gaussians2d', 'objects_pois']) - elif dataType == 'prototype': - dropTables(connection, ['prototypes', 'objects_prototypes']) - else: - print('Unknown data type {} to delete from database'.format(dataType)) - connection.close() + with sqlite3.connect(filename) as connection: + if dataType == 'object': + dropTables(connection, ['objects', 'objects_features']) + elif dataType == 'interaction': + dropTables(connection, ['interactions', 'indicators']) + elif dataType == 'bb': + dropTables(connection, ['bounding_boxes']) + elif dataType == 'pois': + dropTables(connection, ['gaussians2d', 'objects_pois']) + elif dataType == 'prototype': + dropTables(connection, ['prototypes', 'objects_prototypes']) + else: + print('Unknown data type {} to delete from database'.format(dataType)) else: print('{} does not exist'.format(filename)) def tableExists(connection, tableName): 'indicates if the table exists in the database' try: - #connection = sqlite3.connect(filename) cursor = connection.cursor() cursor.execute('SELECT COUNT(*) FROM SQLITE_MASTER WHERE type = \'table\' AND name = \''+tableName+'\'') return cursor.fetchone()[0] == 1 @@ -112,7 +110,6 @@ '''Creates an index for the column in the table I will make querying with a condition on this column faster''' try: - #connection = sqlite3.connect(filename) cursor = connection.cursor() s = "CREATE " if unique: @@ -250,113 +247,110 @@ either features, objects (feature groups) or bounding box series) The number loaded is either the first objectNumbers objects, or the indices in objectNumbers from the database''' - connection = sqlite3.connect(filename) - - if tablePrefix is None: - prefix = '' - else: - prefix = tablePrefix + '_' - objects = loadTrajectoriesFromTable(connection, prefix+'positions', trajectoryType, objectNumbers, timeStep) - objectVelocities = loadTrajectoriesFromTable(connection, prefix+'velocities', trajectoryType, objectNumbers, timeStep) + objects = [] + with sqlite3.connect(filename) as connection: + if tablePrefix is None: + prefix = '' + else: + prefix = tablePrefix + '_' + objects = loadTrajectoriesFromTable(connection, prefix+'positions', trajectoryType, objectNumbers, timeStep) + objectVelocities = loadTrajectoriesFromTable(connection, prefix+'velocities', trajectoryType, objectNumbers, timeStep) - if len(objectVelocities) > 0: - for o,v in zip(objects, objectVelocities): - if o.getNum() == v.getNum(): - o.velocities = v.positions - o.velocities.duplicateLastPosition() # avoid having velocity shorter by one position than positions - else: - print('Could not match positions {0} with velocities {1}'.format(o.getNum(), v.getNum())) - - if trajectoryType == 'object': - cursor = connection.cursor() - try: - # attribute feature numbers to objects - queryStatement = 'SELECT trajectory_id, object_id FROM objects_features' - if objectNumbers is not None: - queryStatement += ' WHERE object_id '+getObjectCriteria(objectNumbers) - queryStatement += ' ORDER BY object_id' # order is important to group all features per object - logging.debug(queryStatement) - cursor.execute(queryStatement) + if len(objectVelocities) > 0: + for o,v in zip(objects, objectVelocities): + if o.getNum() == v.getNum(): + o.velocities = v.positions + o.velocities.duplicateLastPosition() # avoid having velocity shorter by one position than positions + else: + print('Could not match positions {0} with velocities {1}'.format(o.getNum(), v.getNum())) - featureNumbers = {} - for row in cursor: - objId = row[1] - if objId not in featureNumbers: - featureNumbers[objId] = [row[0]] - else: - featureNumbers[objId].append(row[0]) - - for obj in objects: - obj.featureNumbers = featureNumbers[obj.getNum()] + if trajectoryType == 'object': + cursor = connection.cursor() + try: + # attribute feature numbers to objects + queryStatement = 'SELECT trajectory_id, object_id FROM objects_features' + if objectNumbers is not None: + queryStatement += ' WHERE object_id '+getObjectCriteria(objectNumbers) + queryStatement += ' ORDER BY object_id' # order is important to group all features per object + logging.debug(queryStatement) + cursor.execute(queryStatement) + + featureNumbers = {} + for row in cursor: + objId = row[1] + if objId not in featureNumbers: + featureNumbers[objId] = [row[0]] + else: + featureNumbers[objId].append(row[0]) - # load userType - userTypes = loadUserTypesFromTable(cursor, objectNumbers) - for obj in objects: - obj.userType = userTypes[obj.getNum()] + for obj in objects: + obj.featureNumbers = featureNumbers[obj.getNum()] - if withFeatures: - nFeatures = 0 + # load userType + userTypes = loadUserTypesFromTable(cursor, objectNumbers) for obj in objects: - nFeatures = max(nFeatures, max(obj.featureNumbers)) - features = loadTrajectoriesFromSqlite(filename, 'feature', nFeatures+1, timeStep = timeStep) - for obj in objects: - obj.setFeatures(features) - - except sqlite3.OperationalError as error: - printDBError(error) - objects = [] + obj.userType = userTypes[obj.getNum()] - connection.close() + if withFeatures: + nFeatures = 0 + for obj in objects: + nFeatures = max(nFeatures, max(obj.featureNumbers)) + features = loadTrajectoriesFromSqlite(filename, 'feature', nFeatures+1, timeStep = timeStep) + for obj in objects: + obj.setFeatures(features) + + except sqlite3.OperationalError as error: + printDBError(error) return objects def loadObjectFeatureFrameNumbers(filename, objectNumbers = None): 'Loads the feature frame numbers for each object' - connection = sqlite3.connect(filename) - cursor = connection.cursor() - try: - queryStatement = 'SELECT OF.object_id, TL.trajectory_id, TL.length FROM (SELECT trajectory_id, max(frame_number)-min(frame_number) AS length FROM positions GROUP BY trajectory_id) TL, objects_features OF WHERE TL.trajectory_id = OF.trajectory_id' - if objectNumbers is not None: - queryStatement += ' AND object_id '+getObjectCriteria(objectNumbers) - queryStatement += ' ORDER BY OF.object_id, TL.length DESC' - logging.debug(queryStatement) - cursor.execute(queryStatement) - objectFeatureNumbers = {} - for row in cursor: - objId = row[0] - if objId in objectFeatureNumbers: - objectFeatureNumbers[objId].append(row[1]) - else: - objectFeatureNumbers[objId] = [row[1]] - return objectFeatureNumbers - except sqlite3.OperationalError as error: - printDBError(error) - return None + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + try: + queryStatement = 'SELECT OF.object_id, TL.trajectory_id, TL.length FROM (SELECT trajectory_id, max(frame_number)-min(frame_number) AS length FROM positions GROUP BY trajectory_id) TL, objects_features OF WHERE TL.trajectory_id = OF.trajectory_id' + if objectNumbers is not None: + queryStatement += ' AND object_id '+getObjectCriteria(objectNumbers) + queryStatement += ' ORDER BY OF.object_id, TL.length DESC' + logging.debug(queryStatement) + cursor.execute(queryStatement) + objectFeatureNumbers = {} + for row in cursor: + objId = row[0] + if objId in objectFeatureNumbers: + objectFeatureNumbers[objId].append(row[1]) + else: + objectFeatureNumbers[objId] = [row[1]] + return objectFeatureNumbers + except sqlite3.OperationalError as error: + printDBError(error) + return None def addCurvilinearTrajectoriesFromSqlite(filename, objects): '''Adds curvilinear positions (s_coordinate, y_coordinate, lane) from a database to an existing MovingObject dict (indexed by each objects's num)''' - connection = sqlite3.connect(filename) - cursor = connection.cursor() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + + try: + cursor.execute('SELECT * from curvilinear_positions order by trajectory_id, frame_number') + except sqlite3.OperationalError as error: + printDBError(error) + return [] - try: - cursor.execute('SELECT * from curvilinear_positions order by trajectory_id, frame_number') - except sqlite3.OperationalError as error: - printDBError(error) - return [] - - missingObjectNumbers = [] - objNum = None - for row in cursor: - if objNum != row[0]: - objNum = row[0] + missingObjectNumbers = [] + objNum = None + for row in cursor: + if objNum != row[0]: + objNum = row[0] + if objNum in objects: + objects[objNum].curvilinearPositions = moving.CurvilinearTrajectory() + else: + missingObjectNumbers.append(objNum) if objNum in objects: - objects[objNum].curvilinearPositions = moving.CurvilinearTrajectory() - else: - missingObjectNumbers.append(objNum) - if objNum in objects: - objects[objNum].curvilinearPositions.addPositionSYL(row[2],row[3],row[4]) - if len(missingObjectNumbers) > 0: - print('List of missing objects to attach corresponding curvilinear trajectories: {}'.format(missingObjectNumbers)) + objects[objNum].curvilinearPositions.addPositionSYL(row[2],row[3],row[4]) + if len(missingObjectNumbers) > 0: + print('List of missing objects to attach corresponding curvilinear trajectories: {}'.format(missingObjectNumbers)) def saveTrajectoriesToTable(connection, objects, trajectoryType, tablePrefix = None): 'Saves trajectories in table tableName' @@ -427,12 +421,20 @@ Either feature positions (and velocities if they exist) or curvilinear positions will be saved at a time''' - connection = sqlite3.connect(outputFilename) - try: - saveTrajectoriesToTable(connection, objects, trajectoryType, None) - except sqlite3.OperationalError as error: - printDBError(error) - connection.close() + with sqlite3.connect(outputFilename) as connection: + try: + saveTrajectoriesToTable(connection, objects, trajectoryType, None) + except sqlite3.OperationalError as error: + printDBError(error) + +def setRoadUserTypes(filename, objects): + '''Saves the user types of the objects in the sqlite database stored in filename + The objects should exist in the objects table''' + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + for obj in objects: + cursor.execute('update objects set road_user_type = {} WHERE object_id = {}'.format(obj.getUserType(), obj.getNum())) + connection.commit() def loadBBMovingObjectsFromSqlite(filename, objectType = 'bb', objectNumbers = None, timeStep = None): '''Loads bounding box moving object from an SQLite @@ -440,23 +442,20 @@ or Urban Tracker Load descriptions?''' - connection = sqlite3.connect(filename) objects = [] + with sqlite3.connect(filename) as connection: + if objectType == 'bb': + topCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbtop', objectNumbers, timeStep) + bottomCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbbottom', objectNumbers, timeStep) + userTypes = loadUserTypesFromTable(connection.cursor(), objectNumbers) # string format is same as object - if objectType == 'bb': - topCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbtop', objectNumbers, timeStep) - bottomCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbbottom', objectNumbers, timeStep) - userTypes = loadUserTypesFromTable(connection.cursor(), objectNumbers) # string format is same as object - - for t, b in zip(topCorners, bottomCorners): - num = t.getNum() - if t.getNum() == b.getNum(): - annotation = moving.BBMovingObject(num, t.getTimeInterval(), t, b, userTypes[num]) - objects.append(annotation) - else: - print ('Unknown type of bounding box {}'.format(objectType)) - - connection.close() + for t, b in zip(topCorners, bottomCorners): + num = t.getNum() + if t.getNum() == b.getNum(): + annotation = moving.BBMovingObject(num, t.getTimeInterval(), t, b, userTypes[num]) + objects.append(annotation) + else: + print ('Unknown type of bounding box {}'.format(objectType)) return objects def saveInteraction(cursor, interaction): @@ -465,16 +464,15 @@ def saveInteractionsToSqlite(filename, interactions): 'Saves the interactions in the table' - connection = sqlite3.connect(filename) - cursor = connection.cursor() - try: - createInteractionTable(cursor) - for inter in interactions: - saveInteraction(cursor, inter) - except sqlite3.OperationalError as error: - printDBError(error) - connection.commit() - connection.close() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + try: + createInteractionTable(cursor) + for inter in interactions: + saveInteraction(cursor, inter) + except sqlite3.OperationalError as error: + printDBError(error) + connection.commit() def saveIndicator(cursor, interactionNum, indicator): for instant in indicator.getTimeInterval(): @@ -483,51 +481,49 @@ def saveIndicatorsToSqlite(filename, interactions, indicatorNames = events.Interaction.indicatorNames): 'Saves the indicator values in the table' - connection = sqlite3.connect(filename) - cursor = connection.cursor() - try: - createInteractionTable(cursor) - createIndicatorTable(cursor) - for inter in interactions: - saveInteraction(cursor, inter) - for indicatorName in indicatorNames: - indicator = inter.getIndicator(indicatorName) - if indicator is not None: - saveIndicator(cursor, inter.getNum(), indicator) - except sqlite3.OperationalError as error: - printDBError(error) - connection.commit() - connection.close() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + try: + createInteractionTable(cursor) + createIndicatorTable(cursor) + for inter in interactions: + saveInteraction(cursor, inter) + for indicatorName in indicatorNames: + indicator = inter.getIndicator(indicatorName) + if indicator is not None: + saveIndicator(cursor, inter.getNum(), indicator) + except sqlite3.OperationalError as error: + printDBError(error) + connection.commit() def loadInteractionsFromSqlite(filename): '''Loads interaction and their indicators TODO choose the interactions to load''' interactions = [] - connection = sqlite3.connect(filename) - cursor = connection.cursor() - try: - cursor.execute('SELECT INT.id, INT.object_id1, INT.object_id2, INT.first_frame_number, INT.last_frame_number, IND.indicator_type, IND.frame_number, IND.value from interactions INT, indicators IND WHERE INT.id = IND.interaction_id ORDER BY INT.id, IND.indicator_type, IND.frame_number') - interactionNum = -1 - indicatorTypeNum = -1 - tmpIndicators = {} - for row in cursor: - if row[0] != interactionNum: - interactionNum = row[0] - interactions.append(events.Interaction(interactionNum, moving.TimeInterval(row[3],row[4]), row[1], row[2])) - interactions[-1].indicators = {} - if indicatorTypeNum != row[5] or row[0] != interactionNum: - indicatorTypeNum = row[5] - indicatorName = events.Interaction.indicatorNames[indicatorTypeNum] - indicatorValues = {row[6]:row[7]} - interactions[-1].indicators[indicatorName] = indicators.SeverityIndicator(indicatorName, indicatorValues, mostSevereIsMax = not indicatorName in events.Interaction.timeIndicators) - else: - indicatorValues[row[6]] = row[7] - interactions[-1].indicators[indicatorName].timeInterval.last = row[6] - except sqlite3.OperationalError as error: - printDBError(error) - return [] - connection.close() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + try: + cursor.execute('SELECT INT.id, INT.object_id1, INT.object_id2, INT.first_frame_number, INT.last_frame_number, IND.indicator_type, IND.frame_number, IND.value from interactions INT, indicators IND WHERE INT.id = IND.interaction_id ORDER BY INT.id, IND.indicator_type, IND.frame_number') + interactionNum = -1 + indicatorTypeNum = -1 + tmpIndicators = {} + for row in cursor: + if row[0] != interactionNum: + interactionNum = row[0] + interactions.append(events.Interaction(interactionNum, moving.TimeInterval(row[3],row[4]), row[1], row[2])) + interactions[-1].indicators = {} + if indicatorTypeNum != row[5] or row[0] != interactionNum: + indicatorTypeNum = row[5] + indicatorName = events.Interaction.indicatorNames[indicatorTypeNum] + indicatorValues = {row[6]:row[7]} + interactions[-1].indicators[indicatorName] = indicators.SeverityIndicator(indicatorName, indicatorValues, mostSevereIsMax = not indicatorName in events.Interaction.timeIndicators) + else: + indicatorValues[row[6]] = row[7] + interactions[-1].indicators[indicatorName].timeInterval.last = row[6] + except sqlite3.OperationalError as error: + printDBError(error) + return [] return interactions # load first and last object instants # CREATE TEMP TABLE IF NOT EXISTS object_instants AS SELECT OF.object_id, min(frame_number) as first_instant, max(frame_number) as last_instant from positions P, objects_features OF WHERE P.trajectory_id = OF.trajectory_id group by OF.object_id order by OF.object_id @@ -535,35 +531,33 @@ def createBoundingBoxTable(filename, invHomography = None): '''Create the table to store the object bounding boxes in image space ''' - connection = sqlite3.connect(filename) - cursor = connection.cursor() - try: - cursor.execute('CREATE TABLE IF NOT EXISTS bounding_boxes (object_id INTEGER, frame_number INTEGER, x_top_left REAL, y_top_left REAL, x_bottom_right REAL, y_bottom_right REAL, PRIMARY KEY(object_id, frame_number))') - cursor.execute('INSERT INTO bounding_boxes SELECT object_id, frame_number, min(x), min(y), max(x), max(y) from ' - '(SELECT object_id, frame_number, (x*{}+y*{}+{})/w as x, (x*{}+y*{}+{})/w as y from ' - '(SELECT OF.object_id, P.frame_number, P.x_coordinate as x, P.y_coordinate as y, P.x_coordinate*{}+P.y_coordinate*{}+{} as w from positions P, objects_features OF WHERE P.trajectory_id = OF.trajectory_id)) '.format(invHomography[0,0], invHomography[0,1], invHomography[0,2], invHomography[1,0], invHomography[1,1], invHomography[1,2], invHomography[2,0], invHomography[2,1], invHomography[2,2])+ - 'GROUP BY object_id, frame_number') - except sqlite3.OperationalError as error: - printDBError(error) - connection.commit() - connection.close() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + try: + cursor.execute('CREATE TABLE IF NOT EXISTS bounding_boxes (object_id INTEGER, frame_number INTEGER, x_top_left REAL, y_top_left REAL, x_bottom_right REAL, y_bottom_right REAL, PRIMARY KEY(object_id, frame_number))') + cursor.execute('INSERT INTO bounding_boxes SELECT object_id, frame_number, min(x), min(y), max(x), max(y) from ' + '(SELECT object_id, frame_number, (x*{}+y*{}+{})/w as x, (x*{}+y*{}+{})/w as y from ' + '(SELECT OF.object_id, P.frame_number, P.x_coordinate as x, P.y_coordinate as y, P.x_coordinate*{}+P.y_coordinate*{}+{} as w from positions P, objects_features OF WHERE P.trajectory_id = OF.trajectory_id)) '.format(invHomography[0,0], invHomography[0,1], invHomography[0,2], invHomography[1,0], invHomography[1,1], invHomography[1,2], invHomography[2,0], invHomography[2,1], invHomography[2,2])+ + 'GROUP BY object_id, frame_number') + except sqlite3.OperationalError as error: + printDBError(error) + connection.commit() def loadBoundingBoxTableForDisplay(filename): '''Loads bounding boxes from bounding_boxes table for display over trajectories''' - connection = sqlite3.connect(filename) - cursor = connection.cursor() boundingBoxes = {} # list of bounding boxes for each instant - try: - cursor.execute('SELECT name FROM sqlite_master WHERE type=\'table\' AND name=\'bounding_boxes\'') - result = cursor.fetchall() - if len(result) > 0: - cursor.execute('SELECT * FROM bounding_boxes') - for row in cursor: - boundingBoxes.setdefault(row[1], []).append([moving.Point(row[2], row[3]), moving.Point(row[4], row[5])]) - except sqlite3.OperationalError as error: - printDBError(error) - return boundingBoxes - connection.close() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + try: + cursor.execute('SELECT name FROM sqlite_master WHERE type=\'table\' AND name=\'bounding_boxes\'') + result = cursor.fetchall() + if len(result) > 0: + cursor.execute('SELECT * FROM bounding_boxes') + for row in cursor: + boundingBoxes.setdefault(row[1], []).append([moving.Point(row[2], row[3]), moving.Point(row[4], row[5])]) + except sqlite3.OperationalError as error: + printDBError(error) + return boundingBoxes return boundingBoxes ######################### @@ -572,56 +566,53 @@ def savePrototypesToSqlite(filename, prototypes): '''save the prototypes (a prototype is defined by a filename, a number and type''' - connection = sqlite3.connect(filename) - cursor = connection.cursor() - try: - cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, PRIMARY KEY (prototype_filename, prototype_id, trajectory_type))') - for p in prototypes: - cursor.execute('INSERT INTO prototypes VALUES(?,?,?,?)', (p.getFilename(), p.getNum(), p.getTrajectoryType(), p.getNMatchings())) - except sqlite3.OperationalError as error: - printDBError(error) - connection.commit() - connection.close() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + try: + cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, PRIMARY KEY (prototype_filename, prototype_id, trajectory_type))') + for p in prototypes: + cursor.execute('INSERT INTO prototypes VALUES(?,?,?,?)', (p.getFilename(), p.getNum(), p.getTrajectoryType(), p.getNMatchings())) + except sqlite3.OperationalError as error: + printDBError(error) + connection.commit() def savePrototypeAssignmentsToSqlite(filename, objects, labels, prototypes): - connection = sqlite3.connect(filename) - cursor = connection.cursor() - try: - cursor.execute('CREATE TABLE IF NOT EXISTS objects_prototypes (object_id INTEGER, prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), PRIMARY KEY(object_id, prototype_filename, prototype_id, trajectory_type))') - for obj, label in zip(objects, labels): - proto = prototypes[label] - cursor.execute('INSERT INTO objects_prototypes VALUES(?,?,?,?)', (obj.getNum(), proto.getFilename(), proto.getNum(), proto.getTrajectoryType())) - except sqlite3.OperationalError as error: - printDBError(error) - connection.commit() - connection.close() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + try: + cursor.execute('CREATE TABLE IF NOT EXISTS objects_prototypes (object_id INTEGER, prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), PRIMARY KEY(object_id, prototype_filename, prototype_id, trajectory_type))') + for obj, label in zip(objects, labels): + proto = prototypes[label] + cursor.execute('INSERT INTO objects_prototypes VALUES(?,?,?,?)', (obj.getNum(), proto.getFilename(), proto.getNum(), proto.getTrajectoryType())) + except sqlite3.OperationalError as error: + printDBError(error) + connection.commit() def loadPrototypesFromSqlite(filename, withTrajectories = True): 'Loads prototype ids and matchings (if stored)' - connection = sqlite3.connect(filename) - cursor = connection.cursor() prototypes = [] - objects = [] - try: - cursor.execute('SELECT * FROM prototypes') - for row in cursor: - prototypes.append(moving.Prototype(row[0], row[1], row[2], row[3])) - if withTrajectories: - for p in prototypes: - p.setMovingObject(loadTrajectoriesFromSqlite(p.getFilename(), p.getTrajectoryType(), [p.getNum()])[0]) - # loadingInformation = {} # complicated slightly optimized - # for p in prototypes: - # dbfn = p.getFilename() - # trajType = p.getTrajectoryType() - # if (dbfn, trajType) in loadingInformation: - # loadingInformation[(dbfn, trajType)].append(p) - # else: - # loadingInformation[(dbfn, trajType)] = [p] - # for k, v in loadingInformation.iteritems(): - # objects += loadTrajectoriesFromSqlite(k[0], k[1], [p.getNum() for p in v]) - except sqlite3.OperationalError as error: - printDBError(error) - connection.close() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + objects = [] + try: + cursor.execute('SELECT * FROM prototypes') + for row in cursor: + prototypes.append(moving.Prototype(row[0], row[1], row[2], row[3])) + if withTrajectories: + for p in prototypes: + p.setMovingObject(loadTrajectoriesFromSqlite(p.getFilename(), p.getTrajectoryType(), [p.getNum()])[0]) + # loadingInformation = {} # complicated slightly optimized + # for p in prototypes: + # dbfn = p.getFilename() + # trajType = p.getTrajectoryType() + # if (dbfn, trajType) in loadingInformation: + # loadingInformation[(dbfn, trajType)].append(p) + # else: + # loadingInformation[(dbfn, trajType)] = [p] + # for k, v in loadingInformation.iteritems(): + # objects += loadTrajectoriesFromSqlite(k[0], k[1], [p.getNum() for p in v]) + except sqlite3.OperationalError as error: + printDBError(error) if len(set([p.getTrajectoryType() for p in prototypes])) > 1: print('Different types of prototypes in database ({}).'.format(set([p.getTrajectoryType() for p in prototypes]))) return prototypes @@ -629,80 +620,77 @@ def savePOIsToSqlite(filename, gmm, gmmType, gmmId): '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' - connection = sqlite3.connect(filename) - cursor = connection.cursor() - if gmmType not in ['beginning', 'end']: - print('Unknown POI type {}. Exiting'.format(gmmType)) - import sys - sys.exit() - try: - cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (poi_id INTEGER, id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covariance VARCHAR, covariance_type VARCHAR, weight, precisions_cholesky VARCHAR, PRIMARY KEY(poi_id, id))') - for i in xrange(gmm.n_components): - cursor.execute('INSERT INTO gaussians2d VALUES(?,?,?,?,?,?,?,?,?)', (gmmId, i, gmmType, gmm.means_[i][0], gmm.means_[i][1], str(gmm.covariances_[i].tolist()), gmm.covariance_type, gmm.weights_[i], str(gmm.precisions_cholesky_[i].tolist()))) - connection.commit() - except sqlite3.OperationalError as error: - printDBError(error) - connection.close() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + if gmmType not in ['beginning', 'end']: + print('Unknown POI type {}. Exiting'.format(gmmType)) + import sys + sys.exit() + try: + cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (poi_id INTEGER, id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covariance VARCHAR, covariance_type VARCHAR, weight, precisions_cholesky VARCHAR, PRIMARY KEY(poi_id, id))') + for i in xrange(gmm.n_components): + cursor.execute('INSERT INTO gaussians2d VALUES(?,?,?,?,?,?,?,?,?)', (gmmId, i, gmmType, gmm.means_[i][0], gmm.means_[i][1], str(gmm.covariances_[i].tolist()), gmm.covariance_type, gmm.weights_[i], str(gmm.precisions_cholesky_[i].tolist()))) + connection.commit() + except sqlite3.OperationalError as error: + printDBError(error) def savePOIAssignmentsToSqlite(filename, objects): 'save the od fields of objects' - connection = sqlite3.connect(filename) - cursor = connection.cursor() - try: - cursor.execute('CREATE TABLE IF NOT EXISTS objects_pois (object_id INTEGER, origin_poi_id INTEGER, destination_poi_id INTEGER, PRIMARY KEY(object_id))') - for o in objects: - cursor.execute('INSERT INTO objects_pois VALUES(?,?,?)', (o.getNum(), o.od[0], o.od[1])) - connection.commit() - except sqlite3.OperationalError as error: - printDBError(error) - connection.close() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + try: + cursor.execute('CREATE TABLE IF NOT EXISTS objects_pois (object_id INTEGER, origin_poi_id INTEGER, destination_poi_id INTEGER, PRIMARY KEY(object_id))') + for o in objects: + cursor.execute('INSERT INTO objects_pois VALUES(?,?,?)', (o.getNum(), o.od[0], o.od[1])) + connection.commit() + except sqlite3.OperationalError as error: + printDBError(error) def loadPOIsFromSqlite(filename): 'Loads all 2D Gaussians in the database' from sklearn import mixture # todo if not avalaible, load data in duck-typed class with same fields from ast import literal_eval - connection = sqlite3.connect(filename) - cursor = connection.cursor() pois = [] - try: - cursor.execute('SELECT * from gaussians2d') - gmmId = None - gmm = [] - for row in cursor: - if gmmId is None or row[0] != gmmId: - if len(gmm) > 0: - tmp = mixture.GaussianMixture(len(gmm), covarianceType) - tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) - tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) - tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) - tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] - tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm]) - pois.append(tmp) - gaussian = {'type': row[2], - 'mean': row[3:5], - 'covar': array(literal_eval(row[5])), - 'weight': row[7], - 'precisions': array(literal_eval(row[8]))} - gmm = [gaussian] - covarianceType = row[6] - gmmId = row[0] - else: - gmm.append({'type': row[2], - 'mean': row[3:5], - 'covar': array(literal_eval(row[5])), - 'weight': row[7], - 'precisions': array(literal_eval(row[8]))}) - if len(gmm) > 0: - tmp = mixture.GaussianMixture(len(gmm), covarianceType) - tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) - tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) - tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) - tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] - tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm]) - pois.append(tmp) - except sqlite3.OperationalError as error: - printDBError(error) - connection.close() + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + try: + cursor.execute('SELECT * from gaussians2d') + gmmId = None + gmm = [] + for row in cursor: + if gmmId is None or row[0] != gmmId: + if len(gmm) > 0: + tmp = mixture.GaussianMixture(len(gmm), covarianceType) + tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) + tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) + tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) + tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] + tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm]) + pois.append(tmp) + gaussian = {'type': row[2], + 'mean': row[3:5], + 'covar': array(literal_eval(row[5])), + 'weight': row[7], + 'precisions': array(literal_eval(row[8]))} + gmm = [gaussian] + covarianceType = row[6] + gmmId = row[0] + else: + gmm.append({'type': row[2], + 'mean': row[3:5], + 'covar': array(literal_eval(row[5])), + 'weight': row[7], + 'precisions': array(literal_eval(row[8]))}) + if len(gmm) > 0: + tmp = mixture.GaussianMixture(len(gmm), covarianceType) + tmp.means_ = array([gaussian['mean'] for gaussian in gmm]) + tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm]) + tmp.weights_ = array([gaussian['weight'] for gaussian in gmm]) + tmp.gmmTypes = [gaussian['type'] for gaussian in gmm] + tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm]) + pois.append(tmp) + except sqlite3.OperationalError as error: + printDBError(error) return pois ######################### @@ -882,16 +870,6 @@ connection.commit() connection.close() -def setRoadUserTypes(filename, objects): - '''Saves the user types of the objects in the sqlite database stored in filename - The objects should exist in the objects table''' - connection = sqlite3.connect(filename) - cursor = connection.cursor() - for obj in objects: - cursor.execute('update objects set road_user_type = {} WHERE object_id = {}'.format(obj.getUserType(), obj.getNum())) - connection.commit() - connection.close() - ######################### # txt files ######################### @@ -902,7 +880,7 @@ try: return open(filename, option) except IOError: - print 'File %s could not be opened.' % filename + print 'File {} could not be opened.'.format(filename) if quitting: from sys import exit exit() @@ -996,25 +974,25 @@ def loadObjectNumbersInLinkFromVissimFile(filename, linkIds): '''Finds the ids of the objects that go through any of the link in the list linkIds''' - connection = sqlite3.connect(filename) - cursor = connection.cursor() - queryStatement = 'SELECT DISTINCT trajectory_id FROM curvilinear_positions where link_id IN ('+','.join([str(id) for id in linkIds])+')' - try: - cursor.execute(queryStatement) - return [row[0] for row in cursor] - except sqlite3.OperationalError as error: - printDBError(error) + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + queryStatement = 'SELECT DISTINCT trajectory_id FROM curvilinear_positions where link_id IN ('+','.join([str(id) for id in linkIds])+')' + try: + cursor.execute(queryStatement) + return [row[0] for row in cursor] + except sqlite3.OperationalError as error: + printDBError(error) def getNObjectsInLinkFromVissimFile(filename, linkIds): '''Returns the number of objects that traveled through the link ids''' - connection = sqlite3.connect(filename) - cursor = connection.cursor() - queryStatement = 'SELECT link_id, COUNT(DISTINCT trajectory_id) FROM curvilinear_positions where link_id IN ('+','.join([str(id) for id in linkIds])+') GROUP BY link_id' - try: - cursor.execute(queryStatement) - return {row[0]:row[1] for row in cursor} - except sqlite3.OperationalError as error: - printDBError(error) + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + queryStatement = 'SELECT link_id, COUNT(DISTINCT trajectory_id) FROM curvilinear_positions where link_id IN ('+','.join([str(id) for id in linkIds])+') GROUP BY link_id' + try: + cursor.execute(queryStatement) + return {row[0]:row[1] for row in cursor} + except sqlite3.OperationalError as error: + printDBError(error) def loadTrajectoriesFromVissimFile(filename, simulationStepsPerTimeUnit, objectNumbers = None, warmUpLastInstant = None, usePandas = False, nDecimals = 2, lowMemory = True): '''Reads data from VISSIM .fzp trajectory file @@ -1067,30 +1045,30 @@ objects[objNum].curvilinearPositions.addPositionSYL(s, y, lane) line = readline(inputfile, '*$') elif filename.endswith(".sqlite"): - connection = sqlite3.connect(filename) - cursor = connection.cursor() - queryStatement = 'SELECT t, trajectory_id, link_id, lane_id, s_coordinate, y_coordinate FROM curvilinear_positions' - if objectNumbers is not None: - queryStatement += ' WHERE trajectory_id '+getObjectCriteria(objectNumbers) - queryStatement += ' ORDER BY trajectory_id, t' - try: - cursor.execute(queryStatement) - for row in cursor: - objNum = row[1] - instant = row[0]*simulationStepsPerTimeUnit - s = row[4] - y = row[5] - lane = '{}_{}'.format(row[2], row[3]) - if objNum not in objects: - if warmUpLastInstant is None or instant >= warmUpLastInstant: - if objectNumbers is None or len(objects) < objectNumbers: - objects[objNum] = moving.MovingObject(num = objNum, timeInterval = moving.TimeInterval(instant, instant)) - objects[objNum].curvilinearPositions = moving.CurvilinearTrajectory() - if (warmUpLastInstant is None or instant >= warmUpLastInstant) and objNum in objects: - objects[objNum].timeInterval.last = instant - objects[objNum].curvilinearPositions.addPositionSYL(s, y, lane) - except sqlite3.OperationalError as error: - printDBError(error) + with sqlite3.connect(filename) as connection: + cursor = connection.cursor() + queryStatement = 'SELECT t, trajectory_id, link_id, lane_id, s_coordinate, y_coordinate FROM curvilinear_positions' + if objectNumbers is not None: + queryStatement += ' WHERE trajectory_id '+getObjectCriteria(objectNumbers) + queryStatement += ' ORDER BY trajectory_id, t' + try: + cursor.execute(queryStatement) + for row in cursor: + objNum = row[1] + instant = row[0]*simulationStepsPerTimeUnit + s = row[4] + y = row[5] + lane = '{}_{}'.format(row[2], row[3]) + if objNum not in objects: + if warmUpLastInstant is None or instant >= warmUpLastInstant: + if objectNumbers is None or len(objects) < objectNumbers: + objects[objNum] = moving.MovingObject(num = objNum, timeInterval = moving.TimeInterval(instant, instant)) + objects[objNum].curvilinearPositions = moving.CurvilinearTrajectory() + if (warmUpLastInstant is None or instant >= warmUpLastInstant) and objNum in objects: + objects[objNum].timeInterval.last = instant + objects[objNum].curvilinearPositions.addPositionSYL(s, y, lane) + except sqlite3.OperationalError as error: + printDBError(error) else: print("File type of "+filename+" not supported (only .sqlite and .fzp files)") return objects.values()