118 lines
4.0 KiB
Python
118 lines
4.0 KiB
Python
"""onnx_to_tensorrt.py
|
|
|
|
For converting a MODNet ONNX model to a TensorRT engine.
|
|
"""
|
|
|
|
|
|
import os
|
|
import argparse
|
|
|
|
import tensorrt as trt
|
|
|
|
# Compare the major version numerically.  Comparing only the first character
# of the version string ('1' < '7') would wrongly reject TensorRT 10 and later.
if int(trt.__version__.split('.')[0]) < 7:
    raise SystemExit('TensorRT version < 7')


# Batch size applied to the network input and to the optimization profile.
BATCH_SIZE = 1
|
|
|
|
|
|
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Recognized options are -v/--verbose, --int8, --dla_core, --width and
    --height, followed by two positionals: the input ONNX file and the
    output TensorRT engine file.
    """
    # (names, keyword arguments) pairs, registered in this order so that the
    # generated usage/help text lists them in the same sequence.
    argument_specs = (
        (('-v', '--verbose'),
         {'action': 'store_true',
          'help': 'enable verbose output (for debugging) [False]'}),
        (('--int8',),
         {'action': 'store_true',
          'help': 'build INT8 TensorRT engine [False]'}),
        (('--dla_core',),
         {'type': int, 'default': -1,
          'help': 'id of DLA core for inference, ranging from 0 to N-1 [-1]'}),
        (('--width',),
         {'type': int, 'default': 640,
          'help': 'input image width of the model [640]'}),
        (('--height',),
         {'type': int, 'default': 480,
          'help': 'input image height of the model [480]'}),
        (('input_onnx',),
         {'type': str, 'help': 'the input onnx file'}),
        (('output_engine',),
         {'type': str, 'help': 'the output TensorRT engine file'}),
    )

    cli = argparse.ArgumentParser()
    for names, options in argument_specs:
        cli.add_argument(*names, **options)
    return cli.parse_args()
|
|
|
|
|
|
def load_onnx(onnx_file_path):
    """Return the full contents of the ONNX file as bytes."""
    # Binary mode: the content is handed to the parser as raw bytes.
    with open(onnx_file_path, 'rb') as onnx_file:
        model_bytes = onnx_file.read()
    return model_bytes
|
|
|
|
|
|
def set_net_batch(network, batch_size):
    """Overwrite the leading dimension of the first network input.

    The shape of input 0 is copied into a list, its first entry is replaced
    with ``batch_size``, and the list is assigned back as the input's shape.
    The same ``network`` object is returned.
    """
    # Work on a list copy so the first entry can be reassigned.
    input_dims = list(network.get_input(0).shape)
    input_dims[0] = batch_size
    network.get_input(0).shape = input_dims
    return network
|
|
|
|
|
|
def build_engine(onnx_file_path, width, height,
                 do_int8=False, dla_core=-1, verbose=False):
    """Build a TensorRT engine from ONNX using the older API.

    Args:
        onnx_file_path: path of the ONNX model file to read.
        width: input image width used in the optimization profile.
        height: input image height used in the optimization profile.
        do_int8: request an INT8 engine; this currently always ends in a
            RuntimeError (platform unsupported, or not implemented yet).
        dla_core: DLA core id; any value >= 0 raises RuntimeError because
            DLA is not implemented.  The default is -1 (no DLA).  It was
            previously ``False``, which compares equal to 0 and therefore
            made every default call fail the DLA check.
        verbose: use a VERBOSE TensorRT logger instead of the default one.

    Returns:
        The result of ``builder.build_engine(network, config)``, or ``None``
        when the ONNX parser reports a failure (errors are printed).

    Raises:
        RuntimeError: INT8 was requested, or ``dla_core`` is >= 0.
    """
    onnx_data = load_onnx(onnx_file_path)

    trt_logger = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger()
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(trt_logger) as builder, \
            builder.create_network(explicit_batch) as network, \
            trt.OnnxParser(network, trt_logger) as parser:
        if do_int8 and not builder.platform_has_fast_int8:
            raise RuntimeError('INT8 not supported on this platform')
        if not parser.parse(onnx_data):
            print('ERROR: Failed to parse the ONNX file.')
            for error_index in range(parser.num_errors):
                print(parser.get_error(error_index))
            return None
        network = set_net_batch(network, BATCH_SIZE)

        builder.max_batch_size = BATCH_SIZE
        config = builder.create_builder_config()
        config.max_workspace_size = 1 << 30
        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
        config.set_flag(trt.BuilderFlag.FP16)

        # min, opt and max shapes are identical, so the profile pins the
        # input to a single fixed shape.
        input_shape = (BATCH_SIZE, 3, height, width)
        profile = builder.create_optimization_profile()
        profile.set_shape(
            'Input',        # input tensor name
            input_shape,    # min shape
            input_shape,    # opt shape
            input_shape)    # max shape
        config.add_optimization_profile(profile)

        if do_int8:
            raise RuntimeError('INT8 not implemented yet')
        if dla_core >= 0:
            raise RuntimeError('DLA_core not implemented yet')
        engine = builder.build_engine(network, config)

        return engine
|
|
|
|
|
|
def main():
    """Convert the ONNX file named on the command line into an engine file.

    Parses the arguments, checks that the input file exists, builds the
    engine and writes its serialized form to the output path.

    Raises:
        FileNotFoundError: the input ONNX path is not an existing file.
        SystemExit: ``build_engine`` returned ``None``.
    """
    args = parse_args()
    onnx_path = args.input_onnx
    engine_path = args.output_engine

    if not os.path.isfile(onnx_path):
        raise FileNotFoundError(onnx_path)

    for notice in ('Building an engine. This would take a while...',
                   '(Use "-v" or "--verbose" to enable verbose logging.)'):
        print(notice)

    engine = build_engine(
        onnx_path, args.width, args.height,
        do_int8=args.int8, dla_core=args.dla_core, verbose=args.verbose)
    if engine is None:
        raise SystemExit('ERROR: failed to build the TensorRT engine!')
    print('Completed creating engine.')

    # Open the output first, then serialize, as the original ordering did.
    with open(engine_path, 'wb') as engine_file:
        engine_file.write(engine.serialize())
    print('Serialized the TensorRT engine to file: %s' % engine_path)
|
|
|
|
|
|
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()
|