"""trt_yolo_cv.py

This script could be used to make object detection video with
TensorRT optimized YOLO engine.

"cv" means "create video"
made by BigJoon (ref. jkjung-avt)
"""


import os
import argparse

import cv2
import pycuda.autoinit  # This is needed for initializing CUDA driver

from utils.yolo_classes import get_cls_dict
from utils.visualization import BBoxVisualization
from utils.yolo_with_plugins import TrtYOLO


def parse_args():
    """Parse input arguments."""
    desc = ('Run the TensorRT optimized object detection model on an input '
            'video and save BBoxed overlaid output as another video.')
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument(
        '-v', '--video', type=str, required=True,
        help='input video file name')
    parser.add_argument(
        '-o', '--output', type=str, required=True,
        help='output video file name')
    parser.add_argument(
        '-c', '--category_num', type=int, default=80,
        help='number of object categories [80]')
    parser.add_argument(
        '-m', '--model', type=str, required=True,
        help=('[yolov3-tiny|yolov3|yolov3-spp|yolov4-tiny|yolov4|'
              'yolov4-csp|yolov4x-mish|yolov4-p5]-[{dimension}], where '
              '{dimension} could be either a single number (e.g. '
              '288, 416, 608) or 2 numbers, WxH (e.g. 416x256)'))
    parser.add_argument(
        '-l', '--letter_box', action='store_true',
        help='inference with letterboxed image [False]')
    args = parser.parse_args()
    return args


def _get_fps(cap, default=30.0):
    """Return the frame rate reported by `cap`, or `default` if unusable.

    # Arguments
      cap: an opened cv2.VideoCapture.
      default: frame rate to use when the capture reports none.

    # Returns
      The frame rate (float) to use for the output video.
    """
    fps = cap.get(cv2.CAP_PROP_FPS)
    # VideoCapture.get() may return 0 when the property is not available;
    # `not (fps > 0)` also rejects NaN and negative values.
    if not (fps > 0):
        return default
    return fps


def loop_and_detect(cap, trt_yolo, conf_th, vis, writer):
    """Read every frame of the input video, detect objects, write results.

    # Arguments
      cap: the cv2.VideoCapture instance (input video source).
      trt_yolo: the TRT YOLO object detector instance.
      conf_th: confidence/score threshold for object detection.
      vis: for visualization (draws boxes onto each frame).
      writer: the VideoWriter object for the output video.
    """
    while True:
        ret, frame = cap.read()
        # `ret` is False (and `frame` None) once the video is exhausted
        # or a frame cannot be decoded.
        if not ret or frame is None:
            break
        boxes, confs, clss = trt_yolo.detect(frame, conf_th)
        frame = vis.draw_bboxes(frame, boxes, confs, clss)
        writer.write(frame)
        print('.', end='', flush=True)  # per-frame progress indicator
    print('\nDone.')


def main():
    """Validate arguments, then annotate the input video into the output."""
    args = parse_args()
    if args.category_num <= 0:
        raise SystemExit('ERROR: bad category_num (%d)!' % args.category_num)
    if not os.path.isfile('yolo/%s.trt' % args.model):
        raise SystemExit('ERROR: file (yolo/%s.trt) not found!' % args.model)

    cap = cv2.VideoCapture(args.video)
    if not cap.isOpened():
        raise SystemExit('ERROR: failed to open the input video file!')
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Match the input frame rate so the output plays at the original speed
    # (previously hard-coded to 30 fps).
    fps = _get_fps(cap)

    writer = cv2.VideoWriter(
        args.output,
        cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
    if not writer.isOpened():
        cap.release()
        raise SystemExit('ERROR: failed to open the output video file!')

    # Release capture and writer even if model loading or detection fails,
    # so a partially written output file is still finalized.
    try:
        cls_dict = get_cls_dict(args.category_num)
        vis = BBoxVisualization(cls_dict)
        trt_yolo = TrtYOLO(args.model, args.category_num, args.letter_box)
        loop_and_detect(cap, trt_yolo, conf_th=0.3, vis=vis, writer=writer)
    finally:
        writer.release()
        cap.release()


if __name__ == '__main__':
    main()