TensorRT-Demo/yolo/plugins.py

141 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""plugins.py
I referenced the code from https://github.com/dongfangduoshou123/YoloV3-TensorRT/blob/master/seralizeEngineFromPythonAPI.py
"""
import ctypes
import numpy as np
import tensorrt as trt
from yolo_to_onnx import (is_pan_arch, DarkNetParser, get_category_num,
get_h_and_w, get_output_convs, get_anchors)
def _load_yolo_layer_plugin():
    """Load libyolo_layer.so, the shared library providing YoloLayer_TRT.

    The library is first loaded from '../plugins/libyolo_layer.so' resolved
    against the current working directory (the original behavior).  If that
    fails, the same relative location is tried against the directory that
    contains this module, so the module can also be imported when the
    process runs from a different working directory.

    Raises:
        SystemExit: if the library cannot be loaded from any candidate path.
    """
    import os  # local import: only needed to build the fallback path

    candidates = ['../plugins/libyolo_layer.so']
    module_file = globals().get('__file__')
    if module_file:
        candidates.append(os.path.join(
            os.path.dirname(os.path.abspath(module_file)),
            '..', 'plugins', 'libyolo_layer.so'))

    last_error = None
    for lib_path in candidates:
        try:
            ctypes.cdll.LoadLibrary(lib_path)
            return
        except OSError as e:
            last_error = e
    raise SystemExit('ERROR: failed to load ../plugins/libyolo_layer.so. '
                     'Did you forget to do a "make" in the "../plugins/" '
                     'subdirectory?') from last_error


# Load the library at import time; presumably this is what registers the
# "YoloLayer_TRT" plugin creator looked up later in this module.
_load_yolo_layer_plugin()
def get_scales(cfg_file_path):
    """Get scale_x_y's of all yolo layers from the cfg file.

    The cfg is parsed section by section.  Every ``[yolo]`` section yields
    exactly one entry, in file order.  A ``[yolo]`` section without its own
    ``scale_x_y`` option gets 1.0, so cfg files that set ``scale_x_y`` on
    only some yolo layers are handled correctly.  The previous version
    counted ``scale_x_y`` lines globally and failed an ``assert`` in that
    case, or returned a misaligned list when run with ``python -O``.

    Args:
        cfg_file_path: path to a darknet ``.cfg`` file.

    Returns:
        A list of floats, one ``scale_x_y`` per ``[yolo]`` section.
    """
    scales = []
    in_yolo_section = False
    with open(cfg_file_path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith('['):
                # A new section header: only [yolo] sections are tracked.
                in_yolo_section = line.startswith('[yolo]')
                if in_yolo_section:
                    scales.append(1.0)
                continue
            if not in_yolo_section:
                continue
            # Accept both "scale_x_y=1.05" and "scale_x_y = 1.05"; commented
            # lines such as "#scale_x_y=..." do not match the key.
            key, sep, value = line.partition('=')
            if sep and key.strip() == 'scale_x_y':
                scales[-1] = float(value)
    return scales
def get_new_coords(cfg_file_path):
    """Get new_coords flag of yolo layers from the cfg file.

    The cfg is parsed section by section.  Every ``[yolo]`` section has a
    ``new_coords`` value, defaulting to 0 when the option is absent.  A
    single flag is returned because one value is applied to every yolo
    plugin.

    Args:
        cfg_file_path: path to a darknet ``.cfg`` file.

    Returns:
        The common ``new_coords`` value of all ``[yolo]`` sections, or 0
        when the cfg has no ``[yolo]`` section.

    Raises:
        ValueError: if the ``[yolo]`` sections disagree on ``new_coords``,
            including when only some of them set it.  The previous version
            failed an ``assert`` for a partial specification and silently
            used the last value when the values differed.
    """
    flags = []
    in_yolo_section = False
    with open(cfg_file_path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith('['):
                # A new section header: only [yolo] sections are tracked.
                in_yolo_section = line.startswith('[yolo]')
                if in_yolo_section:
                    flags.append(0)
                continue
            if not in_yolo_section:
                continue
            key, sep, value = line.partition('=')
            if sep and key.strip() == 'new_coords':
                flags[-1] = int(value)
    if not flags:
        return 0
    if len(set(flags)) != 1:
        raise ValueError('inconsistent new_coords across yolo layers: %s' %
                         str(flags))
    return flags[0]
def get_plugin_creator(plugin_name, logger):
    """Return the TensorRT plugin creator registered under ``plugin_name``.

    ``trt.init_libnvinfer_plugins`` is called first with ``logger`` and an
    empty namespace.  The global plugin registry is then scanned for a
    creator whose ``name`` matches.

    Returns:
        The first matching plugin creator, or None if none is registered.
    """
    trt.init_libnvinfer_plugins(logger, '')
    registry = trt.get_plugin_registry()
    matches = (creator for creator in registry.plugin_creator_list
               if creator.name == plugin_name)
    return next(matches, None)
def _get_yolo_whs(num_outputs, h, w):
    """Return [grid_width, grid_height] of each yolo output layer.

    The input width/height is divided by 32 and 16 (2 outputs); by 32, 16
    and 8 (3 outputs); or by 64, 32, 16 and 8 (4 outputs), in that order.

    Raises:
        TypeError: if ``num_outputs`` is not 2, 3 or 4 (the original
            exception type, kept for existing callers).
        ValueError: if the input is so small that a grid dimension would be
            0, which would otherwise cause a ZeroDivisionError, or a
            zero-sized plugin grid, later on.
    """
    divisors_by_num_outputs = {
        2: (32, 16),
        3: (32, 16, 8),
        4: (64, 32, 16, 8),
    }
    divisors = divisors_by_num_outputs.get(num_outputs)
    if divisors is None:
        raise TypeError('bad number of outputs: %d' % num_outputs)
    yolo_whs = [[w // d, h // d] for d in divisors]
    if any(grid_w == 0 or grid_h == 0 for grid_w, grid_h in yolo_whs):
        raise ValueError('input size %dx%d (WxH) is too small for %d yolo '
                         'layers' % (w, h, num_outputs))
    return yolo_whs


def add_yolo_plugins(network, model_name, logger):
    """Add yolo plugins into a TensorRT network.

    Each current output tensor of ``network`` is fed into a new
    "YoloLayer_TRT" plugin layer, configured from the darknet cfg file
    ``model_name + '.cfg'``.  The plugin outputs then replace the original
    tensors as the network outputs.

    Args:
        network: TensorRT network definition; its number of outputs must
            match the number of yolo layers in the cfg.
        model_name: model name; ``model_name + '.cfg'`` is the cfg file.
        logger: TensorRT logger, used when initializing the plugin library.

    Returns:
        The same ``network`` object, modified in place.

    Raises:
        TypeError: if the cfg has an unsupported number of yolo outputs.
        ValueError: if the cfg and network disagree on the number of yolo
            layers, a scale_x_y is below 1.0, or the input size is too small.
        RuntimeError: if the YoloLayer_TRT plugin creator is not available.
    """
    cfg_file_path = model_name + '.cfg'
    parser = DarkNetParser()
    layer_configs = parser.parse_cfg_file(cfg_file_path)
    num_classes = get_category_num(cfg_file_path)
    output_tensor_names = get_output_convs(layer_configs)
    h, w = get_h_and_w(layer_configs)

    yolo_whs = _get_yolo_whs(len(output_tensor_names), h, w)
    if is_pan_arch(cfg_file_path):
        # Presumably PAN-style cfgs emit the yolo outputs in the opposite
        # (smallest divisor first) order.
        yolo_whs.reverse()

    # Anchors of each yolo layer.
    anchors = get_anchors(cfg_file_path)
    if len(anchors) != len(yolo_whs):
        raise ValueError('bad number of yolo layers: %d vs. %d' %
                         (len(anchors), len(yolo_whs)))
    if network.num_outputs != len(anchors):
        raise ValueError('bad number of network outputs: %d vs. %d' %
                         (network.num_outputs, len(anchors)))

    # scale_x_y of each yolo layer.
    scales = get_scales(cfg_file_path)
    if any(s < 1.0 for s in scales):
        raise ValueError('bad scale_x_y: %s' % str(scales))
    if len(scales) != len(anchors):
        raise ValueError('bad number of scales: %d vs. %d' %
                         (len(scales), len(anchors)))

    # The new_coords parameter (found in yolov4-style cfgs).
    new_coords = get_new_coords(cfg_file_path)

    # Get the TensorRT plugin creator.
    plugin_creator = get_plugin_creator('YoloLayer_TRT', logger)
    if not plugin_creator:
        raise RuntimeError('cannot get YoloLayer_TRT plugin creator')

    old_tensors = [network.get_output(i) for i in range(network.num_outputs)]
    new_tensors = []
    # All four sequences have the same length (validated above).
    for old_tensor, (yolo_w, yolo_h), layer_anchors, scale in zip(
            old_tensors, yolo_whs, anchors, scales):
        # Ratio between the network input width and this layer's grid width.
        input_multiplier = w // yolo_w
        plugin = plugin_creator.create_plugin(
            'YoloLayer_TRT', trt.PluginFieldCollection([
                trt.PluginField("yoloWidth", np.array(yolo_w, dtype=np.int32),
                                trt.PluginFieldType.INT32),
                trt.PluginField("yoloHeight", np.array(yolo_h, dtype=np.int32),
                                trt.PluginFieldType.INT32),
                trt.PluginField("inputMultiplier",
                                np.array(input_multiplier, dtype=np.int32),
                                trt.PluginFieldType.INT32),
                trt.PluginField("newCoords",
                                np.array(new_coords, dtype=np.int32),
                                trt.PluginFieldType.INT32),
                trt.PluginField("numClasses",
                                np.array(num_classes, dtype=np.int32),
                                trt.PluginFieldType.INT32),
                # layer_anchors is presumably a flat [w0, h0, w1, h1, ...]
                # sequence, hence the // 2 for the anchor count.
                trt.PluginField("numAnchors",
                                np.array(len(layer_anchors) // 2,
                                         dtype=np.int32),
                                trt.PluginFieldType.INT32),
                trt.PluginField("anchors",
                                np.ascontiguousarray(layer_anchors,
                                                     dtype=np.float32),
                                trt.PluginFieldType.FLOAT32),
                trt.PluginField("scaleXY", np.array(scale, dtype=np.float32),
                                trt.PluginFieldType.FLOAT32),
            ]))
        new_tensors.append(
            network.add_plugin_v2([old_tensor], plugin).get_output(0))

    # Replace the yolo output layers of the TensorRT network: mark the
    # plugin outputs as outputs, then un-mark the original tensors.
    for new_tensor in new_tensors:
        network.mark_output(new_tensor)
    for old_tensor in old_tensors:
        network.unmark_output(old_tensor)
    return network
def add_concat(network, model_name, logger):
    """Merge all yolo outputs of a TensorRT network into a single output.

    The current outputs (there must be 2 to 4 of them) are concatenated with
    a concatenation layer and un-marked as outputs.  The concatenated tensor
    is named 'detections' and becomes the network's output.  ``model_name``
    and ``logger`` are not used here.

    Returns:
        The same ``network`` object, modified in place.

    Raises:
        TypeError: if the network does not have between 2 and 4 outputs.
    """
    n_out = network.num_outputs
    if not 2 <= n_out <= 4:
        raise TypeError('bad number of yolo layers: %d' % n_out)
    outputs = list(map(network.get_output, range(n_out)))
    detections = network.add_concatenation(outputs).get_output(0)
    for tensor in outputs:
        network.unmark_output(tensor)
    detections.name = 'detections'
    network.mark_output(detections)
    return network