Mercurial Hosting > traffic-intelligence
comparison python/storage.py @ 927:c030f735c594
added assignment of trajectories to prototypes and cleanup of insert queries
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Tue, 11 Jul 2017 17:56:23 -0400 |
parents | acb5379c5fd7 |
children | 0e63a918a1ca |
comparison
equal
deleted
inserted
replaced
926:dbd81710d515 | 927:c030f735c594 |
---|---|
52 elif dataType == 'bb': | 52 elif dataType == 'bb': |
53 dropTables(connection, ['bounding_boxes']) | 53 dropTables(connection, ['bounding_boxes']) |
54 elif dataType == 'pois': | 54 elif dataType == 'pois': |
55 dropTables(connection, ['gaussians2d', 'objects_pois']) | 55 dropTables(connection, ['gaussians2d', 'objects_pois']) |
56 elif dataType == 'prototype': | 56 elif dataType == 'prototype': |
57 dropTables(connection, ['prototypes']) | 57 dropTables(connection, ['prototypes', 'objects_prototypes']) |
58 else: | 58 else: |
59 print('Unknown data type {} to delete from database'.format(dataType)) | 59 print('Unknown data type {} to delete from database'.format(dataType)) |
60 connection.close() | 60 connection.close() |
61 else: | 61 else: |
62 print('{} does not exist'.format(filename)) | 62 print('{} does not exist'.format(filename)) |
98 | 98 |
99 def createIndicatorTable(cursor): | 99 def createIndicatorTable(cursor): |
100 cursor.execute('CREATE TABLE IF NOT EXISTS indicators (interaction_id INTEGER, indicator_type INTEGER, frame_number INTEGER, value REAL, FOREIGN KEY(interaction_id) REFERENCES interactions(id), PRIMARY KEY(interaction_id, indicator_type, frame_number))') | 100 cursor.execute('CREATE TABLE IF NOT EXISTS indicators (interaction_id INTEGER, indicator_type INTEGER, frame_number INTEGER, value REAL, FOREIGN KEY(interaction_id) REFERENCES interactions(id), PRIMARY KEY(interaction_id, indicator_type, frame_number))') |
101 | 101 |
102 def insertTrajectoryQuery(tableName): | 102 def insertTrajectoryQuery(tableName): |
103 return "INSERT INTO "+tableName+" (trajectory_id, frame_number, x_coordinate, y_coordinate) VALUES (?,?,?,?)" | 103 return "INSERT INTO "+tableName+" VALUES (?,?,?,?)" |
104 | 104 |
105 def insertObjectQuery(): | 105 def insertObjectQuery(): |
106 return "INSERT INTO objects (object_id, road_user_type, n_objects) VALUES (?,?,?)" | 106 return "INSERT INTO objects VALUES (?,?,?)" |
107 | 107 |
108 def insertObjectFeatureQuery(): | 108 def insertObjectFeatureQuery(): |
109 return "INSERT INTO objects_features (object_id, trajectory_id) VALUES (?,?)" | 109 return "INSERT INTO objects_features VALUES (?,?)" |
110 | 110 |
111 def createIndex(connection, tableName, columnName, unique = False): | 111 def createIndex(connection, tableName, columnName, unique = False): |
112 '''Creates an index for the column in the table | 112 '''Creates an index for the column in the table |
113 I will make querying with a condition on this column faster''' | 113 I will make querying with a condition on this column faster''' |
114 try: | 114 try: |
149 else: | 149 else: |
150 print("Argument minmax unknown: {}".format(minmax)) | 150 print("Argument minmax unknown: {}".format(minmax)) |
151 return cursor.fetchone()[0] | 151 return cursor.fetchone()[0] |
152 except sqlite3.OperationalError as error: | 152 except sqlite3.OperationalError as error: |
153 printDBError(error) | 153 printDBError(error) |
154 | |
155 def loadPrototypeMatchIndexesFromSqlite(filename): | |
156 """ | |
157 This function loads the prototypes table in the database of name <filename>. | |
158 It returns a list of tuples representing matching ids : [(prototype_id, matched_trajectory_id),...] | |
159 """ | |
160 matched_indexes = [] | |
161 | |
162 connection = sqlite3.connect(filename) | |
163 cursor = connection.cursor() | |
164 | |
165 try: | |
166 cursor.execute('SELECT * from prototypes order by prototype_id, trajectory_id_matched') | |
167 except sqlite3.OperationalError as error: | |
168 printDBError(error) | |
169 return [] | |
170 | |
171 for row in cursor: | |
172 matched_indexes.append((row[0],row[1])) | |
173 | |
174 connection.close() | |
175 return matched_indexes | |
176 | 154 |
177 def getObjectCriteria(objectNumbers): | 155 def getObjectCriteria(objectNumbers): |
178 if objectNumbers is None: | 156 if objectNumbers is None: |
179 query = '' | 157 query = '' |
180 elif type(objectNumbers) == int: | 158 elif type(objectNumbers) == int: |
430 cursor.execute(objectFeatureQuery, (obj.getNum(), featureNum)) | 408 cursor.execute(objectFeatureQuery, (obj.getNum(), featureNum)) |
431 cursor.execute(objectQuery, (obj.getNum(), obj.getUserType(), 1)) | 409 cursor.execute(objectQuery, (obj.getNum(), obj.getUserType(), 1)) |
432 # Parse curvilinear position structure | 410 # Parse curvilinear position structure |
433 elif(trajectoryType == 'curvilinear'): | 411 elif(trajectoryType == 'curvilinear'): |
434 createCurvilinearTrajectoryTable(cursor) | 412 createCurvilinearTrajectoryTable(cursor) |
435 curvilinearQuery = "insert into curvilinear_positions (trajectory_id, frame_number, s_coordinate, y_coordinate, lane) values (?,?,?,?,?)" | 413 curvilinearQuery = "INSERT INTO curvilinear_positions VALUES (?,?,?,?,?)" |
436 for obj in objects: | 414 for obj in objects: |
437 num = obj.getNum() | 415 num = obj.getNum() |
438 frameNum = obj.getFirstInstant() | 416 frameNum = obj.getFirstInstant() |
439 for p in obj.getCurvilinearPositions(): | 417 for p in obj.getCurvilinearPositions(): |
440 cursor.execute(curvilinearQuery, (num, frameNum, p[0], p[1], p[2])) | 418 cursor.execute(curvilinearQuery, (num, frameNum, p[0], p[1], p[2])) |
483 | 461 |
484 def saveInteraction(cursor, interaction): | 462 def saveInteraction(cursor, interaction): |
485 roadUserNumbers = list(interaction.getRoadUserNumbers()) | 463 roadUserNumbers = list(interaction.getRoadUserNumbers()) |
486 cursor.execute('INSERT INTO interactions VALUES({}, {}, {}, {}, {})'.format(interaction.getNum(), roadUserNumbers[0], roadUserNumbers[1], interaction.getFirstInstant(), interaction.getLastInstant())) | 464 cursor.execute('INSERT INTO interactions VALUES({}, {}, {}, {}, {})'.format(interaction.getNum(), roadUserNumbers[0], roadUserNumbers[1], interaction.getFirstInstant(), interaction.getLastInstant())) |
487 | 465 |
488 def saveInteractions(filename, interactions): | 466 def saveInteractionsToSqlite(filename, interactions): |
489 'Saves the interactions in the table' | 467 'Saves the interactions in the table' |
490 connection = sqlite3.connect(filename) | 468 connection = sqlite3.connect(filename) |
491 cursor = connection.cursor() | 469 cursor = connection.cursor() |
492 try: | 470 try: |
493 createInteractionTable(cursor) | 471 createInteractionTable(cursor) |
501 def saveIndicator(cursor, interactionNum, indicator): | 479 def saveIndicator(cursor, interactionNum, indicator): |
502 for instant in indicator.getTimeInterval(): | 480 for instant in indicator.getTimeInterval(): |
503 if indicator[instant]: | 481 if indicator[instant]: |
504 cursor.execute('INSERT INTO indicators VALUES({}, {}, {}, {})'.format(interactionNum, events.Interaction.indicatorNameToIndices[indicator.getName()], instant, indicator[instant])) | 482 cursor.execute('INSERT INTO indicators VALUES({}, {}, {}, {})'.format(interactionNum, events.Interaction.indicatorNameToIndices[indicator.getName()], instant, indicator[instant])) |
505 | 483 |
506 def saveIndicators(filename, interactions, indicatorNames = events.Interaction.indicatorNames): | 484 def saveIndicatorsToSqlite(filename, interactions, indicatorNames = events.Interaction.indicatorNames): |
507 'Saves the indicator values in the table' | 485 'Saves the indicator values in the table' |
508 connection = sqlite3.connect(filename) | 486 connection = sqlite3.connect(filename) |
509 cursor = connection.cursor() | 487 cursor = connection.cursor() |
510 try: | 488 try: |
511 createInteractionTable(cursor) | 489 createInteractionTable(cursor) |
519 except sqlite3.OperationalError as error: | 497 except sqlite3.OperationalError as error: |
520 printDBError(error) | 498 printDBError(error) |
521 connection.commit() | 499 connection.commit() |
522 connection.close() | 500 connection.close() |
523 | 501 |
524 def loadInteractions(filename): | 502 def loadInteractionsFromSqlite(filename): |
525 '''Loads interaction and their indicators | 503 '''Loads interaction and their indicators |
526 | 504 |
527 TODO choose the interactions to load''' | 505 TODO choose the interactions to load''' |
528 interactions = [] | 506 interactions = [] |
529 connection = sqlite3.connect(filename) | 507 connection = sqlite3.connect(filename) |
530 cursor = connection.cursor() | 508 cursor = connection.cursor() |
531 try: | 509 try: |
532 cursor.execute('select INT.id, INT.object_id1, INT.object_id2, INT.first_frame_number, INT.last_frame_number, IND.indicator_type, IND.frame_number, IND.value from interactions INT, indicators IND WHERE INT.id = IND.interaction_id ORDER BY INT.id, IND.indicator_type, IND.frame_number') | 510 cursor.execute('SELECT INT.id, INT.object_id1, INT.object_id2, INT.first_frame_number, INT.last_frame_number, IND.indicator_type, IND.frame_number, IND.value from interactions INT, indicators IND WHERE INT.id = IND.interaction_id ORDER BY INT.id, IND.indicator_type, IND.frame_number') |
533 interactionNum = -1 | 511 interactionNum = -1 |
534 indicatorTypeNum = -1 | 512 indicatorTypeNum = -1 |
535 tmpIndicators = {} | 513 tmpIndicators = {} |
536 for row in cursor: | 514 for row in cursor: |
537 if row[0] != interactionNum: | 515 if row[0] != interactionNum: |
595 def savePrototypesToSqlite(filename, prototypes): | 573 def savePrototypesToSqlite(filename, prototypes): |
596 '''save the prototypes (a prototype is defined by a filename, a number and type''' | 574 '''save the prototypes (a prototype is defined by a filename, a number and type''' |
597 connection = sqlite3.connect(filename) | 575 connection = sqlite3.connect(filename) |
598 cursor = connection.cursor() | 576 cursor = connection.cursor() |
599 try: | 577 try: |
600 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (filename VARCHAR, id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, PRIMARY KEY (filename, id))') | 578 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, PRIMARY KEY (prototype_filename, prototype_id, trajectory_type))') |
601 for p in prototypes: | 579 for p in prototypes: |
602 cursor.execute('INSERT INTO prototypes (filename, id, trajectory_type, nmatchings) VALUES (?,?,?,?)', (p.getFilename(), p.getNum(), p.getTrajectoryType(), p.getNMatchings())) | 580 cursor.execute('INSERT INTO prototypes VALUES(?,?,?,?)', (p.getFilename(), p.getNum(), p.getTrajectoryType(), p.getNMatchings())) |
603 except sqlite3.OperationalError as error: | 581 except sqlite3.OperationalError as error: |
604 printDBError(error) | 582 printDBError(error) |
605 connection.commit() | 583 connection.commit() |
606 connection.close() | 584 connection.close() |
607 | 585 |
608 def savePrototypeAssignments(filename, objects): | 586 def savePrototypeAssignmentsToSqlite(filename, objects, labels, prototypes): |
609 pass | 587 connection = sqlite3.connect(filename) |
588 cursor = connection.cursor() | |
589 try: | |
590 cursor.execute('CREATE TABLE IF NOT EXISTS objects_prototypes (object_id INTEGER, prototype_filename VARCHAR, prototype_id INTEGER, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), PRIMARY KEY(object_id, prototype_filename, prototype_id, trajectory_type))') | |
591 for obj, label in zip(objects, labels): | |
592 proto = prototypes[label] | |
593 cursor.execute('INSERT INTO objects_prototypes VALUES(?,?,?,?)', (obj.getNum(), proto.getFilename(), proto.getNum(), proto.getTrajectoryType())) | |
594 except sqlite3.OperationalError as error: | |
595 printDBError(error) | |
596 connection.commit() | |
597 connection.close() | |
610 | 598 |
611 def loadPrototypesFromSqlite(filename, withTrajectories = True): | 599 def loadPrototypesFromSqlite(filename, withTrajectories = True): |
612 'Loads prototype ids and matchings (if stored)' | 600 'Loads prototype ids and matchings (if stored)' |
613 connection = sqlite3.connect(filename) | 601 connection = sqlite3.connect(filename) |
614 cursor = connection.cursor() | 602 cursor = connection.cursor() |
636 connection.close() | 624 connection.close() |
637 if len(set([p.getTrajectoryType() for p in prototypes])) > 1: | 625 if len(set([p.getTrajectoryType() for p in prototypes])) > 1: |
638 print('Different types of prototypes in database ({}).'.format(set([p.getTrajectoryType() for p in prototypes]))) | 626 print('Different types of prototypes in database ({}).'.format(set([p.getTrajectoryType() for p in prototypes]))) |
639 return prototypes | 627 return prototypes |
640 | 628 |
641 def savePOIs(filename, gmm, gmmType, gmmId): | 629 def savePOIsToSqlite(filename, gmm, gmmType, gmmId): |
642 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) | 630 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) |
643 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' | 631 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' |
644 connection = sqlite3.connect(filename) | 632 connection = sqlite3.connect(filename) |
645 cursor = connection.cursor() | 633 cursor = connection.cursor() |
646 if gmmType not in ['beginning', 'end']: | 634 if gmmType not in ['beginning', 'end']: |
648 import sys | 636 import sys |
649 sys.exit() | 637 sys.exit() |
650 try: | 638 try: |
651 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (poi_id INTEGER, id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covariance VARCHAR, covariance_type VARCHAR, weight, precisions_cholesky VARCHAR, PRIMARY KEY(poi_id, id))') | 639 cursor.execute('CREATE TABLE IF NOT EXISTS gaussians2d (poi_id INTEGER, id INTEGER, type VARCHAR, x_center REAL, y_center REAL, covariance VARCHAR, covariance_type VARCHAR, weight, precisions_cholesky VARCHAR, PRIMARY KEY(poi_id, id))') |
652 for i in xrange(gmm.n_components): | 640 for i in xrange(gmm.n_components): |
653 cursor.execute('INSERT INTO gaussians2d VALUES({}, {}, \'{}\', {}, {}, \'{}\', \'{}\', {}, \'{}\')'.format(gmmId, i, gmmType, gmm.means_[i][0], gmm.means_[i][1], str(gmm.covariances_[i].tolist()), gmm.covariance_type, gmm.weights_[i], str(gmm.precisions_cholesky_[i].tolist()))) | 641 cursor.execute('INSERT INTO gaussians2d VALUES(?,?,?,?,?,?,?,?,?)', (gmmId, i, gmmType, gmm.means_[i][0], gmm.means_[i][1], str(gmm.covariances_[i].tolist()), gmm.covariance_type, gmm.weights_[i], str(gmm.precisions_cholesky_[i].tolist()))) |
654 connection.commit() | 642 connection.commit() |
655 except sqlite3.OperationalError as error: | 643 except sqlite3.OperationalError as error: |
656 printDBError(error) | 644 printDBError(error) |
657 connection.close() | 645 connection.close() |
658 | 646 |
659 def savePOIAssignments(filename, objects): | 647 def savePOIAssignmentsToSqlite(filename, objects): |
660 'save the od fields of objects' | 648 'save the od fields of objects' |
661 connection = sqlite3.connect(filename) | 649 connection = sqlite3.connect(filename) |
662 cursor = connection.cursor() | 650 cursor = connection.cursor() |
663 try: | 651 try: |
664 cursor.execute('CREATE TABLE IF NOT EXISTS objects_pois (object_id INTEGER, origin_poi_id INTEGER, destination_poi_id INTEGER, PRIMARY KEY(object_id))') | 652 cursor.execute('CREATE TABLE IF NOT EXISTS objects_pois (object_id INTEGER, origin_poi_id INTEGER, destination_poi_id INTEGER, PRIMARY KEY(object_id))') |
665 for o in objects: | 653 for o in objects: |
666 cursor.execute('INSERT INTO objects_pois VALUES({},{},{})'.format(o.getNum(), o.od[0], o.od[1])) | 654 cursor.execute('INSERT INTO objects_pois VALUES(?,?,?)', (o.getNum(), o.od[0], o.od[1])) |
667 connection.commit() | 655 connection.commit() |
668 except sqlite3.OperationalError as error: | 656 except sqlite3.OperationalError as error: |
669 printDBError(error) | 657 printDBError(error) |
670 connection.close() | 658 connection.close() |
671 | 659 |
672 def loadPOIs(filename): | 660 def loadPOIsFromSqlite(filename): |
673 'Loads all 2D Gaussians in the database' | 661 'Loads all 2D Gaussians in the database' |
674 from sklearn import mixture # todo if not avalaible, load data in duck-typed class with same fields | 662 from sklearn import mixture # todo if not avalaible, load data in duck-typed class with same fields |
675 from ast import literal_eval | 663 from ast import literal_eval |
676 connection = sqlite3.connect(filename) | 664 connection = sqlite3.connect(filename) |
677 cursor = connection.cursor() | 665 cursor = connection.cursor() |