Mercurial Hosting > traffic-intelligence
changeset 910:b58a1061a717
loading is faster for longest object features
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Wed, 28 Jun 2017 15:36:25 -0400 |
parents | cd038493f8c6 |
children | 3dd5acfa1899 |
files | python/storage.py scripts/learn-motion-patterns.py |
diffstat | 2 files changed, 33 insertions(+), 11 deletions(-) [+] |
line wrap: on
line diff
--- a/python/storage.py Mon Jun 26 17:45:32 2017 -0400 +++ b/python/storage.py Wed Jun 28 15:36:25 2017 -0400 @@ -234,7 +234,7 @@ def loadUserTypesFromTable(cursor, objectNumbers): objectCriteria = getObjectCriteria(objectNumbers) - queryStatement = 'SELECT object_id, road_user_type from objects' + queryStatement = 'SELECT object_id, road_user_type FROM objects' if objectNumbers is not None: queryStatement += ' WHERE object_id '+objectCriteria cursor.execute(queryStatement) @@ -265,13 +265,12 @@ cursor = connection.cursor() try: # attribute feature numbers to objects - objectCriteria = getObjectCriteria(objectNumbers) queryStatement = 'SELECT trajectory_id, object_id FROM objects_features' if objectNumbers is not None: - queryStatement += ' WHERE object_id '+objectCriteria + queryStatement += ' WHERE object_id '+getObjectCriteria(objectNumbers) queryStatement += ' ORDER BY object_id' # order is important to group all features per object + logging.debug(queryStatement) cursor.execute(queryStatement) - logging.debug(queryStatement) featureNumbers = {} for row in cursor: @@ -304,6 +303,29 @@ connection.close() return objects +def loadObjectFeatureFrameNumbers(filename, objectNumbers = None): + 'Loads the feature frame numbers for each object' + connection = sqlite3.connect(filename) + cursor = connection.cursor() + try: + queryStatement = 'SELECT OF.object_id, TL.trajectory_id, TL.length FROM (SELECT trajectory_id, max(frame_number)-min(frame_number) AS length FROM positions GROUP BY trajectory_id) TL, objects_features OF WHERE TL.trajectory_id = OF.trajectory_id' + if objectNumbers is not None: + queryStatement += ' AND object_id '+getObjectCriteria(objectNumbers) + queryStatement += ' ORDER BY OF.object_id, TL.length DESC' + logging.debug(queryStatement) + cursor.execute(queryStatement) + objectFeatureNumbers = {} + for row in cursor: + objId = row[0] + if objId in objectFeatureNumbers: + objectFeatureNumbers[objId].append(row[1]) + else: + objectFeatureNumbers[objId] = [row[1]] + return objectFeatureNumbers + except sqlite3.OperationalError as error: + printDBError(error) + return None + def addCurvilinearTrajectoriesFromSqlite(filename, objects): '''Adds curvilinear positions (s_coordinate, y_coordinate, lane) from a database to an existing MovingObject dict (indexed by each objects's num)'''
--- a/scripts/learn-motion-patterns.py Mon Jun 26 17:45:32 2017 -0400 +++ b/scripts/learn-motion-patterns.py Wed Jun 28 15:36:25 2017 -0400 @@ -45,14 +45,14 @@ trajectoryType = 'object' prototypeType = 'feature' -#features = storage.loadTrajectoriesFromSqlite(databaseFilename, args.trajectoryType) -objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, trajectoryType, withFeatures = (args.trajectoryType == 'objectfeatures'), objectNumbers = args.nTrajectories, timeStep = args.positionSubsamplingRate) - if args.trajectoryType == 'objectfeatures': - features = [] - for o in objects: - features += o.getNLongestFeatures(args.maxNObjectFeatures) - objects = features + objectFeatureNumbers = storage.loadObjectFeatureFrameNumbers(args.databaseFilename, objectNumbers = args.nTrajectories) + featureNumbers = [] + for numbers in objectFeatureNumbers.values(): + featureNumbers += numbers[:min(len(numbers), args.maxNObjectFeatures)] + objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, 'feature', objectNumbers = featureNumbers, timeStep = args.positionSubsamplingRate) +else: + objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, trajectoryType, withFeatures = (args.trajectoryType == 'objectfeatures'), objectNumbers = args.nTrajectories, timeStep = args.positionSubsamplingRate) trajectories = [o.getPositions().asArray().T for o in objects]