view python/storage.py @ 343:74e437ab5f11

first version of indicator loading code
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Fri, 21 Jun 2013 15:28:59 -0400
parents 4d69486869a5
children 14a2405f54f8
line wrap: on
line source

#! /usr/bin/env python
# -*- coding: utf-8 -*-
'''Various utilities to save and load data'''

import utils, moving, events, indicators

import sqlite3

__metaclass__ = type # (Python 2) classes defined in this module without base are new-style


# integer codes of the road user types, presumably as found in the NGSIM user type column
ngsimUserTypes = dict(twowheels = 1,
                      car = 2,
                      truck = 3)

#########################
# Sqlite
#########################

def saveTrajectoriesToSqlite(objects, outFilename, trajectoryType, objectNumbers = -1):
    """
    Writes trajectories to the positions table of an sqlite file

    @param[in] objects -> a list of trajectories; each must provide getPositions(),
      an iterable of points with x and y attributes
    @param[in] outFilename -> name of the .sqlite file to write to (created if needed)
    @param[in] trajectoryType -> only 'feature' is currently supported:
      for any other value, only the positions table is created
    @param[in] objectNumbers -> only -1 (all objects) is currently supported:
      for any other value, nothing is written

    Trajectory ids and frame numbers are assigned sequentially, starting at 1,
    in the order of objects and of their positions.
    """
    connection = sqlite3.connect(outFilename)
    try:
        cursor = connection.cursor()
        cursor.execute('CREATE TABLE IF NOT EXISTS "positions"(trajectory_id INTEGER,frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))')

        if trajectoryType == 'feature' and isinstance(objectNumbers, int) and objectNumbers == -1:
            query = "insert into positions (trajectory_id, frame_number, x_coordinate, y_coordinate) values (?,?,?,?)"
            for trajectoryId, trajectory in enumerate(objects, 1):
                cursor.executemany(query, ((trajectoryId, frameNumber, position.x, position.y)
                                           for frameNumber, position in enumerate(trajectory.getPositions(), 1)))
        connection.commit()
    finally:
        # always release the file, even if an insertion fails
        connection.close()

def setRoadUserTypes(filename, objects):
    '''Saves the user types of the objects in the sqlite database stored in filename
    The objects should exist in the objects table

    Each object must provide getNum() (its object_id) and getUserType().
    Objects whose id is not in the objects table are not updated.'''
    connection = sqlite3.connect(filename)
    try:
        cursor = connection.cursor()
        # bound parameters instead of string formatting: a None user type is stored as NULL
        cursor.executemany('update objects set road_user_type = ? where object_id = ?',
                           [(obj.getUserType(), obj.getNum()) for obj in objects])
        connection.commit()
    finally:
        connection.close()

def printDBError(error):
    '''Prints the given database error on the standard output'''
    message = 'DB Error: {0}'.format(error)
    print(message)

def loadPrototypeMatchIndexesFromSqlite(filename):
    """
    This function loads the prototypes table in the database of name <filename>.
    It returns a list of tuples representing matching ids : [(prototype_id, matched_trajectory_id),...]
    sorted by prototype_id, then matched trajectory id.
    If the table cannot be read, the error is printed and an empty list is returned.
    """
    connection = sqlite3.connect(filename)
    try:
        cursor = connection.cursor()
        try:
            # explicit columns: the result does not depend on the column order of the table
            cursor.execute('SELECT prototype_id, trajectory_id_matched from prototypes order by prototype_id, trajectory_id_matched')
        except sqlite3.OperationalError as error:
            printDBError(error)
            return []
        return [(row[0], row[1]) for row in cursor]
    finally:
        # closed on the error path as well
        connection.close()

def getTrajectoryIdQuery(objectNumbers, trajectoryType):
    '''Returns the SQL fragment restricting a query to some trajectory ids

    objectNumbers: -1 for no restriction (returns ''), an integer n for ids
      between 0 and n, or a sequence (list, tuple...) of ids
    trajectoryType: 'feature' gives a 'where trajectory_id ...' clause,
      'object' gives an 'and OF.object_id ...' condition to append to an
      existing where clause

    The returned fragment ends with a space.
    Raises ValueError if a restriction is requested for an unknown trajectory type.'''
    import numbers

    if trajectoryType == 'feature':
        statementBeginning = 'where trajectory_id '
    elif trajectoryType == 'object':
        statementBeginning = 'and OF.object_id '
    else:
        print('no trajectory type was chosen')
        statementBeginning = None

    isInteger = isinstance(objectNumbers, numbers.Integral)
    if isInteger and objectNumbers == -1:
        return ''
    if statementBeginning is None:
        raise ValueError('unknown trajectory type {0}: cannot restrict the query to {1}'.format(trajectoryType, objectNumbers))
    if isInteger:
        return statementBeginning+'between 0 and {0} '.format(objectNumbers)
    return statementBeginning+'in ('+', '.join([str(n) for n in objectNumbers])+') '

def loadTrajectoriesFromTable(connection, tableName, trajectoryType, objectNumbers = -1):
    '''Loads trajectories (in the general sense) from the given table
    can be positions or velocities

    connection: open sqlite3 connection (it is not closed here)
    tableName: name of a table whose columns are (trajectory_id, frame_number,
      x_coordinate, y_coordinate), e.g. positions or velocities
    trajectoryType: 'feature' loads each trajectory of the table;
      'object' loads, for each object of the objects_features table, the
      average of the coordinates of its features at each frame
    objectNumbers: -1 for all trajectories, an integer n for ids between 0
      and n, or a list of ids

    returns a list of moving objects ordered by id, or [] if the query fails
    or the trajectory type is unknown'''
    cursor = connection.cursor()

    try:
        if trajectoryType == 'feature':
            trajectoryIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
            cursor.execute('SELECT * from '+tableName+' '+trajectoryIdQuery+'order by trajectory_id, frame_number')
        elif trajectoryType == 'object':
            objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
            cursor.execute('SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from '+tableName+' P, objects_features OF where P.trajectory_id = OF.trajectory_id '+objectIdQuery+'group by OF.object_id, P.frame_number order by OF.object_id, P.frame_number')
        else:
            print('no trajectory type was chosen')
            return []
    except sqlite3.OperationalError as error:
        printDBError(error)
        return []

    # rows are sorted by id then frame: consecutive rows with the same id extend the same object
    objects = []
    obj = None
    objId = None
    for row in cursor:
        if obj is None or row[0] != objId:
            objId = row[0]
            obj = moving.MovingObject(row[0], timeInterval = moving.TimeInterval(row[1], row[1]), positions = moving.Trajectory([[row[2]],[row[3]]]))
            objects.append(obj) # completed in place by the following rows
        else:
            obj.timeInterval.last = row[1]
            obj.positions.addPositionXY(row[2],row[3])

    return objects

def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = -1):
    '''Loads nObjects or the indices in objectNumbers from the database
    TODO: load feature numbers and not average feature trajectories
    TODO: other ways of averaging trajectories (load all points, sorted by frame_number and leave the aggregation to be done in python)

    filename: sqlite file with a positions table and optionally a velocities table
    trajectoryType: 'feature' or 'object' (see loadTrajectoriesFromTable)
    objectNumbers: -1 for all, an integer n for ids between 0 and n, or a list of ids

    For objects, the featureNumbers (ids of the features of each object, from
    objects_features) and userType (road_user_type of the objects table)
    attributes are also set; [] is returned if a database error occurs while
    loading them. The connection is always closed.'''
    connection = sqlite3.connect(filename) # add test if it open
    try:
        objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers)
        objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers)

        # both lists are ordered by id: positions and velocities are matched by rank
        for o, v in zip(objects, objectVelocities):
            if o.getNum() == v.getNum():
                o.velocities = v.positions
                o.velocities.duplicateLastPosition() # avoid having velocity shorter by one position than positions
            else:
                print('Could not match positions {0} with velocities {1}'.format(o.getNum(), v.getNum()))

        if trajectoryType == 'object':
            cursor = connection.cursor()
            try:
                # attribute feature numbers to objects
                objectIdQuery = getTrajectoryIdQuery(objectNumbers, trajectoryType)
                cursor.execute('SELECT P.trajectory_id, OF.object_id from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id '+objectIdQuery+'group by P.trajectory_id order by OF.object_id') # order is important to group all features per object

                featureNumbers = {}
                for featureId, objId in cursor:
                    featureNumbers.setdefault(objId, []).append(featureId)

                for obj in objects:
                    obj.featureNumbers = featureNumbers[obj.getNum()]

                # load userType
                if objectIdQuery == '':
                    cursor.execute('SELECT object_id, road_user_type from objects')
                else:
                    # drop the leading 'and OF.' (7 characters) to keep the 'object_id ...' condition
                    cursor.execute('SELECT object_id, road_user_type from objects where '+objectIdQuery[7:])
                userTypes = dict(cursor.fetchall())

                for obj in objects:
                    obj.userType = userTypes[obj.getNum()]
            except sqlite3.OperationalError as error:
                printDBError(error)
                return []

        return objects
    finally:
        connection.close()

def removeObjectsFromSqlite(filename):
    'Removes the objects and objects_features tables in the filename (the connection is closed even on error)'
    connection = sqlite3.connect(filename)
    try:
        utils.dropTables(connection, ['objects', 'objects_features'])
    finally:
        connection.close()

def deleteIndicators(filename):
    '''Deletes all indicator data in db

    TODO: not implemented yet; this function currently does nothing
    (filename is not even opened).'''
    return None

def createInteractionTable(cursor):
    '''Creates the interactions table, if it does not exist, with the given cursor

    Each interaction has an id, the ids of its two road users (object_id1 and
    object_id2) and its first and last frame numbers.'''
    # the key column of the objects table is object_id (there is no id column):
    # referencing objects(id) makes every insertion fail when foreign keys are enforced
    cursor.execute('CREATE TABLE IF NOT EXISTS interactions (id INTEGER PRIMARY KEY, object_id1 INTEGER, object_id2 INTEGER, first_frame_number INTEGER, last_frame_number INTEGER, FOREIGN KEY(object_id1) REFERENCES objects(object_id), FOREIGN KEY(object_id2) REFERENCES objects(object_id))')

def createIndicatorTables(cursor):
    '''Creates the indicators table, if it does not exist, with the given cursor

    One row per (interaction_id, indicator_type, frame_number) holds the
    indicator value; interaction_id references the id of the interactions table.
    (A normalized layout with separate indicators and indicator_values tables
    was considered instead of this single table.)'''
    tableDefinition = ('interaction_id INTEGER, indicator_type INTEGER, frame_number INTEGER, value REAL, '
                       'FOREIGN KEY(interaction_id) REFERENCES interactions(id), '
                       'PRIMARY KEY(interaction_id, indicator_type, frame_number)')
    cursor.execute('CREATE TABLE IF NOT EXISTS indicators (' + tableDefinition + ')')

def saveInteraction(cursor, interaction):
    '''Inserts one interaction in the interactions table with the given cursor

    The interaction must provide getNum(), getRoadUserNumbers() (the first two
    numbers, in iteration order, are stored as object_id1 and object_id2),
    getFirstInstant() and getLastInstant().'''
    roadUserNumbers = list(interaction.getRoadUserNumbers())
    # bound parameters: no SQL built from values, None is stored as NULL
    cursor.execute('INSERT INTO interactions VALUES(?, ?, ?, ?, ?)',
                   (interaction.getNum(), roadUserNumbers[0], roadUserNumbers[1], interaction.getFirstInstant(), interaction.getLastInstant()))

def saveInteractions(filename, interactions):
    '''Saves the interactions in the interactions table of the sqlite file filename

    The table is created if needed. An sqlite3.OperationalError is printed and
    the interactions inserted before it are still committed; other errors
    propagate without commit. The connection is always closed.'''
    connection = sqlite3.connect(filename)
    try:
        cursor = connection.cursor()
        try:
            createInteractionTable(cursor)
            for inter in interactions:
                saveInteraction(cursor, inter)
        except sqlite3.OperationalError as error:
            printDBError(error)
        connection.commit()
    finally:
        connection.close()

def saveIndicator(cursor, interactionNum, indicator):
    '''Inserts the values of indicator for interaction number interactionNum
    in the indicators table with the given cursor

    The indicator type is the index of the indicator name in
    events.Interaction.indicatorNameToIndices. Instants without a value are
    skipped (assumes indicator[instant] is None there -- TODO confirm);
    values equal to 0 are saved.'''
    indicatorType = events.Interaction.indicatorNameToIndices[indicator.getName()]
    for instant in indicator.getTimeInterval():
        value = indicator[instant]
        # a truthiness test would drop legitimate 0 values
        if value is not None:
            cursor.execute('INSERT INTO indicators VALUES(?, ?, ?, ?)', (interactionNum, indicatorType, instant, value))

def saveIndicators(filename, interactions, indicatorNames = events.Interaction.indicatorNames):
    '''Saves the interactions and their indicator values in the sqlite file filename

    The interactions and indicators tables are created if needed. For each
    interaction, the indicators named in indicatorNames for which
    getIndicator does not return None are saved.
    An sqlite3.OperationalError is printed and what was inserted before it is
    still committed; other errors propagate without commit. The connection is
    always closed.'''
    connection = sqlite3.connect(filename)
    try:
        cursor = connection.cursor()
        try:
            createInteractionTable(cursor)
            createIndicatorTables(cursor)
            for inter in interactions:
                saveInteraction(cursor, inter)
                for indicatorName in indicatorNames:
                    indicator = inter.getIndicator(indicatorName)
                    if indicator is not None:
                        saveIndicator(cursor, inter.getNum(), indicator)
        except sqlite3.OperationalError as error:
            printDBError(error)
        connection.commit()
    finally:
        connection.close()

def loadIndicators(filename):
    '''Loads interaction indicators from the interactions and indicators
    tables of the sqlite file filename

    Returns a list of events.Interaction ordered by interaction id, whose
    indicators attribute maps each indicator name to an
    indicators.SeverityIndicator built from {frame_number: value}.
    Only interactions with at least one indicator value are loaded (inner join).
    Returns [] if the tables cannot be read.

    TODO choose the interactions to load'''
    import itertools
    from operator import itemgetter

    interactions = []
    connection = sqlite3.connect(filename)
    try:
        cursor = connection.cursor()
        try:
            cursor.execute('select INT.id, INT.object_id1, INT.object_id2, INT.first_frame_number, INT.last_frame_number, IND.indicator_type, IND.frame_number, IND.value from interactions INT, indicators IND where INT.id = IND.interaction_id ORDER BY INT.id, IND.indicator_type')
            # rows are sorted by interaction id then indicator type, so each
            # interaction, and each indicator within it, is a contiguous group
            for interactionNum, interactionRows in itertools.groupby(cursor, key = itemgetter(0)):
                interactionRows = list(interactionRows)
                # the interaction columns (road users, first and last frames) are
                # repeated on every row of the interaction: take them from its own rows
                firstRow = interactionRows[0]
                interaction = events.Interaction(interactionNum, moving.TimeInterval(firstRow[3], firstRow[4]), firstRow[1], firstRow[2])
                interaction.indicators = {}
                for indicatorTypeNum, indicatorRows in itertools.groupby(interactionRows, key = itemgetter(5)):
                    indicatorName = events.Interaction.indicatorNames[indicatorTypeNum]
                    indicatorValues = dict((row[6], row[7]) for row in indicatorRows)
                    interaction.indicators[indicatorName] = indicators.SeverityIndicator(indicatorName, indicatorValues)
                interactions.append(interaction)
        except sqlite3.OperationalError as error:
            printDBError(error)
            return []
    finally:
        connection.close()
    return interactions
# load first and last object instants
# CREATE TEMP TABLE IF NOT EXISTS object_instants AS SELECT OF.object_id, min(frame_number) as first_instant, max(frame_number) as last_instant from positions P, objects_features OF where P.trajectory_id = OF.trajectory_id group by OF.object_id order by OF.object_id


#########################
# txt files
#########################

def loadTrajectoriesFromNgsimFile(filename, nObjects = -1, sequenceNum = -1):
    '''Reads data from the trajectory data provided by NGSIM project
    and returns the list of Feature objects

    filename: NGSIM text file, with one line of whitespace-separated values
      per vehicle and frame, the lines of a vehicle being consecutive
    nObjects: maximum number of objects to load (all objects if <= 0)
    sequenceNum: currently unused

    Blank lines are skipped. The program exits if the file cannot be opened.'''
    inputFile = utils.openCheck(filename)
    if not inputFile:
        import sys
        sys.exit()

    def createObject(numbers):
        'Creates a moving object from the first line (split in numbers) of a vehicle'
        # TODO: do the geometry
        firstFrameNum = int(numbers[1])
        lastFrameNum = firstFrameNum+int(numbers[2])-1 # numbers[2] is the number of frames of the vehicle
        obj = moving.MovingObject(num = int(numbers[0]),
                                  timeInterval = moving.TimeInterval(firstFrameNum, lastFrameNum),
                                  positions = moving.Trajectory([[float(numbers[6])],[float(numbers[7])]]),
                                  userType = int(numbers[10]))
        obj.userType = int(numbers[10])
        obj.laneNums = [int(numbers[13])]
        obj.precedingVehicles = [int(numbers[14])] # lead vehicle (before)
        obj.followingVehicles = [int(numbers[15])] # following vehicle (after)
        obj.spaceHeadways = [float(numbers[16])] # feet
        obj.timeHeadways = [float(numbers[17])] # seconds
        obj.curvilinearPositions = moving.CurvilinearTrajectory([float(numbers[5])],[float(numbers[4])], obj.laneNums) # X is the longitudinal coordinate
        obj.speeds = [float(numbers[11])]
        obj.size = [float(numbers[8]), float(numbers[9])] # 8 length, 9 width # TODO: temporary, should use a geometry object
        return obj

    def addInstant(obj, numbers):
        'Adds the data of a following line (split in numbers) of the same vehicle to obj'
        obj.laneNums.append(int(numbers[13]))
        obj.positions.addPositionXY(float(numbers[6]), float(numbers[7]))
        obj.curvilinearPositions.addPosition(float(numbers[5]), float(numbers[4]), obj.laneNums[-1])
        obj.speeds.append(float(numbers[11]))
        obj.precedingVehicles.append(int(numbers[14]))
        obj.followingVehicles.append(int(numbers[15]))
        obj.spaceHeadways.append(float(numbers[16]))
        obj.timeHeadways.append(float(numbers[17]))

        if (obj.size[0] != float(numbers[8])):
            print('changed length obj %d' % (obj.getNum()))
        if (obj.size[1] != float(numbers[9])):
            print('changed width obj %d' % (obj.getNum()))

    def checkLength(obj):
        'Adapts the time interval of a complete obj to its number of positions (to deal with issues in NGSIM data)'
        if (obj.length() != obj.positions.length()):
            print('length pb with object %s (%d,%d)' % (obj.getNum(),obj.length(),obj.positions.length()))
            # the last instant is stored in the time interval (setting obj.last had no effect)
            obj.timeInterval.last = obj.getFirstInstant()+obj.positions.length()-1
            # TODO: compute velocities from the positions and compare their norm to speeds

    objects = []
    obj = None # vehicle being read
    try:
        for line in inputFile:
            numbers = line.strip().split()
            if len(numbers) == 0:
                continue
            if obj is None:
                obj = createObject(numbers)
            elif obj.getNum() != int(numbers[0]): # first line of a new vehicle: obj is complete
                checkLength(obj)
                objects.append(obj)
                obj = None
                if (nObjects > 0) and (len(objects) >= nObjects):
                    break
                obj = createObject(numbers)
            else:
                addInstant(obj, numbers)
        # the last vehicle of the file is only complete at the end of the file
        if obj is not None:
            checkLength(obj)
            objects.append(obj)
    finally:
        inputFile.close()
    return objects

def convertNgsimFile(inFile, outFile, append = False, nObjects = -1, sequenceNum = 0):
    '''Reads data from the trajectory data provided by NGSIM project
    and converts to our current format.

    inFile: NGSIM text file
    outFile: output file, appended to if append is True, overwritten otherwise
    nObjects: maximum number of objects to convert (all objects if <= 0)
    sequenceNum: passed on to loadTrajectoriesFromNgsimFile

    Prints the number of converted objects for user types 1, 2 and 3.'''
    # load first, so that outFile is not truncated if the input cannot be read
    features = loadTrajectoriesFromNgsimFile(inFile, nObjects, sequenceNum)
    nObjectsPerType = [0,0,0]
    with open(outFile, 'a' if append else 'w') as out:
        for f in features:
            nObjectsPerType[f.userType-1] += 1
            f.write(out)

    print(nObjectsPerType)

def writePositionsToCsv(f, obj):
    '''Writes one csv line per instant of obj to the open file f:
    object number, instant, x, y, followed by the first two components of
    the curvilinear position if obj has curvilinear positions'''
    timeInterval = obj.getTimeInterval()
    positions = obj.getPositions()
    curvilinearPositions = obj.getCurvilinearPositions()
    for i in range(int(obj.length())):
        p1 = positions[i]
        s = '{},{},{},{}'.format(obj.getNum(),timeInterval[i],p1.x,p1.y)
        if curvilinearPositions is not None:
            p2 = curvilinearPositions[i]
            s += ',{},{}'.format(p2[0],p2[1])
        f.write(s+'\n')

def writeTrajectoriesToCsv(filename, objects):
    'Writes the positions of all objects, one line per object and instant, to the csv file filename (overwritten)'
    with open(filename, 'w') as f: # closed even if writing fails
        for obj in objects:
            writePositionsToCsv(f, obj)

if __name__ == "__main__":
    # run the doctests of this module, stored in tests/storage.txt
    import unittest
    import doctest
    storageSuite = doctest.DocFileSuite('tests/storage.txt')
    unittest.TextTestRunner().run(storageSuite)