
inference.py

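"""Edge-side inference worker for the helmet-detection incremental
learning example: it pulls frames from a video stream, runs the little
(edge) model, and writes hard-example frames, with predicted boxes
drawn, to HE_SAVED_URL for later re-training."""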
import logging
import os
import time

import cv2
import numpy as np

import neptune
from neptune.incremental_learning import InferenceResult

LOG = logging.getLogger(__name__)

# Directory where hard-example frames are saved, injected by the platform.
he_saved_url = neptune.context.get_parameters('HE_SAVED_URL')

class_names = ['person', 'helmet', 'helmet_on', 'helmet_off']

def draw_boxes(img, labels, scores, bboxes, class_names, colors):
    """Draw labelled detection boxes on an image.

    bboxes are expected in (ymin, xmin, ymax, xmax) order, hence the
    coordinate reversal before the OpenCV calls, which want (x, y).
    """
    line_type = 2
    text_thickness = 1
    box_thickness = 1

    # Map the comma-separated color names to BGR codes; anything
    # unrecognized falls back to red.
    colors_code = []
    for color in colors.split(","):
        if color == 'green':
            colors_code.append((0, 255, 0))
        elif color == 'blue':
            colors_code.append((255, 0, 0))
        elif color == 'yellow':
            colors_code.append((0, 255, 255))
        else:
            colors_code.append((0, 0, 255))

    label_dict = {i: label for i, label in enumerate(class_names)}
    for i in range(bboxes.shape[0]):
        bbox = bboxes[i]
        # Skip boxes with infinite coordinates or degenerate size.
        if np.isinf(bbox).any():
            continue
        label = int(labels[i])
        text = "%s:%.2f" % (label_dict.get(label), scores[i])
        p1 = (int(bbox[0]), int(bbox[1]))
        p2 = (int(bbox[2]), int(bbox[3]))
        if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1):
            continue
        cv2.rectangle(img, p1[::-1], p2[::-1], colors_code[label],
                      box_thickness)
        cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0),
                    text_thickness, line_type)
    return img

def preprocess(image, input_shape):
    """Preprocess function for edge model inference.

    Resizes the image while keeping its aspect ratio, pads the borders
    with gray (128), and normalizes pixel values to [0, 1].
    """
    h, w, _ = image.shape
    input_h, input_w = input_shape
    scale = min(float(input_w) / float(w), float(input_h) / float(h))
    nw = int(w * scale)
    nh = int(h * scale)
    image = cv2.resize(image, (nw, nh))

    # Paste the resized image onto the center of a gray canvas.
    new_image = np.zeros((input_h, input_w, 3), np.float32)
    new_image.fill(128)
    bh, bw, _ = new_image.shape
    new_image[int((bh - nh) / 2):(nh + int((bh - nh) / 2)),
              int((bw - nw) / 2):(nw + int((bw - nw) / 2)), :] = image

    new_image /= 255.
    new_image = np.expand_dims(new_image, 0)  # Add batch dimension.
    return new_image

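# The two helpers below assume the edge model is served as a frozen
# TensorFlow graph through a tf.Session: the tensor names ('images:0',
# 'shapes:0', 'output/classes:0', ...) are the ones baked into that
# graph and would need adjusting for a different model.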
def create_input_feed(sess, new_image, img_data):
    """Create input feed for edge model inference"""
    input_feed = {}
    input_img_data = sess.graph.get_tensor_by_name('images:0')
    input_feed[input_img_data] = new_image
    input_img_shape = sess.graph.get_tensor_by_name('shapes:0')
    input_feed[input_img_shape] = [img_data.shape[0], img_data.shape[1]]
    return input_feed

def create_output_fetch(sess):
    """Create output fetch for edge model inference"""
    output_classes = sess.graph.get_tensor_by_name('output/classes:0')
    output_scores = sess.graph.get_tensor_by_name('output/scores:0')
    output_boxes = sess.graph.get_tensor_by_name('output/boxes:0')
    return [output_classes, output_scores, output_boxes]

def output_deal(inference_result: InferenceResult, nframe, img_rgb):
    """Save hard-example frames with their predicted boxes drawn."""
    # Convert back to BGR, which is what cv2.imwrite expects.
    img_rgb = np.array(img_rgb)
    img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    colors = 'yellow,blue,green,red'
    if inference_result.is_hard_example:
        labels, scores, bbox_list_pred = inference_result.infer_result
        img = draw_boxes(img_rgb, labels, scores, bbox_list_pred,
                         class_names, colors)
        cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img)

def mkdir(path):
    path = path.strip()
    if not os.path.exists(path):
        os.makedirs(path)
        LOG.info(f"{path} does not exist, creating it")

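# run() is configured entirely through platform-injected context
# parameters; illustrative values (hypothetical, not taken from this
# repo):
#   HE_SAVED_URL  e.g. /home/data/hard_examples   # hard-example output dir
#   input_shape   e.g. 352,640                    # model input height,width
#   video_url     e.g. rtsp://localhost/video     # camera stream address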
def run():
    input_shape_str = neptune.context.get_parameters("input_shape")
    input_shape = tuple(int(v) for v in input_shape_str.split(","))
    camera_address = neptune.context.get_parameters('video_url')
    mkdir(he_saved_url)

    # Create the little (edge) model object.
    model = neptune.incremental_learning.TSModel(
        preprocess=preprocess,
        input_shape=input_shape,
        create_input_feed=create_input_feed,
        create_output_fetch=create_output_fetch
    )

    # Create the inference object.
    inference_instance = neptune.incremental_learning.Inference(model)

    # Use a video stream for testing.
    camera = cv2.VideoCapture(camera_address)
    fps = 10
    nframe = 0
    while True:
        ret, input_bgr = camera.read()
        if not ret:
            # Reconnect if the stream drops or is not yet available.
            LOG.info(
                f"camera is not open, camera_address={camera_address},"
                f" sleep 5 seconds.")
            time.sleep(5)
            camera = cv2.VideoCapture(camera_address)
            continue

        # Sample one frame out of every `fps`; skip the rest.
        if nframe % fps:
            nframe += 1
            continue

        # OpenCV delivers BGR frames; the model expects RGB.
        img_rgb = cv2.cvtColor(input_bgr, cv2.COLOR_BGR2RGB)
        nframe += 1
        LOG.info(f"camera is open, current frame index is {nframe}")
        inference_result = inference_instance.inference(img_rgb)
        output_deal(inference_result, nframe, img_rgb)


if __name__ == "__main__":
    run()