comparison python/storage.py @ 920:499154254f37

improved prototype loading
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Wed, 05 Jul 2017 16:30:04 -0400
parents 7b3f2e0a2652
children 630934595871
comparison
equal deleted inserted replaced
919:7b3f2e0a2652 920:499154254f37
52 elif dataType == 'bb': 52 elif dataType == 'bb':
53 dropTables(connection, ['bounding_boxes']) 53 dropTables(connection, ['bounding_boxes'])
54 elif dataType == 'pois': 54 elif dataType == 'pois':
55 dropTables(connection, ['gaussians2d', 'objects_pois']) 55 dropTables(connection, ['gaussians2d', 'objects_pois'])
56 elif dataType == 'prototype': 56 elif dataType == 'prototype':
57 dropTables(connection, ['prototypes', 'prototype_positions', 'prototype_velocities']) 57 dropTables(connection, ['prototypes'])
58 else: 58 else:
59 print('Unknown data type {} to delete from database'.format(dataType)) 59 print('Unknown data type {} to delete from database'.format(dataType))
60 connection.close() 60 connection.close()
61 else: 61 else:
62 print('{} does not exist'.format(filename)) 62 print('{} does not exist'.format(filename))
590 590
591 ######################### 591 #########################
592 # saving and loading for scene interpretation: POIs and Prototypes 592 # saving and loading for scene interpretation: POIs and Prototypes
593 ######################### 593 #########################
594 594
595 def savePrototypesToSqlite(filename, prototypeIndices, trajectoryType, objects = None, nMatchings = None, dbFilenames = None): 595 def savePrototypesToSqlite(filename, prototypeIndices, trajectoryType, nMatchings = None, dbFilenames = None):
596 '''save the prototype indices 596 '''save the prototype indices
597 if objects is not None, the trajectories are also saved in prototype_positions and _velocities 597 if objects is not None, the trajectories are also saved in prototype_positions and _velocities
598 (prototypeIndices have to be in objects 598 (prototypeIndices have to be in objects
599 objects will be saved as features, with the centroid trajectory as if it is a feature) 599 objects will be saved as features, with the centroid trajectory as if it is a feature)
600 nMatchings, if not None, is a list of the number of matches 600 nMatchings, if not None, is a list of the number of matches
602 602
603 The order of prototypeIndices, objects, nMatchings and dbFilenames should be consistent''' 603 The order of prototypeIndices, objects, nMatchings and dbFilenames should be consistent'''
604 connection = sqlite3.connect(filename) 604 connection = sqlite3.connect(filename)
605 cursor = connection.cursor() 605 cursor = connection.cursor()
606 try: 606 try:
607 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (id INTEGER, dbfilename VARCHAR, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, positions_id INTEGER, PRIMARY KEY (id, dbfilename))') 607 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (id INTEGER, dbfilename VARCHAR, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, PRIMARY KEY (id, dbfilename))')
608 for i, protoId in enumerate(prototypeIndices): 608 for i, protoId in enumerate(prototypeIndices):
609 if nMatchings is not None: 609 if nMatchings is not None:
610 n = nMatchings[i] 610 n = nMatchings[i]
611 else: 611 else:
612 n = 'NULL' 612 n = 'NULL'
613 if dbFilenames is not None: 613 if dbFilenames is not None:
614 dbfn = dbFilenames[i] 614 dbfn = dbFilenames[i]
615 else: 615 else:
616 dbfn = filename 616 dbfn = filename
617 cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings, positions_id) VALUES ({},\"{}\",\"{}\",{}, {})'.format(protoId, dbfn, trajectoryType, n, i)) 617 cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings) VALUES (?,?,?,?)', (protoId, dbfn, trajectoryType, n))
618 if objects is not None: # save positions and velocities
619 features = []
620 for i, o in enumerate(objects):
621 f = copy(o)
622 f.num = i
623 features.append(f)
624 saveTrajectoriesToTable(connection, features, 'feature', 'prototype')
625 except sqlite3.OperationalError as error: 618 except sqlite3.OperationalError as error:
626 printDBError(error) 619 printDBError(error)
627 connection.commit() 620 connection.commit()
628 connection.close() 621 connection.close()
629 622
630 def savePrototypeAssignments(filename, objects): 623 def savePrototypeAssignments(filename, objects):
631 pass 624 pass
632 625
633 def loadPrototypesFromSqlite(filename): 626 def loadPrototypesFromSqlite(filename, withTrajectories = True):
634 'Loads prototype ids and matchings (if stored)' 627 'Loads prototype ids and matchings (if stored)'
635 connection = sqlite3.connect(filename) 628 connection = sqlite3.connect(filename)
636 cursor = connection.cursor() 629 cursor = connection.cursor()
637 prototypeIndices = [] 630 prototypeIndices = []
638 dbFilenames = [] 631 dbFilenames = []
639 trajectoryTypes = [] 632 trajectoryTypes = []
640 nMatchings = [] 633 nMatchings = []
641 trajectoryNumbers = [] 634 objects = []
642 try: 635 try:
643 cursor.execute('SELECT * FROM prototypes') 636 cursor.execute('SELECT * FROM prototypes')
644 for row in cursor: 637 for row in cursor:
645 prototypeIndices.append(row[0]) 638 prototypeIndices.append(row[0])
646 dbFilenames.append(row[1]) 639 dbFilenames.append(row[1])
647 trajectoryTypes.append(row[2]) 640 trajectoryTypes.append(row[2])
648 if row[3] is not None: 641 if row[3] is not None:
649 nMatchings.append(row[3]) 642 nMatchings.append(row[3])
650 if row[4] is not None: 643 if withTrajectories:
651 trajectoryNumbers.append(row[4]) 644 loadingInformation = {}
652 if tableExists(connection, 'prototype_positions'): # load prototypes trajectories 645 for dbfn, trajType, protoId in zip(dbFilenames, trajectoryTypes, prototypeIndices):
653 objects = loadTrajectoriesFromSqlite(filename, 'feature', trajectoryNumbers, tablePrefix = 'prototype') 646 if (dbfn, trajType) in loadingInformation:
654 else: 647 loadingInformation[(dbfn, trajType)].append(protoId)
655 objects = None 648 else:
649 loadingInformation[(dbfn, trajType)] = [protoId]
650 for k, v in loadingInformation.iteritems():
651 objects += loadTrajectoriesFromSqlite(k[0], k[1], v)
656 except sqlite3.OperationalError as error: 652 except sqlite3.OperationalError as error:
657 printDBError(error) 653 printDBError(error)
658 connection.close() 654 connection.close()
659 if len(set(trajectoryTypes)) > 1: 655 if len(set(trajectoryTypes)) > 1:
660 print('Different types of prototypes in database ({}).'.format(set(trajectoryTypes))) 656 print('Different types of prototypes in database ({}).'.format(set(trajectoryTypes)))