view scripts/dltrack.py @ 1236:100fe098abe9

progress on classification
author Nicolas Saunier <nicolas.saunier@polymtl.ca>
date Tue, 19 Sep 2023 17:04:30 -0400
parents 855abc69fa99
children 31a441efca6c

#! /usr/bin/env python3
# from https://docs.ultralytics.com/modes/track/
import sys, argparse
from copy import copy
from collections import Counter
from ultralytics import YOLO
from torch import cat
from torchvision import ops
import cv2

from trafficintelligence import cvutils, moving, storage, utils

parser = argparse.ArgumentParser(description='The program tracks objects using the Ultralytics YOLO tracking mode.')#, epilog = 'Either the configuration filename or the other parameters (at least video and database filenames) need to be provided.')
parser.add_argument('-i', dest = 'videoFilename', help = 'name of the video file', required = True)
parser.add_argument('-d', dest = 'databaseFilename', help = 'name of the Sqlite database file', required = True)
parser.add_argument('-m', dest = 'detectorFilename', help = 'name of the detection model file', required = True)
parser.add_argument('-t', dest = 'trackerFilename', help = 'name of the tracker file', required = True)
parser.add_argument('--display', dest = 'display', help = 'show the results (careful with long videos, risk of running out of memory)', action = 'store_true')
parser.add_argument('-f', dest = 'firstFrameNum', help = 'number of the first frame to process', type = int, default = 0)
parser.add_argument('-l', dest = 'lastFrameNum', help = 'number of the last frame to process', type = int, default = float('Inf'))
parser.add_argument('--bike-pct', dest = 'bikeProportion', help = 'minimum proportion of frames a person must be classified as bike or motorbike to be classified as a cyclist', type = float, default = 0.2)
args = parser.parse_args()
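# example invocation (illustrative filenames; the tracker file is an ultralytics tracker
# configuration such as bytetrack.yaml):
# ./dltrack.py -i video.mp4 -d tracks.sqlite -m yolov8x.pt -t bytetrack.yaml -f 0 -l 1000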

# required functionality?
# # filename of the video to process (can be images, eg image%04d.png)
# video-filename = laurier.avi
# # filename of the database where results are saved
# database-filename = laurier.sqlite
# # filename of the homography matrix
# homography-filename = laurier-homography.txt
# # filename of the camera intrinsic matrix
# intrinsic-camera-filename = intrinsic-camera.txt
# # -0.11759321 0.0148536 0.00030756 -0.00020578 -0.00091816
# distortion-coefficients = -0.11759321
# distortion-coefficients = 0.0148536
# distortion-coefficients = 0.00030756 
# distortion-coefficients = -0.00020578 
# distortion-coefficients = -0.00091816
# # undistorted image multiplication
# undistorted-size-multiplication = 1.31
# # Interpolation method for remapping image when correcting for distortion: 0 for INTER_NEAREST - a nearest-neighbor interpolation; 1 for INTER_LINEAR - a bilinear interpolation (used by default); 2 for INTER_CUBIC - a bicubic interpolation over 4x4 pixel neighborhood; 3 for INTER_LANCZOS4
# interpolation-method = 1
# # filename of the mask image (where features are detected)
# mask-filename = none
# # undistort the video for feature tracking
# undistort = false
# # load features from database
# load-features = false
# # display trajectories on the video
# display = false
# # original video frame rate (number of frames/s)
# video-fps = 29.97
# # number of digits of precision for all measurements derived from video
# # measurement-precision = 3
# # first frame to process
# frame1 = 0
# # number of frames to process: 0 means processing all frames
# nframes = 0
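
# Hedged sketch (not executed) of how the undistortion parameters above could be used with
# plain OpenCV; intrinsicCameraMatrix (3x3 array), distortionCoefficients (the 5 values above)
# and the frame size (width, height) are assumed to be available, all names are hypothetical.
# import numpy as np
# newImgSize = (int(1.31*width), int(1.31*height)) # undistorted-size-multiplication
# newCameraMatrix = np.copy(intrinsicCameraMatrix)
# newCameraMatrix[0,2] = newImgSize[0]/2
# newCameraMatrix[1,2] = newImgSize[1]/2
# map1, map2 = cv2.initUndistortRectifyMap(intrinsicCameraMatrix, distortionCoefficients, None, newCameraMatrix, newImgSize, cv2.CV_32FC1)
# undistortedFrame = cv2.remap(frame, map1, map2, interpolation = cv2.INTER_LINEAR) # interpolation-method = 1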

# TODO add option to refine position with mask for vehicles

# use 2 x the bytetrack track buffer to remove finished objects from the current ones


# check if one can go to specific frame https://docs.ultralytics.com/modes/track/#persisting-tracks-loop

# Load a model
model = YOLO(args.detectorFilename) # seg yolov8x-seg.pt
# seg could be used on cropped image... if can be loaded and kept in memory
# model = YOLO('/home/nicolas/Research/Data/classification-models/yolo_nas_l.pt ') # AttributeError: 'YoloNAS_L' object has no attribute 'get'
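# Hedged sketch of the idea above (not executed): a segmentation model could be kept in memory
# and run on the crop of each detected box to refine the object position; 'segModel' is a
# hypothetical name and the cropping/inference lines would go inside the per-box loop below.
# segModel = YOLO('yolov8x-seg.pt')
# x1, y1, x2, y2 = [int(c) for c in box.xyxy[0].tolist()]
# segResults = segModel(frame[y1:y2, x1:x2])
# mask = segResults[0].masks # None if nothing is segmented in the crop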

# Track with the model
if args.display:
    windowName = 'frame'
    cv2.namedWindow(windowName, cv2.WINDOW_NORMAL)

capture = cv2.VideoCapture(args.videoFilename)
#results = model.track(source=args.videoFilename, tracker="/home/nicolas/Research/Data/classification-models/bytetrack.yaml", classes=list(moving.cocoTypeNames.keys()), stream=True)
objects = []
currentObjects = {}
featureNum = 0

frameNum = args.firstFrameNum
capture.set(cv2.CAP_PROP_POS_FRAMES, frameNum)
lastFrameNum = args.lastFrameNum

success, frame = capture.read()
results = model.track(frame, tracker=args.trackerFilename, classes=list(moving.cocoTypeNames.keys()), persist=True)
# create object with user type and list of features (currently the two bbox corners; planned: 3 features, the bottom points and the middle, + projection)
while capture.isOpened() and success and frameNum <= lastFrameNum:
#for frameNum, result in enumerate(results):
    result = results[0]
    print(frameNum, len(result.boxes), 'objects')
    for box in result.boxes:
        #print(box.cls, box.id, box.xyxy)
        if box.id is not None: # None are objects with low confidence
            num = int(box.id.item())
            #xyxy = box.xyxy[0].tolist()
            if num in currentObjects:
                currentObjects[num].timeInterval.last = frameNum
                currentObjects[num].bboxes[frameNum] = copy(box.xyxy)
                currentObjects[num].userTypes.append(moving.coco2Types[int(box.cls.item())])
                currentObjects[num].features[0].tmpPositions[frameNum] = moving.Point(box.xyxy[0,0].item(), box.xyxy[0,1].item())
                currentObjects[num].features[1].tmpPositions[frameNum] = moving.Point(box.xyxy[0,2].item(), box.xyxy[0,3].item())
            else:
                inter = moving.TimeInterval(frameNum,frameNum)
                currentObjects[num] = moving.MovingObject(num, inter)
                currentObjects[num].bboxes = {frameNum: copy(box.xyxy)}
                currentObjects[num].userTypes = [moving.coco2Types[int(box.cls.item())]]
                currentObjects[num].features = [moving.MovingObject(featureNum), moving.MovingObject(featureNum+1)]
                currentObjects[num].featureNumbers = [featureNum, featureNum+1]
                currentObjects[num].features[0].tmpPositions = {frameNum: moving.Point(box.xyxy[0,0].item(), box.xyxy[0,1].item())}
                currentObjects[num].features[1].tmpPositions = {frameNum: moving.Point(box.xyxy[0,2].item(), box.xyxy[0,3].item())}
                featureNum += 2
    if args.display:
        cvutils.cvImshow(windowName, result.plot()) # original image in orig_img
        key = cv2.waitKey()
        if cvutils.quitKey(key):
            break
    frameNum += 1
    success, frame = capture.read()
    if success: # avoid calling the tracker on an empty frame at the end of the video
        results = model.track(frame, tracker=args.trackerFilename, classes=list(moving.cocoTypeNames.keys()), persist=True) # keep the same tracker and class filter as for the first frame

# classification
for num, obj in currentObjects.items():
    #obj.setUserType(utils.mostCommon(obj.userTypes)) # improve? mix with speed?
    userTypeStats = Counter(obj.userTypes)
    # 3 is motorcycle and 4 is cyclist (check that this does not turn all motorbikes into cyclists)
    if ((4 in userTypeStats or (3 in userTypeStats and 4 in userTypeStats and userTypeStats[3] <= userTypeStats[4]))
        and userTypeStats[3]+userTypeStats[4] > args.bikeProportion*userTypeStats.total()):
        obj.setUserType(4)
    else:
        obj.setUserType(userTypeStats.most_common()[0][0])
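
# example of the rule above: with userTypeStats = Counter({2: 75, 4: 25}) (75 frames as
# pedestrian, 25 as cyclist), 4 is present and 25 > 0.2*100, so the object is classified
# as a cyclist even though pedestrian is the most common type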

# merge bikes and people
#Build a bipartite graph between bicycle/motorcycle tracks and person tracks
#Link weight = sum of the IoUs / length of the bike track
#Hungarian algorithm for the assignment
#Check pedestrian/bicycle overlap: if it lasts long, the user may have changed mode (find examples)

# for all cyclists and motorbikes (see the commented sketch below)
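# Hedged sketch of the matching planned above (not executed): the link weight between a
# bicycle/motorcycle track and a person track could be the sum of per-frame bbox IoUs
# normalised by the bike track length, with the assignment solved by the Hungarian algorithm
# (scipy would be an extra dependency; 2, 3 and 4 are pedestrian, motorcycle and bicycle).
# from scipy.optimize import linear_sum_assignment
# import numpy as np
# bikes = [o for o in currentObjects.values() if o.getUserType() in (3, 4)]
# persons = [o for o in currentObjects.values() if o.getUserType() == 2]
# costs = np.zeros((len(bikes), len(persons)))
# for i, bike in enumerate(bikes):
#     for j, person in enumerate(persons):
#         commonFrames = set(bike.bboxes).intersection(person.bboxes)
#         iouSum = sum(float(ops.box_iou(bike.bboxes[f], person.bboxes[f])) for f in commonFrames)
#         costs[i, j] = -iouSum/bike.length() # negative because linear_sum_assignment minimises
# bikeIndices, personIndices = linear_sum_assignment(costs)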

# interpolate and generate velocity (?) before saving (a commented velocity sketch follows this loop)
for num, obj in currentObjects.items():
    obj.features[0].timeInterval = copy(obj.getTimeInterval())
    obj.features[1].timeInterval = copy(obj.getTimeInterval())
    if obj.length() != len(obj.features[0].tmpPositions): # interpolate
        obj.features[0].positions = moving.Trajectory.fromPointDict(obj.features[0].tmpPositions)
        obj.features[1].positions = moving.Trajectory.fromPointDict(obj.features[1].tmpPositions)
    else:
        obj.features[0].positions = moving.Trajectory.fromPointList(list(obj.features[0].tmpPositions.values()))
        obj.features[1].positions = moving.Trajectory.fromPointList(list(obj.features[1].tmpPositions.values()))
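
# Hedged sketch for the velocity question above (not executed): velocities could be
# approximated by finite differences of consecutive feature positions; whether
# trafficintelligence already provides a differentiation method is not checked here.
# for obj in currentObjects.values():
#     for feature in obj.features:
#         pts = list(feature.tmpPositions.values())
#         feature.velocities = moving.Trajectory.fromPointList(
#             [moving.Point(p2.x-p1.x, p2.y-p1.y) for p1, p2 in zip(pts[:-1], pts[1:])])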
        
storage.saveTrajectoriesToSqlite(args.databaseFilename, list(currentObjects.values()), 'object')

# todo save bbox and mask to study localization / representation
# apply quality checks (deviation and acceleration bounds)?