
little_model.py 7.2 kB

import logging
import time
import copy
import os

import cv2
import numpy as np

import neptune
from neptune.hard_example_mining import IBTFilter
from neptune.joint_inference.joint_inference import InferenceResult

LOG = logging.getLogger(__name__)

# Predefined color values for frames and display categories
colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255),
          (0, 255, 255), (255, 255, 255)]
class_names = ['person', 'helmet', 'helmet_on', 'helmet_off']

all_output_path = neptune.context.get_parameters(
    'all_examples_inference_output'
)
hard_example_edge_output_path = neptune.context.get_parameters(
    'hard_example_edge_inference_output'
)
hard_example_cloud_output_path = neptune.context.get_parameters(
    'hard_example_cloud_inference_output'
)

def draw_boxes(img, bboxes, colors, text_thickness, box_thickness):
    """Draw labeled bounding boxes on a copy of the image."""
    img_copy = copy.deepcopy(img)
    line_type = 2
    # Fall back to sane defaults when no thickness is given; cv2 rejects None.
    text_thickness = 1 if text_thickness is None else text_thickness
    box_thickness = 2 if box_thickness is None else box_thickness
    # Map the comma-separated color names to BGR color codes.
    colors = colors.split(",")
    colors_code = []
    for color in colors:
        if color == 'green':
            colors_code.append((0, 255, 0))
        elif color == 'blue':
            colors_code.append((255, 0, 0))
        elif color == 'yellow':
            colors_code.append((0, 255, 255))
        else:
            colors_code.append((0, 0, 255))

    label_dict = {i: label for i, label in enumerate(class_names)}
    for bbox in bboxes:
        # Skip boxes with infinite coordinates.
        if float("inf") in bbox or float("-inf") in bbox:
            continue
        label = int(bbox[5])
        score = "%.2f" % round(bbox[4], 2)
        text = label_dict.get(label) + ":" + score
        p1 = (int(bbox[1]), int(bbox[0]))
        p2 = (int(bbox[3]), int(bbox[2]))
        # Skip degenerate (sub-pixel) boxes.
        if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1):
            continue
        cv2.rectangle(img_copy, p1[::-1], p2[::-1], colors_code[label],
                      box_thickness)
        cv2.putText(img_copy, text, (p1[1], p1[0] + 20 * (label + 1)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0),
                    text_thickness, line_type)
    return img_copy
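
# A sketch of the bbox layout draw_boxes expects, inferred from the indexing
# above: [x1, y1, x2, y2, score, class_id]. A detection such as
# [30, 50, 120, 200, 0.87, 2] (illustrative values) would be drawn with the
# label "helmet_on:0.87".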

def preprocess(image, input_shape):
    """Preprocess functions in edge model inference"""
    # Resize image with unchanged aspect ratio using padding (letterbox).
    h, w, _ = image.shape
    input_h, input_w = input_shape
    scale = min(float(input_w) / float(w), float(input_h) / float(h))
    nw = int(w * scale)
    nh = int(h * scale)
    image = cv2.resize(image, (nw, nh))
    # Paste the resized image onto a gray canvas, then normalize to [0, 1].
    new_image = np.zeros((input_h, input_w, 3), np.float32)
    new_image.fill(128)
    bh, bw, _ = new_image.shape
    new_image[int((bh - nh) / 2):(nh + int((bh - nh) / 2)),
              int((bw - nw) / 2):(nw + int((bw - nw) / 2)), :] = image
    new_image /= 255.
    new_image = np.expand_dims(new_image, 0)  # Add batch dimension.
    return new_image
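
# Worked example of the letterbox math above (illustrative numbers): for a
# 1280x720 frame with input_shape (416, 416), scale = min(416/1280, 416/720)
# = 0.325, so the frame is resized to 416x234 and centered with 91 gray rows
# of padding above and below; the 128 fill becomes ~0.5 after dividing by 255.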

def create_input_feed(sess, new_image, img_data):
    """Create input feed for edge model inference"""
    input_feed = {}
    input_img_data = sess.graph.get_tensor_by_name('images:0')
    input_feed[input_img_data] = new_image
    input_img_shape = sess.graph.get_tensor_by_name('shapes:0')
    input_feed[input_img_shape] = [img_data.shape[0], img_data.shape[1]]
    return input_feed


def create_output_fetch(sess):
    """Create output fetch for edge model inference"""
    output_classes = sess.graph.get_tensor_by_name('concat_19:0')
    output_scores = sess.graph.get_tensor_by_name('concat_18:0')
    output_boxes = sess.graph.get_tensor_by_name('concat_17:0')
    output_fetch = [output_classes, output_scores, output_boxes]
    return output_fetch
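
# Note: 'images:0', 'shapes:0' and the 'concat_*:0' names above are tensor
# names baked into this particular frozen graph; a model exported differently
# would expose different names, so these two helpers are model-specific.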

def postprocess(model_output):
    """Convert raw model output to [x1, y1, x2, y2, score, class] lists."""
    all_classes, all_scores, all_bboxes = model_output
    bboxes = []
    for c, s, bbox in zip(all_classes, all_scores, all_bboxes):
        # Swap each coordinate pair from (y, x) to (x, y) order.
        bbox[0], bbox[1], bbox[2], bbox[3] = bbox[1], bbox[0], bbox[3], bbox[2]
        bboxes.append(bbox.tolist() + [s, c])
    return bboxes
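
# For instance, assuming the TF-style [y1, x1, y2, x2] box order implied by
# the swap, a model_output of ([2], [0.87], [[50, 30, 200, 120]]) (as arrays)
# becomes [[30, 50, 120, 200, 0.87, 2]], matching the draw_boxes example above.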

def output_deal(inference_result: InferenceResult, nframe, img_rgb):
    """Save annotated frames; hard examples also go to dedicated dirs."""
    img_rgb = np.array(img_rgb)
    img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    collaboration_frame = draw_boxes(img_rgb, inference_result.final_result,
                                     colors="green,blue,yellow,red",
                                     text_thickness=None,
                                     box_thickness=None)
    cv2.imwrite(f"{all_output_path}/{nframe}.jpeg", collaboration_frame)
    # Save hard example images to their dirs.
    if not inference_result.is_hard_example:
        return
    if inference_result.hard_example_cloud_result is not None:
        cv2.imwrite(f"{hard_example_cloud_output_path}/{nframe}.jpeg",
                    collaboration_frame)
    edge_collaboration_frame = draw_boxes(
        img_rgb,
        inference_result.hard_example_edge_result,
        colors="green,blue,yellow,red",
        text_thickness=None,
        box_thickness=None)
    cv2.imwrite(f"{hard_example_edge_output_path}/{nframe}.jpeg",
                edge_collaboration_frame)

def mkdir(path):
    path = path.strip()
    if not os.path.exists(path):
        os.makedirs(path)
        LOG.info(f"{path} does not exist, creating the dir")

def run():
    input_shape_str = neptune.context.get_parameters("input_shape")
    input_shape = tuple(int(v) for v in input_shape_str.split(","))
    camera_address = neptune.context.get_parameters('video_url')

    mkdir(all_output_path)
    mkdir(hard_example_edge_output_path)
    mkdir(hard_example_cloud_output_path)

    # Create the little (edge) model object.
    model = neptune.joint_inference.TSLittleModel(
        preprocess=preprocess,
        postprocess=postprocess,
        input_shape=input_shape,
        create_input_feed=create_input_feed,
        create_output_fetch=create_output_fetch
    )

    # Create the hard example mining algorithm.
    threshold_box = float(neptune.context.get_hem_parameters(
        "threshold_box", 0.5
    ))
    threshold_img = float(neptune.context.get_hem_parameters(
        "threshold_img", 0.5
    ))
    hard_example_mining_algorithm = IBTFilter(threshold_img, threshold_box)
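    # The IBTFilter decides per frame whether the edge result suffices or the
    # frame should be escalated to the cloud model; broadly, threshold_box is
    # the per-box confidence cut-off and threshold_img governs how many
    # low-confidence boxes a frame may contain (exact semantics are defined
    # in neptune.hard_example_mining).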

    # Create the joint inference object.
    inference_instance = neptune.joint_inference.JointInference(
        little_model=model,
        hard_example_mining_algorithm=hard_example_mining_algorithm
    )

    # Use a video stream for testing.
    camera = cv2.VideoCapture(camera_address)
    fps = 10
    nframe = 0
    while True:
        ret, input_yuv = camera.read()
        if not ret:
            LOG.info(
                f"camera is not open, camera_address={camera_address},"
                f" sleep 5 seconds.")
            time.sleep(5)
            camera = cv2.VideoCapture(camera_address)
            continue
        # Process only every `fps`-th frame.
        if nframe % fps:
            nframe += 1
            continue
        img_rgb = cv2.cvtColor(input_yuv, cv2.COLOR_BGR2RGB)
        nframe += 1
        LOG.info(f"camera is open, current frame index is {nframe}")
        inference_result = inference_instance.inference(img_rgb)
        output_deal(inference_result, nframe, img_rgb)


if __name__ == "__main__":
    run()
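
# How the context parameters are supplied depends on the neptune deployment,
# but this script expects at least the following to be set before run() is
# called (values shown are hypothetical examples):
#   input_shape                           e.g. "416,416"
#   video_url                             e.g. "rtsp://example.com/stream"
#   all_examples_inference_output         dir for every annotated frame
#   hard_example_edge_inference_output    dir for edge hard-example frames
#   hard_example_cloud_inference_output   dir for cloud hard-example frames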