60 lines
2.0 KiB
Python
60 lines
2.0 KiB
Python
|
"""ssd_tf.py
|
||
|
|
||
|
This module implements the TfSSD class.
|
||
|
"""
|
||
|
|
||
|
|
||
|
import numpy as np
|
||
|
import cv2
|
||
|
import tensorflow as tf
|
||
|
|
||
|
|
||
|
def _postprocess_tf(img, boxes, scores, classes, conf_th):
|
||
|
"""Postprocess TensorFlow SSD output."""
|
||
|
h, w, _ = img.shape
|
||
|
out_boxes = boxes[0] * np.array([h, w, h, w])
|
||
|
out_boxes = out_boxes.astype(np.int32)
|
||
|
out_boxes = out_boxes[:, [1, 0, 3, 2]] # swap x's and y's
|
||
|
out_confs = scores[0]
|
||
|
out_clss = classes[0].astype(np.int32)
|
||
|
|
||
|
# only return bboxes with confidence score above threshold
|
||
|
mask = np.where(out_confs >= conf_th)
|
||
|
return out_boxes[mask], out_confs[mask], out_clss[mask]
|
||
|
|
||
|
|
||
|
class TfSSD(object):
|
||
|
"""TfSSD class encapsulates things needed to run TensorFlow SSD."""
|
||
|
|
||
|
def __init__(self, model, input_shape):
|
||
|
self.model = model
|
||
|
self.input_shape = input_shape
|
||
|
|
||
|
# load detection graph
|
||
|
ssd_graph = tf.Graph()
|
||
|
with ssd_graph.as_default():
|
||
|
graph_def = tf.GraphDef()
|
||
|
with tf.gfile.GFile('ssd/%s.pb' % model, 'rb') as fid:
|
||
|
serialized_graph = fid.read()
|
||
|
graph_def.ParseFromString(serialized_graph)
|
||
|
tf.import_graph_def(graph_def, name='')
|
||
|
|
||
|
# define input/output tensors
|
||
|
self.image_tensor = ssd_graph.get_tensor_by_name('image_tensor:0')
|
||
|
self.det_boxes = ssd_graph.get_tensor_by_name('detection_boxes:0')
|
||
|
self.det_scores = ssd_graph.get_tensor_by_name('detection_scores:0')
|
||
|
self.det_classes = ssd_graph.get_tensor_by_name('detection_classes:0')
|
||
|
|
||
|
# create the session for inferencing
|
||
|
self.sess = tf.Session(graph=ssd_graph)
|
||
|
|
||
|
def __del__(self):
|
||
|
self.sess.close()
|
||
|
|
||
|
def detect(self, img, conf_th):
|
||
|
img_resized = _preprocess_tf(img, self.input_shape)
|
||
|
boxes, scores, classes = self.sess.run(
|
||
|
[self.det_boxes, self.det_scores, self.det_classes],
|
||
|
feed_dict={self.image_tensor: np.expand_dims(img_resized, 0)})
|
||
|
return _postprocess_tf(img, boxes, scores, classes, conf_th)
|