PyTorch-YOLOv3/detect_onnx.py

71 lines
2.3 KiB
Python
Raw Permalink Normal View History

2023-03-04 19:47:42 +08:00
# coding: utf-8
# author: hxy
# 2019-12-10
"""
照片的inference
默认推理过程在CPU上
"""
2023-03-04 18:54:44 +08:00
import os
import time
2023-03-04 19:47:42 +08:00
import logging
2023-03-04 18:54:44 +08:00
import onnxruntime
2023-03-06 11:41:08 +08:00
from darknet_api import process_img, get_boxes, draw_box,draw_box_save
2023-03-04 19:47:42 +08:00
# 定义日志格式
def log_set():
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# 加载onnx模型
def load_model(onnx_model):
sess = onnxruntime.InferenceSession(onnx_model)
in_name = [input.name for input in sess.get_inputs()][0]
out_name = [output.name for output in sess.get_outputs()]
logging.info("输入的name:{}, 输出的name:{}".format(in_name, out_name))
return sess, in_name, out_name
if __name__ == '__main__':
log_set()
input_shape = (416 , 416)
# anchors
anchors_yolo = [[(116, 90), (156, 198), (373, 326)], [(30, 61), (62, 45), (59, 119)],
[(10, 13), (16, 30), (33, 23)]]
anchors_yolo_tiny = [[(81, 82), (135, 169), (344, 319)], [(10, 14), (23, 27), (37, 58)]]
session, inname, outname = load_model(onnx_model='output/yolov3-416.onnx')
logging.info("开始Inference....")
# 照片的批量inference
img_files_path = 'data/samples'
imgs = os.listdir(img_files_path)
2023-03-06 11:41:08 +08:00
img_output_path = 'output/'
2023-03-04 19:47:42 +08:00
logging.debug(imgs)
for img_name in imgs:
img_full_path = os.path.join(img_files_path, img_name)
2023-03-06 11:41:08 +08:00
img_save_path = os.path.join(img_output_path,img_name)
2023-03-04 19:47:42 +08:00
logging.debug(img_full_path)
img, img_shape, testdata = process_img(img_path=img_full_path,
input_shape=input_shape)
s = time.time()
prediction = session.run(outname, {inname: testdata})
# logging.info("推理照片 %s 耗时:% .2fms" % (img_name, ((time.time() - s)*1000)))
boxes = get_boxes(prediction=prediction,
anchors=anchors_yolo,
img_shape=input_shape)
draw_box(boxes=boxes,
img=img,
img_shape=img_shape)
2023-03-06 11:41:08 +08:00
draw_box_save(
boxes=boxes,
img=img,
img_shape=img_shape,
img_path=img_save_path
)
2023-03-04 19:47:42 +08:00
logging.info("推理照片 %s 耗时:% .2fms" % (img_name, ((time.time() - s)*1000)))