changeset 1236:100fe098abe9

progress on classification
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Tue, 19 Sep 2023 17:04:30 -0400
parents 855abc69fa99
children 31a441efca6c
files scripts/dltrack.py
diffstat 1 files changed, 33 insertions(+), 14 deletions(-)
line diff
--- a/scripts/dltrack.py	Fri Sep 15 11:56:16 2023 -0400
+++ b/scripts/dltrack.py	Tue Sep 19 17:04:30 2023 -0400
@@ -2,7 +2,10 @@
 # from https://docs.ultralytics.com/modes/track/
 import sys, argparse
 from copy import copy
+from collections import Counter
 from ultralytics import YOLO
+from torch import cat
+from torchvision import ops
 import cv2
 
 from trafficintelligence import cvutils, moving, storage, utils
@@ -15,6 +18,7 @@
 parser.add_argument('--display', dest = 'display', help = 'show the results (careful with long videos, risk of running out of memory)', action = 'store_true')
 parser.add_argument('-f', dest = 'firstFrameNum', help = 'number of the first frame to process', type = int, default = 0)
 parser.add_argument('-l', dest = 'lastFrameNum', help = 'number of the last frame to process', type = int, default = float('Inf'))
+parser.add_argument('--bike-pct', dest = 'bikeProportion', help = 'minimum proportion of frames in which a person must be classified as bike or motorbike to be classified as a cyclist', type = float, default = 0.2)
 args = parser.parse_args()
 
 # required functionality?
@@ -66,7 +70,6 @@
 # model = YOLO('/home/nicolas/Research/Data/classification-models/yolo_nas_l.pt ') # AttributeError: 'YoloNAS_L' object has no attribute 'get'
 
 # Track with the model
-#results = model.track(source=args.videoFilename, tracker="/home/nicolas/Research/Data/classification-models/bytetrack.yaml", classes=list(moving.cocoTypeNames.keys()), show=True) # , save_txt=True
 if args.display:
     windowName = 'frame'
     cv2.namedWindow(windowName, cv2.WINDOW_NORMAL)
@@ -87,27 +90,27 @@
 while capture.isOpened() and success and frameNum <= lastFrameNum:
 #for frameNum, result in enumerate(results):
     result = results[0]
-    print(frameNum, len(result.boxes))
+    print(frameNum, len(result.boxes), 'objects')
     for box in result.boxes:
         #print(box.cls, box.id, box.xyxy)
         if box.id is not None: # None are objects with low confidence
-            num = int(box.id)
-            xyxy = box.xyxy[0].tolist()
+            num = int(box.id.item())
+            #xyxy = box.xyxy[0].tolist()
             if num in currentObjects:
                 currentObjects[num].timeInterval.last = frameNum
-                currentObjects[num].userTypes.append(moving.coco2Types[int(box.cls)])
-                currentObjects[num].features[0].tmpPositions[frameNum] = moving.Point(xyxy[0],xyxy[1])
-                currentObjects[num].features[1].tmpPositions[frameNum] = moving.Point(xyxy[2],xyxy[3])
-                #features[0].getPositions().addPositionXY(xyxy[0],xyxy[1])
-                #features[1].getPositions().addPositionXY(xyxy[2],xyxy[3])
+                currentObjects[num].bboxes[frameNum] = copy(box.xyxy)
+                currentObjects[num].userTypes.append(moving.coco2Types[int(box.cls.item())])
+                currentObjects[num].features[0].tmpPositions[frameNum] = moving.Point(box.xyxy[0,0].item(), box.xyxy[0,1].item())
+                currentObjects[num].features[1].tmpPositions[frameNum] = moving.Point(box.xyxy[0,2].item(), box.xyxy[0,3].item())
             else:
                 inter = moving.TimeInterval(frameNum,frameNum)
                 currentObjects[num] = moving.MovingObject(num, inter)
-                currentObjects[num].userTypes = [moving.coco2Types[int(box.cls)]]
+                currentObjects[num].bboxes = {frameNum: copy(box.xyxy)}
+                currentObjects[num].userTypes = [moving.coco2Types[int(box.cls.item())]]
                 currentObjects[num].features = [moving.MovingObject(featureNum), moving.MovingObject(featureNum+1)]
                 currentObjects[num].featureNumbers = [featureNum, featureNum+1]
-                currentObjects[num].features[0].tmpPositions = {frameNum: moving.Point(xyxy[0],xyxy[1])}
-                currentObjects[num].features[1].tmpPositions = {frameNum: moving.Point(xyxy[2],xyxy[3])}
+                currentObjects[num].features[0].tmpPositions = {frameNum: moving.Point(box.xyxy[0,0].item(), box.xyxy[0,1].item())}
+                currentObjects[num].features[1].tmpPositions = {frameNum: moving.Point(box.xyxy[0,2].item(), box.xyxy[0,3].item())}
                 featureNum += 2
     if args.display:
         cvutils.cvImshow(windowName, result.plot()) # original image in orig_img
@@ -118,9 +121,25 @@
     success, frame = capture.read()
     results = model.track(frame, persist=True)
 
-# interpolate and generate velocity before saving
+# classification
 for num, obj in currentObjects.items():
-    obj.setUserType(utils.mostCommon(obj.userTypes))
+    #obj.setUserType(utils.mostCommon(obj.userTypes)) # improve? mix with speed?
+    userTypeStats = Counter(obj.userTypes)
+    if (4 in userTypeStats or (3 in userTypeStats and 4 in userTypeStats and userTypeStats[3]<=userTypeStats[4])) and userTypeStats[3]+userTypeStats[4] > args.bikeProportion*userTypeStats.total(): # 3 is motorcycle and 4 is cyclist (check that this does not turn all motorbikes into cyclists)
+        obj.setUserType(4)
+    else:
+        obj.setUserType(userTypeStats.most_common()[0][0])
+
+# merge bikes and people
+# build a bipartite graph between bike/motorbike tracks and person tracks
+# link weight = sum of IoUs / length of the bike track
+# Hungarian algorithm for the assignment
+# check pedestrian/bike overlaps: if long, possible mode change (find examples)
+
+# for all cyclists and motorbikes
+
+# interpolate and generate velocity (?) before saving
+for num, obj in currentObjects.items():
     obj.features[0].timeInterval = copy(obj.getTimeInterval())
     obj.features[1].timeInterval = copy(obj.getTimeInterval())
     if obj.length() != len(obj.features[0].tmpPositions): # interpolate
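
The comments added above sketch the planned merging of bike/motorbike tracks with person tracks: build a bipartite graph between the two sets, weight each link by the sum of the frame-by-frame IoUs divided by the length of the bike track, and solve the assignment with the Hungarian algorithm. The sketch below is not part of the changeset; it assumes scipy is available, that currentObjects is the dict built in the tracking loop (with the bboxes attribute added here), and that user type 2 is pedestrian and the getNum()/getUserType() accessors behave as in trafficintelligence's moving module. matchBikesAndRiders and iouThreshold are hypothetical names.

import numpy as np
from scipy.optimize import linear_sum_assignment
from torch import cat
from torchvision import ops

def matchBikesAndRiders(currentObjects, iouThreshold = 0.2):
    '''Hypothetical helper: pairs motorbike/cyclist tracks with pedestrian tracks
    whose bounding boxes overlap them enough over the bike track'''
    bikes = [o for o in currentObjects.values() if o.getUserType() in (3, 4)] # motorcyclist, cyclist
    people = [o for o in currentObjects.values() if o.getUserType() == 2]     # pedestrian (assumed index)
    costs = np.zeros((len(bikes), len(people)))
    for i, bike in enumerate(bikes):
        for j, person in enumerate(people):
            commonFrames = sorted(set(bike.bboxes) & set(person.bboxes))
            if len(commonFrames) > 0:
                bikeBoxes = cat([bike.bboxes[f] for f in commonFrames])
                personBoxes = cat([person.bboxes[f] for f in commonFrames])
                # frame-by-frame IoUs are on the diagonal of the pairwise IoU matrix;
                # their sum is normalised by the length of the bike track
                iou = ops.box_iou(bikeBoxes, personBoxes).diagonal().sum().item()
                costs[i, j] = -iou/len(bike.bboxes) # negated since linear_sum_assignment minimises
    rows, cols = linear_sum_assignment(costs) # Hungarian algorithm
    return [(bikes[i].getNum(), people[j].getNum()) for i, j in zip(rows, cols)
            if -costs[i, j] >= iouThreshold]

The returned pairs could then be merged into single cyclist/motorcyclist objects; a long pedestrian/bike overlap outside these pairs would be a candidate for the mode change mentioned in the last comment above.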