changeset 729:dad99b86a104 dev

merge with default
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Mon, 10 Aug 2015 17:52:19 -0400
parents 4e89341edd29 (current diff) c6d4ea05a2d0 (diff)
children a850a4f92735
files
diffstat 6 files changed, 200 insertions(+), 66 deletions(-) [+]
line wrap: on
line diff
--- a/python/cvutils.py	Mon Aug 10 17:51:49 2015 -0400
+++ b/python/cvutils.py	Mon Aug 10 17:52:19 2015 -0400
@@ -25,9 +25,22 @@
 cvRed = (0,0,255)
 cvGreen = (0,255,0)
 cvBlue = (255,0,0)
+cvCyan = (255, 255, 0)
+cvYellow = (0, 255, 255)
+cvMagenta = (255, 0, 255)
+cvWhite = (255, 255, 255)
+cvBlack = (0,0,0)
+cvColors3 = utils.PlottingPropertyValues([cvRed,
+                                          cvGreen,
+                                          cvBlue])
 cvColors = utils.PlottingPropertyValues([cvRed,
                                          cvGreen,
-                                         cvBlue])
+                                         cvBlue,
+                                         cvCyan,
+                                         cvYellow,
+                                         cvMagenta,
+                                         cvWhite,
+                                         cvBlack])
 
 def quitKey(key):
     return chr(key&255)== 'q' or chr(key&255) == 'Q'
@@ -58,6 +71,9 @@
     #out = utils.openCheck(resultFilename)
     img.save(resultFilename)
 
+def rgb2gray(rgb):
+    return dot(rgb[...,:3], [0.299, 0.587, 0.144])
+
 def matlab2PointCorrespondences(filename):
     '''Loads and converts the point correspondences saved 
     by the matlab camera calibration tool'''
@@ -98,10 +114,11 @@
         return cvmat
 
     def cvPlot(img, positions, color, lastCoordinate = None, **kwargs):
-        last = lastCoordinate+1
-        if lastCoordinate is not None and lastCoordinate >=0:
+        if lastCoordinate is None:
+            last = positions.length()-1
+        elif lastCoordinate >=0:
             last = min(positions.length()-1, lastCoordinate)
-        for i in range(0, last-1):
+        for i in range(0, last):
             cv2.line(img, positions[i].asint().astuple(), positions[i+1].asint().astuple(), color, **kwargs)
 
     def cvImshow(windowName, img, rescale = 1.0):
@@ -264,7 +281,7 @@
         return croppedImg, yCropMin, yCropMax, xCropMin, xCropMax
 
 
-    def displayTrajectories(videoFilename, objects, boundingBoxes = {}, homography = None, firstFrameNum = 0, lastFrameNumArg = None, printFrames = True, rescale = 1., nFramesStep = 1, saveAllImages = False, undistort = False, intrinsicCameraMatrix = None, distortionCoefficients = None, undistortedImageMultiplication = 1.):
+    def displayTrajectories(videoFilename, objects, boundingBoxes = {}, homography = None, firstFrameNum = 0, lastFrameNumArg = None, printFrames = True, rescale = 1., nFramesStep = 1, saveAllImages = False, undistort = False, intrinsicCameraMatrix = None, distortionCoefficients = None, undistortedImageMultiplication = 1., annotations = [], gtMatches = {}, toMatches = {}):
         '''Displays the objects overlaid frame by frame over the video '''
         from moving import userTypeNames
         from math import ceil, log10
@@ -290,6 +307,7 @@
             else:
                 lastFrameNum = lastFrameNumArg
             nZerosFilename = int(ceil(log10(lastFrameNum)))
+            objectToDeleteIds = []
             while ret and not quitKey(key) and frameNum <= lastFrameNum:
                 ret, img = capture.read()
                 if ret:
@@ -297,24 +315,45 @@
                         img = cv2.remap(img, map1, map2, interpolation=cv2.INTER_LINEAR)
                     if printFrames:
                         print('frame {0}'.format(frameNum))
+                    if len(objectToDeleteIds) > 0:
+                        objects = [o for o in objects if o.getNum() not in objectToDeleteIds]
+                        objectToDeleteIds = []
+                    # plot objects
                     for obj in objects:
                         if obj.existsAtInstant(frameNum):
+                            if obj.getLastInstant() == frameNum:
+                                objectToDeleteIds.append(obj.getNum())
                             if not hasattr(obj, 'projectedPositions'):
                                 if homography is not None:
                                     obj.projectedPositions = obj.positions.project(homography)
                                 else:
                                     obj.projectedPositions = obj.positions
-                            cvPlot(img, obj.projectedPositions, cvRed, frameNum-obj.getFirstInstant())
-                            if frameNum in boundingBoxes.keys():
-                                for rect in boundingBoxes[frameNum]:
-                                    cv2.rectangle(img, rect[0].asint().astuple(), rect[1].asint().astuple(), cvRed)
-                            elif obj.hasFeatures():
+                            cvPlot(img, obj.projectedPositions, cvColors[obj.getNum()], frameNum-obj.getFirstInstant())
+                            if frameNum not in boundingBoxes.keys() and obj.hasFeatures():
                                 imgcrop, yCropMin, yCropMax, xCropMin, xCropMax = imageBox(img, obj, frameNum, homography, width, height)
                                 cv2.rectangle(img, (xCropMin, yCropMin), (xCropMax, yCropMax), cvBlue, 1)
                             objDescription = '{} '.format(obj.num)
                             if userTypeNames[obj.userType] != 'unknown':
                                 objDescription += userTypeNames[obj.userType][0].upper()
-                            cv2.putText(img, objDescription, obj.projectedPositions[frameNum-obj.getFirstInstant()].asint().astuple(), cv2.cv.CV_FONT_HERSHEY_PLAIN, 1, cvRed)
+                            if len(annotations) > 0: # if we loaded annotations, but there is no match
+                                if frameNum not in toMatches[obj.getNum()]:
+                                    objDescription += " FA"
+                            cv2.putText(img, objDescription, obj.projectedPositions[frameNum-obj.getFirstInstant()].asint().astuple(), cv2.cv.CV_FONT_HERSHEY_PLAIN, 1, cvColors[obj.getNum()])
+                    # plot object bounding boxes
+                    if frameNum in boundingBoxes.keys():
+                        for rect in boundingBoxes[frameNum]:
+                            cv2.rectangle(img, rect[0].asint().astuple(), rect[1].asint().astuple(), cvColors[obj.getNum()])
+                    # plot ground truth
+                    if len(annotations) > 0:
+                        for gt in annotations:
+                            if gt.existsAtInstant(frameNum):
+                                if frameNum in gtMatches[gt.getNum()]:
+                                    color = cvColors[gtMatches[gt.getNum()][frameNum]] # same color as object
+                                else:
+                                    color = cvRed
+                                    cv2.putText(img, 'Miss', gt.topLeftPositions[frameNum-gt.getFirstInstant()].asint().astuple(), cv2.cv.CV_FONT_HERSHEY_PLAIN, 1, cvRed)
+                                cv2.rectangle(img, gt.topLeftPositions[frameNum-gt.getFirstInstant()].asint().astuple(), gt.bottomRightPositions[frameNum-gt.getFirstInstant()].asint().astuple(), color)
+                    # saving images and going to next
                     if not saveAllImages:
                         cvImshow(windowName, img, rescale)
                         key = cv2.waitKey()
--- a/python/indicators.py	Mon Aug 10 17:51:49 2015 -0400
+++ b/python/indicators.py	Mon Aug 10 17:52:19 2015 -0400
@@ -8,6 +8,9 @@
 from numpy import array, arange, mean, floor, mean
 
 
+def multivariateName(indicatorNames):
+    return '_'.join(indicatorNames)
+
 # need for a class representing the indicators, their units, how to print them in graphs...
 class TemporalIndicator(object):
     '''Class for temporal indicators
@@ -44,10 +47,7 @@
 
     def __getitem__(self, t):
         'Returns the value at time t'
-        if t in self.values.keys():
-            return self.values[t]
-        else:
-            return None
+        return self.values.get(t)
 
     def getIthValue(self, i):
         sortedKeys = sorted(self.values.keys())
@@ -86,18 +86,27 @@
         plot([(x+timeShift)/xfactor for x in time], [self.values[i]/yfactor for i in time], options+marker, **kwargs)
         if self.maxValue:
             ylim(ymax = self.maxValue)
-	
-    def valueSorted(self):
-	''' return the values after sort the keys in the indicator
-        This should probably not be used: to delete''' 
-        print('Deprecated: values should not be accessed in this way')
-        values=[]
-        keys = self.values.keys()
-        keys.sort()
-        for key in keys:
-            values.append(self.values[key]) 
-        return values
+
+    @classmethod
+    def createMultivariate(cls, indicators):
+        '''Creates a new temporal indicator where the value at each instant is a list 
+        of the indicator values at the instant, in the same order
+        the time interval will be the union of the time intervals of the indicators
+        name is concatenation of the indicator names'''
+        if len(indicators) < 2:
+            print('Error creating multivariate indicator with only {} indicator'.format(len(indicators)))
+            return None
 
+        timeInterval = moving.TimeInterval.unionIntervals([indic.getTimeInterval() for indic in indicators])
+        values = {}
+        for t in timeInterval:
+            tmpValues = [indic[t] for indic in indicators]
+            uniqueValues = set(tmpValues)
+            if len(uniqueValues) >= 2 or uniqueValues.pop() is not None:
+                values[t] = tmpValues
+        return cls(multivariateName([indic.name for indic in indicators]), values)
+
+# TODO static method avec class en parametre pour faire des indicateurs agrege, list par instant
 
 def l1Distance(x, y): # lambda x,y:abs(x-y)
     if x is None or y is None:
@@ -105,12 +114,20 @@
     else:
         return abs(x-y)
 
+def multiL1Matching(x, y, thresholds, proportionMatching=1.):
+    n = 0
+    nDimensions = len(x)
+    for i in range(nDimensions):
+        if l1Distance(x[i], y[i]) <= thresholds[i]:
+            n += 1
+    return n >= nDimensions*proportionMatching
+
 from utils import LCSS as utilsLCSS
 
 class LCSS(utilsLCSS):
     '''Adapted LCSS class for indicators, same pattern'''
     def __init__(self, similarityFunc, delta = float('inf'), minLength = 0, aligned = False, lengthFunc = min):
-        utilsLCSS.__init__(self, similarityFunc, delta, aligned, lengthFunc)
+        utilsLCSS.__init__(self, similarityFunc = similarityFunc, delta = delta, aligned = aligned, lengthFunc = lengthFunc)
         self.minLength = minLength
 
     def checkIndicator(self, indicator):
--- a/python/moving.py	Mon Aug 10 17:51:49 2015 -0400
+++ b/python/moving.py	Mon Aug 10 17:52:19 2015 -0400
@@ -78,12 +78,13 @@
         else:
             return None
 
-def unionIntervals(intervals):
-    'returns the smallest interval containing all intervals'
-    inter = intervals[0]
-    for i in intervals[1:]:
-        inter = Interval.union(inter, i)
-    return inter
+    @classmethod
+    def unionIntervals(cls, intervals):
+        'returns the smallest interval containing all intervals'
+        inter = cls(intervals[0].first, intervals[0].last)
+        for i in intervals[1:]:
+            inter = cls.union(inter, i)
+        return inter
 
 
 class TimeInterval(Interval):
@@ -1067,6 +1068,42 @@
             print 'The object does not exist at '+str(inter)
             return None
 
+    def getObjectsInMask(self, mask, homography = None, minLength = 1):
+        '''Returns new objects made of the positions in the mask
+        mask is in the destination of the homography space'''
+        if homography is not None:
+            self.projectedPositions = self.positions.project(homography)
+        else:
+            self.projectedPositions = self.positions
+        def inMask(positions, i, mask):
+            p = positions[i]
+            return mask[p.y, p.x] != 0.
+
+        #subTimeIntervals self.getFirstInstant()+i
+        filteredIndices = [inMask(self.projectedPositions, i, mask) for i in range(int(self.length()))]
+        # 'connected components' in subTimeIntervals
+        l = 0
+        intervalLabels = []
+        prev = True
+        for i in filteredIndices:
+            if i:
+                if not prev: # new interval
+                    l += 1
+                intervalLabels.append(l)
+            else:
+                intervalLabels.append(-1)
+            prev = i
+        intervalLabels = array(intervalLabels)
+        subObjects = []
+        for l in set(intervalLabels):
+            if l >= 0:
+                if sum(intervalLabels == l) >= minLength:
+                    times = [self.getFirstInstant()+i for i in range(len(intervalLabels)) if intervalLabels[i] == l]
+                    subTimeInterval = TimeInterval(min(times), max(times))
+                    subObjects.append(self.getObjectInTimeInterval(subTimeInterval))
+
+        return subObjects
+
     def getPositions(self):
         return self.positions
 
@@ -1517,7 +1554,7 @@
         else:
             return matchingDistance + 1
 
-def computeClearMOT(annotations, objects, matchingDistance, firstInstant, lastInstant, debug = False):
+def computeClearMOT(annotations, objects, matchingDistance, firstInstant, lastInstant, returnMatches = False, debug = False):
     '''Computes the CLEAR MOT metrics 
 
     Reference:
@@ -1536,6 +1573,12 @@
     fpt number of false alarm.frames (tracker objects without match in each frame)
     gt number of GT.frames
 
+    if returnMatches is True, return as 2 new arguments the GT and TO matches
+    matches is a dict
+    matches[i] is the list of matches for GT/TO i
+    the list of matches is a dict, indexed by time, for the TO/GT id matched at time t 
+    (an instant t not present in matches[i] at which GT/TO exists means a missed detection or false alarm)
+
     TODO: Should we use the distance as weights or just 1/0 if distance below matchingDistance?
     (add argument useDistanceForWeights = False)'''
     from munkres import Munkres
@@ -1548,6 +1591,9 @@
     fpt = 0 # number of false alarm.frames (tracker objects without match in each frame)
     mme = 0 # number of mismatches
     matches = {} # match[i] is the tracker track associated with GT i (using object references)
+    if returnMatches:
+        gtMatches = {a.getNum():{} for a in annotations}
+        toMatches = {o.getNum():{} for o in objects}
     for t in xrange(firstInstant, lastInstant+1):
         previousMatches = matches.copy()
         # go through currently matched GT-TO and check if they are still matched withing matchingDistance
@@ -1583,6 +1629,10 @@
                     dist += costs[k][v]
         if debug:
             print('{} '.format(t)+', '.join(['{} {}'.format(k.getNum(), v.getNum()) for k,v in matches.iteritems()]))
+        if returnMatches:
+            for a,o in matches.iteritems():
+                gtMatches[a.getNum()][t] = o.getNum()
+                toMatches[o.getNum()][t] = a.getNum()
         
         # compute metrics elements
         ct += len(matches)
@@ -1615,8 +1665,10 @@
         mota = 1.-float(mt+fpt+mme)/gt
     else:
         mota = None
-    return motp, mota, mt, mme, fpt, gt
-
+    if returnMatches:
+        return motp, mota, mt, mme, fpt, gt, gtMatches, toMatches
+    else:
+        return motp, mota, mt, mme, fpt, gt
 
 def plotRoadUsers(objects, colors):
     '''Colors is a PlottingPropertyValues instance'''
--- a/python/storage.py	Mon Aug 10 17:51:49 2015 -0400
+++ b/python/storage.py	Mon Aug 10 17:52:19 2015 -0400
@@ -654,11 +654,11 @@
         s = f.readline()
     return s.strip()
 
-def getLines(f, commentCharacters = commentChar):
+def getLines(f, delimiterChar = delimiterChar, commentCharacters = commentChar):
     '''Gets a complete entry (all the lines) in between delimiterChar.'''
     dataStrings = []
     s = readline(f, commentCharacters)
-    while len(s) > 0:
+    while len(s) > 0 and s[0] != delimiterChar:
         dataStrings += [s.strip()]
         s = readline(f, commentCharacters)
     return dataStrings
@@ -701,7 +701,7 @@
 def generatePDLaneColumn(data):
     data['LANE'] = data['LANE\LINK\NO'].astype(str)+'_'+data['LANE\INDEX'].astype(str)
 
-def loadTrajectoriesFromVissimFile(filename, simulationStepsPerTimeUnit, nObjects = -1, warmUpLastInstant = None, usePandas = False, nDecimals = 2):
+def loadTrajectoriesFromVissimFile(filename, simulationStepsPerTimeUnit, nObjects = -1, warmUpLastInstant = None, usePandas = False, nDecimals = 2, lowMemory = True):
     '''Reads data from VISSIM .fzp trajectory file
     simulationStepsPerTimeUnit is the number of simulation steps per unit of time used by VISSIM
     for example, there seems to be 5 simulation steps per simulated second in VISSIM, 
@@ -716,7 +716,7 @@
 
     if usePandas:
         from pandas import read_csv
-        data = read_csv(filename, delimiter=';', comment='*', header=0, skiprows = 1)
+        data = read_csv(filename, delimiter=';', comment='*', header=0, skiprows = 1, low_memory = lowMemory)
         generatePDLaneColumn(data)
         data['TIME'] = data['$VEHICLE:SIMSEC']*simulationStepsPerTimeUnit
         if warmUpLastInstant is not None:
@@ -782,7 +782,7 @@
     columns = ['NO', '$VEHICLE:SIMSEC', 'POS']
     if lanes is not None:
         columns += ['LANE\LINK\NO', 'LANE\INDEX']
-    data = read_csv(filename, delimiter=';', comment='*', header=0, skiprows = 1, usecols = columns)
+    data = read_csv(filename, delimiter=';', comment='*', header=0, skiprows = 1, usecols = columns, low_memory = lowMemory)
     data = selectPDLanes(data, lanes)
     data.sort(['$VEHICLE:SIMSEC'], inplace = True)
 
@@ -806,7 +806,7 @@
     If lanes is not None, only the data for the selected lanes will be provided
     (format as string x_y where x is link index and y is lane index)'''
     from pandas import read_csv, merge
-    data = read_csv(filename, delimiter=';', comment='*', header=0, skiprows = 1, usecols = ['LANE\LINK\NO', 'LANE\INDEX', '$VEHICLE:SIMSEC', 'NO', 'POS'])
+    data = read_csv(filename, delimiter=';', comment='*', header=0, skiprows = 1, usecols = ['LANE\LINK\NO', 'LANE\INDEX', '$VEHICLE:SIMSEC', 'NO', 'POS'], low_memory = lowMemory)
     data = selectPDLanes(data, lanes)
     merged = merge(data, data, how='inner', left_on=['LANE\LINK\NO', 'LANE\INDEX', '$VEHICLE:SIMSEC'], right_on=['LANE\LINK\NO', 'LANE\INDEX', '$VEHICLE:SIMSEC'], sort = False)
     merged = merged[merged['NO_x']>merged['NO_y']]
--- a/python/ubc_utils.py	Mon Aug 10 17:51:49 2015 -0400
+++ b/python/ubc_utils.py	Mon Aug 10 17:52:19 2015 -0400
@@ -1,7 +1,7 @@
 #! /usr/bin/env python
 '''Various utilities to load data saved by the UBC tool(s)'''
 
-import utils, events
+import utils, events, storage
 from moving import MovingObject, TimeInterval, Trajectory
 
 
@@ -57,13 +57,13 @@
     by just copying the corresponding trajectory and velocity data
     from the inFilename, and saving the characteristics in objects (first line)
     into outFilename'''
-    infile = utils.openCheck(inFilename)
-    outfile = utils.openCheck(outFilename,'w')
+    infile = storage.openCheck(inFilename)
+    outfile = storage.openCheck(outFilename,'w')
 
     if (inFilename.find('features') >= 0) or (not infile) or (not outfile):
         return
 
-    lines = utils.getLines(infile)
+    lines = storage.getLines(infile)
     objNum = 0 # in inFilename
     while lines != []:
         # find object in objects (index i)
@@ -80,16 +80,16 @@
             outfile.write(utils.delimiterChar+'\n')
         # next object
         objNum += 1
-        lines = utils.getLines(infile)
+        lines = storage.getLines(infile)
 
     print('read {0} objects'.format(objNum))
 
 def modifyTrajectoryFile(modifyLines, filenameIn, filenameOut):
     '''Reads filenameIn, replaces the lines with the result of modifyLines and writes the result in filenameOut'''
-    fileIn = utils.openCheck(filenameIn, 'r', True)
-    fileOut = utils.openCheck(filenameOut, "w", True)
+    fileIn = storage.openCheck(filenameIn, 'r', True)
+    fileOut = storage.openCheck(filenameOut, "w", True)
 
-    lines = utils.getLines(fileIn)
+    lines = storage.getLines(fileIn)
     trajNum = 0
     while (lines != []):
         modifiedLines = modifyLines(trajNum, lines)
@@ -97,7 +97,7 @@
             for l in modifiedLines:
                 fileOut.write(l+"\n")
             fileOut.write(utils.delimiterChar+"\n")
-        lines = utils.getLines(fileIn)
+        lines = storage.getLines(fileIn)
         trajNum += 1
          
     fileIn.close()
@@ -106,17 +106,17 @@
 def copyTrajectoryFile(keepTrajectory, filenameIn, filenameOut):
     '''Reads filenameIn, keeps the trajectories for which the function keepTrajectory(trajNum, lines) is True
     and writes the result in filenameOut'''
-    fileIn = utils.openCheck(filenameIn, 'r', True)
-    fileOut = utils.openCheck(filenameOut, "w", True)
+    fileIn = storage.openCheck(filenameIn, 'r', True)
+    fileOut = storage.openCheck(filenameOut, "w", True)
 
-    lines = utils.getLines(fileIn)
+    lines = storage.getLines(fileIn)
     trajNum = 0
     while (lines != []):
         if keepTrajectory(trajNum, lines):
             for l in lines:
                 fileOut.write(l+"\n")
             fileOut.write(utils.delimiterChar+"\n")
-        lines = utils.getLines(fileIn)
+        lines = storage.getLines(fileIn)
         trajNum += 1
         
     fileIn.close()
@@ -125,14 +125,14 @@
 def loadTrajectories(filename, nObjects = -1):
     '''Loads trajectories'''
 
-    file = utils.openCheck(filename)
+    file = storage.openCheck(filename)
     if (not file):
         return []
 
     objects = []
     objNum = 0
     objectType = getFileType(filename)
-    lines = utils.getLines(file)
+    lines = storage.getLines(file)
     while (lines != []) and ((nObjects<0) or (objNum<nObjects)):
         l = lines[0].split(' ')
         parsedLine = [int(n) for n in l[:4]]
@@ -162,7 +162,7 @@
         else:
             print("Error two lines of data for feature %d"%(f.num))
 
-        lines = utils.getLines(file)
+        lines = storage.getLines(file)
 
     file.close()
     return objects
@@ -177,13 +177,13 @@
     'Loads interactions from the old UBC traffic event format'
     from events import Interaction 
     from indicators import SeverityIndicator
-    file = utils.openCheck(filename)
+    file = storage.openCheck(filename)
     if (not file):
         return []
 
     interactions = []
     interactionNum = 0
-    lines = utils.getLines(file)
+    lines = storage.getLines(file)
     while (lines != []) and ((nInteractions<0) or (interactionNum<nInteractions)):
         parsedLine = [int(n) for n in lines[0].split(' ')]
         inter = Interaction(interactionNum, TimeInterval(parsedLine[1],parsedLine[2]), parsedLine[3], parsedLine[4], categoryNum = parsedLine[5])
@@ -198,7 +198,7 @@
 
         interactions.append(inter)
         interactionNum+=1
-        lines = utils.getLines(file)
+        lines = storage.getLines(file)
 
     file.close()
     return interactions
@@ -206,13 +206,13 @@
 def loadCollisionPoints(filename, nPoints = -1):
     '''Loads collision points and returns a dict
     with keys as a pair of the numbers of the two interacting objects'''
-    file = utils.openCheck(filename)
+    file = storage.openCheck(filename)
     if (not file):
         return []
 
     points = {}
     num = 0
-    lines = utils.getLines(file)
+    lines = storage.getLines(file)
     while (lines != []) and ((nPoints<0) or (num<nPoints)):
         parsedLine = [int(n) for n in lines[0].split(' ')]
         protagonistNums = (parsedLine[0], parsedLine[1])
@@ -220,7 +220,7 @@
                                    [float(n) for n in lines[2].split(' ')]]
 
         num+=1
-        lines = utils.getLines(file)
+        lines = storage.getLines(file)
 
     file.close()
     return points
--- a/scripts/compute-clearmot.py	Mon Aug 10 17:51:49 2015 -0400
+++ b/scripts/compute-clearmot.py	Mon Aug 10 17:52:19 2015 -0400
@@ -2,7 +2,8 @@
 
 import sys, argparse
 from numpy import loadtxt
-import moving, storage
+from numpy.linalg import inv
+import moving, storage, cvutils
 
 # TODO: need to trim objects to same mask ?
 
@@ -16,8 +17,11 @@
 parser.add_argument('-g', dest = 'groundTruthDatabaseFilename', help = 'name of the Sqlite database containing the ground truth', required = True)
 parser.add_argument('-o', dest = 'homographyFilename', help = 'name of the filename for the homography (if tracking was done using the homography)')
 parser.add_argument('-m', dest = 'matchingDistance', help = 'matching distance between tracker and ground truth trajectories', required = True, type = float)
+parser.add_argument('--mask', dest = 'maskFilename', help = 'filename of the mask file used to define the where objects were tracked')
 parser.add_argument('-f', dest = 'firstInstant', help = 'first instant for measurement', required = True, type = int)
 parser.add_argument('-l', dest = 'lastInstant', help = 'last instant for measurement', required = True, type = int)
+parser.add_argument('--display', dest = 'display', help = 'display the ground truth to object matches (graphically)', action = 'store_true')
+parser.add_argument('-i', dest = 'videoFilename', help = 'name of the video file (for display)')
 args = parser.parse_args()
 
 if args.homographyFilename is not None:
@@ -26,14 +30,36 @@
     homography = None
 
 objects = storage.loadTrajectoriesFromSqlite(args.trackerDatabaseFilename, 'object')
+
+if args.maskFilename is not None:
+    maskObjects = []
+    from matplotlib.pyplot import imread
+    mask = imread(args.maskFilename)
+    if len(mask) > 1:
+        mask = mask[:,:,0]
+    for obj in objects:
+        maskObjects += obj.getObjectsInMask(mask, inv(homography), 2) # TODO add option to keep object if at least one feature in mask
+    objects = maskObjects    
+
 annotations = storage.loadGroundTruthFromSqlite(args.groundTruthDatabaseFilename)
 for a in annotations:
     a.computeCentroidTrajectory(homography)
 
-motp, mota, mt, mme, fpt, gt = moving.computeClearMOT(annotations, objects, args.matchingDistance, args.firstInstant, args.lastInstant)
+if args.display:
+    motp, mota, mt, mme, fpt, gt, gtMatches, toMatches = moving.computeClearMOT(annotations, objects, args.matchingDistance, args.firstInstant, args.lastInstant, True)
+else:
+    motp, mota, mt, mme, fpt, gt = moving.computeClearMOT(annotations, objects, args.matchingDistance, args.firstInstant, args.lastInstant)
+
 
 print 'MOTP: {}'.format(motp)
 print 'MOTA: {}'.format(mota)
 print 'Number of missed objects.frames: {}'.format(mt)
 print 'Number of mismatches: {}'.format(mme)
 print 'Number of false alarms.frames: {}'.format(fpt)
+if args.display:
+    cvutils.displayTrajectories(args.videoFilename, objects, {}, inv(homography), args.firstInstant, args.lastInstant, annotations = annotations, gtMatches = gtMatches, toMatches = toMatches)#, rescale = args.rescale, nFramesStep = args.nFramesStep, saveAllImages = args.saveAllImages, undistort = (undistort or args.undistort), intrinsicCameraMatrix = intrinsicCameraMatrix, distortionCoefficients = distortionCoefficients, undistortedImageMultiplication = undistortedImageMultiplication)
+
+    #print('Ground truth matches')
+    #print(gtMatches)
+    #print('Object matches')
+    #rint toMatches