
interface.py 4.5 kB

# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging

import cv2
import numpy as np
import tensorflow as tf

LOG = logging.getLogger(__name__)

os.environ['BACKEND_TYPE'] = 'TENSORFLOW'

flags = tf.flags.FLAGS


def create_input_feed(sess, new_image, img_data):
    """Create input feed for edge model inference"""
    input_feed = {}

    input_img_data = sess.graph.get_tensor_by_name('images:0')
    input_feed[input_img_data] = new_image

    input_img_shape = sess.graph.get_tensor_by_name('shapes:0')
    input_feed[input_img_shape] = [img_data.shape[0], img_data.shape[1]]

    return input_feed


def create_output_fetch(sess):
    """Create output fetch for edge model inference"""
    output_classes = sess.graph.get_tensor_by_name('concat_19:0')
    output_scores = sess.graph.get_tensor_by_name('concat_18:0')
    output_boxes = sess.graph.get_tensor_by_name('concat_17:0')

    output_fetch = [output_classes, output_scores, output_boxes]
    return output_fetch


class Estimator:

    def __init__(self, **kwargs):
        """
        Initialize the TensorFlow session and inference configuration.
        """
        graph = tf.Graph()
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        config.gpu_options.per_process_gpu_memory_fraction = 0.1
        self.session = tf.Session(graph=graph, config=config)

        self.input_shape = [416, 736]
        self.create_input_feed = create_input_feed
        self.create_output_fetch = create_output_fetch

    def load(self, model_url=""):
        with self.session.as_default():
            with self.session.graph.as_default():
                with tf.gfile.FastGFile(model_url, 'rb') as handle:
                    LOG.info(f"Load model {model_url}, "
                             f"ParseFromString start .......")
                    graph_def = tf.GraphDef()
                    graph_def.ParseFromString(handle.read())
                    LOG.info("ParseFromString end .......")

                    tf.import_graph_def(graph_def, name='')
                    LOG.info("Import_graph_def end .......")

        LOG.info("Import model from pb end .......")

    @staticmethod
    def preprocess(image, input_shape):
        """Preprocess functions in edge model inference"""
        # resize image with unchanged aspect ratio using padding by opencv
        h, w, _ = image.shape
        input_h, input_w = input_shape
        scale = min(float(input_w) / float(w), float(input_h) / float(h))
        nw = int(w * scale)
        nh = int(h * scale)
        image = cv2.resize(image.astype(np.float32), (nw, nh))

        new_image = np.zeros((input_h, input_w, 3), np.float32)
        new_image.fill(128)
        bh, bw, _ = new_image.shape
        new_image[int((bh - nh) / 2):(nh + int((bh - nh) / 2)),
                  int((bw - nw) / 2):(nw + int((bw - nw) / 2)), :] = image

        new_image /= 255.
        new_image = np.expand_dims(new_image, 0)  # Add batch dimension.
        return new_image

    @staticmethod
    def postprocess(model_output):
        """Convert raw model output into a list of detections."""
        all_classes, all_scores, all_bboxes = model_output
        bboxes = []
        for c, s, bbox in zip(all_classes, all_scores, all_bboxes):
            # Swap the coordinate pairs (commonly
            # [y_min, x_min, y_max, x_max] -> [x_min, y_min, x_max, y_max]).
            bbox[0], bbox[1], bbox[2], bbox[3] = (
                bbox[1].tolist(), bbox[0].tolist(),
                bbox[3].tolist(), bbox[2].tolist())
            # Each detection: [coordinates..., score, class].
            bboxes.append(bbox.tolist() + [s.tolist(), c.tolist()])

        return bboxes

    def predict(self, data, **kwargs):
        img_data_np = np.array(data)
        with self.session.as_default():
            new_image = self.preprocess(img_data_np, self.input_shape)
            input_feed = self.create_input_feed(
                self.session, new_image, img_data_np)
            output_fetch = self.create_output_fetch(self.session)
            output = self.session.run(output_fetch, input_feed)

        return self.postprocess(output)
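
A minimal usage sketch of this Estimator, assuming a frozen TF1 graph that exposes the 'images:0', 'shapes:0' and 'concat_17/18/19:0' tensors this module binds to; the model and image paths below are hypothetical, not part of the repository:

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    estimator = Estimator()
    # Hypothetical path to a frozen-graph .pb file.
    estimator.load(model_url="./models/helmet_detection.pb")

    # Hypothetical test image; predict() accepts any HxWx3 ndarray from cv2.
    frame = cv2.imread("./data/sample.jpg")
    detections = estimator.predict(frame)

    # postprocess() yields one [coordinates..., score, class] list per
    # detected object.
    for det in detections:
        print(det)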