traffic-intelligence
comparison scripts/dltrack.py @ 1236:100fe098abe9
progress on classification
author | Nicolas Saunier <nicolas.saunier@polymtl.ca>
date | Tue, 19 Sep 2023 17:04:30 -0400
parents | 855abc69fa99
children | 31a441efca6c
1235:855abc69fa99 (old) | 1236:100fe098abe9 (new)
1 #! /usr/bin/env python3 | 1 #! /usr/bin/env python3 |
2 # from https://docs.ultralytics.com/modes/track/ | 2 # from https://docs.ultralytics.com/modes/track/ |
3 import sys, argparse | 3 import sys, argparse |
4 from copy import copy | 4 from copy import copy |
| 5 from collections import Counter |
5 from ultralytics import YOLO | 6 from ultralytics import YOLO |
| 7 from torch import cat |
| 8 from torchvision import ops |
6 import cv2 | 9 import cv2 |
7 | 10 |
8 from trafficintelligence import cvutils, moving, storage, utils | 11 from trafficintelligence import cvutils, moving, storage, utils |
9 | 12 |
10 parser = argparse.ArgumentParser(description='The program tracks objects following the ultralytics yolo executable.')#, epilog = 'Either the configuration filename or the other parameters (at least video and database filenames) need to be provided.') | 13 parser = argparse.ArgumentParser(description='The program tracks objects following the ultralytics yolo executable.')#, epilog = 'Either the configuration filename or the other parameters (at least video and database filenames) need to be provided.') |
13 parser.add_argument('-m', dest = 'detectorFilename', help = 'name of the detection model file', required = True) | 16 parser.add_argument('-m', dest = 'detectorFilename', help = 'name of the detection model file', required = True) |
14 parser.add_argument('-t', dest = 'trackerFilename', help = 'name of the tracker file', required = True) | 17 parser.add_argument('-t', dest = 'trackerFilename', help = 'name of the tracker file', required = True) |
15 parser.add_argument('--display', dest = 'display', help = 'show the results (careful with long videos, risk of running out of memory)', action = 'store_true') | 18 parser.add_argument('--display', dest = 'display', help = 'show the results (careful with long videos, risk of running out of memory)', action = 'store_true') |
16 parser.add_argument('-f', dest = 'firstFrameNum', help = 'number of first frame number to process', type = int, default = 0) | 19 parser.add_argument('-f', dest = 'firstFrameNum', help = 'number of first frame number to process', type = int, default = 0) |
17 parser.add_argument('-l', dest = 'lastFrameNum', help = 'number of last frame number to process', type = int, default = float('Inf')) | 20 parser.add_argument('-l', dest = 'lastFrameNum', help = 'number of last frame number to process', type = int, default = float('Inf')) |
| 21 parser.add_argument('--bike-pct', dest = 'bikeProportion', help = 'proportion of time a person must be classified as bike or motorbike to be classified as a cyclist', type = float, default = 0.2) |
18 args = parser.parse_args() | 22 args = parser.parse_args() |
19 | 23 |
20 # required functionality? | 24 # required functionality? |
21 # # filename of the video to process (can be images, eg image%04d.png) | 25 # # filename of the video to process (can be images, eg image%04d.png) |
22 # video-filename = laurier.avi | 26 # video-filename = laurier.avi |
64 model = YOLO(args.detectorFilename, ) # seg yolov8x-seg.pt | 68 model = YOLO(args.detectorFilename, ) # seg yolov8x-seg.pt |
65 # seg could be used on cropped image... if can be loaded and kept in memory | 69 # seg could be used on cropped image... if can be loaded and kept in memory |
66 # model = YOLO('/home/nicolas/Research/Data/classification-models/yolo_nas_l.pt ') # AttributeError: 'YoloNAS_L' object has no attribute 'get' | 70 # model = YOLO('/home/nicolas/Research/Data/classification-models/yolo_nas_l.pt ') # AttributeError: 'YoloNAS_L' object has no attribute 'get' |
67 | 71 |
68 # Track with the model | 72 # Track with the model |
69 #results = model.track(source=args.videoFilename, tracker="/home/nicolas/Research/Data/classification-models/bytetrack.yaml", classes=list(moving.cocoTypeNames.keys()), show=True) # , save_txt=True | |
70 if args.display: | 73 if args.display: |
71     windowName = 'frame' | 74     windowName = 'frame' |
72     cv2.namedWindow(windowName, cv2.WINDOW_NORMAL) | 75     cv2.namedWindow(windowName, cv2.WINDOW_NORMAL) |
73 | 76 |
74 capture = cv2.VideoCapture(args.videoFilename) | 77 capture = cv2.VideoCapture(args.videoFilename) |
85 results = model.track(frame, tracker=args.trackerFilename, classes=list(moving.cocoTypeNames.keys()), persist=True) | 88 results = model.track(frame, tracker=args.trackerFilename, classes=list(moving.cocoTypeNames.keys()), persist=True) |
86 # create object with user type and list of 3 features (bottom ones and middle) + projection | 89 # create object with user type and list of 3 features (bottom ones and middle) + projection |
87 while capture.isOpened() and success and frameNum <= lastFrameNum: | 90 while capture.isOpened() and success and frameNum <= lastFrameNum: |
88     #for frameNum, result in enumerate(results): | 91     #for frameNum, result in enumerate(results): |
89     result = results[0] | 92     result = results[0] |
90     print(frameNum, len(result.boxes)) | 93     print(frameNum, len(result.boxes), 'objects') |
91     for box in result.boxes: | 94     for box in result.boxes: |
92         #print(box.cls, box.id, box.xyxy) | 95         #print(box.cls, box.id, box.xyxy) |
93         if box.id is not None: # None are objects with low confidence | 96         if box.id is not None: # None are objects with low confidence |
94             num = int(box.id) | 97             num = int(box.id.item()) |
95             xyxy = box.xyxy[0].tolist() | 98             #xyxy = box.xyxy[0].tolist() |
96             if num in currentObjects: | 99             if num in currentObjects: |
97                 currentObjects[num].timeInterval.last = frameNum | 100                 currentObjects[num].timeInterval.last = frameNum |
98                 currentObjects[num].userTypes.append(moving.coco2Types[int(box.cls)]) | 101                 currentObjects[num].bboxes[frameNum] = copy(box.xyxy) |
99                 currentObjects[num].features[0].tmpPositions[frameNum] = moving.Point(xyxy[0],xyxy[1]) | 102                 currentObjects[num].userTypes.append(moving.coco2Types[int(box.cls.item())]) |
100                 currentObjects[num].features[1].tmpPositions[frameNum] = moving.Point(xyxy[2],xyxy[3]) | 103                 currentObjects[num].features[0].tmpPositions[frameNum] = moving.Point(box.xyxy[0,0].item(), box.xyxy[0,1].item()) |
101                 #features[0].getPositions().addPositionXY(xyxy[0],xyxy[1]) | 104                 currentObjects[num].features[1].tmpPositions[frameNum] = moving.Point(box.xyxy[0,2].item(), box.xyxy[0,3].item()) |
102                 #features[1].getPositions().addPositionXY(xyxy[2],xyxy[3]) | |
103             else: | 105             else: |
104                 inter = moving.TimeInterval(frameNum,frameNum) | 106                 inter = moving.TimeInterval(frameNum,frameNum) |
105                 currentObjects[num] = moving.MovingObject(num, inter) | 107                 currentObjects[num] = moving.MovingObject(num, inter) |
106                 currentObjects[num].userTypes = [moving.coco2Types[int(box.cls)]] | 108                 currentObjects[num].bboxes = {frameNum: copy(box.xyxy)} |
| 109                 currentObjects[num].userTypes = [moving.coco2Types[int(box.cls.item())]] |
107                 currentObjects[num].features = [moving.MovingObject(featureNum), moving.MovingObject(featureNum+1)] | 110                 currentObjects[num].features = [moving.MovingObject(featureNum), moving.MovingObject(featureNum+1)] |
108                 currentObjects[num].featureNumbers = [featureNum, featureNum+1] | 111                 currentObjects[num].featureNumbers = [featureNum, featureNum+1] |
109                 currentObjects[num].features[0].tmpPositions = {frameNum: moving.Point(xyxy[0],xyxy[1])} | 112                 currentObjects[num].features[0].tmpPositions = {frameNum: moving.Point(box.xyxy[0,0].item(), box.xyxy[0,1].item())} |
110                 currentObjects[num].features[1].tmpPositions = {frameNum: moving.Point(xyxy[2],xyxy[3])} | 113                 currentObjects[num].features[1].tmpPositions = {frameNum: moving.Point(box.xyxy[0,2].item(), box.xyxy[0,3].item())} |
111                 featureNum += 2 | 114                 featureNum += 2 |
112     if args.display: | 115     if args.display: |
113         cvutils.cvImshow(windowName, result.plot()) # original image in orig_img | 116         cvutils.cvImshow(windowName, result.plot()) # original image in orig_img |
114         key = cv2.waitKey() | 117         key = cv2.waitKey() |
115         if cvutils.quitKey(key): | 118         if cvutils.quitKey(key): |
116             break | 119             break |
117     frameNum += 1 | 120     frameNum += 1 |
118     success, frame = capture.read() | 121     success, frame = capture.read() |
119     results = model.track(frame, persist=True) | 122     results = model.track(frame, persist=True) |
120 | 123 |
121 # interpolate and generate velocity before saving | 124 # classification |
122 for num, obj in currentObjects.items(): | 125 for num, obj in currentObjects.items(): |
123     obj.setUserType(utils.mostCommon(obj.userTypes)) | 126     #obj.setUserType(utils.mostCommon(obj.userTypes)) # improve? mix with speed? |
| 127     userTypeStats = Counter(obj.userTypes) |
| 128     if (4 in userTypeStats or (3 in userTypeStats and 4 in userTypeStats and userTypeStats[3]<=userTypeStats[4])) and userTypeStats[3]+userTypeStats[4] > args.bikeProportion*userTypeStats.total(): # 3 is motorcycle and 4 is cyclist (check that this does not turn all motorbikes into cyclists) |
| 129         obj.setUserType(4) |
| 130     else: |
| 131         obj.setUserType(userTypeStats.most_common()[0][0]) |
| 132 |
| 133 # merge bikes and people |
| 134 #Build a bipartite graph between bike/motorbike tracks and person tracks |
| 135 #Link weight = sum of the IoUs / length of the bike track |
| 136 #Hungarian algorithm |
| 137 #Check pedestrian/bike overlaps: if long, it is a mode change (find examples) |
| 138 |
| 139 # for all cyclists and motorbikes |
| 140 |
| 141 # interpolate and generate velocity (?) before saving |
| 142 for num, obj in currentObjects.items(): |
124     obj.features[0].timeInterval = copy(obj.getTimeInterval()) | 143     obj.features[0].timeInterval = copy(obj.getTimeInterval()) |
125     obj.features[1].timeInterval = copy(obj.getTimeInterval()) | 144     obj.features[1].timeInterval = copy(obj.getTimeInterval()) |
126     if obj.length() != len(obj.features[0].tmpPositions): # interpolate | 145     if obj.length() != len(obj.features[0].tmpPositions): # interpolate |
127         obj.features[0].positions = moving.Trajectory.fromPointDict(obj.features[0].tmpPositions) | 146         obj.features[0].positions = moving.Trajectory.fromPointDict(obj.features[0].tmpPositions) |
128         obj.features[1].positions = moving.Trajectory.fromPointDict(obj.features[1].tmpPositions) | 147         obj.features[1].positions = moving.Trajectory.fromPointDict(obj.features[1].tmpPositions) |
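
The classification vote added at new lines 127-131 can be read on its own as the sketch below. It assumes the type codes given in the inline comment (3 = motorcycle, 4 = cyclist; 2 = pedestrian is an extra assumption), with a plain bikeProportion argument standing in for args.bikeProportion. Note that the second disjunct of the committed condition already requires 4 to be in the counter, so the test reduces to the simpler form used here; Counter.total() requires Python 3.10+.

from collections import Counter

def voteUserType(userTypes, bikeProportion=0.2):
    '''Majority vote over the per-frame types of one track, biased towards
    cyclists: the track becomes a cyclist (4) if at least one frame was
    detected as cyclist and the motorcycle (3) + cyclist (4) frames exceed
    the given proportion of all frames.'''
    stats = Counter(userTypes)
    if 4 in stats and stats[3] + stats[4] > bikeProportion * stats.total():
        return 4
    return stats.most_common(1)[0][0]

# e.g. a mostly-pedestrian track with a few cyclist detections is
# reclassified: voteUserType([2]*7 + [4]*3) returns 4 (3 > 0.2*10 frames)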
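The translated comments at new lines 134-137 outline the planned merge of bikes and people: weigh each bike/motorbike-person track pair by the summed per-frame IoU divided by the bike track length, then solve the bipartite assignment with the Hungarian algorithm. The imports added at the top (from torch import cat, from torchvision import ops) suggest the IoUs are meant to come from torchvision's box operations. A possible sketch under those assumptions, using the per-frame 1x4 xyxy tensors stored in the bboxes attribute above, scipy's linear_sum_assignment as the Hungarian solver, and a hypothetical minOverlap threshold:

import numpy as np
from scipy.optimize import linear_sum_assignment # Hungarian algorithm
from torchvision import ops

def trackOverlap(bikeBboxes, personBboxes):
    '''Link weight from the comment "sum of the IoUs / length of the bike
    track": frame-by-frame IoU on the frames where both tracks have a box
    (each box a 1x4 xyxy tensor, as stored in bboxes above).'''
    commonFrames = set(bikeBboxes) & set(personBboxes)
    iouSum = sum(ops.box_iou(bikeBboxes[f], personBboxes[f]).item() for f in commonFrames)
    return iouSum / len(bikeBboxes)

def matchBikesPersons(bikes, persons, minOverlap=0.3):
    '''One-to-one assignment of person tracks to bike/motorbike tracks
    maximizing total overlap (costs are negated since the solver minimizes);
    minOverlap is an assumed threshold to drop weak matches.'''
    costs = np.array([[-trackOverlap(b.bboxes, p.bboxes) for p in persons] for b in bikes])
    rows, cols = linear_sum_assignment(costs)
    return [(bikes[i], persons[j]) for i, j in zip(rows, cols) if -costs[i, j] >= minOverlap]

A long-lasting pedestrian/bike overlap that this matching surfaces would then be a candidate mode change, as the last translated comment suggests.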
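Finally, the saving loop keeps tmpPositions as dicts from frame number to point and falls back to moving.Trajectory.fromPointDict when frames are missing (obj.length() counts the whole time interval, len(tmpPositions) only the detected frames). As a generic illustration of that gap-filling, independent of the trafficintelligence API (plain (x, y) tuples stand in for moving.Point):

def interpolatePositions(tmpPositions):
    '''Fill the missing frames of a dict frame -> (x, y) by linear
    interpolation between the surrounding detections; a stand-in for what
    a fromPointDict-style constructor has to do, not the library code.'''
    frames = sorted(tmpPositions)
    filled = {}
    for f0, f1 in zip(frames[:-1], frames[1:]):
        (x0, y0), (x1, y1) = tmpPositions[f0], tmpPositions[f1]
        for f in range(f0, f1): # keeps f0 itself (alpha = 0) and fills the gap
            alpha = (f - f0) / (f1 - f0)
            filled[f] = ((1 - alpha) * x0 + alpha * x1, (1 - alpha) * y0 + alpha * y1)
    filled[frames[-1]] = tmpPositions[frames[-1]]
    return filled

# e.g. interpolatePositions({0: (0., 0.), 3: (3., 6.)}) also yields
# frames 1 and 2, at (1., 2.) and (2., 4.)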