You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

interface.py 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import logging
  16. import cv2
  17. import numpy as np
  18. import tensorflow as tf
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)
# Set at import time. NOTE(review): presumably read by the hosting framework
# to select the TensorFlow backend -- confirm against the consumer.
os.environ['BACKEND_TYPE'] = 'TENSORFLOW'
  21. def preprocess(image, input_shape):
  22. ih, iw = input_shape
  23. h, w, _ = image.shape
  24. org_img_shape = (w, h)
  25. scale = min(iw / w, ih / h)
  26. nw, nh = int(scale * w), int(scale * h)
  27. image_resized = cv2.resize(image.astype(np.float32), (nw, nh))
  28. image_paded = np.full(shape=[ih, iw, 3], fill_value=128.0)
  29. dw, dh = (iw - nw) // 2, (ih - nh) // 2
  30. image_paded[dh:nh + dh, dw:nw + dw, :] = image_resized
  31. image_paded = image_paded / 255.
  32. preprocessed_data = image_paded.astype(np.float32)[np.newaxis, :]
  33. return preprocessed_data, org_img_shape
  34. def postprocess(data, org_img_shape):
  35. pred_sbbox, pred_mbbox, pred_lbbox = data[1], data[2], data[0]
  36. num_classes = 4
  37. score_threshold = 0.3
  38. input_size = 544
  39. pred_bbox = np.concatenate(
  40. [np.reshape(pred_sbbox, (-1, 5 + num_classes)),
  41. np.reshape(pred_mbbox, (-1, 5 + num_classes)),
  42. np.reshape(pred_lbbox, (-1, 5 + num_classes))], axis=0)
  43. valid_scale = [0, np.inf]
  44. pred_bbox = np.array(pred_bbox)
  45. pred_xywh = pred_bbox[:, 0:4]
  46. pred_conf = pred_bbox[:, 4]
  47. pred_prob = pred_bbox[:, 5:]
  48. pred_coor = np.concatenate(
  49. [pred_xywh[:, :2] - pred_xywh[:, 2:] * 0.5,
  50. pred_xywh[:, :2] + pred_xywh[:, 2:] * 0.5], axis=-1)
  51. org_w, org_h = org_img_shape
  52. resize_ratio = min(1.0 * input_size / org_w, 1.0 * input_size / org_h)
  53. dw = (input_size - resize_ratio * org_w) / 2.
  54. dh = (input_size - resize_ratio * org_h) / 2.
  55. pred_coor[:, 0::2] = 1.0 * (pred_coor[:, 0::2] - dw) / resize_ratio
  56. pred_coor[:, 1::2] = 1.0 * (pred_coor[:, 1::2] - dh) / resize_ratio
  57. # clip some boxes those are out of range
  58. pred_coor = np.concatenate(
  59. [np.maximum(pred_coor[:, :2], [0, 0]),
  60. np.minimum(pred_coor[:, 2:], [org_w - 1, org_h - 1])], axis=-1)
  61. invalid_mask = np.logical_or((pred_coor[:, 0] > pred_coor[:, 2]),
  62. (pred_coor[:, 1] > pred_coor[:, 3]))
  63. pred_coor[invalid_mask] = 0
  64. # discard some invalidboxes
  65. bboxes_scale = np.sqrt(
  66. np.multiply.reduce(pred_coor[:, 2:4] - pred_coor[:, 0:2], axis=-1))
  67. scale_mask = np.logical_and((valid_scale[0] < bboxes_scale),
  68. (bboxes_scale < valid_scale[1]))
  69. # discard some boxes with low scores
  70. classes = np.argmax(pred_prob, axis=-1)
  71. scores = pred_conf * pred_prob[np.arange(len(pred_coor)), classes]
  72. score_mask = scores > score_threshold
  73. mask = score_mask
  74. coors, scores, classes = pred_coor[mask], scores[mask], classes[mask]
  75. bboxes = np.concatenate(
  76. [coors, scores[:, np.newaxis], classes[:, np.newaxis]], axis=-1)
  77. bboxes = nms(bboxes, 0.4)
  78. return bboxes
  79. def bboxes_iou(boxes1, boxes2):
  80. boxes1 = np.array(boxes1)
  81. boxes2 = np.array(boxes2)
  82. boxes1_area = (boxes1[..., 2] - boxes1[..., 0]) * (
  83. boxes1[..., 3] - boxes1[..., 1])
  84. boxes2_area = (boxes2[..., 2] - boxes2[..., 0]) * (
  85. boxes2[..., 3] - boxes2[..., 1])
  86. left_up = np.maximum(boxes1[..., :2], boxes2[..., :2])
  87. right_down = np.minimum(boxes1[..., 2:], boxes2[..., 2:])
  88. inter_section = np.maximum(right_down - left_up, 0.0)
  89. inter_area = inter_section[..., 0] * inter_section[..., 1]
  90. union_area = boxes1_area + boxes2_area - inter_area
  91. ious = np.maximum(1.0 * inter_area / union_area, np.finfo(np.float32).eps)
  92. return ious
  93. def nms(bboxes, iou_threshold, sigma=0.3, method='nms'):
  94. """
  95. :param bboxes: (xmin, ymin, xmax, ymax, score, class)
  96. Note: soft-nms, https://arxiv.org/pdf/1704.04503.pdf
  97. https://github.com/bharatsingh430/soft-nms
  98. """
  99. classes_in_img = list(set(bboxes[:, 5]))
  100. best_bboxes = []
  101. for cls in classes_in_img:
  102. cls_mask = (bboxes[:, 5] == cls)
  103. cls_bboxes = bboxes[cls_mask]
  104. while len(cls_bboxes) > 0:
  105. max_ind = np.argmax(cls_bboxes[:, 4])
  106. best_bbox = cls_bboxes[max_ind]
  107. best_bbox_ = best_bbox.tolist()
  108. # cast into int for cls
  109. best_bbox_[5] = int(best_bbox[5])
  110. best_bboxes.append(best_bbox_)
  111. cls_bboxes = np.concatenate(
  112. [cls_bboxes[: max_ind], cls_bboxes[max_ind + 1:]])
  113. iou = bboxes_iou(best_bbox[np.newaxis, :4], cls_bboxes[:, :4])
  114. weight = np.ones((len(iou),), dtype=np.float32)
  115. assert method in ['nms', 'soft-nms']
  116. if method == 'nms':
  117. iou_mask = iou > iou_threshold
  118. weight[iou_mask] = 0.0
  119. if method == 'soft-nms':
  120. weight = np.exp(-(1.0 * iou ** 2 / sigma))
  121. cls_bboxes[:, 4] = cls_bboxes[:, 4] * weight
  122. score_mask = cls_bboxes[:, 4] > 0.
  123. cls_bboxes = cls_bboxes[score_mask]
  124. return best_bboxes
  125. def create_input_feed(sess, img_data, new_image=None):
  126. input_feed = {}
  127. input_img_data = sess.graph.get_tensor_by_name('input/input_data:0')
  128. input_feed[input_img_data] = img_data
  129. return input_feed
  130. def create_output_fetch(sess):
  131. """Create output fetch for edge model inference"""
  132. pred_sbbox = sess.graph.get_tensor_by_name('pred_sbbox/concat_2:0')
  133. pred_mbbox = sess.graph.get_tensor_by_name('pred_mbbox/concat_2:0')
  134. pred_lbbox = sess.graph.get_tensor_by_name('pred_lbbox/concat_2:0')
  135. output_fetch = [pred_sbbox, pred_mbbox, pred_lbbox]
  136. return output_fetch
  137. class Estimator:
  138. def __init__(self, **kwargs):
  139. """
  140. initialize logging configuration
  141. """
  142. graph = tf.Graph()
  143. config = tf.ConfigProto(allow_soft_placement=True)
  144. config.gpu_options.allow_growth = True
  145. config.gpu_options.per_process_gpu_memory_fraction = 0.1
  146. self.session = tf.Session(graph=graph, config=config)
  147. self.input_shape = [544, 544]
  148. self.create_input_feed = create_input_feed
  149. self.create_output_fetch = create_output_fetch
  150. def load(self, model_url=""):
  151. with self.session.as_default():
  152. with self.session.graph.as_default():
  153. with tf.gfile.FastGFile(model_url, 'rb') as handle:
  154. LOG.info(f"Load model {model_url}, "
  155. f"ParseFromString start .......")
  156. graph_def = tf.GraphDef()
  157. graph_def.ParseFromString(handle.read())
  158. LOG.info("ParseFromString end .......")
  159. tf.import_graph_def(graph_def, name='')
  160. LOG.info("Import_graph_def end .......")
  161. LOG.info("Import model from pb end .......")
  162. def predict(self, data, **kwargs):
  163. img_data_np = np.array(data)
  164. new_image, shapes = preprocess(img_data_np, self.input_shape)
  165. with self.session.as_default():
  166. input_feed = self.create_input_feed(
  167. self.session, new_image, img_data_np)
  168. output_fetch = self.create_output_fetch(self.session)
  169. output = self.session.run(output_fetch, input_feed)
  170. return postprocess(output, shapes)