98 lines
3.1 KiB
Python
98 lines
3.1 KiB
Python
|
"""trt_yolo_cv.py
|
||
|
|
||
|
This script could be used to make object detection video with
|
||
|
TensorRT optimized YOLO engine.
|
||
|
|
||
|
"cv" means "create video"
|
||
|
made by BigJoon (ref. jkjung-avt)
|
||
|
"""
|
||
|
|
||
|
|
||
|
import os
|
||
|
import argparse
|
||
|
|
||
|
import cv2
|
||
|
import pycuda.autoinit # This is needed for initializing CUDA driver
|
||
|
|
||
|
from utils.yolo_classes import get_cls_dict
|
||
|
from utils.visualization import BBoxVisualization
|
||
|
from utils.yolo_with_plugins import TrtYOLO
|
||
|
|
||
|
|
||
|
def parse_args():
|
||
|
"""Parse input arguments."""
|
||
|
desc = ('Run the TensorRT optimized object detecion model on an input '
|
||
|
'video and save BBoxed overlaid output as another video.')
|
||
|
parser = argparse.ArgumentParser(description=desc)
|
||
|
parser.add_argument(
|
||
|
'-v', '--video', type=str, required=True,
|
||
|
help='input video file name')
|
||
|
parser.add_argument(
|
||
|
'-o', '--output', type=str, required=True,
|
||
|
help='output video file name')
|
||
|
parser.add_argument(
|
||
|
'-c', '--category_num', type=int, default=80,
|
||
|
help='number of object categories [80]')
|
||
|
parser.add_argument(
|
||
|
'-m', '--model', type=str, required=True,
|
||
|
help=('[yolov3-tiny|yolov3|yolov3-spp|yolov4-tiny|yolov4|'
|
||
|
'yolov4-csp|yolov4x-mish|yolov4-p5]-[{dimension}], where '
|
||
|
'{dimension} could be either a single number (e.g. '
|
||
|
'288, 416, 608) or 2 numbers, WxH (e.g. 416x256)'))
|
||
|
parser.add_argument(
|
||
|
'-l', '--letter_box', action='store_true',
|
||
|
help='inference with letterboxed image [False]')
|
||
|
args = parser.parse_args()
|
||
|
return args
|
||
|
|
||
|
|
||
|
def loop_and_detect(cap, trt_yolo, conf_th, vis, writer):
|
||
|
"""Continuously capture images from camera and do object detection.
|
||
|
|
||
|
# Arguments
|
||
|
cap: the camera instance (video source).
|
||
|
trt_yolo: the TRT YOLO object detector instance.
|
||
|
conf_th: confidence/score threshold for object detection.
|
||
|
vis: for visualization.
|
||
|
writer: the VideoWriter object for the output video.
|
||
|
"""
|
||
|
|
||
|
while True:
|
||
|
ret, frame = cap.read()
|
||
|
if frame is None: break
|
||
|
boxes, confs, clss = trt_yolo.detect(frame, conf_th)
|
||
|
frame = vis.draw_bboxes(frame, boxes, confs, clss)
|
||
|
writer.write(frame)
|
||
|
print('.', end='', flush=True)
|
||
|
|
||
|
print('\nDone.')
|
||
|
|
||
|
|
||
|
def main():
|
||
|
args = parse_args()
|
||
|
if args.category_num <= 0:
|
||
|
raise SystemExit('ERROR: bad category_num (%d)!' % args.category_num)
|
||
|
if not os.path.isfile('yolo/%s.trt' % args.model):
|
||
|
raise SystemExit('ERROR: file (yolo/%s.trt) not found!' % args.model)
|
||
|
|
||
|
cap = cv2.VideoCapture(args.video)
|
||
|
if not cap.isOpened():
|
||
|
raise SystemExit('ERROR: failed to open the input video file!')
|
||
|
frame_width, frame_height = int(cap.get(3)), int(cap.get(4))
|
||
|
writer = cv2.VideoWriter(
|
||
|
args.output,
|
||
|
cv2.VideoWriter_fourcc(*'mp4v'), 30, (frame_width, frame_height))
|
||
|
|
||
|
cls_dict = get_cls_dict(args.category_num)
|
||
|
vis = BBoxVisualization(cls_dict)
|
||
|
trt_yolo = TrtYOLO(args.model, args.category_num, args.letter_box)
|
||
|
|
||
|
loop_and_detect(cap, trt_yolo, conf_th=0.3, vis=vis, writer=writer)
|
||
|
|
||
|
writer.release()
|
||
|
cap.release()
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
main()
|