2023-03-06 20:44:29 +08:00
|
|
|
"""trt_yolo.py
|
|
|
|
|
|
|
|
This script demonstrates how to do real-time object detection with a
TensorRT optimized YOLO engine.
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
import time
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
import pycuda.autoinit # This is needed for initializing CUDA driver
|
|
|
|
|
|
|
|
from utils.yolo_classes import get_cls_dict
|
|
|
|
from utils.camera import add_camera_args, Camera
|
|
|
|
from utils.display import open_window, set_display, show_fps
|
|
|
|
from utils.visualization import BBoxVisualization
|
|
|
|
from utils.yolo_with_plugins import TrtYOLO
|
|
|
|
|
|
|
|
|
|
|
|
# Identifier of the display window; passed to open_window(), cv2.imshow(),
# cv2.getWindowProperty() and set_display() below.
WINDOW_NAME = 'TrtYOLODemo'
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Build the command-line parser and parse the process arguments.

    The parser is first handed to ``add_camera_args`` (which contributes
    the video-source options), then the detector options are added:

    * ``-c/--category_num``: int, default 80.
    * ``-t/--conf_thresh``: float, default 0.3.
    * ``-m/--model``: str, required.
    * ``-l/--letter_box``: boolean flag, off unless given.

    # Returns
        The namespace returned by ``parser.parse_args()``.
    """
    description = (
        'Capture and display live camera video, while doing '
        'real-time object detection with TensorRT optimized '
        'YOLO model on Jetson'
    )
    model_help = (
        '[yolov3-tiny|yolov3|yolov3-spp|yolov4-tiny|yolov4|'
        'yolov4-csp|yolov4x-mish|yolov4-p5]-[{dimension}], where '
        '{dimension} could be either a single number (e.g. '
        '288, 416, 608) or 2 numbers, WxH (e.g. 416x256)'
    )

    # Camera options go in first so they keep their place in --help output.
    parser = add_camera_args(argparse.ArgumentParser(description=description))

    parser.add_argument('-c', '--category_num',
                        type=int,
                        default=80,
                        help='number of object categories [80]')
    parser.add_argument('-t', '--conf_thresh',
                        type=float,
                        default=0.3,
                        help='set the detection confidence threshold')
    parser.add_argument('-m', '--model',
                        type=str,
                        required=True,
                        help=model_help)
    parser.add_argument('-l', '--letter_box',
                        action='store_true',
                        help='inference with letterboxed image [False]')

    return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
def loop_and_detect(cam, trt_yolo, conf_th, vis):
    """Continuously capture images from camera and do object detection.

    Each iteration reads one frame, runs detection on it, draws the
    resulting boxes and the current FPS figure, and shows the frame in
    the window named ``WINDOW_NAME``.

    The loop ends when the window property query returns a negative
    value, when ``cam.read()`` returns ``None``, or when ESC is pressed.
    Pressing 'F' or 'f' toggles fullscreen display.

    # Arguments
      cam: the camera instance (video source).
      trt_yolo: the TRT YOLO object detector instance.
      conf_th: confidence/score threshold for object detection.
      vis: for visualization.
    """
    full_scrn = False
    fps = 0.0
    tic = time.time()
    while True:
        # A negative property value is treated as "window no longer
        # exists" (e.g. the user closed it), so stop looping.
        if cv2.getWindowProperty(WINDOW_NAME, 0) < 0:
            break
        img = cam.read()
        if img is None:
            # No frame available from the video source: stop.
            break

        # Get the TensorRT inference results.
        boxes, confs, clss = trt_yolo.detect(img, conf_th)
        # Draw the detection boxes.
        img = vis.draw_bboxes(img, boxes, confs, clss)
        img = show_fps(img, fps)
        cv2.imshow(WINDOW_NAME, img)

        toc = time.time()
        elapsed = toc - tic
        # time.time() may return the same value twice on a coarse clock,
        # or step backwards if the wall clock is adjusted. Skip the FPS
        # update in that case instead of dividing by zero or folding a
        # negative rate into the average.
        if elapsed > 0.0:
            curr_fps = 1.0 / elapsed
            # calculate an exponentially decaying average of fps number
            fps = curr_fps if fps == 0.0 else (fps*0.95 + curr_fps*0.05)
        tic = toc

        key = cv2.waitKey(1)
        if key == 27:  # ESC key: quit program
            break
        elif key == ord('F') or key == ord('f'):  # Toggle fullscreen
            full_scrn = not full_scrn
            set_display(WINDOW_NAME, full_scrn)
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Validate the arguments, set up camera/detector/window, and run.

    # Raises
      SystemExit: if ``category_num`` is not positive, if the engine
        file ``yolo/<model>.trt`` does not exist, or if the camera
        cannot be opened.
    """
    args = parse_args()
    if args.category_num <= 0:
        raise SystemExit('ERROR: bad category_num (%d)!' % args.category_num)
    if not os.path.isfile('yolo/%s.trt' % args.model):
        raise SystemExit('ERROR: file (yolo/%s.trt) not found!' % args.model)

    cam = Camera(args)
    if not cam.isOpened():
        raise SystemExit('ERROR: failed to open camera!')

    # From here on the camera is held open. Release it (and close any
    # windows) even if engine creation, window setup or the detection
    # loop raises, including on Ctrl-C.
    try:
        # Load the class labels (original comment: COCO dataset labels).
        cls_dict = get_cls_dict(args.category_num)
        # Create the bounding-box drawer.
        vis = BBoxVisualization(cls_dict)
        # Create the inference engine.
        trt_yolo = TrtYOLO(args.model, args.category_num, args.letter_box)

        # Open the display window.
        open_window(
            WINDOW_NAME, 'Camera TensorRT YOLO Demo',
            cam.img_width, cam.img_height)
        # Start inference.
        loop_and_detect(cam, trt_yolo, args.conf_thresh, vis=vis)
    finally:
        cam.release()
        cv2.destroyAllWindows()
|
|
|
|
|
|
|
|
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|