Mercurial Hosting > traffic-intelligence
comparison python/storage.py @ 921:630934595871
work in progress with prototype class
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Wed, 05 Jul 2017 18:01:43 -0400 |
parents | 499154254f37 |
children | acb5379c5fd7 |
comparison
equal
deleted
inserted
replaced
920:499154254f37 | 921:630934595871 |
---|---|
590 | 590 |
591 ######################### | 591 ######################### |
592 # saving and loading for scene interpretation: POIs and Prototypes | 592 # saving and loading for scene interpretation: POIs and Prototypes |
593 ######################### | 593 ######################### |
594 | 594 |
595 def savePrototypesToSqlite(filename, prototypeIndices, trajectoryType, nMatchings = None, dbFilenames = None): | 595 def savePrototypesToSqlite(filename, prototypes): |
596 '''save the prototype indices | 596 '''save the prototypes (a prototype is defined by a filename, a number and type''' |
597 if objects is not None, the trajectories are also saved in prototype_positions and _velocities | 597 connection = sqlite3.connect(filename) |
598 (prototypeIndices have to be in objects | 598 cursor = connection.cursor() |
599 objects will be saved as features, with the centroid trajectory as if it is a feature) | 599 try: |
600 nMatchings, if not None, is a list of the number of matches | 600 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (dbfilename VARCHAR, id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, PRIMARY KEY (id, dbfilename))') |
601 dbFilenames, if not None, is a list of the DB filenames | 601 for p in prototypes: |
602 | 602 cursor.execute('INSERT INTO prototypes (dbfilename, id, trajectory_type, nmatchings) VALUES (?,?,?,?)', (p.getFilename(), p.getNum(), p.getTrajectoryType(), p.getNMatchings())) |
603 The order of prototypeIndices, objects, nMatchings and dbFilenames should be consistent''' | |
604 connection = sqlite3.connect(filename) | |
605 cursor = connection.cursor() | |
606 try: | |
607 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (id INTEGER, dbfilename VARCHAR, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, PRIMARY KEY (id, dbfilename))') | |
608 for i, protoId in enumerate(prototypeIndices): | |
609 if nMatchings is not None: | |
610 n = nMatchings[i] | |
611 else: | |
612 n = 'NULL' | |
613 if dbFilenames is not None: | |
614 dbfn = dbFilenames[i] | |
615 else: | |
616 dbfn = filename | |
617 cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings) VALUES (?,?,?,?)', (protoId, dbfn, trajectoryType, n)) | |
618 except sqlite3.OperationalError as error: | 603 except sqlite3.OperationalError as error: |
619 printDBError(error) | 604 printDBError(error) |
620 connection.commit() | 605 connection.commit() |
621 connection.close() | 606 connection.close() |
622 | 607 |
625 | 610 |
626 def loadPrototypesFromSqlite(filename, withTrajectories = True): | 611 def loadPrototypesFromSqlite(filename, withTrajectories = True): |
627 'Loads prototype ids and matchings (if stored)' | 612 'Loads prototype ids and matchings (if stored)' |
628 connection = sqlite3.connect(filename) | 613 connection = sqlite3.connect(filename) |
629 cursor = connection.cursor() | 614 cursor = connection.cursor() |
630 prototypeIndices = [] | 615 prototypes = [] |
631 dbFilenames = [] | |
632 trajectoryTypes = [] | |
633 nMatchings = [] | |
634 objects = [] | 616 objects = [] |
635 try: | 617 try: |
636 cursor.execute('SELECT * FROM prototypes') | 618 cursor.execute('SELECT * FROM prototypes') |
637 for row in cursor: | 619 for row in cursor: |
638 prototypeIndices.append(row[0]) | 620 prototypes.append(moving.Prototype(row[0], row[1], row[2], row[3])) |
639 dbFilenames.append(row[1]) | |
640 trajectoryTypes.append(row[2]) | |
641 if row[3] is not None: | |
642 nMatchings.append(row[3]) | |
643 if withTrajectories: | 621 if withTrajectories: |
644 loadingInformation = {} | 622 for p in prototypes: |
645 for dbfn, trajType, protoId in zip(dbFilenames, trajectoryTypes, prototypeIndices): | 623 p.setMovingObject(loadTrajectoriesFromSqlite(p.getFilename(), p.getTrajectoryType(), [p.getNum()])[0]) |
646 if (dbfn, trajType) in loadingInformation: | 624 # loadingInformation = {} # complicated slightly optimized |
647 loadingInformation[(dbfn, trajType)].append(protoId) | 625 # for p in prototypes: |
648 else: | 626 # dbfn = p.getFilename() |
649 loadingInformation[(dbfn, trajType)] = [protoId] | 627 # trajType = p.getTrajectoryType() |
650 for k, v in loadingInformation.iteritems(): | 628 # if (dbfn, trajType) in loadingInformation: |
651 objects += loadTrajectoriesFromSqlite(k[0], k[1], v) | 629 # loadingInformation[(dbfn, trajType)].append(p) |
652 except sqlite3.OperationalError as error: | 630 # else: |
653 printDBError(error) | 631 # loadingInformation[(dbfn, trajType)] = [p] |
654 connection.close() | 632 # for k, v in loadingInformation.iteritems(): |
655 if len(set(trajectoryTypes)) > 1: | 633 # objects += loadTrajectoriesFromSqlite(k[0], k[1], [p.getNum() for p in v]) |
656 print('Different types of prototypes in database ({}).'.format(set(trajectoryTypes))) | 634 except sqlite3.OperationalError as error: |
657 return prototypeIndices, dbFilenames, trajectoryTypes, nMatchings, objects | 635 printDBError(error) |
| | 636 connection.close() |
| | 637 if len(set([p.getTrajectoryType() for p in prototypes])) > 1: |
| | 638 print('Different types of prototypes in database ({}).'.format(set([p.getTrajectoryType() for p in prototypes]))) |
| | 639 return prototypes |
658 | 640 |
659 def savePOIs(filename, gmm, gmmType, gmmId): | 641 def savePOIs(filename, gmm, gmmType, gmmId): |
660 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) | 642 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) |
661 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' | 643 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' |
662 connection = sqlite3.connect(filename) | 644 connection = sqlite3.connect(filename) |