成功在NX上运行推理
parent
29c2033053
commit
f709af4300
|
@ -153,6 +153,31 @@ def draw_box(boxes, img, img_shape):
|
|||
cv2.waitKey(10)
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
def draw_box_save(boxes,img,img_shape,img_path):
|
||||
label = ["background", "person",
|
||||
"bicycle", "car", "motorbike", "aeroplane",
|
||||
"bus", "train", "truck", "boat", "traffic light",
|
||||
"fire hydrant", "stop sign", "parking meter", "bench",
|
||||
"bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
|
||||
"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag",
|
||||
"tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball",
|
||||
"kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
|
||||
"tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
|
||||
"bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog",
|
||||
"pizza", "donut", "cake", "chair", "sofa", "potted plant", "bed", "dining table",
|
||||
"toilet", "TV monitor", "laptop", "mouse", "remote", "keyboard", "cell phone",
|
||||
"microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
|
||||
"scissors", "teddy bear", "hair drier", "toothbrush"]
|
||||
for box in boxes:
|
||||
x1 = int((box[0] - box[2] / 2) * img_shape[1])
|
||||
y1 = int((box[1] - box[3] / 2) * img_shape[0])
|
||||
x2 = int((box[0] + box[2] / 2) * img_shape[1])
|
||||
y2 = int((box[1] + box[3] / 2) * img_shape[0])
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
cv2.putText(img, label[int(box[5])] + ":" + str(round(box[4], 3)), (x1 + 5, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.5, (0, 0, 255), 1)
|
||||
print(label[int(box[5])] + ":" + "概率值:%.3f" % box[4])
|
||||
cv2.imwrite(img_path,img)
|
||||
|
||||
# 获取预测框
|
||||
def get_boxes(prediction, anchors, img_shape, confidence_threshold=0.25, nms_threshold=0.6):
|
||||
|
|
|
@ -10,7 +10,7 @@ import os
|
|||
import time
|
||||
import logging
|
||||
import onnxruntime
|
||||
from darknet_api import process_img, get_boxes, draw_box
|
||||
from darknet_api import process_img, get_boxes, draw_box,draw_box_save
|
||||
|
||||
|
||||
# 定义日志格式
|
||||
|
@ -41,10 +41,13 @@ if __name__ == '__main__':
|
|||
# 照片的批量inference
|
||||
img_files_path = 'data/samples'
|
||||
imgs = os.listdir(img_files_path)
|
||||
img_output_path = 'output/'
|
||||
|
||||
logging.debug(imgs)
|
||||
for img_name in imgs:
|
||||
img_full_path = os.path.join(img_files_path, img_name)
|
||||
img_save_path = os.path.join(img_output_path,img_name)
|
||||
|
||||
logging.debug(img_full_path)
|
||||
img, img_shape, testdata = process_img(img_path=img_full_path,
|
||||
input_shape=input_shape)
|
||||
|
@ -58,4 +61,10 @@ if __name__ == '__main__':
|
|||
draw_box(boxes=boxes,
|
||||
img=img,
|
||||
img_shape=img_shape)
|
||||
draw_box_save(
|
||||
boxes=boxes,
|
||||
img=img,
|
||||
img_shape=img_shape,
|
||||
img_path=img_save_path
|
||||
)
|
||||
logging.info("推理照片 %s 耗时:% .2fms" % (img_name, ((time.time() - s)*1000)))
|
||||
|
|
|
@ -449,7 +449,7 @@ class GraphBuilderONNX(object):
|
|||
print(helper.printable_graph(self.graph_def))
|
||||
model_def = helper.make_model(self.graph_def,
|
||||
producer_name='NVIDIA TensorRT sample',
|
||||
opset_imports=[helper.make_opsetid(domain="", version=17)])
|
||||
opset_imports=[helper.make_opsetid(domain="", version=15)])
|
||||
return model_def
|
||||
|
||||
def _make_onnx_node(self, layer_name, layer_dict):
|
||||
|
|
Loading…
Reference in New Issue