添加onnx转换注释;

main
邱棚 2023-03-06 23:33:45 +08:00
parent eb89f036bd
commit 4e9a5c2ba1
1 changed files with 44 additions and 5 deletions

View File

@ -123,6 +123,8 @@ def get_output_convs(layer_configs):
return output_convs return output_convs
# 从cfg中解析出分类数
# 返回值int 分类个数
def get_category_num(cfg_file_path): def get_category_num(cfg_file_path):
"""Find number of output classes of the yolo model.""" """Find number of output classes of the yolo model."""
with open(cfg_file_path, 'r') as f: with open(cfg_file_path, 'r') as f:
@ -184,12 +186,16 @@ class DarkNetParser(object):
# A list of YOLO layers containing dictionaries with all layer # A list of YOLO layers containing dictionaries with all layer
# parameters: # parameters:
# OrderedDic 可以保持键值对的顺序[]
self.layer_configs = OrderedDict() self.layer_configs = OrderedDict()
# 支持的节点类型
self.supported_layers = supported_layers if supported_layers else \ self.supported_layers = supported_layers if supported_layers else \
['net', 'convolutional', 'maxpool', 'shortcut', ['net', 'convolutional', 'maxpool', 'shortcut',
'route', 'upsample', 'yolo'] 'route', 'upsample', 'yolo']
self.layer_counter = 0 self.layer_counter = 0
# 加载网络模型文件.cfg
def parse_cfg_file(self, cfg_file_path): def parse_cfg_file(self, cfg_file_path):
"""Takes the yolov?.cfg file and parses it layer by layer, """Takes the yolov?.cfg file and parses it layer by layer,
appending each layer's parameters as a dictionary to layer_configs. appending each layer's parameters as a dictionary to layer_configs.
@ -200,11 +206,15 @@ class DarkNetParser(object):
with open(cfg_file_path, 'r') as cfg_file: with open(cfg_file_path, 'r') as cfg_file:
remainder = cfg_file.read() remainder = cfg_file.read()
while remainder is not None: while remainder is not None:
# 从字符串中加载一层网络,并生成字典结构
layer_dict, layer_name, remainder = self._next_layer(remainder) layer_dict, layer_name, remainder = self._next_layer(remainder)
if layer_dict is not None: if layer_dict is not None:
# 将一层网络结构根据名称,依次加入有序字典中,生成整个网络结构
self.layer_configs[layer_name] = layer_dict self.layer_configs[layer_name] = layer_dict
return self.layer_configs return self.layer_configs
# 返回当前层生成的字典键值对,并指向下一层结构
# layer_dict 一层内的网络结构字典, layer_name 层名称, remainder 剩余字符串
def _next_layer(self, remainder): def _next_layer(self, remainder):
"""Takes in a string and segments it by looking for DarkNet delimiters. """Takes in a string and segments it by looking for DarkNet delimiters.
Returns the layer parameters and the remaining string after the last delimiter. Returns the layer parameters and the remaining string after the last delimiter.
@ -560,23 +570,32 @@ class GraphBuilderONNX(object):
""" """
for layer_name in layer_configs.keys(): for layer_name in layer_configs.keys():
layer_dict = layer_configs[layer_name] layer_dict = layer_configs[layer_name]
# 根据网络结构分别生成onnx节点节点的作用是操作符如conv,relu等每个操作都要写成Onnx格式
major_node_specs = self._make_onnx_node(layer_name, layer_dict) major_node_specs = self._make_onnx_node(layer_name, layer_dict)
# 成功生成后,添加到主网络结构中
if major_node_specs.name is not None: if major_node_specs.name is not None:
self.major_node_specs.append(major_node_specs) self.major_node_specs.append(major_node_specs)
# remove dummy 'route' and 'yolo' nodes # remove dummy 'route' and 'yolo' nodes
self.major_node_specs = [node for node in self.major_node_specs self.major_node_specs = [node for node in self.major_node_specs
if 'dummy' not in node.name] if 'dummy' not in node.name]
outputs = list() outputs = list()
# 遍历输出字典中的名称
for tensor_name in self.output_tensors.keys(): for tensor_name in self.output_tensors.keys():
# 输出维度,例[batch,255,13,13]
output_dims = [self.batch_size, ] + \ output_dims = [self.batch_size, ] + \
self.output_tensors[tensor_name] self.output_tensors[tensor_name]
# 创建Onnx的"变量"ValueInfoProto
output_tensor = helper.make_tensor_value_info( output_tensor = helper.make_tensor_value_info(
tensor_name, TensorProto.FLOAT, output_dims) tensor_name, TensorProto.FLOAT, output_dims)
# 添加到输出列表中
outputs.append(output_tensor) outputs.append(output_tensor)
inputs = [self.input_tensor] inputs = [self.input_tensor]
# 加载权重到ndarray中weight文件如何存储的按顺序存储的二进制文件
weight_loader = WeightLoader(weights_file_path) weight_loader = WeightLoader(weights_file_path)
initializer = list() initializer = list()
# If a layer has parameters, add them to the initializer and input lists. # If a layer has parameters, add them to the initializer and input lists.
# 大概是按照layer生成各层级的节点信息并保存权重darknet格式到节点中
# initializer是包含权重信息的input则是ValueInfoProto可以理解为变量
for layer_name in self.param_dict.keys(): for layer_name in self.param_dict.keys():
_, layer_type = layer_name.split('_', 1) _, layer_type = layer_name.split('_', 1)
params = self.param_dict[layer_name] params = self.param_dict[layer_name]
@ -591,6 +610,7 @@ class GraphBuilderONNX(object):
initializer.extend(initializer_layer) initializer.extend(initializer_layer)
inputs.extend(inputs_layer) inputs.extend(inputs_layer)
del weight_loader del weight_loader
# 生成onnx图
self.graph_def = helper.make_graph( self.graph_def = helper.make_graph(
nodes=self._nodes, nodes=self._nodes,
name=self.model_name, name=self.model_name,
@ -1021,42 +1041,61 @@ def main():
raise SystemExit('ERROR: file (%s) not found!' % weights_file_path) raise SystemExit('ERROR: file (%s) not found!' % weights_file_path)
output_file_path = '%s.onnx' % args.model output_file_path = '%s.onnx' % args.model
# Darknet模型解释器(.cfg->.onnx格式)
print('Parsing DarkNet cfg file...') print('Parsing DarkNet cfg file...')
parser = DarkNetParser() parser = DarkNetParser()
# 从cfg文件中加载网络结构并按顺序存入字典中[layer_name,[key,value]]
layer_configs = parser.parse_cfg_file(cfg_file_path) layer_configs = parser.parse_cfg_file(cfg_file_path)
# 获取yolo输出分类个数单位int
category_num = get_category_num(cfg_file_path) category_num = get_category_num(cfg_file_path)
# 获取输出层名称从yolo层提取用于推算后获取结果
output_tensor_names = get_output_convs(layer_configs) output_tensor_names = get_output_convs(layer_configs)
# e.g. ['036_convolutional', '044_convolutional', '052_convolutional'] # e.g. ['036_convolutional', '044_convolutional', '052_convolutional']
# 获取输出维度 (80 + 5) * 3 = 255
c = (category_num + 5) * get_anchor_num(cfg_file_path) c = (category_num + 5) * get_anchor_num(cfg_file_path)
# 获取输入图像宽高
h, w = get_h_and_w(layer_configs) h, w = get_h_and_w(layer_configs)
# 获取输出格式
if len(output_tensor_names) == 2: if len(output_tensor_names) == 2:
# 2种候选框
output_tensor_shapes = [ output_tensor_shapes = [
[c, h // 32, w // 32], [c, h // 16, w // 16]] [c, h // 32, w // 32], [c, h // 16, w // 16]]
elif len(output_tensor_names) == 3: elif len(output_tensor_names) == 3:
# 3种候选框
output_tensor_shapes = [ output_tensor_shapes = [
[c, h // 32, w // 32], [c, h // 16, w // 16], [c, h // 32, w // 32], [c, h // 16, w // 16],
[c, h // 8, w // 8]] [c, h // 8, w // 8]]
elif len(output_tensor_names) == 4: elif len(output_tensor_names) == 4:
# 4种候选框
output_tensor_shapes = [ output_tensor_shapes = [
[c, h // 64, w // 64], [c, h // 32, w // 32], [c, h // 64, w // 64], [c, h // 32, w // 32],
[c, h // 16, w // 16], [c, h // 8, w // 8]] [c, h // 16, w // 16], [c, h // 8, w // 8]]
# 判断是金字塔模式(自下而上)还是倒金字塔模式(自上而下),决定了输出的顺序
if is_pan_arch(cfg_file_path): if is_pan_arch(cfg_file_path):
# 输出从大图到小图,改为小图先输出
output_tensor_shapes.reverse() output_tensor_shapes.reverse()
# 生成输出字典格式,以416为例 [036_convolutional2551313],[044_convolutional,255,26,26],[052_convolutional,255,52,52]
output_tensor_dims = OrderedDict( output_tensor_dims = OrderedDict(
zip(output_tensor_names, output_tensor_shapes)) zip(output_tensor_names, output_tensor_shapes))
# 创建ONNX生成器
print('Building ONNX graph...') print('Building ONNX graph...')
builder = GraphBuilderONNX( builder = GraphBuilderONNX(
args.model, output_tensor_dims, MAX_BATCH_SIZE) args.model, # Pytorch模型
output_tensor_dims, # 输出字典[[],[],[]]
MAX_BATCH_SIZE) # 最大Batch数量
# 编译ONNX模型
yolo_model_def = builder.build_onnx_graph( yolo_model_def = builder.build_onnx_graph(
layer_configs=layer_configs, layer_configs=layer_configs, # 网络层结构
weights_file_path=weights_file_path, weights_file_path=weights_file_path, # 网络权重
verbose=True) verbose=True) # 显示生成过程
# 检查生成的ONNX模型
print('Checking ONNX model...') print('Checking ONNX model...')
onnx.checker.check_model(yolo_model_def) onnx.checker.check_model(yolo_model_def)
# 保存ONNX模型
print('Saving ONNX file...') print('Saving ONNX file...')
onnx.save(yolo_model_def, output_file_path) onnx.save(yolo_model_def, output_file_path)