You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

inference.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import time
  16. import warnings
  17. import cv2
  18. import numpy as np
  19. from sedna.common.config import Context
  20. from sedna.common.file_ops import FileOps
  21. from sedna.core.incremental_learning import IncrementalLearning
  22. from interface import Estimator
  # Display names for detection labels: draw_boxes shows label i as
  # class_names[i].
  class_names = ["person", "helmet", "helmet_on", "helmet_off"]

  # Output folders: hard-example frames go to he_saved_url, every annotated
  # frame goes to rsl_saved_url. Both default to "/tmp".
  he_saved_url = Context.get_parameters("HE_SAVED_URL", "/tmp")
  rsl_saved_url = Context.get_parameters("RESULT_SAVED_URL", "/tmp")

  # Runs at import time. NOTE(review): with clean=False this presumably only
  # prepares the folders without deleting their contents — confirm against
  # sedna's FileOps.clean_folder.
  FileOps.clean_folder([he_saved_url, rsl_saved_url], clean=False)
  def draw_boxes(img, labels, scores, bboxes, class_names, colors):
      """Draw labelled detection boxes onto ``img`` and return it.

      ``img`` is drawn on in place via ``cv2.rectangle``/``cv2.putText`` and
      the same object is returned.

      Args:
          img: Image array to draw on (modified in place).
          labels: Per-box class labels, indexable by box position. Each one
              is converted with ``int()``, so float labels are accepted.
          scores: Per-box scores, printed with two decimals.
          bboxes: Array whose ``shape[0]`` rows are boxes. Only the first
              four values of a row are used. The corner points are reversed
              before being handed to OpenCV (which takes ``(x, y)``), so rows
              are read as ``(y1, x1, y2, x2)``.
          class_names: Display names; label ``i`` is shown as
              ``class_names[i]``. Labels outside the list show the number.
          colors: Comma-separated colour names, one per label index.
              'green', 'blue' and 'yellow' are recognised. Any other name, and
              any label without an entry, is drawn in red.

      Returns:
          ``img``.
      """
      line_type = 2
      text_thickness = 1
      box_thickness = 1
      text_color = (255, 0, 0)
      # Colours are BGR tuples, as OpenCV expects.
      default_color = (0, 0, 255)
      named_colors = {
          'green': (0, 255, 0),
          'blue': (255, 0, 0),
          'yellow': (0, 255, 255),
      }
      colors_code = [named_colors.get(name, default_color)
                     for name in colors.split(",")]

      for i in range(bboxes.shape[0]):
          bbox = bboxes[i]
          # int() below raises on NaN/inf, so skip any non-finite coordinate.
          if not np.isfinite(np.asarray(bbox[:4], dtype=float)).all():
              continue
          p1 = (int(bbox[0]), int(bbox[1]))
          p2 = (int(bbox[2]), int(bbox[3]))
          # Skip boxes with no area once truncated to whole pixels.
          if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1):
              continue
          # Index names/colours with the int label. Indexing a list with a
          # float label raises TypeError, which would silently drop the box.
          label = int(labels[i])
          if 0 <= label < len(class_names):
              name = class_names[label]
          else:
              name = str(label)
          if 0 <= label < len(colors_code):
              color = colors_code[label]
          else:
              color = default_color
          text = f"{name}:{float(scores[i]):.2f}"
          try:
              cv2.rectangle(img, p1[::-1], p2[::-1], color, box_thickness)
              # The caption is shifted down 20 px per label index (presumably
              # so captions of different classes at one corner don't overlap).
              cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)),
                          cv2.FONT_HERSHEY_SIMPLEX, 0.6, text_color,
                          text_thickness, line_type)
          except TypeError as err:
              warnings.warn(f"Draw box fail: {err}")
      return img
  def output_deal(is_hard_example, infer_result, nframe, img_rgb):
      """Annotate one frame with its detections and save it as JPEG.

      The annotated frame is always written to ``rsl_saved_url``. When
      ``is_hard_example`` is truthy it is first also written to
      ``he_saved_url``. The file name is ``<nframe>.jpeg`` in both places.

      Args:
          is_hard_example: Whether the frame was flagged as a hard example.
          infer_result: ``(labels, scores, bboxes)``, passed to ``draw_boxes``.
          nframe: Frame index, used as the output file name.
          img_rgb: RGB image (anything ``np.array`` accepts). It is converted
              to BGR before drawing/saving with OpenCV.
      """
      img_bgr = cv2.cvtColor(np.array(img_rgb), cv2.COLOR_RGB2BGR)
      labels, scores, bbox_list_pred = infer_result
      img = draw_boxes(img_bgr, labels, scores, bbox_list_pred, class_names,
                       'yellow,blue,green,red')

      out_dirs = [rsl_saved_url]
      if is_hard_example:
          out_dirs.insert(0, he_saved_url)
      file_name = f"{nframe}.jpeg"
      for out_dir in out_dirs:
          out_path = os.path.join(out_dir, file_name)
          # cv2.imwrite signals most failures (e.g. a missing folder) by
          # returning False rather than raising, so check it explicitly.
          if not cv2.imwrite(out_path, img):
              warnings.warn(f"Failed to write image: {out_path}")
  def mkdir(path):
      """Create directory ``path``, including missing parents, if absent.

      Surrounding whitespace is stripped from ``path`` first. If anything
      (directory or file) already exists at ``path``, nothing is done.

      Args:
          path: Directory path to create.
      """
      path = path.strip()
      if not os.path.exists(path):
          # exist_ok tolerates another process creating the directory between
          # the existence check above and this call.
          os.makedirs(path, exist_ok=True)
  def deal_infer_rsl(model_output):
      """Flatten raw detections into ``[box..., score, class]`` rows.

      Args:
          model_output: ``(classes, scores, bboxes)``, iterated in lockstep.
              Each bbox must support item assignment and ``.tolist()``
              (assumed to be NumPy rows).

      Returns:
          One list per detection: the box values with positions 0<->1 and
          2<->3 swapped, followed by the score and then the class.

      Note:
          The swap is written back into each bbox. When ``bboxes`` is a 2-D
          NumPy array its rows are views, so the caller's array is modified
          too.
      """
      classes, confidences, boxes = model_output
      rows = []
      for cls_id, conf, box in zip(classes, confidences, boxes):
          box[0], box[1] = box[1], box[0]
          box[2], box[3] = box[3], box[2]
          rows.append([*box.tolist(), conf, cls_id])
      return rows
  def run():
      """Run inference on a video stream until interrupted.

      Opens the stream named by the ``video_url`` parameter and passes every
      ``frame_interval``-th frame through sedna's ``IncrementalLearning``,
      using ``deal_infer_rsl`` as post-processing. Each processed frame is
      saved via ``output_deal``. When a read fails, the capture is released
      and reopened after a 5 second pause.

      Raises:
          ValueError: if the ``input_shape`` parameter is missing or empty.
      """
      camera_address = Context.get_parameters('video_url')
      # get hard example mining algorithm from config
      hard_example_mining = IncrementalLearning.get_hem_algorithm_from_config(
          threshold_img=0.9
      )
      input_shape_str = Context.get_parameters("input_shape")
      if not input_shape_str:
          raise ValueError(
              "input_shape parameter is required (comma-separated ints)")
      input_shape = tuple(int(v) for v in input_shape_str.split(","))
      # create Incremental Learning instance
      incremental_instance = IncrementalLearning(
          estimator=Estimator, hard_example_mining=hard_example_mining
      )
      # Only every `frame_interval`-th frame is run through the model.
      frame_interval = 10
      nframe = 0
      # use video streams for testing
      camera = cv2.VideoCapture(camera_address)
      try:
          while True:
              ret, frame_bgr = camera.read()
              if not ret:
                  # Release the dead capture before reconnecting so a handle
                  # is not leaked on every retry.
                  camera.release()
                  time.sleep(5)
                  camera = cv2.VideoCapture(camera_address)
                  continue
              if nframe % frame_interval:
                  nframe += 1
                  continue
              # OpenCV delivers BGR frames; the model is fed RGB.
              img_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
              nframe += 1
              warnings.warn(f"camera is open, current frame index is {nframe}")
              results, _, is_hard_example = incremental_instance.inference(
                  img_rgb, post_process=deal_infer_rsl, input_shape=input_shape)
              output_deal(is_hard_example, results, nframe, img_rgb)
      finally:
          camera.release()


  if __name__ == "__main__":
      run()