# view scripts/learn-motion-patterns.py @ 1032:d0e339359d8a
#
# work in progress
# author Nicolas Saunier <nicolas.saunier@polymtl.ca>
# date Tue, 19 Jun 2018 17:07:50 -0400
# parents cc5cb04b04b0
# children 8ffb3ae9f3d2
# line wrap: on
# line source

#! /usr/bin/env python3

import sys, argparse

import numpy as np

from trafficintelligence import ml, utils, storage, moving

# Command-line interface: input/output databases, trajectory selection,
# LCSS similarity parameters, clustering options and actions to perform.
parser = argparse.ArgumentParser(description='The program learns prototypes for the motion patterns')  # , epilog = ''
# TODO: support a configuration file option (--cfg, dest 'configFilename')

# databases
parser.add_argument('-d', dest='databaseFilename', required=True,
                    help='name of the Sqlite database file')
parser.add_argument('-o', dest='outputPrototypeDatabaseFilename',
                    help='name of the Sqlite database file to save prototypes')
parser.add_argument('-i', dest='inputPrototypeDatabaseFilename',
                    help='name of the Sqlite database file for prototypes to start the algorithm with')

# trajectories to load
parser.add_argument('-t', dest='trajectoryType', default='objectfeatures',
                    choices=['objectfeatures', 'feature', 'object'],
                    help='type of trajectories to learn from')
parser.add_argument('--max-nobjectfeatures', dest='maxNObjectFeatures', type=int, default=1,
                    help='maximum number of features per object to load')
parser.add_argument('-n', dest='nTrajectories', type=int, default=None,
                    help='number of the object or feature trajectories to load')

# similarity and clustering parameters
parser.add_argument('-e', dest='epsilon', type=float, required=True,
                    help='distance for the similarity of trajectory points')
parser.add_argument('--metric', dest='metric', default='cityblock',  # cityblock is the Manhattan distance
                    help='metric for the similarity of trajectory points')
parser.add_argument('-s', dest='minSimilarity', type=float, required=True,
                    help='minimum similarity to put a trajectory in a cluster')
parser.add_argument('-c', dest='minClusterSize', type=int, default=0,
                    help='minimum cluster size')

# actions and options
parser.add_argument('--learn', dest='learn', action='store_true',
                    help='learn')
parser.add_argument('--optimize', dest='optimizeCentroid', action='store_true',
                    help='recompute centroid at each assignment')
parser.add_argument('--random', dest='randomInitialization', action='store_true',
                    help='random initialization of clustering algorithm')
parser.add_argument('--subsample', dest='positionSubsamplingRate', type=int,
                    help='rate of position subsampling (1 every n positions)')
parser.add_argument('--display', dest='display', action='store_true',
                    help='display trajectories')
parser.add_argument('--save-similarities', dest='saveSimilarities', action='store_true',
                    help='save computed similarities (in addition to prototypes)')
parser.add_argument('--save-matches', dest='saveMatches', action='store_true',
                    help='saves the assignments of the objects (not for features) to the prototypes')
parser.add_argument('--assign', dest='assign', action='store_true',
                    help='assigns the objects to the prototypes and saves them (do not use min cluster size as it will discard prototypes at the beginning if the initial cluster is too small)')

args = parser.parse_args()

# use cases
# 1. learn proto from one file, save in same or another
# 2. load proto, load objects, update proto, save proto
# 3. assign objects from one db to proto
# 4. load objects from several files, save in another -> see metadata: site with view and times
# 5. keep prototypes, with positions/velocities, in separate db (keep link to original data through filename, type and index)

# TODO add possibility to cluster with velocities
# TODO add possibility to load all trajectories and use minclustersize
# save the objects that match the prototypes
# write an assignment function for objects

# Load the trajectories to cluster or assign.
if args.trajectoryType == 'objectfeatures':
    # Learn from feature trajectories: keep at most maxNObjectFeatures
    # features per object (slicing already stops at the end of shorter lists).
    trajectoryType = 'feature'
    objectFeatureNumbers = storage.loadObjectFeatureFrameNumbers(args.databaseFilename, objectNumbers = args.nTrajectories)
    featureNumbers = [featureNum
                      for numbers in objectFeatureNumbers.values()
                      for featureNum in numbers[:args.maxNObjectFeatures]]
    objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, trajectoryType, objectNumbers = featureNumbers, timeStep = args.positionSubsamplingRate)
else:
    trajectoryType = args.trajectoryType
    objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, trajectoryType, objectNumbers = args.nTrajectories, timeStep = args.positionSubsamplingRate)

# Position arrays, transposed (presumably to one row per position — TODO confirm
# the layout returned by asArray()).
trajectories = [obj.getPositions().asArray().T for obj in objects]

# Load the initial prototypes, if any. Their trajectories are put in front of
# the loaded trajectories, so initial prototype k is trajectory k.
initialPrototypes = []
initialPrototypeIndices = None
if args.inputPrototypeDatabaseFilename is not None:
    initialPrototypes = storage.loadPrototypesFromSqlite(args.inputPrototypeDatabaseFilename, True)
    trajectories = [proto.getMovingObject().getPositions().asArray().T for proto in initialPrototypes] + trajectories
    if len(initialPrototypes) > 0:
        initialPrototypeIndices = list(range(len(initialPrototypes)))

# LCSS similarity between trajectories, normalized.
lcss = utils.LCSS(metric = args.metric, epsilon = args.epsilon)
nTrajectories = len(trajectories)

# Pairwise similarity matrix initialized to -1 (presumably "not computed yet",
# filled by the ml functions); it can be reused by further clustering calls
# without being reinitialized.
similarities = np.full((nTrajectories, nTrajectories), -1.)
similarityFunc = lcss.computeNormalized
prototypeIndices = (ml.prototypeCluster(trajectories, similarities, args.minSimilarity, similarityFunc, args.optimizeCentroid, args.randomInitialization, initialPrototypeIndices)
                    if args.learn
                    else initialPrototypeIndices)

# Assign every trajectory to its most similar prototype (or none) and print the
# cluster sizes.
if args.assign:
    if prototypeIndices is None:
        # Nothing was learnt (--learn) and no prototype was loaded (-i):
        # there is nothing to assign the trajectories to.
        sys.exit('Error: no prototypes to assign the trajectories to: use --learn and/or provide prototypes with -i')
    if not args.learn and args.minClusterSize >= 1:
        print('Warning: you did not learn the prototypes and you are using minimum cluster size of {}, which may lead to removing prototypes and assigning them to others'.format(args.minClusterSize))
    # prototypeIndices may change here, e.g. when clusters smaller than
    # minClusterSize are discarded (see the --assign help)
    prototypeIndices, labels = ml.assignToPrototypeClusters(trajectories, prototypeIndices, similarities, args.minSimilarity, similarityFunc, args.minClusterSize)
    clusterSizes = ml.computeClusterSizes(labels, prototypeIndices, -1)
    print(clusterSizes)

# Save the prototypes, then optionally the similarities, the object-prototype
# assignments and a plot.
# Indices in prototypeIndices and labels refer to the trajectories list: the
# initial prototypes (if any) come first, then the loaded objects, so object j
# is trajectory j+nInitialPrototypes.
if args.learn or args.assign:
    nInitialPrototypes = len(initialPrototypes)
    prototypes = []
    for i in prototypeIndices:
        # presumably each cluster size counts the prototype itself, hence the -1;
        # no matchings are counted when only learning
        nMatchings = clusterSizes[i]-1 if args.assign else 0
        if i < nInitialPrototypes:
            if args.assign:
                initialPrototypes[i].nMatchings += nMatchings
            prototypes.append(initialPrototypes[i])
        else:
            prototypes.append(moving.Prototype(args.databaseFilename, objects[i-nInitialPrototypes].getNum(), trajectoryType, nMatchings))

    if args.outputPrototypeDatabaseFilename is None:
        outputPrototypeDatabaseFilename = args.databaseFilename
    else:
        outputPrototypeDatabaseFilename = args.outputPrototypeDatabaseFilename
    # The initial prototypes are saved again along with the new ones: clear the
    # prototypes of the database they were loaded from before saving into it,
    # whether it was given with -o or is the default output database (-d).
    if args.inputPrototypeDatabaseFilename == outputPrototypeDatabaseFilename:
        storage.deleteFromSqlite(outputPrototypeDatabaseFilename, 'prototype')
    storage.savePrototypesToSqlite(outputPrototypeDatabaseFilename, prototypes)

    if args.saveSimilarities:
        # todo save trajectories and prototypes
        np.savetxt(utils.removeExtension(args.databaseFilename)+'-prototype-similarities.txt.gz', similarities, '%.4f')

    # maps a prototype trajectory index to its position in prototypes
    labelsToProtoIndices = {protoId: i for i, protoId in enumerate(prototypeIndices)}
    if args.assign and args.saveMatches:
        # Keep only the labels of the loaded objects (skipping the initial
        # prototypes) and the assigned objects (negative label = unassigned).
        matchedObjects = []
        matchedProtoIndices = []
        for o, label in zip(objects, labels[nInitialPrototypes:]):
            if label >= 0:
                matchedObjects.append(o)
                matchedProtoIndices.append(labelsToProtoIndices[label])
        storage.savePrototypeAssignmentsToSqlite(args.databaseFilename, matchedObjects, trajectoryType, matchedProtoIndices, prototypes)

    if args.display and args.assign:
        from matplotlib.pyplot import figure, show, axis
        figure()
        prototypeIndexSet = set(prototypeIndices)
        # objects: black crosses if unassigned, else the color of their prototype
        for j, o in enumerate(objects):
            i = j+nInitialPrototypes  # index of o in trajectories and labels
            if i not in prototypeIndexSet:
                if labels[i] < 0:
                    o.plot('kx')
                else:
                    o.plot(utils.colors[labels[i]])
        # prototypes, with circle markers
        for i in prototypeIndices:
            if i < nInitialPrototypes:
                initialPrototypes[i].getMovingObject().plot(utils.colors[i]+'o')
            else:
                objects[i-nInitialPrototypes].plot(utils.colors[i]+'o')
        axis('equal')
        show()
else:
    print('Not learning nor assigning: doing nothing')