Mercurial Hosting > traffic-intelligence
comparison python/storage.py @ 919:7b3f2e0a2652
saving and loading prototype trajectories
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Wed, 05 Jul 2017 13:16:47 -0400 |
parents | 3a06007a4bb7 |
children | 499154254f37 |
comparison legend: equal, deleted, inserted, replaced
918:3a06007a4bb7 | 919:7b3f2e0a2652 |
---|---|
4 | 4 |
5 import utils, moving, events, indicators, shutil | 5 import utils, moving, events, indicators, shutil |
6 from base import VideoFilenameAddable | 6 from base import VideoFilenameAddable |
7 | 7 |
8 from os import path | 8 from os import path |
9 from copy import copy | |
9 import sqlite3, logging | 10 import sqlite3, logging |
10 from numpy import log, min as npmin, max as npmax, round as npround, array, sum as npsum, loadtxt, floor as npfloor, ceil as npceil, linalg | 11 from numpy import log, min as npmin, max as npmax, round as npround, array, sum as npsum, loadtxt, floor as npfloor, ceil as npceil, linalg |
11 from pandas import read_csv, merge | 12 from pandas import read_csv, merge |
12 | 13 |
13 | 14 |
51 elif dataType == 'bb': | 52 elif dataType == 'bb': |
52 dropTables(connection, ['bounding_boxes']) | 53 dropTables(connection, ['bounding_boxes']) |
53 elif dataType == 'pois': | 54 elif dataType == 'pois': |
54 dropTables(connection, ['gaussians2d', 'objects_pois']) | 55 dropTables(connection, ['gaussians2d', 'objects_pois']) |
55 elif dataType == 'prototype': | 56 elif dataType == 'prototype': |
56 dropTables(connection, ['prototypes']) | 57 dropTables(connection, ['prototypes', 'prototype_positions', 'prototype_velocities']) |
57 else: | 58 else: |
58 print('Unknown data type {} to delete from database'.format(dataType)) | 59 print('Unknown data type {} to delete from database'.format(dataType)) |
59 connection.close() | 60 connection.close() |
60 else: | 61 else: |
61 print('{} does not exist'.format(filename)) | 62 print('{} does not exist'.format(filename)) |
62 | 63 |
63 def tableExists(filename, tableName): | 64 def tableExists(connection, tableName): |
64 'indicates if the table exists in the database' | 65 'indicates if the table exists in the database' |
65 try: | 66 try: |
66 connection = sqlite3.connect(filename) | 67 #connection = sqlite3.connect(filename) |
67 cursor = connection.cursor() | 68 cursor = connection.cursor() |
68 cursor.execute('SELECT COUNT(*) FROM SQLITE_MASTER WHERE type = \'table\' AND name = \''+tableName+'\'') | 69 cursor.execute('SELECT COUNT(*) FROM SQLITE_MASTER WHERE type = \'table\' AND name = \''+tableName+'\'') |
69 return cursor.fetchone()[0] == 1 | 70 return cursor.fetchone()[0] == 1 |
70 except sqlite3.OperationalError as error: | 71 except sqlite3.OperationalError as error: |
71 printDBError(error) | 72 printDBError(error) |
72 | 73 |
73 def createTrajectoryTable(cursor, tableName): | 74 def createTrajectoryTable(cursor, tableName): |
74 if tableName in ['positions', 'velocities']: | 75 if tableName.endswith('positions') or tableName.endswith('velocities'): |
75 cursor.execute("CREATE TABLE IF NOT EXISTS "+tableName+" (trajectory_id INTEGER, frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))") | 76 cursor.execute("CREATE TABLE IF NOT EXISTS "+tableName+" (trajectory_id INTEGER, frame_number INTEGER, x_coordinate REAL, y_coordinate REAL, PRIMARY KEY(trajectory_id, frame_number))") |
76 else: | 77 else: |
77 print('Unallowed name {} for trajectory table'.format(tableName)) | 78 print('Unallowed name {} for trajectory table'.format(tableName)) |
78 | 79 |
79 def createObjectsTable(cursor): | 80 def createObjectsTable(cursor): |
264 userTypes = {} | 265 userTypes = {} |
265 for row in cursor: | 266 for row in cursor: |
266 userTypes[row[0]] = row[1] | 267 userTypes[row[0]] = row[1] |
267 return userTypes | 268 return userTypes |
268 | 269 |
269 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None, withFeatures = False, timeStep = None): | 270 def loadTrajectoriesFromSqlite(filename, trajectoryType, objectNumbers = None, withFeatures = False, timeStep = None, tablePrefix = None): |
270 '''Loads the trajectories (in the general sense, | 271 '''Loads the trajectories (in the general sense, |
271 either features, objects (feature groups) or bounding box series) | 272 either features, objects (feature groups) or bounding box series) |
272 The number loaded is either the first objectNumbers objects, | 273 The number loaded is either the first objectNumbers objects, |
273 or the indices in objectNumbers from the database''' | 274 or the indices in objectNumbers from the database''' |
274 connection = sqlite3.connect(filename) | 275 connection = sqlite3.connect(filename) |
275 | 276 |
276 objects = loadTrajectoriesFromTable(connection, 'positions', trajectoryType, objectNumbers, timeStep) | 277 if tablePrefix is None: |
277 objectVelocities = loadTrajectoriesFromTable(connection, 'velocities', trajectoryType, objectNumbers, timeStep) | 278 prefix = '' |
279 else: | |
280 prefix = tablePrefix + '_' | |
281 objects = loadTrajectoriesFromTable(connection, prefix+'positions', trajectoryType, objectNumbers, timeStep) | |
282 objectVelocities = loadTrajectoriesFromTable(connection, prefix+'velocities', trajectoryType, objectNumbers, timeStep) | |
278 | 283 |
279 if len(objectVelocities) > 0: | 284 if len(objectVelocities) > 0: |
280 for o,v in zip(objects, objectVelocities): | 285 for o,v in zip(objects, objectVelocities): |
281 if o.getNum() == v.getNum(): | 286 if o.getNum() == v.getNum(): |
282 o.velocities = v.positions | 287 o.velocities = v.positions |
588 ######################### | 593 ######################### |
589 | 594 |
590 def savePrototypesToSqlite(filename, prototypeIndices, trajectoryType, objects = None, nMatchings = None, dbFilenames = None): | 595 def savePrototypesToSqlite(filename, prototypeIndices, trajectoryType, objects = None, nMatchings = None, dbFilenames = None): |
591 '''save the prototype indices | 596 '''save the prototype indices |
592 if objects is not None, the trajectories are also saved in prototype_positions and _velocities | 597 if objects is not None, the trajectories are also saved in prototype_positions and _velocities |
593 (prototypeIndices have to be in objects) | 598 (prototypeIndices have to be in objects |
599 objects will be saved as features, with the centroid trajectory as if it is a feature) | |
594 nMatchings, if not None, is a list of the number of matches | 600 nMatchings, if not None, is a list of the number of matches |
595 dbFilenames, if not None, is a list of the DB filenames''' | 601 dbFilenames, if not None, is a list of the DB filenames |
602 | |
603 The order of prototypeIndices, objects, nMatchings and dbFilenames should be consistent''' | |
596 connection = sqlite3.connect(filename) | 604 connection = sqlite3.connect(filename) |
597 cursor = connection.cursor() | 605 cursor = connection.cursor() |
598 try: | 606 try: |
599 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (id INTEGER, dbfilename VARCHAR, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, positions_id INTEGER, PRIMARY KEY (id, dbfilename))') | 607 cursor.execute('CREATE TABLE IF NOT EXISTS prototypes (id INTEGER, dbfilename VARCHAR, trajectory_type VARCHAR CHECK (trajectory_type IN (\"feature\", \"object\")), nmatchings INTEGER, positions_id INTEGER, PRIMARY KEY (id, dbfilename))') |
600 for i, protoId in enumerate(prototypeIndices): | 608 for i, protoId in enumerate(prototypeIndices): |
605 if dbFilenames is not None: | 613 if dbFilenames is not None: |
606 dbfn = dbFilenames[i] | 614 dbfn = dbFilenames[i] |
607 else: | 615 else: |
608 dbfn = filename | 616 dbfn = filename |
609 cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings, positions_id) VALUES ({},\"{}\",\"{}\",{}, {})'.format(protoId, dbfn, trajectoryType, n, i)) | 617 cursor.execute('INSERT INTO prototypes (id, dbfilename, trajectory_type, nmatchings, positions_id) VALUES ({},\"{}\",\"{}\",{}, {})'.format(protoId, dbfn, trajectoryType, n, i)) |
610 #cursor.execute('SELECT * from sqlite_master WHERE type = \"table\" and name = \"{}\"'.format(tableNames[trajectoryType])) | |
611 if objects is not None: # save positions and velocities | 618 if objects is not None: # save positions and velocities |
612 pass | 619 features = [] |
620 for i, o in enumerate(objects): | |
621 f = copy(o) | |
622 f.num = i | |
623 features.append(f) | |
624 saveTrajectoriesToTable(connection, features, 'feature', 'prototype') | |
613 except sqlite3.OperationalError as error: | 625 except sqlite3.OperationalError as error: |
614 printDBError(error) | 626 printDBError(error) |
615 connection.commit() | 627 connection.commit() |
616 connection.close() | 628 connection.close() |
617 | 629 |
624 cursor = connection.cursor() | 636 cursor = connection.cursor() |
625 prototypeIndices = [] | 637 prototypeIndices = [] |
626 dbFilenames = [] | 638 dbFilenames = [] |
627 trajectoryTypes = [] | 639 trajectoryTypes = [] |
628 nMatchings = [] | 640 nMatchings = [] |
641 trajectoryNumbers = [] | |
629 try: | 642 try: |
630 cursor.execute('SELECT * FROM prototypes') | 643 cursor.execute('SELECT * FROM prototypes') |
631 for row in cursor: | 644 for row in cursor: |
632 prototypeIndices.append(row[0]) | 645 prototypeIndices.append(row[0]) |
633 dbFilenames.append(row[1]) | 646 dbFilenames.append(row[1]) |
634 trajectoryTypes.append(row[2]) | 647 trajectoryTypes.append(row[2]) |
635 if row[3] is not None: | 648 if row[3] is not None: |
636 nMatchings.append(row[3]) | 649 nMatchings.append(row[3]) |
650 if row[4] is not None: | |
651 trajectoryNumbers.append(row[4]) | |
652 if tableExists(connection, 'prototype_positions'): # load prototypes trajectories | |
653 objects = loadTrajectoriesFromSqlite(filename, 'feature', trajectoryNumbers, tablePrefix = 'prototype') | |
654 else: | |
655 objects = None | |
637 except sqlite3.OperationalError as error: | 656 except sqlite3.OperationalError as error: |
638 printDBError(error) | 657 printDBError(error) |
639 connection.close() | 658 connection.close() |
640 if len(set(trajectoryTypes)) > 1: | 659 if len(set(trajectoryTypes)) > 1: |
641 print('Different types of prototypes in database ({}).'.format(set(trajectoryTypes))) | 660 print('Different types of prototypes in database ({}).'.format(set(trajectoryTypes))) |
642 return prototypeIndices, dbFilenames, trajectoryTypes, nMatchings | 661 return prototypeIndices, dbFilenames, trajectoryTypes, nMatchings, objects |
643 | 662 |
644 def savePOIs(filename, gmm, gmmType, gmmId): | 663 def savePOIs(filename, gmm, gmmType, gmmId): |
645 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) | 664 '''Saves a Gaussian mixture model (of class sklearn.mixture.GaussianMixture) |
646 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' | 665 gmmType is a type of GaussianMixture, learnt either from beginnings or ends of trajectories''' |
647 connection = sqlite3.connect(filename) | 666 connection = sqlite3.connect(filename) |