diff scripts/learn-motion-patterns.py @ 979:cc89267b5ff9

work on learning and assigning
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Mon, 19 Feb 2018 10:47:19 -0500
parents 989917b1ed85
children 23f98ebb113f
line wrap: on
line diff
--- a/scripts/learn-motion-patterns.py	Thu Feb 08 05:53:50 2018 -0500
+++ b/scripts/learn-motion-patterns.py	Mon Feb 19 10:47:19 2018 -0500
@@ -42,23 +42,21 @@
 # save the objects that match the prototypes
 # write an assignment function for objects
 
-trajectoryType = args.trajectoryType
-prototypeType = args.trajectoryType
+# load trajectories to cluster or assign
 if args.trajectoryType == 'objectfeatures':
-    trajectoryType = 'object'
-    prototypeType = 'feature'
-
-if args.trajectoryType == 'objectfeatures':
+    trajectoryType = 'feature'
     objectFeatureNumbers = storage.loadObjectFeatureFrameNumbers(args.databaseFilename, objectNumbers = args.nTrajectories)
     featureNumbers = []
     for numbers in objectFeatureNumbers.values():
         featureNumbers += numbers[:min(len(numbers), args.maxNObjectFeatures)]
     objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, 'feature', objectNumbers = featureNumbers, timeStep = args.positionSubsamplingRate)
 else:
-    objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, trajectoryType, withFeatures = (args.trajectoryType == 'objectfeatures'), objectNumbers = args.nTrajectories, timeStep = args.positionSubsamplingRate)
+    trajectoryType = args.trajectoryType
+    objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, trajectoryType, objectNumbers = args.nTrajectories, timeStep = args.positionSubsamplingRate)
 
-    
 trajectories = [o.getPositions().asArray().T for o in objects]
+
+# load initial prototypes, if any    
 if args.inputPrototypeDatabaseFilename is not None:
     initialPrototypes = storage.loadPrototypesFromSqlite(args.inputPrototypeDatabaseFilename, True)
     trajectories = [p.getMovingObject().getPositions().asArray().T for p in initialPrototypes]+trajectories
@@ -99,7 +97,7 @@
             initialPrototypes[i].nMatchings += nMatchings
             prototypes.append(initialPrototypes[i])
         else:
-            prototypes.append(moving.Prototype(args.databaseFilename, objects[i-len(initialPrototypes)].getNum(), prototypeType, nMatchings))
+            prototypes.append(moving.Prototype(args.databaseFilename, objects[i-len(initialPrototypes)].getNum(), trajectoryType, nMatchings))
 
     if args.outputPrototypeDatabaseFilename is None:
         outputPrototypeDatabaseFilename = args.databaseFilename
@@ -114,10 +112,8 @@
         np.savetxt(utils.removeExtension(args.databaseFilename)+'-prototype-similarities.txt.gz', similarities, '%.4f')
 
     labelsToProtoIndices = {protoId: i for i, protoId in enumerate(prototypeIndices)}
-    if args.assign and args.saveMatches: # or args.assign
-    # save in the db that contained originally the data
-        # retirer les assignations anterieures?
-        storage.savePrototypeAssignmentsToSqlite(args.databaseFilename, objects, [labelsToProtoIndices[l] for l in labels], prototypes)
+    if args.assign and args.saveMatches:
+        storage.savePrototypeAssignmentsToSqlite(args.databaseFilename, objects, trajectoryType, [labelsToProtoIndices[l] for l in labels], prototypes)
 
     if args.display and args.assign:
         from matplotlib.pyplot import figure, show, axis