Mercurial Hosting > traffic-intelligence
changeset 1236:100fe098abe9
progress on classification
author | Nicolas Saunier <nicolas.saunier@polymtl.ca> |
---|---|
date | Tue, 19 Sep 2023 17:04:30 -0400 |
parents | 855abc69fa99 |
children | 31a441efca6c |
files | scripts/dltrack.py |
diffstat | 1 files changed, 33 insertions(+), 14 deletions(-) [+] |
line wrap: on
line diff
--- a/scripts/dltrack.py Fri Sep 15 11:56:16 2023 -0400 +++ b/scripts/dltrack.py Tue Sep 19 17:04:30 2023 -0400 @@ -2,7 +2,10 @@ # from https://docs.ultralytics.com/modes/track/ import sys, argparse from copy import copy +from collections import Counter from ultralytics import YOLO +from torch import cat +from torchvision import ops import cv2 from trafficintelligence import cvutils, moving, storage, utils @@ -15,6 +18,7 @@ parser.add_argument('--display', dest = 'display', help = 'show the results (careful with long videos, risk of running out of memory)', action = 'store_true') parser.add_argument('-f', dest = 'firstFrameNum', help = 'number of first frame number to process', type = int, default = 0) parser.add_argument('-l', dest = 'lastFrameNum', help = 'number of last frame number to process', type = int, default = float('Inf')) +parser.add_argument('--bike-pct', dest = 'bikeProportion', help = 'percent of time a person classified as bike or motorbike to be classified as cyclist', type = float, default = 0.2) args = parser.parse_args() # required functionality? @@ -66,7 +70,6 @@ # model = YOLO('/home/nicolas/Research/Data/classification-models/yolo_nas_l.pt ') # AttributeError: 'YoloNAS_L' object has no attribute 'get' # Track with the model -#results = model.track(source=args.videoFilename, tracker="/home/nicolas/Research/Data/classification-models/bytetrack.yaml", classes=list(moving.cocoTypeNames.keys()), show=True) # , save_txt=True if args.display: windowName = 'frame' cv2.namedWindow(windowName, cv2.WINDOW_NORMAL) @@ -87,27 +90,27 @@ while capture.isOpened() and success and frameNum <= lastFrameNum: #for frameNum, result in enumerate(results): result = results[0] - print(frameNum, len(result.boxes)) + print(frameNum, len(result.boxes), 'objects') for box in result.boxes: #print(box.cls, box.id, box.xyxy) if box.id is not None: # None are objects with low confidence - num = int(box.id) - xyxy = box.xyxy[0].tolist() + num = int(box.id.item()) + #xyxy = box.xyxy[0].tolist() if num in currentObjects: currentObjects[num].timeInterval.last = frameNum - currentObjects[num].userTypes.append(moving.coco2Types[int(box.cls)]) - currentObjects[num].features[0].tmpPositions[frameNum] = moving.Point(xyxy[0],xyxy[1]) - currentObjects[num].features[1].tmpPositions[frameNum] = moving.Point(xyxy[2],xyxy[3]) - #features[0].getPositions().addPositionXY(xyxy[0],xyxy[1]) - #features[1].getPositions().addPositionXY(xyxy[2],xyxy[3]) + currentObjects[num].bboxes[frameNum] = copy(box.xyxy) + currentObjects[num].userTypes.append(moving.coco2Types[int(box.cls.item())]) + currentObjects[num].features[0].tmpPositions[frameNum] = moving.Point(box.xyxy[0,0].item(), box.xyxy[0,1].item()) + currentObjects[num].features[1].tmpPositions[frameNum] = moving.Point(box.xyxy[0,2].item(), box.xyxy[0,3].item()) else: inter = moving.TimeInterval(frameNum,frameNum) currentObjects[num] = moving.MovingObject(num, inter) - currentObjects[num].userTypes = [moving.coco2Types[int(box.cls)]] + currentObjects[num].bboxes = {frameNum: copy(box.xyxy)} + currentObjects[num].userTypes = [moving.coco2Types[int(box.cls.item())]] currentObjects[num].features = [moving.MovingObject(featureNum), moving.MovingObject(featureNum+1)] currentObjects[num].featureNumbers = [featureNum, featureNum+1] - currentObjects[num].features[0].tmpPositions = {frameNum: moving.Point(xyxy[0],xyxy[1])} - currentObjects[num].features[1].tmpPositions = {frameNum: moving.Point(xyxy[2],xyxy[3])} + currentObjects[num].features[0].tmpPositions = {frameNum: moving.Point(box.xyxy[0,0].item(), box.xyxy[0,1].item())} + currentObjects[num].features[1].tmpPositions = {frameNum: moving.Point(box.xyxy[0,2].item(), box.xyxy[0,3].item())} featureNum += 2 if args.display: cvutils.cvImshow(windowName, result.plot()) # original image in orig_img @@ -118,9 +121,25 @@ success, frame = capture.read() results = model.track(frame, persist=True) -# interpolate and generate velocity before saving +# classification for num, obj in currentObjects.items(): - obj.setUserType(utils.mostCommon(obj.userTypes)) + #obj.setUserType(utils.mostCommon(obj.userTypes)) # improve? mix with speed? + userTypeStats = Counter(obj.userTypes) + if (4 in userTypeStats or (3 in userTypeStats and 4 in userTypeStats and userTypeStats[3]<=userTypeStats[4])) and userTypeStats[3]+userTypeStats[4] > args.bikeProportion*userTypeStats.total(): # 3 is motorcycle and 4 is cyclist (verif if not turning all motorbike into cyclists) + obj.setUserType(4) + else: + obj.setUserType(userTypeStats.most_common()[0][0]) + +# merge bikes and people +#Construire graphe bipartite vélo/moto personne +#Lien = somme des iou / longueur track vélo +#Algo Hongrois +#Verif overlap piéton vélo : si long, changement mode (trouver exemples) + +# for all cyclists and motorbikes + +# interpolate and generate velocity (?) before saving +for num, obj in currentObjects.items(): obj.features[0].timeInterval = copy(obj.getTimeInterval()) obj.features[1].timeInterval = copy(obj.getTimeInterval()) if obj.length() != len(obj.features[0].tmpPositions): # interpolate