diff python/storage.py @ 938:fbf12382f3f8

replaced db connection using with
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Mon, 17 Jul 2017 16:11:18 -0400
parents 0e63a918a1ca
children b1e8453c207c
line wrap: on
line diff
--- a/python/storage.py	Mon Jul 17 01:38:06 2017 -0400
+++ b/python/storage.py	Mon Jul 17 16:11:18 2017 -0400
@@ -44,27 +44,25 @@
 def deleteFromSqlite(filename, dataType):
     'Deletes (drops) some tables in the filename depending on type of data'
     if path.isfile(filename):
-        connection = sqlite3.connect(filename)
-        if dataType == 'object':
-            dropTables(connection, ['objects', 'objects_features'])
-        elif dataType == 'interaction':
-            dropTables(connection, ['interactions', 'indicators'])
-        elif dataType == 'bb':
-            dropTables(connection, ['bounding_boxes'])
-        elif dataType == 'pois':
-            dropTables(connection, ['gaussians2d', 'objects_pois'])
-        elif dataType == 'prototype':
-            dropTables(connection, ['prototypes', 'objects_prototypes'])
-        else:
-            print('Unknown data type {} to delete from database'.format(dataType))
-        connection.close()
+        with sqlite3.connect(filename) as connection:
+            if dataType == 'object':
+                dropTables(connection, ['objects', 'objects_features'])
+            elif dataType == 'interaction':
+                dropTables(connection, ['interactions', 'indicators'])
+            elif dataType == 'bb':
+                dropTables(connection, ['bounding_boxes'])
+            elif dataType == 'pois':
+                dropTables(connection, ['gaussians2d', 'objects_pois'])
+            elif dataType == 'prototype':
+                dropTables(connection, ['prototypes', 'objects_prototypes'])
+            else:
+                print('Unknown data type {} to delete from database'.format(dataType))
     else:
         print('{} does not exist'.format(filename))
 
 def tableExists(connection, tableName):
     'indicates if the table exists in the database'
     try:
-        #connection = sqlite3.connect(filename)
         cursor = connection.cursor()
         cursor.execute('SELECT COUNT(*) FROM SQLITE_MASTER WHERE type = \'table\' AND name = \''+tableName+'\'')
         return cursor.fetchone()[0] == 1
@@ -112,7 +110,6 @@
     '''Creates an index for the column in the table
     I will make querying with a condition on this column faster'''
     try:
-        #connection = sqlite3.connect(filename)
         cursor = connection.cursor()
         s = "CREATE "
         if unique:
@@ -250,113 +247,110 @@
     either features, objects (feature groups) or bounding box series) 
     The number loaded is either the first objectNumbers objects,
     or the indices in objectNumbers from the database'''
-    connection = sqlite3.connect(filename)
-
-    if tablePrefix is None:
-        prefix = ''
-    else:
-        prefix = tablePrefix + '_'
-    objects = loadTrajectoriesFromTable(connection, prefix+'positions', trajectoryType, objectNumbers, timeStep)
-    objectVelocities = loadTrajectoriesFromTable(connection, prefix+'velocities', trajectoryType, objectNumbers, timeStep)
+    objects = []
+    with sqlite3.connect(filename) as connection:
+        if tablePrefix is None:
+            prefix = ''
+        else:
+            prefix = tablePrefix + '_'
+        objects = loadTrajectoriesFromTable(connection, prefix+'positions', trajectoryType, objectNumbers, timeStep)
+        objectVelocities = loadTrajectoriesFromTable(connection, prefix+'velocities', trajectoryType, objectNumbers, timeStep)
 
-    if len(objectVelocities) > 0:
-        for o,v in zip(objects, objectVelocities):
-            if o.getNum() == v.getNum():
-                o.velocities = v.positions
-                o.velocities.duplicateLastPosition() # avoid having velocity shorter by one position than positions
-            else:
-                print('Could not match positions {0} with velocities {1}'.format(o.getNum(), v.getNum()))
-
-    if trajectoryType == 'object':
-        cursor = connection.cursor()
-        try:
-            # attribute feature numbers to objects
-            queryStatement = 'SELECT trajectory_id, object_id FROM objects_features'
-            if objectNumbers is not None:
-                queryStatement += ' WHERE object_id '+getObjectCriteria(objectNumbers)
-            queryStatement += ' ORDER BY object_id' # order is important to group all features per object
-            logging.debug(queryStatement)
-            cursor.execute(queryStatement) 
+        if len(objectVelocities) > 0:
+            for o,v in zip(objects, objectVelocities):
+                if o.getNum() == v.getNum():
+                    o.velocities = v.positions
+                    o.velocities.duplicateLastPosition() # avoid having velocity shorter by one position than positions
+                else:
+                    print('Could not match positions {0} with velocities {1}'.format(o.getNum(), v.getNum()))
 
-            featureNumbers = {}
-            for row in cursor:
-                objId = row[1]
-                if objId not in featureNumbers:
-                    featureNumbers[objId] = [row[0]]
-                else:
-                    featureNumbers[objId].append(row[0])
-                    
-            for obj in objects:
-                obj.featureNumbers = featureNumbers[obj.getNum()]
+        if trajectoryType == 'object':
+            cursor = connection.cursor()
+            try:
+                # attribute feature numbers to objects
+                queryStatement = 'SELECT trajectory_id, object_id FROM objects_features'
+                if objectNumbers is not None:
+                    queryStatement += ' WHERE object_id '+getObjectCriteria(objectNumbers)
+                queryStatement += ' ORDER BY object_id' # order is important to group all features per object
+                logging.debug(queryStatement)
+                cursor.execute(queryStatement) 
+
+                featureNumbers = {}
+                for row in cursor:
+                    objId = row[1]
+                    if objId not in featureNumbers:
+                        featureNumbers[objId] = [row[0]]
+                    else:
+                        featureNumbers[objId].append(row[0])
 
-            # load userType
-            userTypes = loadUserTypesFromTable(cursor, objectNumbers)
-            for obj in objects:
-                obj.userType = userTypes[obj.getNum()]
+                for obj in objects:
+                    obj.featureNumbers = featureNumbers[obj.getNum()]
 
-            if withFeatures:
-                nFeatures = 0
+                # load userType
+                userTypes = loadUserTypesFromTable(cursor, objectNumbers)
                 for obj in objects:
-                    nFeatures = max(nFeatures, max(obj.featureNumbers))
-                features = loadTrajectoriesFromSqlite(filename, 'feature', nFeatures+1, timeStep = timeStep)
-                for obj in objects:
-                    obj.setFeatures(features)
-             
-        except sqlite3.OperationalError as error:
-            printDBError(error)
-            objects = []
+                    obj.userType = userTypes[obj.getNum()]
 
-    connection.close()
+                if withFeatures:
+                    nFeatures = 0
+                    for obj in objects:
+                        nFeatures = max(nFeatures, max(obj.featureNumbers))
+                    features = loadTrajectoriesFromSqlite(filename, 'feature', nFeatures+1, timeStep = timeStep)
+                    for obj in objects:
+                        obj.setFeatures(features)
+
+            except sqlite3.OperationalError as error:
+                printDBError(error)
     return objects
 
 def loadObjectFeatureFrameNumbers(filename, objectNumbers = None):
     'Loads the feature frame numbers for each object'
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    try:
-        queryStatement = 'SELECT OF.object_id, TL.trajectory_id, TL.length FROM (SELECT trajectory_id, max(frame_number)-min(frame_number) AS length FROM positions GROUP BY trajectory_id) TL, objects_features OF WHERE TL.trajectory_id = OF.trajectory_id'
-        if objectNumbers is not None:
-            queryStatement += ' AND object_id '+getObjectCriteria(objectNumbers)
-        queryStatement += ' ORDER BY OF.object_id, TL.length DESC'
-        logging.debug(queryStatement)
-        cursor.execute(queryStatement)
-        objectFeatureNumbers = {}
-        for row in cursor:
-            objId = row[0]
-            if objId in objectFeatureNumbers:
-                objectFeatureNumbers[objId].append(row[1])
-            else:
-                objectFeatureNumbers[objId] = [row[1]]
-        return objectFeatureNumbers
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-        return None
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            queryStatement = 'SELECT OF.object_id, TL.trajectory_id, TL.length FROM (SELECT trajectory_id, max(frame_number)-min(frame_number) AS length FROM positions GROUP BY trajectory_id) TL, objects_features OF WHERE TL.trajectory_id = OF.trajectory_id'
+            if objectNumbers is not None:
+                queryStatement += ' AND object_id '+getObjectCriteria(objectNumbers)
+            queryStatement += ' ORDER BY OF.object_id, TL.length DESC'
+            logging.debug(queryStatement)
+            cursor.execute(queryStatement)
+            objectFeatureNumbers = {}
+            for row in cursor:
+                objId = row[0]
+                if objId in objectFeatureNumbers:
+                    objectFeatureNumbers[objId].append(row[1])
+                else:
+                    objectFeatureNumbers[objId] = [row[1]]
+            return objectFeatureNumbers
+        except sqlite3.OperationalError as error:
+            printDBError(error)
+            return None
 
 def addCurvilinearTrajectoriesFromSqlite(filename, objects):
     '''Adds curvilinear positions (s_coordinate, y_coordinate, lane)
     from a database to an existing MovingObject dict (indexed by each objects's num)'''
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+
+        try:
+            cursor.execute('SELECT * from curvilinear_positions order by trajectory_id, frame_number')
+        except sqlite3.OperationalError as error:
+            printDBError(error)
+            return []
 
-    try:
-        cursor.execute('SELECT * from curvilinear_positions order by trajectory_id, frame_number')
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-        return []
-    
-    missingObjectNumbers = []
-    objNum = None
-    for row in cursor:
-        if objNum != row[0]:
-            objNum = row[0]
+        missingObjectNumbers = []
+        objNum = None
+        for row in cursor:
+            if objNum != row[0]:
+                objNum = row[0]
+                if objNum in objects:
+                    objects[objNum].curvilinearPositions = moving.CurvilinearTrajectory()
+                else:
+                    missingObjectNumbers.append(objNum)
             if objNum in objects:
-                objects[objNum].curvilinearPositions = moving.CurvilinearTrajectory()
-            else:
-                missingObjectNumbers.append(objNum)
-        if objNum in objects:
-            objects[objNum].curvilinearPositions.addPositionSYL(row[2],row[3],row[4])
-    if len(missingObjectNumbers) > 0:
-        print('List of missing objects to attach corresponding curvilinear trajectories: {}'.format(missingObjectNumbers))
+                objects[objNum].curvilinearPositions.addPositionSYL(row[2],row[3],row[4])
+        if len(missingObjectNumbers) > 0:
+            print('List of missing objects to attach corresponding curvilinear trajectories: {}'.format(missingObjectNumbers))
 
 def saveTrajectoriesToTable(connection, objects, trajectoryType, tablePrefix = None):
     'Saves trajectories in table tableName'
@@ -427,12 +421,20 @@
     Either feature positions (and velocities if they exist)
     or curvilinear positions will be saved at a time'''
 
-    connection = sqlite3.connect(outputFilename)
-    try:
-        saveTrajectoriesToTable(connection, objects, trajectoryType, None)
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-    connection.close()
+    with sqlite3.connect(outputFilename) as connection:
+        try:
+            saveTrajectoriesToTable(connection, objects, trajectoryType, None)
+        except sqlite3.OperationalError as error:
+            printDBError(error)
+
+def setRoadUserTypes(filename, objects):
+    '''Saves the user types of the objects in the sqlite database stored in filename
+    The objects should exist in the objects table'''
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        for obj in objects:
+            cursor.execute('update objects set road_user_type = {} WHERE object_id = {}'.format(obj.getUserType(), obj.getNum()))
+        connection.commit()
 
 def loadBBMovingObjectsFromSqlite(filename, objectType = 'bb', objectNumbers = None, timeStep = None):
     '''Loads bounding box moving object from an SQLite
@@ -440,23 +442,20 @@
     or Urban Tracker
 
     Load descriptions?'''
-    connection = sqlite3.connect(filename)
     objects = []
+    with sqlite3.connect(filename) as connection:
+        if objectType == 'bb':
+            topCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbtop', objectNumbers, timeStep)
+            bottomCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbbottom', objectNumbers, timeStep)
+            userTypes = loadUserTypesFromTable(connection.cursor(), objectNumbers) # string format is same as object
 
-    if objectType == 'bb':
-        topCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbtop', objectNumbers, timeStep)
-        bottomCorners = loadTrajectoriesFromTable(connection, 'bounding_boxes', 'bbbottom', objectNumbers, timeStep)
-        userTypes = loadUserTypesFromTable(connection.cursor(), objectNumbers) # string format is same as object
-        
-        for t, b in zip(topCorners, bottomCorners):
-            num = t.getNum()
-            if t.getNum() == b.getNum():
-                annotation = moving.BBMovingObject(num, t.getTimeInterval(), t, b, userTypes[num])
-                objects.append(annotation)
-    else:
-        print ('Unknown type of bounding box {}'.format(objectType))
-
-    connection.close()
+            for t, b in zip(topCorners, bottomCorners):
+                num = t.getNum()
+                if t.getNum() == b.getNum():
+                    annotation = moving.BBMovingObject(num, t.getTimeInterval(), t, b, userTypes[num])
+                    objects.append(annotation)
+        else:
+            print ('Unknown type of bounding box {}'.format(objectType))
     return objects
 
 def saveInteraction(cursor, interaction):
@@ -465,16 +464,15 @@
 
 def saveInteractionsToSqlite(filename, interactions):
     'Saves the interactions in the table'
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    try:
-        createInteractionTable(cursor)
-        for inter in interactions:
-            saveInteraction(cursor, inter)
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-    connection.commit()
-    connection.close()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            createInteractionTable(cursor)
+            for inter in interactions:
+                saveInteraction(cursor, inter)
+        except sqlite3.OperationalError as error:
+            printDBError(error)
+        connection.commit()
 
 def saveIndicator(cursor, interactionNum, indicator):
     for instant in indicator.getTimeInterval():
@@ -483,51 +481,49 @@
 
 def saveIndicatorsToSqlite(filename, interactions, indicatorNames = events.Interaction.indicatorNames):
     'Saves the indicator values in the table'
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    try:
-        createInteractionTable(cursor)
-        createIndicatorTable(cursor)
-        for inter in interactions:
-            saveInteraction(cursor, inter)
-            for indicatorName in indicatorNames:
-                indicator = inter.getIndicator(indicatorName)
-                if indicator is not None:
-                    saveIndicator(cursor, inter.getNum(), indicator)
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-    connection.commit()
-    connection.close()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            createInteractionTable(cursor)
+            createIndicatorTable(cursor)
+            for inter in interactions:
+                saveInteraction(cursor, inter)
+                for indicatorName in indicatorNames:
+                    indicator = inter.getIndicator(indicatorName)
+                    if indicator is not None:
+                        saveIndicator(cursor, inter.getNum(), indicator)
+        except sqlite3.OperationalError as error:
+            printDBError(error)
+        connection.commit()
 
 def loadInteractionsFromSqlite(filename):
     '''Loads interaction and their indicators
     
     TODO choose the interactions to load'''
     interactions = []
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    try:
-        cursor.execute('SELECT INT.id, INT.object_id1, INT.object_id2, INT.first_frame_number, INT.last_frame_number, IND.indicator_type, IND.frame_number, IND.value from interactions INT, indicators IND WHERE INT.id = IND.interaction_id ORDER BY INT.id, IND.indicator_type, IND.frame_number')
-        interactionNum = -1
-        indicatorTypeNum = -1
-        tmpIndicators = {}
-        for row in cursor:
-            if row[0] != interactionNum:
-                interactionNum = row[0]
-                interactions.append(events.Interaction(interactionNum, moving.TimeInterval(row[3],row[4]), row[1], row[2]))
-                interactions[-1].indicators = {}
-            if indicatorTypeNum != row[5] or row[0] != interactionNum:
-                indicatorTypeNum = row[5]
-                indicatorName = events.Interaction.indicatorNames[indicatorTypeNum]
-                indicatorValues = {row[6]:row[7]}
-                interactions[-1].indicators[indicatorName] = indicators.SeverityIndicator(indicatorName, indicatorValues, mostSevereIsMax = not indicatorName in events.Interaction.timeIndicators)
-            else:
-                indicatorValues[row[6]] = row[7]
-                interactions[-1].indicators[indicatorName].timeInterval.last = row[6]
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-        return []
-    connection.close()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            cursor.execute('SELECT INT.id, INT.object_id1, INT.object_id2, INT.first_frame_number, INT.last_frame_number, IND.indicator_type, IND.frame_number, IND.value from interactions INT, indicators IND WHERE INT.id = IND.interaction_id ORDER BY INT.id, IND.indicator_type, IND.frame_number')
+            interactionNum = -1
+            indicatorTypeNum = -1
+            tmpIndicators = {}
+            for row in cursor:
+                if row[0] != interactionNum:
+                    interactionNum = row[0]
+                    interactions.append(events.Interaction(interactionNum, moving.TimeInterval(row[3],row[4]), row[1], row[2]))
+                    interactions[-1].indicators = {}
+                if indicatorTypeNum != row[5] or row[0] != interactionNum:
+                    indicatorTypeNum = row[5]
+                    indicatorName = events.Interaction.indicatorNames[indicatorTypeNum]
+                    indicatorValues = {row[6]:row[7]}
+                    interactions[-1].indicators[indicatorName] = indicators.SeverityIndicator(indicatorName, indicatorValues, mostSevereIsMax = not indicatorName in events.Interaction.timeIndicators)
+                else:
+                    indicatorValues[row[6]] = row[7]
+                    interactions[-1].indicators[indicatorName].timeInterval.last = row[6]
+        except sqlite3.OperationalError as error:
+            printDBError(error)
+            return []
     return interactions
 # load first and last object instants
 # CREATE TEMP TABLE IF NOT EXISTS object_instants AS SELECT OF.object_id, min(frame_number) as first_instant, max(frame_number) as last_instant from positions P, objects_features OF WHERE P.trajectory_id = OF.trajectory_id group by OF.object_id order by OF.object_id
@@ -535,35 +531,33 @@
 def createBoundingBoxTable(filename, invHomography = None):
     '''Create the table to store the object bounding boxes in image space
     '''
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    try:
-        cursor.execute('CREATE TABLE IF NOT EXISTS bounding_boxes (object_id INTEGER, frame_number INTEGER, x_top_left REAL, y_top_left REAL, x_bottom_right REAL, y_bottom_right REAL,  PRIMARY KEY(object_id, frame_number))')
-        cursor.execute('INSERT INTO bounding_boxes SELECT object_id, frame_number, min(x), min(y), max(x), max(y) from '
-              '(SELECT object_id, frame_number, (x*{}+y*{}+{})/w as x, (x*{}+y*{}+{})/w as y from '
-              '(SELECT OF.object_id, P.frame_number, P.x_coordinate as x, P.y_coordinate as y, P.x_coordinate*{}+P.y_coordinate*{}+{} as w from positions P, objects_features OF WHERE P.trajectory_id = OF.trajectory_id)) '.format(invHomography[0,0], invHomography[0,1], invHomography[0,2], invHomography[1,0], invHomography[1,1], invHomography[1,2], invHomography[2,0], invHomography[2,1], invHomography[2,2])+
-              'GROUP BY object_id, frame_number')
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-    connection.commit()
-    connection.close()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            cursor.execute('CREATE TABLE IF NOT EXISTS bounding_boxes (object_id INTEGER, frame_number INTEGER, x_top_left REAL, y_top_left REAL, x_bottom_right REAL, y_bottom_right REAL,  PRIMARY KEY(object_id, frame_number))')
+            cursor.execute('INSERT INTO bounding_boxes SELECT object_id, frame_number, min(x), min(y), max(x), max(y) from '
+                  '(SELECT object_id, frame_number, (x*{}+y*{}+{})/w as x, (x*{}+y*{}+{})/w as y from '
+                  '(SELECT OF.object_id, P.frame_number, P.x_coordinate as x, P.y_coordinate as y, P.x_coordinate*{}+P.y_coordinate*{}+{} as w from positions P, objects_features OF WHERE P.trajectory_id = OF.trajectory_id)) '.format(invHomography[0,0], invHomography[0,1], invHomography[0,2], invHomography[1,0], invHomography[1,1], invHomography[1,2], invHomography[2,0], invHomography[2,1], invHomography[2,2])+
+                  'GROUP BY object_id, frame_number')
+        except sqlite3.OperationalError as error:
+            printDBError(error)
+        connection.commit()
 
 def loadBoundingBoxTableForDisplay(filename):
     '''Loads bounding boxes from bounding_boxes table for display over trajectories'''
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
     boundingBoxes = {} # list of bounding boxes for each instant
-    try:
-        cursor.execute('SELECT name FROM sqlite_master WHERE type=\'table\' AND name=\'bounding_boxes\'')
-        result = cursor.fetchall()
-        if len(result) > 0:
-            cursor.execute('SELECT * FROM bounding_boxes')
-            for row in cursor:
-                boundingBoxes.setdefault(row[1], []).append([moving.Point(row[2], row[3]), moving.Point(row[4], row[5])])
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-        return boundingBoxes
-    connection.close()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            cursor.execute('SELECT name FROM sqlite_master WHERE type=\'table\' AND name=\'bounding_boxes\'')
+            result = cursor.fetchall()
+            if len(result) > 0:
+                cursor.execute('SELECT * FROM bounding_boxes')
+                for row in cursor:
+                    boundingBoxes.setdefault(row[1], []).append([moving.Point(row[2], row[3]), moving.Point(row[4], row[5])])
+        except sqlite3.OperationalError as error:
+            printDBError(error)
+            return boundingBoxes
     return boundingBoxes
 
 #########################
@@ -572,56 +566,53 @@
 
 def savePrototypesToSqlite(filename, prototypes):
     '''save the prototypes (a prototype is defined by a filename, a number and type'''
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    try:
-        cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, PRIMARY KEY (prototype_filename, prototype_id, trajectory_type))')
-        for p in prototypes:
-            cursor.execute('INSERT INTO prototypes VALUES(?,?,?,?)', (p.getFilename(), p.getNum(), p.getTrajectoryType(), p.getNMatchings()))
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-    connection.commit()
-    connection.close()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, PRIMARY KEY (prototype_filename, prototype_id, trajectory_type))')
+            for p in prototypes:
+                cursor.execute('INSERT INTO prototypes VALUES(?,?,?,?)', (p.getFilename(), p.getNum(), p.getTrajectoryType(), p.getNMatchings()))
+        except sqlite3.OperationalError as error:
+            printDBError(error)
+        connection.commit()
 
 def savePrototypeAssignmentsToSqlite(filename, objects, labels, prototypes):
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    try:
-        cursor.execute('CREATE TABLE IF NOT EXISTS objects_prototypes (object_id INTEGER, prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), PRIMARY KEY(object_id, prototype_filename, prototype_id, trajectory_type))')
-        for obj, label in zip(objects, labels):
-            proto = prototypes[label]
-            cursor.execute('INSERT INTO objects_prototypes VALUES(?,?,?,?)', (obj.getNum(), proto.getFilename(), proto.getNum(), proto.getTrajectoryType()))
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-    connection.commit()
-    connection.close()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            cursor.execute('CREATE TABLE IF NOT EXISTS objects_prototypes (object_id INTEGER, prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), PRIMARY KEY(object_id, prototype_filename, prototype_id, trajectory_type))')
+            for obj, label in zip(objects, labels):
+                proto = prototypes[label]
+                cursor.execute('INSERT INTO objects_prototypes VALUES(?,?,?,?)', (obj.getNum(), proto.getFilename(), proto.getNum(), proto.getTrajectoryType()))
+        except sqlite3.OperationalError as error:
+            printDBError(error)
+        connection.commit()
 
 def loadPrototypesFromSqlite(filename, withTrajectories = True):
     'Loads prototype ids and matchings (if stored)'
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
     prototypes = []
-    objects = []
-    try:
-        cursor.execute('SELECT * FROM prototypes')
-        for row in cursor:
-            prototypes.append(moving.Prototype(row[0], row[1], row[2], row[3]))
-        if withTrajectories:
-            for p in prototypes:
-                p.setMovingObject(loadTrajectoriesFromSqlite(p.getFilename(), p.getTrajectoryType(), [p.getNum()])[0])
-            # loadingInformation = {} # complicated slightly optimized
-            # for p in prototypes:
-            #     dbfn = p.getFilename()
-            #     trajType = p.getTrajectoryType()
-            #     if (dbfn, trajType) in loadingInformation:
-            #         loadingInformation[(dbfn, trajType)].append(p)
-            #     else:
-            #         loadingInformation[(dbfn, trajType)] = [p]
-            # for k, v in loadingInformation.iteritems():
-            #     objects += loadTrajectoriesFromSqlite(k[0], k[1], [p.getNum() for p in v])
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-    connection.close()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        objects = []
+        try:
+            cursor.execute('SELECT * FROM prototypes')
+            for row in cursor:
+                prototypes.append(moving.Prototype(row[0], row[1], row[2], row[3]))
+            if withTrajectories:
+                for p in prototypes:
+                    p.setMovingObject(loadTrajectoriesFromSqlite(p.getFilename(), p.getTrajectoryType(), [p.getNum()])[0])
+                # loadingInformation = {} # complicated slightly optimized
+                # for p in prototypes:
+                #     dbfn = p.getFilename()
+                #     trajType = p.getTrajectoryType()
+                #     if (dbfn, trajType) in loadingInformation:
+                #         loadingInformation[(dbfn, trajType)].append(p)
+                #     else:
+                #         loadingInformation[(dbfn, trajType)] = [p]
+                # for k, v in loadingInformation.iteritems():
+                #     objects += loadTrajectoriesFromSqlite(k[0], k[1], [p.getNum() for p in v])
+        except sqlite3.OperationalError as error:
+            printDBError(error)
     if len(set([p.getTrajectoryType() for p in prototypes])) > 1:
         print('Different types of prototypes in database ({}).'.format(set([p.getTrajectoryType() for p in prototypes])))
     return prototypes
@@ -629,80 +620,77 @@
 def savePOIsToSqlite(filename, gmm, gmmType, gmmId):
     '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture)
     gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories'''
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    if gmmType not in ['beginning', 'end']:
-        print('Unknown POI type {}. Exiting'.format(gmmType))
-        import sys
-        sys.exit()
-    try:
-        cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (poi_id INTEGER, id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covariance VARCHAR, covariance_type VARCHAR, weight, precisions_cholesky VARCHAR, PRIMARY KEY(poi_id, id))')
-        for i in xrange(gmm.n_components):
-            cursor.execute('INSERT INTO gaussians2d VALUES(?,?,?,?,?,?,?,?,?)', (gmmId, i, gmmType, gmm.means_[i][0], gmm.means_[i][1], str(gmm.covariances_[i].tolist()), gmm.covariance_type, gmm.weights_[i], str(gmm.precisions_cholesky_[i].tolist())))
-        connection.commit()
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-    connection.close()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        if gmmType not in ['beginning', 'end']:
+            print('Unknown POI type {}. Exiting'.format(gmmType))
+            import sys
+            sys.exit()
+        try:
+            cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (poi_id INTEGER, id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covariance VARCHAR, covariance_type VARCHAR, weight, precisions_cholesky VARCHAR, PRIMARY KEY(poi_id, id))')
+            for i in xrange(gmm.n_components):
+                cursor.execute('INSERT INTO gaussians2d VALUES(?,?,?,?,?,?,?,?,?)', (gmmId, i, gmmType, gmm.means_[i][0], gmm.means_[i][1], str(gmm.covariances_[i].tolist()), gmm.covariance_type, gmm.weights_[i], str(gmm.precisions_cholesky_[i].tolist())))
+            connection.commit()
+        except sqlite3.OperationalError as error:
+            printDBError(error)
 
 def savePOIAssignmentsToSqlite(filename, objects):
     'save the od fields of objects'
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    try:
-        cursor.execute('CREATE TABLE IF NOT EXISTS objects_pois (object_id INTEGER, origin_poi_id INTEGER, destination_poi_id INTEGER, PRIMARY KEY(object_id))')
-        for o in objects:
-            cursor.execute('INSERT INTO objects_pois VALUES(?,?,?)', (o.getNum(), o.od[0], o.od[1]))
-        connection.commit()
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-    connection.close()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            cursor.execute('CREATE TABLE IF NOT EXISTS objects_pois (object_id INTEGER, origin_poi_id INTEGER, destination_poi_id INTEGER, PRIMARY KEY(object_id))')
+            for o in objects:
+                cursor.execute('INSERT INTO objects_pois VALUES(?,?,?)', (o.getNum(), o.od[0], o.od[1]))
+            connection.commit()
+        except sqlite3.OperationalError as error:
+            printDBError(error)
     
 def loadPOIsFromSqlite(filename):
     'Loads all 2D Gaussians in the database'
     from sklearn import mixture # todo if not avalaible, load data in duck-typed class with same fields
     from ast import literal_eval
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
     pois = []
-    try:
-        cursor.execute('SELECT * from gaussians2d')
-        gmmId = None
-        gmm = []
-        for row in cursor:
-            if gmmId is None or row[0] != gmmId:
-                if len(gmm) > 0:
-                    tmp = mixture.GaussianMixture(len(gmm), covarianceType)
-                    tmp.means_ = array([gaussian['mean'] for gaussian in gmm])
-                    tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm])
-                    tmp.weights_ = array([gaussian['weight'] for gaussian in gmm])
-                    tmp.gmmTypes = [gaussian['type'] for gaussian in gmm]
-                    tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm])
-                    pois.append(tmp)
-                gaussian = {'type': row[2],
-                            'mean': row[3:5],
-                            'covar': array(literal_eval(row[5])),
-                            'weight': row[7],
-                            'precisions': array(literal_eval(row[8]))}
-                gmm = [gaussian]
-                covarianceType = row[6]
-                gmmId = row[0]
-            else:
-                gmm.append({'type': row[2],
-                            'mean': row[3:5],
-                            'covar': array(literal_eval(row[5])),
-                            'weight': row[7],
-                            'precisions': array(literal_eval(row[8]))})
-        if len(gmm) > 0:
-            tmp = mixture.GaussianMixture(len(gmm), covarianceType)
-            tmp.means_ = array([gaussian['mean'] for gaussian in gmm])
-            tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm])
-            tmp.weights_ = array([gaussian['weight'] for gaussian in gmm])
-            tmp.gmmTypes = [gaussian['type'] for gaussian in gmm]
-            tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm])
-            pois.append(tmp)
-    except sqlite3.OperationalError as error:
-        printDBError(error)
-    connection.close()
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            cursor.execute('SELECT * from gaussians2d')
+            gmmId = None
+            gmm = []
+            for row in cursor:
+                if gmmId is None or row[0] != gmmId:
+                    if len(gmm) > 0:
+                        tmp = mixture.GaussianMixture(len(gmm), covarianceType)
+                        tmp.means_ = array([gaussian['mean'] for gaussian in gmm])
+                        tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm])
+                        tmp.weights_ = array([gaussian['weight'] for gaussian in gmm])
+                        tmp.gmmTypes = [gaussian['type'] for gaussian in gmm]
+                        tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm])
+                        pois.append(tmp)
+                    gaussian = {'type': row[2],
+                                'mean': row[3:5],
+                                'covar': array(literal_eval(row[5])),
+                                'weight': row[7],
+                                'precisions': array(literal_eval(row[8]))}
+                    gmm = [gaussian]
+                    covarianceType = row[6]
+                    gmmId = row[0]
+                else:
+                    gmm.append({'type': row[2],
+                                'mean': row[3:5],
+                                'covar': array(literal_eval(row[5])),
+                                'weight': row[7],
+                                'precisions': array(literal_eval(row[8]))})
+            if len(gmm) > 0:
+                tmp = mixture.GaussianMixture(len(gmm), covarianceType)
+                tmp.means_ = array([gaussian['mean'] for gaussian in gmm])
+                tmp.covariances_ = array([gaussian['covar'] for gaussian in gmm])
+                tmp.weights_ = array([gaussian['weight'] for gaussian in gmm])
+                tmp.gmmTypes = [gaussian['type'] for gaussian in gmm]
+                tmp.precisions_cholesky_ = array([gaussian['precisions'] for gaussian in gmm])
+                pois.append(tmp)
+        except sqlite3.OperationalError as error:
+            printDBError(error)
     return pois
     
 #########################
@@ -882,16 +870,6 @@
     connection.commit()
     connection.close()
 
-def setRoadUserTypes(filename, objects):
-    '''Saves the user types of the objects in the sqlite database stored in filename
-    The objects should exist in the objects table'''
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    for obj in objects:
-        cursor.execute('update objects set road_user_type = {} WHERE object_id = {}'.format(obj.getUserType(), obj.getNum()))
-    connection.commit()
-    connection.close()
-
 #########################
 # txt files
 #########################
@@ -902,7 +880,7 @@
     try:
         return open(filename, option)
     except IOError:
-        print 'File %s could not be opened.' % filename
+        print 'File {} could not be opened.'.format(filename)
         if quitting:
             from sys import exit
             exit()
@@ -996,25 +974,25 @@
 
 def loadObjectNumbersInLinkFromVissimFile(filename, linkIds):
     '''Finds the ids of the objects that go through any of the link in the list linkIds'''
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    queryStatement = 'SELECT DISTINCT trajectory_id FROM curvilinear_positions where link_id IN ('+','.join([str(id) for id in linkIds])+')'
-    try:
-        cursor.execute(queryStatement)
-        return [row[0] for row in cursor]
-    except sqlite3.OperationalError as error:
-        printDBError(error)
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        queryStatement = 'SELECT DISTINCT trajectory_id FROM curvilinear_positions where link_id IN ('+','.join([str(id) for id in linkIds])+')'
+        try:
+            cursor.execute(queryStatement)
+            return [row[0] for row in cursor]
+        except sqlite3.OperationalError as error:
+            printDBError(error)
 
 def getNObjectsInLinkFromVissimFile(filename, linkIds):
     '''Returns the number of objects that traveled through the link ids'''
-    connection = sqlite3.connect(filename)
-    cursor = connection.cursor()
-    queryStatement = 'SELECT link_id, COUNT(DISTINCT trajectory_id) FROM curvilinear_positions where link_id IN ('+','.join([str(id) for id in linkIds])+') GROUP BY link_id'
-    try:
-        cursor.execute(queryStatement)
-        return {row[0]:row[1] for row in cursor}
-    except sqlite3.OperationalError as error:
-        printDBError(error)
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        queryStatement = 'SELECT link_id, COUNT(DISTINCT trajectory_id) FROM curvilinear_positions where link_id IN ('+','.join([str(id) for id in linkIds])+') GROUP BY link_id'
+        try:
+            cursor.execute(queryStatement)
+            return {row[0]:row[1] for row in cursor}
+        except sqlite3.OperationalError as error:
+            printDBError(error)
 
 def loadTrajectoriesFromVissimFile(filename, simulationStepsPerTimeUnit, objectNumbers = None, warmUpLastInstant = None, usePandas = False, nDecimals = 2, lowMemory = True):
     '''Reads data from VISSIM .fzp trajectory file
@@ -1067,30 +1045,30 @@
                     objects[objNum].curvilinearPositions.addPositionSYL(s, y, lane)
                 line = readline(inputfile, '*$')
         elif filename.endswith(".sqlite"):
-            connection = sqlite3.connect(filename)
-            cursor = connection.cursor()
-            queryStatement = 'SELECT t, trajectory_id, link_id, lane_id, s_coordinate, y_coordinate FROM curvilinear_positions'
-            if objectNumbers is not None:
-                queryStatement += ' WHERE trajectory_id '+getObjectCriteria(objectNumbers)
-            queryStatement += ' ORDER BY trajectory_id, t'
-            try:
-                cursor.execute(queryStatement)
-                for row in cursor:
-                    objNum = row[1]
-                    instant = row[0]*simulationStepsPerTimeUnit
-                    s = row[4]
-                    y = row[5]
-                    lane = '{}_{}'.format(row[2], row[3])
-                    if objNum not in objects:
-                        if warmUpLastInstant is None or instant >= warmUpLastInstant:
-                            if objectNumbers is None or len(objects) < objectNumbers:
-                                objects[objNum] = moving.MovingObject(num = objNum, timeInterval = moving.TimeInterval(instant, instant))
-                                objects[objNum].curvilinearPositions = moving.CurvilinearTrajectory()
-                    if (warmUpLastInstant is None or instant >= warmUpLastInstant) and objNum in objects:
-                        objects[objNum].timeInterval.last = instant
-                        objects[objNum].curvilinearPositions.addPositionSYL(s, y, lane)
-            except sqlite3.OperationalError as error:
-                printDBError(error)
+            with sqlite3.connect(filename) as connection:
+                cursor = connection.cursor()
+                queryStatement = 'SELECT t, trajectory_id, link_id, lane_id, s_coordinate, y_coordinate FROM curvilinear_positions'
+                if objectNumbers is not None:
+                    queryStatement += ' WHERE trajectory_id '+getObjectCriteria(objectNumbers)
+                queryStatement += ' ORDER BY trajectory_id, t'
+                try:
+                    cursor.execute(queryStatement)
+                    for row in cursor:
+                        objNum = row[1]
+                        instant = row[0]*simulationStepsPerTimeUnit
+                        s = row[4]
+                        y = row[5]
+                        lane = '{}_{}'.format(row[2], row[3])
+                        if objNum not in objects:
+                            if warmUpLastInstant is None or instant >= warmUpLastInstant:
+                                if objectNumbers is None or len(objects) < objectNumbers:
+                                    objects[objNum] = moving.MovingObject(num = objNum, timeInterval = moving.TimeInterval(instant, instant))
+                                    objects[objNum].curvilinearPositions = moving.CurvilinearTrajectory()
+                        if (warmUpLastInstant is None or instant >= warmUpLastInstant) and objNum in objects:
+                            objects[objNum].timeInterval.last = instant
+                            objects[objNum].curvilinearPositions.addPositionSYL(s, y, lane)
+                except sqlite3.OperationalError as error:
+                    printDBError(error)
         else:
             print("File type of "+filename+" not supported (only .sqlite and .fzp files)")
         return objects.values()