changeset 979:cc89267b5ff9

work on learning and assigning
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Mon, 19 Feb 2018 10:47:19 -0500
parents 184f1dd307f9
children 23f98ebb113f
files python/storage.py scripts/learn-motion-patterns.py
diffstat 2 files changed, 21 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/python/storage.py	Thu Feb 08 05:53:50 2018 -0500
+++ b/python/storage.py	Mon Feb 19 10:47:19 2018 -0500
@@ -78,10 +78,10 @@
 def createObjectsTable(cursor):
     cursor.execute("CREATE TABLE IF NOT EXISTS objects (object_id INTEGER, road_user_type INTEGER, n_objects INTEGER, PRIMARY KEY(object_id))")
 
-def createAssignmentTable(cursor, objectType1, objectType2, columnName1, columnName2):
-    cursor.execute("CREATE TABLE IF NOT EXISTS "+objectType1+"s_"+objectType2+"s ("+columnName1+" INTEGER, "+columnName1+" INTEGER, PRIMARY KEY("+columnName1+","+columnName2+"))")
+def createAssignmentTable(cursor, objectType1, objectType2, objectIdColumnName1, objectIdColumnName2):
+    cursor.execute("CREATE TABLE IF NOT EXISTS "+objectType1+"s_"+objectType2+"s ("+objectIdColumnName1+" INTEGER, "+objectIdColumnName2+" INTEGER, PRIMARY KEY("+objectIdColumnName1+","+objectIdColumnName2+"))")
 
-def createObjectsFeaturesTable(cursor): # same as 
+def createObjectsFeaturesTable(cursor):
     cursor.execute("CREATE TABLE IF NOT EXISTS objects_features (object_id INTEGER, trajectory_id INTEGER, PRIMARY KEY(object_id, trajectory_id))")
 
 
@@ -565,7 +565,7 @@
 #########################
 
 def savePrototypesToSqlite(filename, prototypes):
-    '''save the prototypes (a prototype is defined by a filename, a number and type'''
+    '''save the prototypes (a prototype is defined by a filename, a number (id) and type'''
     with sqlite3.connect(filename) as connection:
         cursor = connection.cursor()
         try:
@@ -576,11 +576,17 @@
             printDBError(error)
         connection.commit()
 
-def savePrototypeAssignmentsToSqlite(filename, objects, labels, prototypes):
+def savePrototypeAssignmentsToSqlite(filename, objects, objectType, labels, prototypes):
     with sqlite3.connect(filename) as connection:
         cursor = connection.cursor()
         try:
-            cursor.execute('CREATE TABLE IF NOT EXISTS objects_prototypes (object_id INTEGER, prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), PRIMARY KEY(object_id, prototype_filename, prototype_id, trajectory_type))')
+            if objectType == 'feature':
+                tableName = 'features_prototypes'
+                objectIdColumnName = 'trajectory_id'
+            elif objectType == 'object':
+                tableName = 'objects_prototypes'
+                objectIdColumnName = 'object_id'
+            cursor.execute('CREATE TABLE IF NOT EXISTS '+tableName+' ('+objectIdColumnName+' INTEGER, prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), PRIMARY KEY('+objectIdColumnName+', prototype_filename, prototype_id, trajectory_type))')
             for obj, label in zip(objects, labels):
                 proto = prototypes[label]
                 cursor.execute('INSERT INTO objects_prototypes VALUES(?,?,?,?)', (obj.getNum(), proto.getFilename(), proto.getNum(), proto.getTrajectoryType()))
--- a/scripts/learn-motion-patterns.py	Thu Feb 08 05:53:50 2018 -0500
+++ b/scripts/learn-motion-patterns.py	Mon Feb 19 10:47:19 2018 -0500
@@ -42,23 +42,21 @@
 # save the objects that match the prototypes
 # write an assignment function for objects
 
-trajectoryType = args.trajectoryType
-prototypeType = args.trajectoryType
+# load trajectories to cluster or assign
 if args.trajectoryType == 'objectfeatures':
-    trajectoryType = 'object'
-    prototypeType = 'feature'
-
-if args.trajectoryType == 'objectfeatures':
+    trajectoryType = 'feature'
     objectFeatureNumbers = storage.loadObjectFeatureFrameNumbers(args.databaseFilename, objectNumbers = args.nTrajectories)
     featureNumbers = []
     for numbers in objectFeatureNumbers.values():
         featureNumbers += numbers[:min(len(numbers), args.maxNObjectFeatures)]
     objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, 'feature', objectNumbers = featureNumbers, timeStep = args.positionSubsamplingRate)
 else:
-    objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, trajectoryType, withFeatures = (args.trajectoryType == 'objectfeatures'), objectNumbers = args.nTrajectories, timeStep = args.positionSubsamplingRate)
+    trajectoryType = args.trajectoryType
+    objects = storage.loadTrajectoriesFromSqlite(args.databaseFilename, trajectoryType, objectNumbers = args.nTrajectories, timeStep = args.positionSubsamplingRate)
 
-    
 trajectories = [o.getPositions().asArray().T for o in objects]
+
+# load initial prototypes, if any    
 if args.inputPrototypeDatabaseFilename is not None:
     initialPrototypes = storage.loadPrototypesFromSqlite(args.inputPrototypeDatabaseFilename, True)
     trajectories = [p.getMovingObject().getPositions().asArray().T for p in initialPrototypes]+trajectories
@@ -99,7 +97,7 @@
             initialPrototypes[i].nMatchings += nMatchings
             prototypes.append(initialPrototypes[i])
         else:
-            prototypes.append(moving.Prototype(args.databaseFilename, objects[i-len(initialPrototypes)].getNum(), prototypeType, nMatchings))
+            prototypes.append(moving.Prototype(args.databaseFilename, objects[i-len(initialPrototypes)].getNum(), trajectoryType, nMatchings))
 
     if args.outputPrototypeDatabaseFilename is None:
         outputPrototypeDatabaseFilename = args.databaseFilename
@@ -114,10 +112,8 @@
         np.savetxt(utils.removeExtension(args.databaseFilename)+'-prototype-similarities.txt.gz', similarities, '%.4f')
 
     labelsToProtoIndices = {protoId: i for i, protoId in enumerate(prototypeIndices)}
-    if args.assign and args.saveMatches: # or args.assign
-    # save in the db that contained originally the data
-        # retirer les assignations anterieures?
-        storage.savePrototypeAssignmentsToSqlite(args.databaseFilename, objects, [labelsToProtoIndices[l] for l in labels], prototypes)
+    if args.assign and args.saveMatches:
+        storage.savePrototypeAssignmentsToSqlite(args.databaseFilename, objects, trajectoryType, [labelsToProtoIndices[l] for l in labels], prototypes)
 
     if args.display and args.assign:
         from matplotlib.pyplot import figure, show, axis