comparison python/storage.py @ 919:7b3f2e0a2652

saving and loading prototype trajectories
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 05 Jul 2017 13:16:47 -0400
parents 3a06007a4bb7
children 499154254f37
comparison
equal deleted inserted replaced
918:3a06007a4bb7 919:7b3f2e0a2652
4 4
5 import utils, moving, events, indicators, shutil 5 import utils, moving, events, indicators, shutil
6 from base import VideoFilenameAddable 6 from base import VideoFilenameAddable
7 7
8 from os import path 8 from os import path
9 from copy import copy
9 import sqlite3, logging 10 import sqlite3, logging
10 from numpy import log, min as npmin, max as npmax, round as npround, array, sum as npsum, loadtxt, floor as npfloor, ceil as npceil, linalg 11 from numpy import log, min as npmin, max as npmax, round as npround, array, sum as npsum, loadtxt, floor as npfloor, ceil as npceil, linalg
11 from pandas import read_csv, merge 12 from pandas import read_csv, merge
12 13
13 14
51 elif dataType == 'bb': 52 elif dataType == 'bb':
52 dropTables(connection, ['bounding_boxes']) 53 dropTables(connection, ['bounding_boxes'])
53 elif dataType == 'pois': 54 elif dataType == 'pois':
54 dropTables(connection, ['gaussians2d', 'objects_pois']) 55 dropTables(connection, ['gaussians2d', 'objects_pois'])
55 elif dataType == 'prototype': 56 elif dataType == 'prototype':
56 dropTables(connection, ['prototypes']) 57 dropTables(connection, ['prototypes', 'prototype_positions', 'prototype_velocities'])
57 else: 58 else:
58 print('Unknown data type {} to delete from database'.format(dataType)) 59 print('Unknown data type {} to delete from database'.format(dataType))
59 connection.close() 60 connection.close()
60 else: 61 else:
61 print('{} does not exist'.format(filename)) 62 print('{} does not exist'.format(filename))
62 63
63 def tableExists(filename, tableName): 64 def tableExists(connection, tableName):
64 'indicates if the table exists in the database' 65 'indicates if the table exists in the database'
65 try: 66 try:
66 connection = sqlite3.connect(filename) 67 #connection = sqlite3.connect(filename)
67 cursor = connection.cursor() 68 cursor = connection.cursor()
68 cursor.execute('SELECT COUNT(*) FROM SQLITE_MASTER WHERE type = \'table\' AND name = \''+tableName+'\'') 69 cursor.execute('SELECT COUNT(*) FROM SQLITE_MASTER WHERE type = \'table\' AND name = \''+tableName+'\'')
69 return cursor.fetchone()[0] == 1 70 return cursor.fetchone()[0] == 1
70 except sqlite3.OperationalError as error: 71 except sqlite3.OperationalError as error:
71 printDBError(error) 72 printDBError(error)
72 73
73 def createTrajectoryTable(cursor, tableName): 74 def createTrajectoryTable(cursor, tableName):
74 if tableName in ['positions', 'velocities']: 75 if tableName.endswith('positions') or tableName.endswith('velocities'):
75 cursor.execute("CREATE TABLE IF NOT EXISTS "+tableName+" (trajectory_id INTEGER, frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))") 76 cursor.execute("CREATE TABLE IF NOT EXISTS "+tableName+" (trajectory_id INTEGER, frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))")
76 else: 77 else:
77 print('Unallowed name {} for trajectory table'.format(tableName)) 78 print('Unallowed name {} for trajectory table'.format(tableName))
78 79
79 def createObjectsTable(cursor): 80 def createObjectsTable(cursor):
264 userTypes = {} 265 userTypes = {}
265 for row in cursor: 266 for row in cursor:
266 userTypes[row[0]] = row[1] 267 userTypes[row[0]] = row[1]
267 return userTypes 268 return userTypes
268 269
269 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None, withFeatures = False, timeStep = None): 270 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None, withFeatures = False, timeStep = None, tablePrefix = None):
270 '''Loads the trajectories (in the general sense, 271 '''Loads the trajectories (in the general sense,
271 either features, objects (feature groups) or bounding box series) 272 either features, objects (feature groups) or bounding box series)
272 The number loaded is either the first objectNumbers objects, 273 The number loaded is either the first objectNumbers objects,
273 or the indices in objectNumbers from the database''' 274 or the indices in objectNumbers from the database'''
274 connection = sqlite3.connect(filename) 275 connection = sqlite3.connect(filename)
275 276
276 objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers, timeStep) 277 if tablePrefix is None:
277 objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers, timeStep) 278 prefix = ''
279 else:
280 prefix = tablePrefix + '_'
281 objects = loadTrajectoriesFromTable(connection, prefix+'positions', trajectoryType, objectNumbers, timeStep)
282 objectVelocities = loadTrajectoriesFromTable(connection, prefix+'velocities', trajectoryType, objectNumbers, timeStep)
278 283
279 if len(objectVelocities) > 0: 284 if len(objectVelocities) > 0:
280 for o,v in zip(objects, objectVelocities): 285 for o,v in zip(objects, objectVelocities):
281 if o.getNum() == v.getNum(): 286 if o.getNum() == v.getNum():
282 o.velocities = v.positions 287 o.velocities = v.positions
588 ######################### 593 #########################
589 594
590 def savePrototypesToSqlite(filename, prototypeIndices, trajectoryType, objects = None, nMatchings = None, dbFilenames = None): 595 def savePrototypesToSqlite(filename, prototypeIndices, trajectoryType, objects = None, nMatchings = None, dbFilenames = None):
591 '''save the prototype indices 596 '''save the prototype indices
592 if objects is not None, the trajectories are also saved in prototype_positions and _velocities 597 if objects is not None, the trajectories are also saved in prototype_positions and _velocities
593 (prototypeIndices have to be in objects) 598 (prototypeIndices have to be in objects
599 objects will be saved as features, with the centroid trajectory as if it is a feature)
594 nMatchings, if not None, is a list of the number of matches 600 nMatchings, if not None, is a list of the number of matches
595 dbFilenames, if not None, is a list of the DB filenames''' 601 dbFilenames, if not None, is a list of the DB filenames
602
603 The order of prototypeIndices, objects, nMatchings and dbFilenames should be consistent'''
596 connection = sqlite3.connect(filename) 604 connection = sqlite3.connect(filename)
597 cursor = connection.cursor() 605 cursor = connection.cursor()
598 try: 606 try:
599 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (id INTEGER, dbfilename VARCHAR, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, positions_id INTEGER, PRIMARY KEY (id, dbfilename))') 607 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (id INTEGER, dbfilename VARCHAR, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, positions_id INTEGER, PRIMARY KEY (id, dbfilename))')
600 for i, protoId in enumerate(prototypeIndices): 608 for i, protoId in enumerate(prototypeIndices):
605 if dbFilenames is not None: 613 if dbFilenames is not None:
606 dbfn = dbFilenames[i] 614 dbfn = dbFilenames[i]
607 else: 615 else:
608 dbfn = filename 616 dbfn = filename
609 cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings, positions_id) VALUES ({},\"{}\",\"{}\",{}, {})'.format(protoId, dbfn, trajectoryType, n, i)) 617 cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings, positions_id) VALUES ({},\"{}\",\"{}\",{}, {})'.format(protoId, dbfn, trajectoryType, n, i))
610 #cursor.execute('SELECT * from sqlite_master WHERE type = \"table\" and name = \"{}\"'.format(tableNames[trajectoryType]))
611 if objects is not None: # save positions and velocities 618 if objects is not None: # save positions and velocities
612 pass 619 features = []
620 for i, o in enumerate(objects):
621 f = copy(o)
622 f.num = i
623 features.append(f)
624 saveTrajectoriesToTable(connection, features, 'feature', 'prototype')
613 except sqlite3.OperationalError as error: 625 except sqlite3.OperationalError as error:
614 printDBError(error) 626 printDBError(error)
615 connection.commit() 627 connection.commit()
616 connection.close() 628 connection.close()
617 629
624 cursor = connection.cursor() 636 cursor = connection.cursor()
625 prototypeIndices = [] 637 prototypeIndices = []
626 dbFilenames = [] 638 dbFilenames = []
627 trajectoryTypes = [] 639 trajectoryTypes = []
628 nMatchings = [] 640 nMatchings = []
641 trajectoryNumbers = []
629 try: 642 try:
630 cursor.execute('SELECT * FROM prototypes') 643 cursor.execute('SELECT * FROM prototypes')
631 for row in cursor: 644 for row in cursor:
632 prototypeIndices.append(row[0]) 645 prototypeIndices.append(row[0])
633 dbFilenames.append(row[1]) 646 dbFilenames.append(row[1])
634 trajectoryTypes.append(row[2]) 647 trajectoryTypes.append(row[2])
635 if row[3] is not None: 648 if row[3] is not None:
636 nMatchings.append(row[3]) 649 nMatchings.append(row[3])
650 if row[4] is not None:
651 trajectoryNumbers.append(row[4])
652 if tableExists(connection, 'prototype_positions'): # load prototypes trajectories
653 objects = loadTrajectoriesFromSqlite(filename, 'feature', trajectoryNumbers, tablePrefix = 'prototype')
654 else:
655 objects = None
637 except sqlite3.OperationalError as error: 656 except sqlite3.OperationalError as error:
638 printDBError(error) 657 printDBError(error)
639 connection.close() 658 connection.close()
640 if len(set(trajectoryTypes)) > 1: 659 if len(set(trajectoryTypes)) > 1:
641 print('Different types of prototypes in database ({}).'.format(set(trajectoryTypes))) 660 print('Different types of prototypes in database ({}).'.format(set(trajectoryTypes)))
642 return prototypeIndices, dbFilenames, trajectoryTypes, nMatchings 661 return prototypeIndices, dbFilenames, trajectoryTypes, nMatchings, objects
643 662
644 def savePOIs(filename, gmm, gmmType, gmmId): 663 def savePOIs(filename, gmm, gmmType, gmmId):
645 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) 664 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture)
646 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' 665 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories'''
647 connection = sqlite3.connect(filename) 666 connection = sqlite3.connect(filename)