diff --git a/examples/incremental_learning/helmet_detection/training/inference.py b/examples/incremental_learning/helmet_detection/training/inference.py index 943d3efd..2686bfbd 100644 --- a/examples/incremental_learning/helmet_detection/training/inference.py +++ b/examples/incremental_learning/helmet_detection/training/inference.py @@ -24,11 +24,11 @@ from sedna.common.file_ops import FileOps from sedna.core.incremental_learning import IncrementalLearning from interface import Estimator - -he_saved_url = Context.get_parameters("HE_SAVED_URL") +he_saved_url = Context.get_parameters("HE_SAVED_URL", '/tmp') +rsl_saved_url = Context.get_parameters("RESULT_SAVED_URL", '/tmp') class_names = ['person', 'helmet', 'helmet_on', 'helmet_off'] -FileOps.clean_folder([he_saved_url], clean=False) +FileOps.clean_folder([he_saved_url, rsl_saved_url], clean=False) def draw_boxes(img, labels, scores, bboxes, class_names, colors): @@ -59,11 +59,14 @@ def draw_boxes(img, labels, scores, bboxes, class_names, colors): p2 = (int(bbox[2]), int(bbox[3])) if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1): continue - cv2.rectangle(img, p1[::-1], p2[::-1], - colors_code[labels[i]], box_thickness) - cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)), - cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), - text_thickness, line_type) + try: + cv2.rectangle(img, p1[::-1], p2[::-1], + colors_code[labels[i]], box_thickness) + cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), + text_thickness, line_type) + except TypeError as err: + warnings.warn(f"Draw box fail: {err}") return img @@ -72,12 +75,13 @@ def output_deal(is_hard_example, infer_result, nframe, img_rgb): img_rgb = np.array(img_rgb) img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) colors = 'yellow,blue,green,red' - if not is_hard_example: - return + lables, scores, bbox_list_pred = infer_result img = draw_boxes(img_rgb, lables, scores, bbox_list_pred, class_names, colors) - cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img) + if is_hard_example: + cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img) + cv2.imwrite(f"{rsl_saved_url}/{nframe}.jpeg", img) def mkdir(path): diff --git a/examples/incremental_learning/helmet_detection/training/interface.py b/examples/incremental_learning/helmet_detection/training/interface.py index 8e8f6302..3d1a2abd 100644 --- a/examples/incremental_learning/helmet_detection/training/interface.py +++ b/examples/incremental_learning/helmet_detection/training/interface.py @@ -15,6 +15,7 @@ import os import six import logging +from urllib.parse import urlparse import cv2 import numpy as np @@ -26,8 +27,21 @@ from validate_utils import validate from yolo3_multiscale import Yolo3 from yolo3_multiscale import YoloConfig + os.environ['BACKEND_TYPE'] = 'TENSORFLOW' +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +s3_url = os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com") +if not (s3_url.startswith("http://") or s3_url.startswith("https://")): + _url = f"https://{s3_url}" +s3_url = urlparse(s3_url) +s3_use_ssl = s3_url.scheme == 'https' if s3_url.scheme else True + +os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("ACCESS_KEY_ID") +os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("SECRET_ACCESS_KEY") +os.environ["S3_ENDPOINT"] = s3_url.netloc +os.environ["S3_USE_HTTPS"] = "1" if s3_use_ssl else "0" LOG = logging.getLogger(__name__) +flags = tf.flags.FLAGS def preprocess(image, input_shape): @@ -89,7 +103,7 @@ class Estimator: data_gen = DataGen(yolo_config, train_data.x) - max_epochs = int(kwargs.get("max_epochs", "1")) + max_epochs = int(kwargs.get("epochs", flags.max_epochs)) config = tf.ConfigProto(allow_soft_placement=True) config.gpu_options.allow_growth = True diff --git a/lib/sedna/algorithms/hard_example_mining/__init__.py b/lib/sedna/algorithms/hard_example_mining/__init__.py index f99864b5..2f9b4c5c 100644 --- a/lib/sedna/algorithms/hard_example_mining/__init__.py +++ b/lib/sedna/algorithms/hard_example_mining/__init__.py @@ -12,142 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Hard Example Mining Algorithms""" -import abc -import math - -from sedna.common.class_factory import ClassFactory, ClassType - -__all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter') - - -class BaseFilter(metaclass=abc.ABCMeta): - """The base class to define unified interface.""" - - def __call__(self, infer_result=None): - """predict function, and it must be implemented by - different methods class. - - :param infer_result: prediction result - :return: `True` means hard sample, `False` means not a hard sample. - """ - raise NotImplementedError - - @classmethod - def data_check(cls, data): - """Check the data in [0,1].""" - return 0 <= float(data) <= 1 - - -@ClassFactory.register(ClassType.HEM, alias="Threshold") -class ThresholdFilter(BaseFilter, abc.ABC): - def __init__(self, threshold=0.5, **kwargs): - self.threshold = float(threshold) - - def __call__(self, infer_result=None): - """ - :param infer_result: [N, 6], (x0, y0, x1, y1, score, class) - :return: `True` means hard sample, `False` means not a hard sample. - """ - # if invalid input, return False - if not (infer_result - and all(map(lambda x: len(x) > 4, infer_result))): - return False - - image_score = 0 - - for bbox in infer_result: - image_score += bbox[4] - - average_score = image_score / (len(infer_result) or 1) - return average_score < self.threshold - - -@ClassFactory.register(ClassType.HEM, alias="CrossEntropy") -class CrossEntropyFilter(BaseFilter, abc.ABC): - """ Implement the hard samples discovery methods named IBT - (image-box-thresholds). - - :param threshold_cross_entropy: threshold_cross_entropy to filter img, - whose hard coefficient is less than - threshold_cross_entropy. And its default value is - threshold_cross_entropy=0.5 - """ - - def __init__(self, threshold_cross_entropy=0.5, **kwargs): - self.threshold_cross_entropy = float(threshold_cross_entropy) - - def __call__(self, infer_result=None): - """judge the img is hard sample or not. - - :param infer_result: - prediction classes list, - such as [class1-score, class2-score, class2-score,....], - where class-score is the score corresponding to the class, - class-score value is in [0,1], who will be ignored if its value - not in [0,1]. - :return: `True` means a hard sample, `False` means not a hard sample. - """ - - if not infer_result: - # if invalid input, return False - return False - - log_sum = 0.0 - data_check_list = [class_probability for class_probability - in infer_result - if self.data_check(class_probability)] - - if len(data_check_list) != len(infer_result): - return False - - for class_data in data_check_list: - log_sum += class_data * math.log(class_data) - confidence_score = 1 + 1.0 * log_sum / math.log( - len(infer_result)) - return confidence_score < self.threshold_cross_entropy - - -@ClassFactory.register(ClassType.HEM, alias="IBT") -class IBTFilter(BaseFilter, abc.ABC): - """Implement the hard samples discovery methods named IBT - (image-box-thresholds). - - :param threshold_img: threshold_img to filter img, whose hard coefficient - is less than threshold_img. - :param threshold_box: threshold_box to calculate hard coefficient, formula - is hard coefficient = number(prediction_boxes less than - threshold_box)/number(prediction_boxes) - """ - - def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs): - self.threshold_box = float(threshold_box) - self.threshold_img = float(threshold_img) - - def __call__(self, infer_result=None): - """Judge the img is hard sample or not. - - :param infer_result: - prediction boxes list, - such as [bbox1, bbox2, bbox3,....], - where bbox = [xmin, ymin, xmax, ymax, score, label] - score should be in [0,1], who will be ignored if its value not - in [0,1]. - :return: `True` means a hard sample, `False` means not a hard sample. - """ - - if not (infer_result - and all(map(lambda x: len(x) > 4, infer_result))): - # if invalid input, return False - return False - - data_check_list = [bbox[4] for bbox in infer_result - if self.data_check(bbox[4])] - if len(data_check_list) != len(infer_result): - return False - - confidence_score_list = [ - float(box_score) for box_score in data_check_list - if float(box_score) <= self.threshold_box] - return (len(confidence_score_list) / len(infer_result) - >= (1 - self.threshold_img)) +from .hard_example_mining import * diff --git a/lib/sedna/algorithms/hard_example_mining/hard_example_mining.py b/lib/sedna/algorithms/hard_example_mining/hard_example_mining.py new file mode 100644 index 00000000..f99864b5 --- /dev/null +++ b/lib/sedna/algorithms/hard_example_mining/hard_example_mining.py @@ -0,0 +1,153 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hard Example Mining Algorithms""" +import abc +import math + +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter') + + +class BaseFilter(metaclass=abc.ABCMeta): + """The base class to define unified interface.""" + + def __call__(self, infer_result=None): + """predict function, and it must be implemented by + different methods class. + + :param infer_result: prediction result + :return: `True` means hard sample, `False` means not a hard sample. + """ + raise NotImplementedError + + @classmethod + def data_check(cls, data): + """Check the data in [0,1].""" + return 0 <= float(data) <= 1 + + +@ClassFactory.register(ClassType.HEM, alias="Threshold") +class ThresholdFilter(BaseFilter, abc.ABC): + def __init__(self, threshold=0.5, **kwargs): + self.threshold = float(threshold) + + def __call__(self, infer_result=None): + """ + :param infer_result: [N, 6], (x0, y0, x1, y1, score, class) + :return: `True` means hard sample, `False` means not a hard sample. + """ + # if invalid input, return False + if not (infer_result + and all(map(lambda x: len(x) > 4, infer_result))): + return False + + image_score = 0 + + for bbox in infer_result: + image_score += bbox[4] + + average_score = image_score / (len(infer_result) or 1) + return average_score < self.threshold + + +@ClassFactory.register(ClassType.HEM, alias="CrossEntropy") +class CrossEntropyFilter(BaseFilter, abc.ABC): + """ Implement the hard samples discovery methods named IBT + (image-box-thresholds). + + :param threshold_cross_entropy: threshold_cross_entropy to filter img, + whose hard coefficient is less than + threshold_cross_entropy. And its default value is + threshold_cross_entropy=0.5 + """ + + def __init__(self, threshold_cross_entropy=0.5, **kwargs): + self.threshold_cross_entropy = float(threshold_cross_entropy) + + def __call__(self, infer_result=None): + """judge the img is hard sample or not. + + :param infer_result: + prediction classes list, + such as [class1-score, class2-score, class2-score,....], + where class-score is the score corresponding to the class, + class-score value is in [0,1], who will be ignored if its value + not in [0,1]. + :return: `True` means a hard sample, `False` means not a hard sample. + """ + + if not infer_result: + # if invalid input, return False + return False + + log_sum = 0.0 + data_check_list = [class_probability for class_probability + in infer_result + if self.data_check(class_probability)] + + if len(data_check_list) != len(infer_result): + return False + + for class_data in data_check_list: + log_sum += class_data * math.log(class_data) + confidence_score = 1 + 1.0 * log_sum / math.log( + len(infer_result)) + return confidence_score < self.threshold_cross_entropy + + +@ClassFactory.register(ClassType.HEM, alias="IBT") +class IBTFilter(BaseFilter, abc.ABC): + """Implement the hard samples discovery methods named IBT + (image-box-thresholds). + + :param threshold_img: threshold_img to filter img, whose hard coefficient + is less than threshold_img. + :param threshold_box: threshold_box to calculate hard coefficient, formula + is hard coefficient = number(prediction_boxes less than + threshold_box)/number(prediction_boxes) + """ + + def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs): + self.threshold_box = float(threshold_box) + self.threshold_img = float(threshold_img) + + def __call__(self, infer_result=None): + """Judge the img is hard sample or not. + + :param infer_result: + prediction boxes list, + such as [bbox1, bbox2, bbox3,....], + where bbox = [xmin, ymin, xmax, ymax, score, label] + score should be in [0,1], who will be ignored if its value not + in [0,1]. + :return: `True` means a hard sample, `False` means not a hard sample. + """ + + if not (infer_result + and all(map(lambda x: len(x) > 4, infer_result))): + # if invalid input, return False + return False + + data_check_list = [bbox[4] for bbox in infer_result + if self.data_check(bbox[4])] + if len(data_check_list) != len(infer_result): + return False + + confidence_score_list = [ + float(box_score) for box_score in data_check_list + if float(box_score) <= self.threshold_box] + return (len(confidence_score_list) / len(infer_result) + >= (1 - self.threshold_img)) diff --git a/lib/sedna/backend/__init__.py b/lib/sedna/backend/__init__.py index 114ca9c5..8c6c386e 100644 --- a/lib/sedna/backend/__init__.py +++ b/lib/sedna/backend/__init__.py @@ -48,7 +48,7 @@ def set_backend(estimator=None, config=None): warnings.warn(f"{backend_type} Not Support yet, use itself") from sedna.backend.base import BackendBase as REGISTER model_save_url = config.get("model_url") - base_model_save = config.get("base_model_save") or model_save_url + base_model_save = config.get("base_model_url") or model_save_url model_save_name = config.get("model_name") return REGISTER( estimator=estimator, use_cuda=use_cuda, diff --git a/lib/sedna/backend/tensorflow/__init__.py b/lib/sedna/backend/tensorflow/__init__.py index cf15fc4f..0017152c 100644 --- a/lib/sedna/backend/tensorflow/__init__.py +++ b/lib/sedna/backend/tensorflow/__init__.py @@ -36,8 +36,9 @@ class TFBackend(BackendBase): super(TFBackend, self).__init__( estimator=estimator, fine_tune=fine_tune, **kwargs) self.framework = "tensorflow" - sess_config = self._init_gpu_session_config( - ) if self.use_cuda else self._init_cpu_session_config() + + sess_config = (self._init_gpu_session_config() + if self.use_cuda else self._init_cpu_session_config()) self.graph = tf.Graph() with self.graph.as_default(): diff --git a/lib/sedna/core/incremental_learning/incremental_learning.py b/lib/sedna/core/incremental_learning/incremental_learning.py index f69492d9..ba0f39ab 100644 --- a/lib/sedna/core/incremental_learning/incremental_learning.py +++ b/lib/sedna/core/incremental_learning/incremental_learning.py @@ -28,9 +28,13 @@ class IncrementalLearning(JobBase): Incremental learning """ - def __init__(self, estimator, config=None): - super(IncrementalLearning, self).__init__( - estimator=estimator, config=config) + def __init__(self, estimator): + """ + Initial a IncrementalLearning job + :param estimator: Customize estimator + """ + + super(IncrementalLearning, self).__init__(estimator=estimator) self.model_urls = self.get_parameters( "MODEL_URLS") # use in evaluation @@ -42,6 +46,15 @@ class IncrementalLearning(JobBase): valid_data=None, post_process=None, **kwargs): + """ + Training task for IncrementalLearning + :param train_data: datasource use for train + :param valid_data: datasource use for evaluation + :param post_process: post process + :param kwargs: params for training of customize estimator + :return: estimator + """ + callback_func = None if post_process is not None: callback_func = ClassFactory.get_cls( @@ -58,6 +71,14 @@ class IncrementalLearning(JobBase): self.estimator) if callback_func else self.estimator def inference(self, data=None, post_process=None, **kwargs): + """ + Inference task for IncrementalLearning + :param data: inference sample + :param post_process: post process + :param kwargs: params for inference of customize estimator + :return: inference result, result after post_process, if is hard sample + """ + if not self.estimator.has_load: self.estimator.load(self.model_path) @@ -81,6 +102,14 @@ class IncrementalLearning(JobBase): return infer_res, res, is_hard_example def evaluate(self, data, post_process=None, **kwargs): + """ + Evaluate task for IncrementalLearning + :param data: datasource use for evaluation + :param post_process: post process + :param kwargs: params for evaluate of customize estimator + :return: evaluate metrics + """ + callback_func = None if callable(post_process): callback_func = post_process