添加onnx转换注释;
parent
eb89f036bd
commit
4e9a5c2ba1
|
@ -123,6 +123,8 @@ def get_output_convs(layer_configs):
|
|||
return output_convs
|
||||
|
||||
|
||||
# 从cfg中解析出分类数
|
||||
# 返回值:int 分类个数
|
||||
def get_category_num(cfg_file_path):
|
||||
"""Find number of output classes of the yolo model."""
|
||||
with open(cfg_file_path, 'r') as f:
|
||||
|
@ -184,12 +186,16 @@ class DarkNetParser(object):
|
|||
|
||||
# A list of YOLO layers containing dictionaries with all layer
|
||||
# parameters:
|
||||
|
||||
# OrderedDic 可以保持键值对的顺序[]
|
||||
self.layer_configs = OrderedDict()
|
||||
# 支持的节点类型
|
||||
self.supported_layers = supported_layers if supported_layers else \
|
||||
['net', 'convolutional', 'maxpool', 'shortcut',
|
||||
'route', 'upsample', 'yolo']
|
||||
self.layer_counter = 0
|
||||
|
||||
# 加载网络模型文件.cfg
|
||||
def parse_cfg_file(self, cfg_file_path):
|
||||
"""Takes the yolov?.cfg file and parses it layer by layer,
|
||||
appending each layer's parameters as a dictionary to layer_configs.
|
||||
|
@ -200,11 +206,15 @@ class DarkNetParser(object):
|
|||
with open(cfg_file_path, 'r') as cfg_file:
|
||||
remainder = cfg_file.read()
|
||||
while remainder is not None:
|
||||
# 从字符串中加载一层网络,并生成字典结构
|
||||
layer_dict, layer_name, remainder = self._next_layer(remainder)
|
||||
if layer_dict is not None:
|
||||
# 将一层网络结构根据名称,依次加入有序字典中,生成整个网络结构
|
||||
self.layer_configs[layer_name] = layer_dict
|
||||
return self.layer_configs
|
||||
|
||||
# 返回当前层生成的字典键值对,并指向下一层结构
|
||||
# layer_dict 一层内的网络结构字典, layer_name 层名称, remainder 剩余字符串
|
||||
def _next_layer(self, remainder):
|
||||
"""Takes in a string and segments it by looking for DarkNet delimiters.
|
||||
Returns the layer parameters and the remaining string after the last delimiter.
|
||||
|
@ -560,23 +570,32 @@ class GraphBuilderONNX(object):
|
|||
"""
|
||||
for layer_name in layer_configs.keys():
|
||||
layer_dict = layer_configs[layer_name]
|
||||
# 根据网络结构,分别生成onnx节点,节点的作用是操作符,如conv,relu等,每个操作都要写成Onnx格式
|
||||
major_node_specs = self._make_onnx_node(layer_name, layer_dict)
|
||||
# 成功生成后,添加到主网络结构中
|
||||
if major_node_specs.name is not None:
|
||||
self.major_node_specs.append(major_node_specs)
|
||||
# remove dummy 'route' and 'yolo' nodes
|
||||
self.major_node_specs = [node for node in self.major_node_specs
|
||||
if 'dummy' not in node.name]
|
||||
outputs = list()
|
||||
# 遍历输出字典中的名称
|
||||
for tensor_name in self.output_tensors.keys():
|
||||
# 输出维度,例[batch,255,13,13]
|
||||
output_dims = [self.batch_size, ] + \
|
||||
self.output_tensors[tensor_name]
|
||||
# 创建Onnx的"变量"(ValueInfoProto)
|
||||
output_tensor = helper.make_tensor_value_info(
|
||||
tensor_name, TensorProto.FLOAT, output_dims)
|
||||
# 添加到输出列表中
|
||||
outputs.append(output_tensor)
|
||||
inputs = [self.input_tensor]
|
||||
# 加载权重到ndarray中,weight文件如何存储的??按顺序存储的二进制文件
|
||||
weight_loader = WeightLoader(weights_file_path)
|
||||
initializer = list()
|
||||
# If a layer has parameters, add them to the initializer and input lists.
|
||||
# If a layer has parameters, add them to the initializer and input lists. ???
|
||||
# 大概是按照layer生成各层级的节点信息,并保存权重(darknet格式)到节点中
|
||||
# initializer是包含权重信息的,input则是ValueInfoProto,可以理解为变量
|
||||
for layer_name in self.param_dict.keys():
|
||||
_, layer_type = layer_name.split('_', 1)
|
||||
params = self.param_dict[layer_name]
|
||||
|
@ -591,6 +610,7 @@ class GraphBuilderONNX(object):
|
|||
initializer.extend(initializer_layer)
|
||||
inputs.extend(inputs_layer)
|
||||
del weight_loader
|
||||
# 生成onnx图
|
||||
self.graph_def = helper.make_graph(
|
||||
nodes=self._nodes,
|
||||
name=self.model_name,
|
||||
|
@ -1021,42 +1041,61 @@ def main():
|
|||
raise SystemExit('ERROR: file (%s) not found!' % weights_file_path)
|
||||
output_file_path = '%s.onnx' % args.model
|
||||
|
||||
# Darknet模型解释器(.cfg->.onnx格式)
|
||||
print('Parsing DarkNet cfg file...')
|
||||
parser = DarkNetParser()
|
||||
# 从cfg文件中加载网络结构,并按顺序存入字典中[layer_name,[key,value]]
|
||||
layer_configs = parser.parse_cfg_file(cfg_file_path)
|
||||
# 获取yolo输出分类个数,单位int
|
||||
category_num = get_category_num(cfg_file_path)
|
||||
# 获取输出层名称从yolo层提取,用于推算后获取结果
|
||||
output_tensor_names = get_output_convs(layer_configs)
|
||||
# e.g. ['036_convolutional', '044_convolutional', '052_convolutional']
|
||||
|
||||
# 获取输出维度 (80 + 5) * 3 = 255
|
||||
c = (category_num + 5) * get_anchor_num(cfg_file_path)
|
||||
# 获取输入图像宽高
|
||||
h, w = get_h_and_w(layer_configs)
|
||||
# 获取输出格式
|
||||
if len(output_tensor_names) == 2:
|
||||
# 2种候选框
|
||||
output_tensor_shapes = [
|
||||
[c, h // 32, w // 32], [c, h // 16, w // 16]]
|
||||
elif len(output_tensor_names) == 3:
|
||||
# 3种候选框
|
||||
output_tensor_shapes = [
|
||||
[c, h // 32, w // 32], [c, h // 16, w // 16],
|
||||
[c, h // 8, w // 8]]
|
||||
elif len(output_tensor_names) == 4:
|
||||
# 4种候选框
|
||||
output_tensor_shapes = [
|
||||
[c, h // 64, w // 64], [c, h // 32, w // 32],
|
||||
[c, h // 16, w // 16], [c, h // 8, w // 8]]
|
||||
# 判断是金字塔模式(自下而上)还是倒金字塔模式(自上而下),决定了输出的顺序
|
||||
if is_pan_arch(cfg_file_path):
|
||||
# 输出从大图到小图,改为小图先输出
|
||||
output_tensor_shapes.reverse()
|
||||
# 生成输出字典格式,以416为例 [036_convolutional,255,13,13],[044_convolutional,255,26,26],[052_convolutional,255,52,52]
|
||||
output_tensor_dims = OrderedDict(
|
||||
zip(output_tensor_names, output_tensor_shapes))
|
||||
|
||||
# 创建ONNX生成器
|
||||
print('Building ONNX graph...')
|
||||
builder = GraphBuilderONNX(
|
||||
args.model, output_tensor_dims, MAX_BATCH_SIZE)
|
||||
args.model, # Pytorch模型
|
||||
output_tensor_dims, # 输出字典[[],[],[]]
|
||||
MAX_BATCH_SIZE) # 最大Batch数量
|
||||
# 编译ONNX模型
|
||||
yolo_model_def = builder.build_onnx_graph(
|
||||
layer_configs=layer_configs,
|
||||
weights_file_path=weights_file_path,
|
||||
verbose=True)
|
||||
layer_configs=layer_configs, # 网络层结构
|
||||
weights_file_path=weights_file_path, # 网络权重
|
||||
verbose=True) # 显示生成过程
|
||||
|
||||
# 检查生成的ONNX模型
|
||||
print('Checking ONNX model...')
|
||||
onnx.checker.check_model(yolo_model_def)
|
||||
|
||||
# 保存ONNX模型
|
||||
print('Saving ONNX file...')
|
||||
onnx.save(yolo_model_def, output_file_path)
|
||||
|
||||
|
|
Loading…
Reference in New Issue