TensorRT-Demo/yolo/onnx_to_tensorrt.py

224 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# onnx_to_tensorrt.py
#
# Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
#
# NOTICE TO LICENSEE:
#
# This source code and/or documentation ("Licensed Deliverables") are
# subject to NVIDIA intellectual property rights under U.S. and
# international Copyright laws.
#
# These Licensed Deliverables contained herein is PROPRIETARY and
# CONFIDENTIAL to NVIDIA and is being provided under the terms and
# conditions of a form of NVIDIA software license agreement by and
# between NVIDIA and Licensee ("License Agreement") or electronically
# accepted by Licensee. Notwithstanding any terms or conditions to
# the contrary in the License Agreement, reproduction or disclosure
# of the Licensed Deliverables to any third party without the express
# written consent of NVIDIA is prohibited.
#
# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
# LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
# SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
# PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
# NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
# DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
# NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
# NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
# LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
# SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
# OF THESE LICENSED DELIVERABLES.
#
# U.S. Government End Users. These Licensed Deliverables are a
# "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
# 1995), consisting of "commercial computer software" and "commercial
# computer software documentation" as such terms are used in 48
# C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
# only as a commercial end item. Consistent with 48 C.F.R.12.212 and
# 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
# U.S. Government End Users acquire the Licensed Deliverables with
# only those rights set forth herein.
#
# Any use of the Licensed Deliverables in individual and commercial
# software must include, in the user documentation and internal
# comments to the code, the above Disclaimer and U.S. Government End
# Users Notice.
#
from __future__ import print_function
import os
import argparse
import tensorrt as trt
from yolo_to_onnx import DarkNetParser, get_h_and_w
from plugins import add_yolo_plugins, add_concat
# Batch size the engine is built for. It is used as the network input's
# batch dimension, as the min/opt/max batch of the optimization profile, and
# as the builder's max_batch_size.
MAX_BATCH_SIZE = 1
def get_c(layer_configs):
    """Return the number of input channels declared by the YOLO config.

    Reads the 'channels' entry of the '000_net' section of the parsed
    Darknet layer configs, defaulting to 3 when that entry is absent.
    """
    return layer_configs['000_net'].get('channels', 3)
def load_onnx(model_name):
    """Return the raw bytes of ``<model_name>.onnx``, or None if it is missing.

    A missing file is reported on stdout together with a hint to run
    yolo_to_onnx.py first.
    """
    onnx_path = '%s.onnx' % model_name
    if os.path.isfile(onnx_path):
        with open(onnx_path, 'rb') as onnx_file:
            return onnx_file.read()
    print('ERROR: file (%s) not found! You might want to run yolo_to_onnx.py first to generate it.' % onnx_path)
    return None
def trt_major_version(version_string):
    """Return the major component of a TensorRT version string as an int.

    E.g. '7.1.3.0' -> 7 and '10.0.1.6' -> 10. Comparing only the first
    character of the string (``version[0] >= '7'``) would misclassify
    TensorRT 10 and later as older than 7.
    """
    return int(version_string.split('.', 1)[0])


def set_net_batch(network, batch_size):
    """Set the batch dimension of the network's first input.

    The ONNX file might have been generated with a different batch size,
    say, 64. With TensorRT older than 7 the network is returned unchanged.

    Args:
        network: TensorRT network definition; input 0 is modified in place.
        batch_size: value written into dimension 0 of input 0's shape.

    Returns:
        The same ``network`` object, so calls can be chained.
    """
    if trt_major_version(trt.__version__) >= 7:
        input_tensor = network.get_input(0)
        shape = list(input_tensor.shape)
        shape[0] = batch_size
        input_tensor.shape = shape
    return network
def build_engine(model_name, do_int8, dla_core, verbose=False):
    """Build a TensorRT engine for a YOLO model from its ONNX export.

    Reads ``<model_name>.cfg`` (Darknet config) for the input channels and
    resolution, and ``<model_name>.onnx`` for the network. It then:
      * sets the input batch size;
      * appends the yolo_layer plugins and a concatenated "detections" output;
      * renames the input tensor to "input";
      * builds the engine with FP16 enabled, with optional INT8 and DLA.

    Args:
        model_name: model base name (e.g. 'yolov4-416'); also used to name
            the INT8 calibration cache 'calib_<model_name>.bin'.
        do_int8: if True, enable INT8 with calibration images from
            'calib_images'.
        dla_core: DLA core id to target, or a negative number for GPU only.
        verbose: if True, use a VERBOSE TensorRT logger.

    Returns:
        The built engine, or None if the ONNX file is missing, fails to
        parse, or the builder produces no engine.

    Raises:
        RuntimeError: INT8 was requested on a platform without fast INT8,
            or a DLA core was requested with TensorRT older than 7.
    """
    # Named distinctly from the ONNX parser created below (previously both
    # were called ``parser``).
    darknet_parser = DarkNetParser()
    layer_configs = darknet_parser.parse_cfg_file(model_name + '.cfg')
    net_c = get_c(layer_configs)               # input channels
    net_h, net_w = get_h_and_w(layer_configs)  # input height and width

    print('Loading the ONNX file...')
    onnx_data = load_onnx(model_name)
    if onnx_data is None:
        return None

    trt_logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger()
    # Compare the major version numerically. A first-character string test
    # such as ``trt.__version__[0] < '7'`` treats '10.x' as older than 7.
    trt_major = int(trt.__version__.split('.', 1)[0])
    # TensorRT 7+ networks are created with the EXPLICIT_BATCH flag.
    network_flags = [] if trt_major < 7 else \
        [1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)]

    with trt.Builder(trt_logger) as builder, \
            builder.create_network(*network_flags) as network, \
            trt.OnnxParser(network, trt_logger) as onnx_parser:
        if do_int8 and not builder.platform_has_fast_int8:
            raise RuntimeError('INT8 not supported on this platform')
        if not onnx_parser.parse(onnx_data):
            print('ERROR: Failed to parse the ONNX file.')
            for error_index in range(onnx_parser.num_errors):
                print(onnx_parser.get_error(error_index))
            return None

        # The ONNX file may have been exported with a different batch size.
        network = set_net_batch(network, MAX_BATCH_SIZE)
        print('Adding yolo_layer plugins.')
        network = add_yolo_plugins(network, model_name, trt_logger)
        # Merge the yolo outputs into a single "detections" output.
        print('Adding a concatenated output as "detections".')
        network = add_concat(network, model_name, trt_logger)
        print('Naming the input tensor as "input".')
        network.get_input(0).name = 'input'

        print('Building the TensorRT engine. This would take a while...')
        print('(Use "--verbose" or "-v" to enable verbose logging.)')
        if trt_major < 7:
            engine = _build_engine_legacy(
                builder, network, model_name, do_int8, dla_core,
                net_h, net_w)
        else:
            engine = _build_engine_with_config(
                builder, network, trt_logger, model_name, do_int8, dla_core,
                net_c, net_h, net_w)
        if engine is not None:
            print('Completed creating engine.')
        return engine


def _build_engine_legacy(builder, network, model_name, do_int8, dla_core,
                         net_h, net_w):
    """Build with the pre-7 API: settings on the builder, build_cuda_engine()."""
    if dla_core >= 0:
        raise RuntimeError('DLA core not supported by old API')
    builder.max_batch_size = MAX_BATCH_SIZE
    builder.max_workspace_size = 1 << 30
    builder.fp16_mode = True  # alternative: builder.platform_has_fast_fp16
    if do_int8:
        from calibrator import YOLOEntropyCalibrator
        builder.int8_mode = True
        builder.int8_calibrator = YOLOEntropyCalibrator(
            'calib_images', (net_h, net_w), 'calib_%s.bin' % model_name)
    return builder.build_cuda_engine(network)


def _build_engine_with_config(builder, network, logger, model_name, do_int8,
                              dla_core, net_c, net_h, net_w):
    """Build with the TensorRT 7+ builder-config API.

    Builder settings that later TensorRT releases deprecated or removed are
    feature-detected instead of being assumed present.
    """
    if hasattr(builder, 'max_batch_size'):  # removed in TensorRT 10
        builder.max_batch_size = MAX_BATCH_SIZE
    config = builder.create_builder_config()
    workspace_bytes = 1 << 30
    if hasattr(config, 'set_memory_pool_limit'):  # replaces max_workspace_size
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
                                     workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes
    config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
    config.set_flag(trt.BuilderFlag.FP16)

    # min == opt == max: the engine accepts exactly one input shape.
    input_shape = (MAX_BATCH_SIZE, net_c, net_h, net_w)
    profile = builder.create_optimization_profile()
    profile.set_shape('input', input_shape, input_shape, input_shape)
    config.add_optimization_profile(profile)

    if do_int8:
        from calibrator import YOLOEntropyCalibrator
        config.set_flag(trt.BuilderFlag.INT8)
        config.int8_calibrator = YOLOEntropyCalibrator(
            'calib_images', (net_h, net_w), 'calib_%s.bin' % model_name)
        config.set_calibration_profile(profile)
    if dla_core >= 0:
        config.default_device_type = trt.DeviceType.DLA
        config.DLA_core = dla_core
        strict_flag = getattr(trt.BuilderFlag, 'STRICT_TYPES', None)
        if strict_flag is None:
            # STRICT_TYPES is gone in newer TensorRT; its documented
            # successor for precision constraints is used instead.
            strict_flag = trt.BuilderFlag.PREFER_PRECISION_CONSTRAINTS
        config.set_flag(strict_flag)
        print('Using DLA core %d.' % dla_core)

    if hasattr(builder, 'build_engine'):
        return builder.build_engine(network, config)
    # build_engine() no longer exists: build a serialized plan and
    # deserialize it so callers still receive an engine object.
    serialized = builder.build_serialized_network(network, config)
    if serialized is None:
        return None
    return trt.Runtime(logger).deserialize_cuda_engine(serialized)
def _parse_args():
    """Define and parse this script's command-line options."""
    cli = argparse.ArgumentParser()
    cli.add_argument('-v', '--verbose', action='store_true',
                     help='enable verbose output (for debugging)')
    cli.add_argument('-c', '--category_num', type=int,
                     help='number of object categories (obsolete)')
    cli.add_argument('-m', '--model', type=str, required=True,
                     help=('[yolov3-tiny|yolov3|yolov3-spp|yolov4-tiny|yolov4|'
                           'yolov4-csp|yolov4x-mish|yolov4-p5]-[{dimension}], where '
                           '{dimension} could be either a single number (e.g. '
                           '288, 416, 608) or 2 numbers, WxH (e.g. 416x256)'))
    cli.add_argument('--int8', action='store_true',
                     help='build INT8 TensorRT engine')
    cli.add_argument('--dla_core', type=int, default=-1,
                     help='id of DLA core for inference (0 ~ N-1)')
    return cli.parse_args()


def main():
    """Build a TensorRT engine for an ONNX-based YOLO model and save it.

    The serialized engine is written to ``<model>.trt``. The script exits
    with an error message if the engine could not be built.
    """
    options = _parse_args()
    trt_engine = build_engine(options.model, options.int8, options.dla_core,
                              options.verbose)
    if trt_engine is None:
        raise SystemExit('ERROR: failed to build the TensorRT engine!')

    output_path = '%s.trt' % options.model
    with open(output_path, 'wb') as out_file:
        out_file.write(trt_engine.serialize())
    print('Serialized the TensorRT engine to file: %s' % output_path)


if __name__ == '__main__':
    main()