TensorRT-Demo/modnet/onnx_to_tensorrt.py

118 lines
4.0 KiB
Python
Raw Permalink Normal View History

2023-03-06 20:44:29 +08:00
"""onnx_to_tensorrt.py
For converting a MODNet ONNX model to a TensorRT engine.
"""
import os
import argparse
import tensorrt as trt
# Compare the numeric major version. A string compare of the first character
# would wrongly treat TensorRT 10.x ('1' < '7') as older than 7.
if int(trt.__version__.split('.')[0]) < 7:
    raise SystemExit('TensorRT version < 7')

# Fixed batch size baked into the engine (min = opt = max in the profile).
BATCH_SIZE = 1
def parse_args():
    """Define the converter's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``verbose``, ``int8``,
        ``dla_core``, ``width``, ``height``, ``input_onnx`` and
        ``output_engine``.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument
    # Optional switches and tunables; registration order is --help order.
    add('-v', '--verbose', action='store_true',
        help='enable verbose output (for debugging) [False]')
    add('--int8', action='store_true',
        help='build INT8 TensorRT engine [False]')
    add('--dla_core', type=int, default=-1,
        help='id of DLA core for inference, ranging from 0 to N-1 [-1]')
    add('--width', type=int, default=640,
        help='input image width of the model [640]')
    add('--height', type=int, default=480,
        help='input image height of the model [480]')
    # Required positional paths.
    add('input_onnx', type=str, help='the input onnx file')
    add('output_engine', type=str, help='the output TensorRT engine file')
    return cli.parse_args()
def load_onnx(onnx_file_path):
    """Return the raw bytes of the ONNX model file at ``onnx_file_path``."""
    with open(onnx_file_path, mode='rb') as onnx_file:
        serialized_model = onnx_file.read()
    return serialized_model
def set_net_batch(network, batch_size):
    """Overwrite the batch (leading) dimension of the network's first input.

    The first input's shape is replaced in place. The same ``network``
    object is returned for convenience.
    """
    first_input = network.get_input(0)
    dims = list(first_input.shape)
    dims[0] = batch_size
    first_input.shape = dims
    return network
def _set_workspace_limit(config, size_bytes):
    """Cap builder scratch memory using whichever API this TensorRT provides.

    ``set_memory_pool_limit`` superseded the ``max_workspace_size`` attribute,
    which later TensorRT releases removed.
    """
    if hasattr(config, 'set_memory_pool_limit'):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, size_bytes)
    else:
        config.max_workspace_size = size_bytes


def _build_cuda_engine(builder, network, config, logger):
    """Build an ICudaEngine, falling back to the serialized-network API.

    Older TensorRT builds the engine directly with ``Builder.build_engine``.
    When that method is absent, build a serialized plan and deserialize it
    with a ``trt.Runtime``. Returns None if the build fails.
    """
    if hasattr(builder, 'build_engine'):
        return builder.build_engine(network, config)
    plan = builder.build_serialized_network(network, config)
    if plan is None:
        return None
    return trt.Runtime(logger).deserialize_cuda_engine(plan)


def build_engine(onnx_file_path, width, height,
                 do_int8=False, dla_core=-1, verbose=False):
    """Build an FP16 TensorRT engine from a single-input ONNX model.

    Args:
        onnx_file_path: path of the ONNX model file.
        width: input image width; fixed in the optimization profile.
        height: input image height; fixed in the optimization profile.
        do_int8: request an INT8 engine (not implemented yet).
        dla_core: DLA core id, or a negative value for GPU only. DLA is not
            implemented yet, so only negative values are accepted.
        verbose: use a VERBOSE TensorRT logger.

    Returns:
        The built engine, or None if ONNX parsing or the engine build failed.

    Raises:
        RuntimeError: INT8 or DLA was requested.
    """
    onnx_data = load_onnx(onnx_file_path)
    trt_logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger()
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(trt_logger) as builder, \
            builder.create_network(explicit_batch) as network, \
            trt.OnnxParser(network, trt_logger) as parser:
        # Reject unsupported options before spending time on parsing.
        if do_int8 and not builder.platform_has_fast_int8:
            raise RuntimeError('INT8 not supported on this platform')
        if do_int8:
            raise RuntimeError('INT8 not implemented yet')
        if dla_core >= 0:
            raise RuntimeError('DLA_core not implemented yet')

        if not parser.parse(onnx_data):
            print('ERROR: Failed to parse the ONNX file.')
            for index in range(parser.num_errors):
                print(parser.get_error(index))
            return None

        network = set_net_batch(network, BATCH_SIZE)
        if hasattr(builder, 'max_batch_size'):
            # Only relevant to implicit-batch networks; absent in newer TensorRT.
            builder.max_batch_size = BATCH_SIZE

        config = builder.create_builder_config()
        _set_workspace_limit(config, 1 << 30)
        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
        config.set_flag(trt.BuilderFlag.FP16)

        # Pin min/opt/max to one shape. Take the input name from the parsed
        # network instead of assuming the exporter called it 'Input'.
        input_shape = (BATCH_SIZE, 3, height, width)
        profile = builder.create_optimization_profile()
        profile.set_shape(network.get_input(0).name,
                          input_shape, input_shape, input_shape)
        config.add_optimization_profile(profile)

        return _build_cuda_engine(builder, network, config, trt_logger)
def main():
    """Command-line entry point: convert an ONNX model to a serialized engine file.

    Raises:
        FileNotFoundError: the input ONNX path is not an existing file.
        SystemExit: the engine could not be built.
    """
    args = parse_args()
    onnx_path = args.input_onnx
    if not os.path.isfile(onnx_path):
        raise FileNotFoundError(onnx_path)

    print('Building an engine. This would take a while...')
    print('(Use "-v" or "--verbose" to enable verbose logging.)')
    engine = build_engine(onnx_path, args.width, args.height,
                          do_int8=args.int8, dla_core=args.dla_core,
                          verbose=args.verbose)
    if engine is None:
        raise SystemExit('ERROR: failed to build the TensorRT engine!')
    print('Completed creating engine.')

    engine_path = args.output_engine
    with open(engine_path, 'wb') as engine_file:
        engine_file.write(engine.serialize())
    print('Serialized the TensorRT engine to file: %s' % engine_path)
if __name__ == '__main__':
    # Run the conversion only when executed as a script, not when imported.
    main()