changeset 1033:8ffb3ae9f3d2

work in progress
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 20 Jun 2018 00:07:03 -0400
parents d0e339359d8a
children 4069d8545922
files scripts/learn-motion-patterns.py trafficintelligence/ml.py trafficintelligence/storage.py
diffstat 3 files changed, 63 insertions(+), 35 deletions(-) [+]
line wrap: on
line diff
--- a/scripts/learn-motion-patterns.py	Tue Jun 19 17:07:50 2018 -0400
+++ b/scripts/learn-motion-patterns.py	Wed Jun 20 00:07:03 2018 -0400
@@ -25,7 +25,7 @@
 parser.add_argument('--display', dest = 'display', help = 'display trajectories', action = 'store_true')
 parser.add_argument('--save-similarities', dest = 'saveSimilarities', help = 'save computed similarities (in addition to prototypes)', action = 'store_true')
 parser.add_argument('--save-matches', dest = 'saveMatches', help = 'saves the assignments of the objects (not for features) to the prototypes', action = 'store_true')
-parser.add_argument('--assign', dest = 'assign', help = 'assigns the objects to the prototypes and saves them (do not use min cluster size as it will discard prototypes at the beginning if the initial cluster is too small)', action = 'store_true')
+parser.add_argument('--assign', dest = 'assign', help = 'assigns the objects to the prototypes and saves the assignments', action = 'store_true')
 
 args = parser.parse_args()
 
@@ -78,26 +78,27 @@
 else:
     prototypeIndices = initialPrototypeIndices
 
-if args.assign:
-    if not args.learn and args.minClusterSize >= 1:
-        print('Warning: you did not learn the prototypes and you are using minimum cluster size of {}, which may lead to removing prototypes and assigning them to others'.format(args.minClusterSize))
-    prototypeIndices, labels = ml.assignToPrototypeClusters(trajectories, prototypeIndices, similarities, args.minSimilarity, similarityFunc, args.minClusterSize)
+if args.assign: # TODO don't touch initial prototypes if not from same db as trajectories
+    #if not args.learn and args.minClusterSize >= 1: # allow only 
+    #   print('Warning: you did not learn the prototypes and you are using minimum cluster size of {}, which may lead to removing prototypes and assigning them to others'.format(args.minClusterSize))
+    # if args.minClusterSize >= 1:
+    #     if initialPrototypeIndices is None:
+    #         prototypeIndices, labels = ml.assignToPrototypeClusters(trajectories, prototypeIndices, similarities, args.minSimilarity, similarityFunc, args.minClusterSize)
+    #     else:
+    #         print('Not assigning with non-zero minimum cluster size and initial prototypes (would remove initial prototypes based on other trajectories')
+    # else:
+    #     prototypeIndices, labels = ml.assignToPrototypeClusters(trajectories, prototypeIndices, similarities, args.minSimilarity, similarityFunc)
+    prototypeIndices, labels = ml.assignToPrototypeClusters(trajectories, prototypeIndices, similarities, args.minSimilarity, similarityFunc)
     clusterSizes = ml.computeClusterSizes(labels, prototypeIndices, -1)
     print(clusterSizes)
 
-if args.learn or args.assign:
+if args.learn and not args.assign:
     prototypes = []
     for i in prototypeIndices:
-        if args.assign:
-            nMatchings = clusterSizes[i]-1
-        #else:
-        #    nMatchings = 0
         if i<len(initialPrototypes):
-            if args.assign:
-                initialPrototypes[i].nMatchings += nMatchings
             prototypes.append(initialPrototypes[i])
         else:
-            prototypes.append(moving.Prototype(args.databaseFilename, objects[i-len(initialPrototypes)].getNum(), trajectoryType, nMatchings))
+            prototypes.append(moving.Prototype(args.databaseFilename, objects[i-len(initialPrototypes)].getNum(), trajectoryType))
 
     if args.outputPrototypeDatabaseFilename is None:
         outputPrototypeDatabaseFilename = args.databaseFilename
@@ -107,26 +108,38 @@
             storage.deleteFromSqlite(args.outputPrototypeDatabaseFilename, 'prototype')
     storage.savePrototypesToSqlite(outputPrototypeDatabaseFilename, prototypes)
 
-    if args.saveSimilarities:
-        # todo save trajectories and prototypes
-        np.savetxt(utils.removeExtension(args.databaseFilename)+'-prototype-similarities.txt.gz', similarities, '%.4f')
+if not args.learn and args.assign: # no new prototypes # not save assignments of past prototypes if removes with minClusterSize
+    prototypes = []
+    for i in prototypeIndices:
+        nMatchings = clusterSizes[i]-1
+        if initialPrototypes[i].nMatchings is None:
+            initialPrototypes[i].nMatchings = nMatchings
+        else:
+            initialPrototypes[i].nMatchings += nMatchings
+        prototypes.append(initialPrototypes[i])
+    if args.outputPrototypeDatabaseFilename is None:
+        outputPrototypeDatabaseFilename = args.databaseFilename
+    else:
+        outputPrototypeDatabaseFilename = args.outputPrototypeDatabaseFilename
+    storage.setPrototypeMatchingsInSqlite(outputPrototypeDatabaseFilename, prototypes)
 
     labelsToProtoIndices = {protoId: i for i, protoId in enumerate(prototypeIndices)}
-    if args.assign and args.saveMatches:
+    if args.saveMatches:
         storage.savePrototypeAssignmentsToSqlite(args.databaseFilename, objects, trajectoryType, [labelsToProtoIndices[l] for l in labels], prototypes)
 
-    if args.display and args.assign:
-        from matplotlib.pyplot import figure, show, axis
-        figure()
-        for i,o in enumerate(objects):
-            if i not in prototypeIndices:
-                if labels[i] < 0:
-                    o.plot('kx')
-                else:
-                    o.plot(utils.colors[labels[i]])
-        for i in prototypeIndices:
-                objects[i].plot(utils.colors[i]+'o')
-        axis('equal')
-        show()
-else:
-    print('Not learning nor assigning: doing nothing')        
+if (args.learn or args.assign) and args.saveSimilarities:
+    np.savetxt(utils.removeExtension(args.databaseFilename)+'-prototype-similarities.txt.gz', similarities, '%.4f')
+
+if args.display and args.assign:
+    from matplotlib.pyplot import figure, show, axis
+    figure()
+    for i,o in enumerate(objects):
+        if i not in prototypeIndices:
+            if labels[i] < 0:
+                o.plot('kx')
+            else:
+                o.plot(utils.colors[labels[i]])
+    for i in prototypeIndices:
+            objects[i].plot(utils.colors[i]+'o')
+    axis('equal')
+    show()
--- a/trafficintelligence/ml.py	Tue Jun 19 17:07:50 2018 -0400
+++ b/trafficintelligence/ml.py	Wed Jun 20 00:07:03 2018 -0400
@@ -210,7 +210,7 @@
         return None
 
     # sort instances based on length
-    indices = range(len(instances))
+    indices = list(range(len(instances)))
     if randomInitialization or optimizeCentroid:
         indices = np.random.permutation(indices).tolist()
     else:
@@ -221,7 +221,7 @@
                 return 0
             else:
                 return 1
-        indices.sort(compare)
+        indices.sort(key=lambda i: len(instances[i]))
     # initialize clusters
     clusters = []
     if initialPrototypeIndices is None:
--- a/trafficintelligence/storage.py	Tue Jun 19 17:07:50 2018 -0400
+++ b/trafficintelligence/storage.py	Wed Jun 20 00:07:03 2018 -0400
@@ -563,7 +563,7 @@
 #########################
 
 def savePrototypesToSqlite(filename, prototypes):
-    '''save the prototypes (a prototype is defined by a filename, a number (id) and type'''
+    '''save the prototypes (a prototype is defined by a filename, a number (id) and type)'''
     with sqlite3.connect(filename) as connection:
         cursor = connection.cursor()
         try:
@@ -574,6 +574,21 @@
             printDBError(error)
         connection.commit()
 
+def setPrototypeMatchingsInSqlite(filename, prototypes):
+    '''updates the prototype matchings'''
+    with sqlite3.connect(filename) as connection:
+        cursor = connection.cursor()
+        try:
+            for p in prototypes:
+                if p.getNMatchings() is None:
+                    nMatchings = 'NULL'
+                else:
+                    nMatchings = p.getNMatchings()
+                cursor.execute('UPDATE prototypes SET nmatchings = {} WHERE prototype_filename = \"{}\" AND prototype_id = {} AND trajectory_type = \"{}\"'.format(nMatchings, p.getFilename(), p.getNum(), p.getTrajectoryType()))
+        except sqlite3.OperationalError as error:
+            printDBError(error)
+        connection.commit()
+
 def savePrototypeAssignmentsToSqlite(filename, objects, objectType, labels, prototypes):
     with sqlite3.connect(filename) as connection:
         cursor = connection.cursor()