view python/storage.py @ 259:8ab76b95ee72

added code to save collision points and crossing zones in txt files
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Tue, 24 Jul 2012 15:18:12 -0400
parents 583a2c4622f9
children c71540470057
line wrap: on
line source

#! /usr/bin/env python
# -*- coding: utf-8 -*-
'''Various utilities to save and load data'''

import utils
import moving

# Python 2: makes the classes defined in this module new-style classes by default
__metaclass__ = type


# NGSIM user (vehicle) type names and their numeric codes
ngsimUserTypes = dict(twowheels = 1,
                      car = 2,
                      truck = 3)

#########################
# txt files
#########################



#########################
# Sqlite
#########################

def writeTrajectoriesToSqlite(objects, outFile, trajectoryType, objectNumbers = -1):
    """
    This function writes trajectories to a specified sqlite file
    @param[in] objects -> a list of trajectories; each must provide getPositions(),
                          returning positions with x and y attributes
    @param[out] outFile -> the .sqlite file containing the written objects;
                           it must not already contain a positions table
    @param[in] trajectoryType -> only 'feature' is supported
    @param[in] objectNumbers : only -1 (write all objects) is supported

    Trajectories are numbered 1, 2, ... in the order of objects, and the positions
    of each trajectory are numbered 1, 2, ... as frame numbers.
    Raises sqlite3.OperationalError if the positions table cannot be created
    (e.g. it already exists); the connection is closed in all cases.
    """
    import sqlite3
    connection = sqlite3.connect(outFile)
    try:
        cursor = connection.cursor()

        schema = "CREATE TABLE \"positions\"(trajectory_id INTEGER,frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))"
        cursor.execute(schema)

        if trajectoryType == 'feature' and type(objectNumbers) == int and objectNumbers == -1:
            query = "insert into positions (trajectory_id, frame_number, x_coordinate, y_coordinate) values (?,?,?,?)"
            for trajectoryId, trajectory in enumerate(objects, 1):
                cursor.executemany(query, [(trajectoryId, frameNumber, position.x, position.y)
                                           for frameNumber, position in enumerate(trajectory.getPositions(), 1)])
        else:
            # previously a silent no-op: make it visible that nothing was written
            print('Only feature trajectories with objectNumbers = -1 can be written: no trajectory written')

        connection.commit()
    finally:
        connection.close()

def loadPrototypeMatchIndexesFromSqlite(filename):
    """
    This function loads the prototypes table in the database of name <filename>.
    It returns a list of tuples representing matching ids : [(prototype_id, matched_trajectory_id),...]
    sorted by prototype_id, then by matched trajectory id.
    If the table cannot be read, the error is printed and an empty list is returned.
    The connection is closed in all cases.
    """
    import sqlite3
    connection = sqlite3.connect(filename)
    try:
        cursor = connection.cursor()
        try:
            # explicit columns: do not depend on the column order of the table
            cursor.execute('SELECT prototype_id, trajectory_id_matched from prototypes order by prototype_id, trajectory_id_matched')
        except sqlite3.OperationalError as err:
            print('DB Error: {0}'.format(err))
            return []
        return [(prototypeId, trajectoryId) for prototypeId, trajectoryId in cursor]
    finally:
        connection.close()

def loadTrajectoriesFromTable(connection, tableName, trajectoryType, objectNumbers = -1):
    '''Loads trajectories (in the general sense) from the given table
    can be positions or velocities

    connection: open sqlite3 connection
    tableName: table to read, with columns trajectory_id, frame_number,
      x_coordinate and y_coordinate; it is written into the query text,
      so it must not come from untrusted input
    trajectoryType: 'feature' loads each feature trajectory (trajectory_id);
      'object' loads, for each object, the average of its features at each
      frame (features are associated to objects by the objects_features table)
    objectNumbers: -1 for all, an int n for the ids between 0 and n,
      or a list of ids (other values load nothing)

    returns a list of moving objects (empty list if the query fails or
    trajectoryType is not recognized)'''
    import sqlite3

    if trajectoryType == 'feature':
        idColumn = 'trajectory_id'
        selection = 'SELECT trajectory_id, frame_number, x_coordinate, y_coordinate from {0}'.format(tableName)
        conditionPrefix = ' where '
        grouping = ''
        ordering = ' order by trajectory_id, frame_number'
    elif trajectoryType == 'object':
        idColumn = 'OF.object_id'
        selection = 'SELECT OF.object_id, P.frame_number, avg(P.x_coordinate), avg(P.y_coordinate) from {0} P, objects_features OF where P.trajectory_id = OF.trajectory_id'.format(tableName)
        conditionPrefix = ' and '
        grouping = ' group by object_id, frame_number'
        # the loop below needs the rows of each object to be consecutive and
        # sorted by frame: GROUP BY alone does not guarantee any output order
        ordering = ' order by object_id, frame_number'
    else:
        print('no trajectory type was chosen')
        return []

    # object ids are bound as parameters instead of being written in the query text
    if type(objectNumbers) == int:
        if objectNumbers == -1:
            condition = ''
            parameters = ()
        else:
            condition = conditionPrefix+idColumn+' between 0 and ?'
            parameters = (objectNumbers,)
    elif type(objectNumbers) == list:
        condition = conditionPrefix+idColumn+' in ({0})'.format(', '.join(['?']*len(objectNumbers)))
        # int() so that e.g. numpy integers can be bound
        parameters = tuple(int(n) for n in objectNumbers)
    else:
        return []

    cursor = connection.cursor()
    try:
        cursor.execute(selection+condition+grouping+ordering, parameters)
    except sqlite3.OperationalError as err:
        print('DB Error: {0}'.format(err))
        return []

    objId = None # no valid id equals None, so the first row always starts an object
    obj = None
    objects = []
    for row in cursor:
        if row[0] != objId:
            objId = row[0]
            if obj is not None:
                objects.append(obj)
            obj = moving.MovingObject(row[0], timeInterval = moving.TimeInterval(row[1], row[1]), positions = moving.Trajectory([[row[2]],[row[3]]]))
        else:
            obj.timeInterval.last = row[1]
            obj.positions.addPositionXY(row[2],row[3])

    if obj is not None:
        objects.append(obj)

    return objects

def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = -1):
    '''Loads nObjects or the indices in objectNumbers from the database 
    TODO: load feature numbers and not average feature trajectories
    TODO: other ways of averaging trajectories (load all points, sorted by frame_number and leave the agregation to be done in python)

    filename: sqlite file with a positions table and optionally a velocities table
    trajectoryType: 'feature' or 'object'
    objectNumbers: -1 for all, an int n for the ids between 0 and n, or a list of ids

    returns the list of moving objects read from the positions table; the
    velocities read for the same object number are stored in its velocities attribute
    '''
    import sqlite3

    # NOTE: sqlite3.connect creates an empty database file if filename does not exist
    connection = sqlite3.connect(filename)
    try:
        objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers)
        objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers)
    finally:
        connection.close()

    if len(objectVelocities) > 0:
        # match by object number, not by rank: an object without velocities
        # must not shift the velocities of all the following objects
        velocitiesByNum = dict((v.num, v.positions) for v in objectVelocities)
        for o in objects:
            if o.num in velocitiesByNum:
                o.velocities = velocitiesByNum[o.num]
            else:
                print('Could not find velocities for positions {0}'.format(o.num))

    return objects

def removeObjectsFromSqlite(filename):
    '''Removes the objects and objects_features tables in the filename
    (through utils.dropTables); the connection is closed even if dropping fails'''
    import sqlite3
    connection = sqlite3.connect(filename)
    try:
        utils.dropTables(connection, ['objects', 'objects_features'])
    finally:
        connection.close()

def loadTrajectoriesFromNgsimFile(filename, nObjects = -1, sequenceNum = -1):
    '''Reads data from the trajectory data provided by NGSIM project 
    and returns the list of Feature objects

    Each non-blank line holds the whitespace-separated fields of one object at
    one frame; consecutive lines with the same first field (object number) are
    gathered into one moving object. Fields used:
    0 object number, 1 first frame number, 2 number of frames,
    4-5 curvilinear position (5 is the longitudinal coordinate), 6-7 position,
    8 length, 9 width, 10 user type, 11 speed, 13 lane, 14 preceding vehicle,
    15 following vehicle, 16 space headway (feet), 17 time headway (seconds)

    nObjects: maximum number of objects to load (all objects if <= 0)
    sequenceNum: not used

    Exits the program (sys.exit) if utils.openCheck does not return a file.'''
    objects = []

    inputFile = utils.openCheck(filename)
    if not inputFile:
        import sys
        sys.exit()

    def createObject(numbers):
        'Creates a moving object from the fields of its first line'
        # TODO: do the geometry and usertype
        firstFrameNum = int(numbers[1])
        lastFrameNum = firstFrameNum+int(numbers[2])-1
        obj = moving.MovingObject(num = int(numbers[0]), 
                                  timeInterval = moving.TimeInterval(firstFrameNum, lastFrameNum), 
                                  positions = moving.Trajectory([[float(numbers[6])],[float(numbers[7])]]), 
                                  userType = int(numbers[10]))
        obj.userType = int(numbers[10])
        obj.laneNums = [int(numbers[13])]
        obj.precedingVehicles = [int(numbers[14])] # lead vehicle (before)
        obj.followingVehicles = [int(numbers[15])] # following vehicle (after)
        obj.spaceHeadways = [float(numbers[16])] # feet
        obj.timeHeadways = [float(numbers[17])] # seconds
        obj.curvilinearPositions = moving.Trajectory([[float(numbers[5])],[float(numbers[4])]]) # X is the longitudinal coordinate
        obj.speeds = [float(numbers[11])]
        obj.size = [float(numbers[8]), float(numbers[9])] # 8 length, 9 width # TODO: temporary, should use a geometry object
        return obj

    def addInstant(obj, numbers):
        'Appends the fields of a subsequent line of the same object to obj'
        obj.positions.addPositionXY(float(numbers[6]), float(numbers[7]))
        obj.curvilinearPositions.addPositionXY(float(numbers[5]), float(numbers[4]))
        obj.speeds.append(float(numbers[11]))
        obj.laneNums.append(int(numbers[13]))
        obj.precedingVehicles.append(int(numbers[14]))
        obj.followingVehicles.append(int(numbers[15]))
        obj.spaceHeadways.append(float(numbers[16]))
        obj.timeHeadways.append(float(numbers[17]))

        if (obj.size[0] != float(numbers[8])):
            print('changed length obj %d' % (obj.num))
        if (obj.size[1] != float(numbers[9])):
            print('changed width obj %d' % (obj.num))

    def addObject(obj):
        '''Adapts the time interval of obj to the number of positions actually
        read (to deal with issues in NGSIM data), then adds obj to objects'''
        if (obj.length() != obj.positions.length()):
            print('length pb with object %s (%d,%d)' % (obj.num,obj.length(),obj.positions.length()))
            # the last instant is stored in the time interval (setting obj.last had no effect)
            obj.timeInterval.last = obj.getFirstInstant()+obj.positions.length()-1
            # TODO: compute velocities from positions and compare their norm to speeds
        objects.append(obj)

    obj = None # object being read
    try:
        for line in inputFile:
            numbers = line.strip().split()
            if len(numbers) == 0: # skip blank lines, e.g. at the end of the file
                continue
            if obj is None:
                obj = createObject(numbers)
            elif obj.num != int(numbers[0]):
                addObject(obj)
                obj = None
                if (nObjects > 0) and (len(objects) >= nObjects):
                    break
                obj = createObject(numbers)
            else:
                addInstant(obj, numbers)
        if obj is not None: # the last object of the file is not added in the loop
            addObject(obj)
    finally:
        inputFile.close()
    return objects

def convertNgsimFile(inFile, outFile, append = False, nObjects = -1, sequenceNum = 0):
    '''Reads data from the trajectory data provided by NGSIM project
    and converts to our current format.

    append: if True, the objects are appended to outFile instead of overwriting it
    nObjects: maximum number of objects to convert (all objects if <= 0)
    sequenceNum: passed on to loadTrajectoriesFromNgsimFile

    Prints the number of converted objects for user types 1, 2 and 3.'''
    # the previously called loadNgsimFile does not exist; nObjects was also ignored
    features = loadTrajectoriesFromNgsimFile(inFile, nObjects, sequenceNum)

    nObjectsPerType = [0,0,0]
    with open(outFile, 'a' if append else 'w') as out:
        for f in features:
            # assumes userType is 1, 2 or 3 (the values of ngsimUserTypes)
            nObjectsPerType[f.userType-1] += 1
            f.write(out) # NOTE(review): relies on the objects providing a write(file) method

    print(nObjectsPerType)

if __name__ == "__main__":
    # run the doctests of tests/storage.txt (path relative to this module's
    # directory, the DocFileSuite default)
    import unittest
    import doctest
    storageSuite = doctest.DocFileSuite('tests/storage.txt')
    runner = unittest.TextTestRunner()
    runner.run(storageSuite)