# Copyright 2021 The KubeEdge Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import time import warnings import cv2 import numpy as np from sedna.common.config import Context from sedna.common.file_ops import FileOps from sedna.core.incremental_learning import IncrementalLearning from interface import Estimator he_saved_url = Context.get_parameters("HE_SAVED_URL", '/tmp') rsl_saved_url = Context.get_parameters("RESULT_SAVED_URL", '/tmp') class_names = ['person', 'helmet', 'helmet_on', 'helmet_off'] FileOps.clean_folder([he_saved_url, rsl_saved_url], clean=False) def draw_boxes(img, labels, scores, bboxes, class_names, colors): line_type = 2 text_thickness = 1 box_thickness = 1 # get color code colors = colors.split(",") colors_code = [] for color in colors: if color == 'green': colors_code.append((0, 255, 0)) elif color == 'blue': colors_code.append((255, 0, 0)) elif color == 'yellow': colors_code.append((0, 255, 255)) else: colors_code.append((0, 0, 255)) label_dict = {i: label for i, label in enumerate(class_names)} for i in range(bboxes.shape[0]): bbox = bboxes[i] if float("inf") in bbox or float("-inf") in bbox: continue label = int(labels[i]) score = "%.2f" % round(scores[i], 2) text = label_dict.get(label) + ":" + score p1 = (int(bbox[0]), int(bbox[1])) p2 = (int(bbox[2]), int(bbox[3])) if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1): continue try: cv2.rectangle(img, p1[::-1], p2[::-1], colors_code[labels[i]], box_thickness) cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), text_thickness, line_type) except TypeError as err: warnings.warn(f"Draw box fail: {err}") return img def output_deal(is_hard_example, infer_result, nframe, img_rgb): # save and show image img_rgb = np.array(img_rgb) img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) colors = 'yellow,blue,green,red' lables, scores, bbox_list_pred = infer_result img = draw_boxes(img_rgb, lables, scores, bbox_list_pred, class_names, colors) if is_hard_example: cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img) cv2.imwrite(f"{rsl_saved_url}/{nframe}.jpeg", img) def mkdir(path): path = path.strip() path = path.rstrip() is_exists = os.path.exists(path) if not is_exists: os.makedirs(path) def deal_infer_rsl(model_output): all_classes, all_scores, all_bboxes = model_output rsl = [] for c, s, bbox in zip(all_classes, all_scores, all_bboxes): bbox[0], bbox[1], bbox[2], bbox[3] = bbox[1], bbox[0], bbox[3], bbox[2] rsl.append(bbox.tolist() + [s, c]) return rsl def run(): camera_address = Context.get_parameters('video_url') # get hard exmaple mining algorithm from config hard_example_mining = IncrementalLearning.get_hem_algorithm_from_config( threshold_img=0.9 ) input_shape_str = Context.get_parameters("input_shape") input_shape = tuple(int(v) for v in input_shape_str.split(",")) # create Incremental Learning instance incremental_instance = IncrementalLearning( estimator=Estimator, hard_example_mining=hard_example_mining ) # use video streams for testing camera = cv2.VideoCapture(camera_address) fps = 10 nframe = 0 # the input of video stream while 1: ret, input_yuv = camera.read() if not ret: time.sleep(5) camera = cv2.VideoCapture(camera_address) continue if nframe % fps: nframe += 1 continue img_rgb = cv2.cvtColor(input_yuv, cv2.COLOR_BGR2RGB) nframe += 1 warnings.warn(f"camera is open, current frame index is {nframe}") results, _, is_hard_example = incremental_instance.inference( img_rgb, post_process=deal_infer_rsl, input_shape=input_shape) output_deal(is_hard_example, results, nframe, img_rgb) if __name__ == "__main__": run()