- Add docs and code comments - fix bugs: epoch was always 1, inference results were not saved, S3 upload failed Signed-off-by: JoeyHwong <joeyhwong@gknow.cn> tags/v0.3.1
| @@ -24,11 +24,11 @@ from sedna.common.file_ops import FileOps | |||||
| from sedna.core.incremental_learning import IncrementalLearning | from sedna.core.incremental_learning import IncrementalLearning | ||||
| from interface import Estimator | from interface import Estimator | ||||
| he_saved_url = Context.get_parameters("HE_SAVED_URL") | |||||
| he_saved_url = Context.get_parameters("HE_SAVED_URL", '/tmp') | |||||
| rsl_saved_url = Context.get_parameters("RESULT_SAVED_URL", '/tmp') | |||||
| class_names = ['person', 'helmet', 'helmet_on', 'helmet_off'] | class_names = ['person', 'helmet', 'helmet_on', 'helmet_off'] | ||||
| FileOps.clean_folder([he_saved_url], clean=False) | |||||
| FileOps.clean_folder([he_saved_url, rsl_saved_url], clean=False) | |||||
| def draw_boxes(img, labels, scores, bboxes, class_names, colors): | def draw_boxes(img, labels, scores, bboxes, class_names, colors): | ||||
| @@ -59,11 +59,14 @@ def draw_boxes(img, labels, scores, bboxes, class_names, colors): | |||||
| p2 = (int(bbox[2]), int(bbox[3])) | p2 = (int(bbox[2]), int(bbox[3])) | ||||
| if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1): | if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1): | ||||
| continue | continue | ||||
| cv2.rectangle(img, p1[::-1], p2[::-1], | |||||
| colors_code[labels[i]], box_thickness) | |||||
| cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)), | |||||
| cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), | |||||
| text_thickness, line_type) | |||||
| try: | |||||
| cv2.rectangle(img, p1[::-1], p2[::-1], | |||||
| colors_code[labels[i]], box_thickness) | |||||
| cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)), | |||||
| cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), | |||||
| text_thickness, line_type) | |||||
| except TypeError as err: | |||||
| warnings.warn(f"Draw box fail: {err}") | |||||
| return img | return img | ||||
| @@ -72,12 +75,13 @@ def output_deal(is_hard_example, infer_result, nframe, img_rgb): | |||||
| img_rgb = np.array(img_rgb) | img_rgb = np.array(img_rgb) | ||||
| img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) | img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) | ||||
| colors = 'yellow,blue,green,red' | colors = 'yellow,blue,green,red' | ||||
| if not is_hard_example: | |||||
| return | |||||
| lables, scores, bbox_list_pred = infer_result | lables, scores, bbox_list_pred = infer_result | ||||
| img = draw_boxes(img_rgb, lables, scores, bbox_list_pred, class_names, | img = draw_boxes(img_rgb, lables, scores, bbox_list_pred, class_names, | ||||
| colors) | colors) | ||||
| cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img) | |||||
| if is_hard_example: | |||||
| cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img) | |||||
| cv2.imwrite(f"{rsl_saved_url}/{nframe}.jpeg", img) | |||||
| def mkdir(path): | def mkdir(path): | ||||
| @@ -15,6 +15,7 @@ | |||||
| import os | import os | ||||
| import six | import six | ||||
| import logging | import logging | ||||
| from urllib.parse import urlparse | |||||
| import cv2 | import cv2 | ||||
| import numpy as np | import numpy as np | ||||
| @@ -26,8 +27,21 @@ from validate_utils import validate | |||||
| from yolo3_multiscale import Yolo3 | from yolo3_multiscale import Yolo3 | ||||
| from yolo3_multiscale import YoloConfig | from yolo3_multiscale import YoloConfig | ||||
| os.environ['BACKEND_TYPE'] = 'TENSORFLOW' | os.environ['BACKEND_TYPE'] = 'TENSORFLOW' | ||||
# Raise TensorFlow's C++ log threshold so INFO and WARNING output is hidden.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Translate the job's S3 settings into the AWS_* / S3_* variable names
# (presumably the ones read by TensorFlow's S3 filesystem -- confirm).
s3_url = os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com")
if not s3_url.startswith(("http://", "https://")):
    # Fix: the prefixed URL used to be stored in an unused `_url`, so a
    # scheme-less endpoint such as "minio:9000" reached urlparse() as-is and
    # produced an empty netloc. Prefix the value that is actually parsed.
    s3_url = f"https://{s3_url}"
s3_url = urlparse(s3_url)
s3_use_ssl = s3_url.scheme == 'https' if s3_url.scheme else True

# os.environ only accepts strings: assigning the None returned by os.getenv()
# for an unset variable raises TypeError at import time. Copy each credential
# only when it is actually defined.
for _aws_name, _source_name in (
        ("AWS_ACCESS_KEY_ID", "ACCESS_KEY_ID"),
        ("AWS_SECRET_ACCESS_KEY", "SECRET_ACCESS_KEY")):
    _credential = os.getenv(_source_name)
    if _credential is not None:
        os.environ[_aws_name] = _credential

os.environ["S3_ENDPOINT"] = s3_url.netloc
os.environ["S3_USE_HTTPS"] = "1" if s3_use_ssl else "0"
| LOG = logging.getLogger(__name__) | LOG = logging.getLogger(__name__) | ||||
| flags = tf.flags.FLAGS | |||||
| def preprocess(image, input_shape): | def preprocess(image, input_shape): | ||||
| @@ -89,7 +103,7 @@ class Estimator: | |||||
| data_gen = DataGen(yolo_config, train_data.x) | data_gen = DataGen(yolo_config, train_data.x) | ||||
| max_epochs = int(kwargs.get("max_epochs", "1")) | |||||
| max_epochs = int(kwargs.get("epochs", flags.max_epochs)) | |||||
| config = tf.ConfigProto(allow_soft_placement=True) | config = tf.ConfigProto(allow_soft_placement=True) | ||||
| config.gpu_options.allow_growth = True | config.gpu_options.allow_growth = True | ||||
| @@ -12,142 +12,4 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| """Hard Example Mining Algorithms""" | |||||
| import abc | |||||
| import math | |||||
| from sedna.common.class_factory import ClassFactory, ClassType | |||||
| __all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter') | |||||
| class BaseFilter(metaclass=abc.ABCMeta): | |||||
| """The base class to define unified interface.""" | |||||
| def __call__(self, infer_result=None): | |||||
| """predict function, and it must be implemented by | |||||
| different methods class. | |||||
| :param infer_result: prediction result | |||||
| :return: `True` means hard sample, `False` means not a hard sample. | |||||
| """ | |||||
| raise NotImplementedError | |||||
| @classmethod | |||||
| def data_check(cls, data): | |||||
| """Check the data in [0,1].""" | |||||
| return 0 <= float(data) <= 1 | |||||
| @ClassFactory.register(ClassType.HEM, alias="Threshold") | |||||
| class ThresholdFilter(BaseFilter, abc.ABC): | |||||
| def __init__(self, threshold=0.5, **kwargs): | |||||
| self.threshold = float(threshold) | |||||
| def __call__(self, infer_result=None): | |||||
| """ | |||||
| :param infer_result: [N, 6], (x0, y0, x1, y1, score, class) | |||||
| :return: `True` means hard sample, `False` means not a hard sample. | |||||
| """ | |||||
| # if invalid input, return False | |||||
| if not (infer_result | |||||
| and all(map(lambda x: len(x) > 4, infer_result))): | |||||
| return False | |||||
| image_score = 0 | |||||
| for bbox in infer_result: | |||||
| image_score += bbox[4] | |||||
| average_score = image_score / (len(infer_result) or 1) | |||||
| return average_score < self.threshold | |||||
| @ClassFactory.register(ClassType.HEM, alias="CrossEntropy") | |||||
| class CrossEntropyFilter(BaseFilter, abc.ABC): | |||||
| """ Implement the hard samples discovery methods named IBT | |||||
| (image-box-thresholds). | |||||
| :param threshold_cross_entropy: threshold_cross_entropy to filter img, | |||||
| whose hard coefficient is less than | |||||
| threshold_cross_entropy. And its default value is | |||||
| threshold_cross_entropy=0.5 | |||||
| """ | |||||
| def __init__(self, threshold_cross_entropy=0.5, **kwargs): | |||||
| self.threshold_cross_entropy = float(threshold_cross_entropy) | |||||
| def __call__(self, infer_result=None): | |||||
| """judge the img is hard sample or not. | |||||
| :param infer_result: | |||||
| prediction classes list, | |||||
| such as [class1-score, class2-score, class2-score,....], | |||||
| where class-score is the score corresponding to the class, | |||||
| class-score value is in [0,1], who will be ignored if its value | |||||
| not in [0,1]. | |||||
| :return: `True` means a hard sample, `False` means not a hard sample. | |||||
| """ | |||||
| if not infer_result: | |||||
| # if invalid input, return False | |||||
| return False | |||||
| log_sum = 0.0 | |||||
| data_check_list = [class_probability for class_probability | |||||
| in infer_result | |||||
| if self.data_check(class_probability)] | |||||
| if len(data_check_list) != len(infer_result): | |||||
| return False | |||||
| for class_data in data_check_list: | |||||
| log_sum += class_data * math.log(class_data) | |||||
| confidence_score = 1 + 1.0 * log_sum / math.log( | |||||
| len(infer_result)) | |||||
| return confidence_score < self.threshold_cross_entropy | |||||
| @ClassFactory.register(ClassType.HEM, alias="IBT") | |||||
| class IBTFilter(BaseFilter, abc.ABC): | |||||
| """Implement the hard samples discovery methods named IBT | |||||
| (image-box-thresholds). | |||||
| :param threshold_img: threshold_img to filter img, whose hard coefficient | |||||
| is less than threshold_img. | |||||
| :param threshold_box: threshold_box to calculate hard coefficient, formula | |||||
| is hard coefficient = number(prediction_boxes less than | |||||
| threshold_box)/number(prediction_boxes) | |||||
| """ | |||||
| def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs): | |||||
| self.threshold_box = float(threshold_box) | |||||
| self.threshold_img = float(threshold_img) | |||||
| def __call__(self, infer_result=None): | |||||
| """Judge the img is hard sample or not. | |||||
| :param infer_result: | |||||
| prediction boxes list, | |||||
| such as [bbox1, bbox2, bbox3,....], | |||||
| where bbox = [xmin, ymin, xmax, ymax, score, label] | |||||
| score should be in [0,1], who will be ignored if its value not | |||||
| in [0,1]. | |||||
| :return: `True` means a hard sample, `False` means not a hard sample. | |||||
| """ | |||||
| if not (infer_result | |||||
| and all(map(lambda x: len(x) > 4, infer_result))): | |||||
| # if invalid input, return False | |||||
| return False | |||||
| data_check_list = [bbox[4] for bbox in infer_result | |||||
| if self.data_check(bbox[4])] | |||||
| if len(data_check_list) != len(infer_result): | |||||
| return False | |||||
| confidence_score_list = [ | |||||
| float(box_score) for box_score in data_check_list | |||||
| if float(box_score) <= self.threshold_box] | |||||
| return (len(confidence_score_list) / len(infer_result) | |||||
| >= (1 - self.threshold_img)) | |||||
| from .hard_example_mining import * | |||||
| @@ -0,0 +1,153 @@ | |||||
| # Copyright 2021 The KubeEdge Authors. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Hard Example Mining Algorithms""" | |||||
| import abc | |||||
| import math | |||||
| from sedna.common.class_factory import ClassFactory, ClassType | |||||
| __all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter') | |||||
class BaseFilter(metaclass=abc.ABCMeta):
    """Common interface shared by the hard-example filters."""

    def __call__(self, infer_result=None):
        """Decide whether a prediction is a hard sample.

        Concrete filters must override this method; the base version
        always raises.

        :param infer_result: prediction result
        :return: `True` means hard sample, `False` means not a hard sample.
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def data_check(cls, data):
        """Return whether `data`, converted with float(), lies in [0, 1]."""
        value = float(data)
        return value >= 0 and value <= 1
@ClassFactory.register(ClassType.HEM, alias="Threshold")
class ThresholdFilter(BaseFilter, abc.ABC):
    """Flag a sample as hard when its mean box score is below a threshold.

    :param threshold: mean-score cut-off, converted with float();
        default 0.5. Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, threshold=0.5, **kwargs):
        self.threshold = float(threshold)

    def __call__(self, infer_result=None):
        """Compare the average of the fifth element of every box to the
        threshold.

        :param infer_result: [N, 6], (x0, y0, x1, y1, score, class)
        :return: `True` means hard sample, `False` means not a hard sample.
        """
        # Empty input, or any box with fewer than five elements, is treated
        # as invalid and reported as "not hard".
        if not infer_result:
            return False
        if not all(len(box) > 4 for box in infer_result):
            return False

        score_total = sum(box[4] for box in infer_result)
        box_count = len(infer_result) or 1
        mean_score = score_total / box_count
        return mean_score < self.threshold
@ClassFactory.register(ClassType.HEM, alias="CrossEntropy")
class CrossEntropyFilter(BaseFilter, abc.ABC):
    """Hard-sample discovery based on the normalised entropy of the class
    scores.

    The confidence of a prediction is `1 + sum(p * log(p)) / log(n)` over
    its `n` class scores; a sample is hard when that confidence is below
    the threshold.

    :param threshold_cross_entropy: confidence cut-off, converted with
        float(); default 0.5. Extra keyword arguments are accepted and
        ignored.
    """

    def __init__(self, threshold_cross_entropy=0.5, **kwargs):
        self.threshold_cross_entropy = float(threshold_cross_entropy)

    def __call__(self, infer_result=None):
        """Judge whether the image is a hard sample.

        :param infer_result:
            prediction classes list,
            such as [class1-score, class2-score, class3-score,....],
            where class-score is the score corresponding to the class.
            Every score must be in [0,1]; if any score is outside that
            range the whole input is treated as invalid.
        :return: `True` means a hard sample, `False` means not a hard sample
            (also returned for empty or invalid input).
        """
        if not infer_result:
            # if invalid input, return False
            return False
        if not all(self.data_check(score) for score in infer_result):
            return False

        class_count = len(infer_result)
        if class_count < 2:
            # log(1) == 0, so the normalisation below would divide by zero.
            # A single-score input is given zero entropy, i.e. full
            # confidence, instead of raising ZeroDivisionError.
            confidence_score = 1.0
        else:
            # p * log(p) tends to 0 as p -> 0, so zero scores contribute
            # nothing; skipping them avoids math.log(0) raising ValueError.
            log_sum = sum(
                float(score) * math.log(float(score))
                for score in infer_result if float(score) > 0)
            confidence_score = 1 + log_sum / math.log(class_count)
        return confidence_score < self.threshold_cross_entropy
@ClassFactory.register(ClassType.HEM, alias="IBT")
class IBTFilter(BaseFilter, abc.ABC):
    """Hard-sample discovery method named IBT (image-box-thresholds).

    :param threshold_img: image-level threshold; a sample is hard when the
        fraction of low-score boxes is at least `1 - threshold_img`.
    :param threshold_box: box-level threshold; a box counts as low-score
        when its score is less than or equal to `threshold_box`.

    Both values are converted with float() and default to 0.5. Extra
    keyword arguments are accepted and ignored.
    """

    def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs):
        self.threshold_box = float(threshold_box)
        self.threshold_img = float(threshold_img)

    def __call__(self, infer_result=None):
        """Judge whether the image is a hard sample.

        :param infer_result:
            prediction boxes list,
            such as [bbox1, bbox2, bbox3,....],
            where bbox = [xmin, ymin, xmax, ymax, score, label].
            Every score must be in [0,1]; if any score is outside that
            range the whole input is treated as invalid.
        :return: `True` means a hard sample, `False` means not a hard sample
            (also returned for empty or invalid input).
        """
        # if invalid input, return False
        if not infer_result:
            return False
        if not all(len(bbox) > 4 for bbox in infer_result):
            return False

        # Keep only the scores that pass the [0, 1] range check; a shorter
        # list than the input means at least one box was out of range.
        valid_scores = []
        for bbox in infer_result:
            score = bbox[4]
            if self.data_check(score):
                valid_scores.append(score)
        if len(valid_scores) != len(infer_result):
            return False

        low_score_count = sum(
            1 for score in valid_scores
            if float(score) <= self.threshold_box)
        hard_ratio = low_score_count / len(infer_result)
        return hard_ratio >= (1 - self.threshold_img)
| @@ -48,7 +48,7 @@ def set_backend(estimator=None, config=None): | |||||
| warnings.warn(f"{backend_type} Not Support yet, use itself") | warnings.warn(f"{backend_type} Not Support yet, use itself") | ||||
| from sedna.backend.base import BackendBase as REGISTER | from sedna.backend.base import BackendBase as REGISTER | ||||
| model_save_url = config.get("model_url") | model_save_url = config.get("model_url") | ||||
| base_model_save = config.get("base_model_save") or model_save_url | |||||
| base_model_save = config.get("base_model_url") or model_save_url | |||||
| model_save_name = config.get("model_name") | model_save_name = config.get("model_name") | ||||
| return REGISTER( | return REGISTER( | ||||
| estimator=estimator, use_cuda=use_cuda, | estimator=estimator, use_cuda=use_cuda, | ||||
| @@ -36,8 +36,9 @@ class TFBackend(BackendBase): | |||||
| super(TFBackend, self).__init__( | super(TFBackend, self).__init__( | ||||
| estimator=estimator, fine_tune=fine_tune, **kwargs) | estimator=estimator, fine_tune=fine_tune, **kwargs) | ||||
| self.framework = "tensorflow" | self.framework = "tensorflow" | ||||
| sess_config = self._init_gpu_session_config( | |||||
| ) if self.use_cuda else self._init_cpu_session_config() | |||||
| sess_config = (self._init_gpu_session_config() | |||||
| if self.use_cuda else self._init_cpu_session_config()) | |||||
| self.graph = tf.Graph() | self.graph = tf.Graph() | ||||
| with self.graph.as_default(): | with self.graph.as_default(): | ||||
| @@ -28,9 +28,13 @@ class IncrementalLearning(JobBase): | |||||
| Incremental learning | Incremental learning | ||||
| """ | """ | ||||
| def __init__(self, estimator, config=None): | |||||
| super(IncrementalLearning, self).__init__( | |||||
| estimator=estimator, config=config) | |||||
| def __init__(self, estimator): | |||||
| """ | |||||
| Initial a IncrementalLearning job | |||||
| :param estimator: Customize estimator | |||||
| """ | |||||
| super(IncrementalLearning, self).__init__(estimator=estimator) | |||||
| self.model_urls = self.get_parameters( | self.model_urls = self.get_parameters( | ||||
| "MODEL_URLS") # use in evaluation | "MODEL_URLS") # use in evaluation | ||||
| @@ -42,6 +46,15 @@ class IncrementalLearning(JobBase): | |||||
| valid_data=None, | valid_data=None, | ||||
| post_process=None, | post_process=None, | ||||
| **kwargs): | **kwargs): | ||||
| """ | |||||
| Training task for IncrementalLearning | |||||
| :param train_data: datasource use for train | |||||
| :param valid_data: datasource use for evaluation | |||||
| :param post_process: post process | |||||
| :param kwargs: params for training of customize estimator | |||||
| :return: estimator | |||||
| """ | |||||
| callback_func = None | callback_func = None | ||||
| if post_process is not None: | if post_process is not None: | ||||
| callback_func = ClassFactory.get_cls( | callback_func = ClassFactory.get_cls( | ||||
| @@ -58,6 +71,14 @@ class IncrementalLearning(JobBase): | |||||
| self.estimator) if callback_func else self.estimator | self.estimator) if callback_func else self.estimator | ||||
| def inference(self, data=None, post_process=None, **kwargs): | def inference(self, data=None, post_process=None, **kwargs): | ||||
| """ | |||||
| Inference task for IncrementalLearning | |||||
| :param data: inference sample | |||||
| :param post_process: post process | |||||
| :param kwargs: params for inference of customize estimator | |||||
| :return: inference result, result after post_process, if is hard sample | |||||
| """ | |||||
| if not self.estimator.has_load: | if not self.estimator.has_load: | ||||
| self.estimator.load(self.model_path) | self.estimator.load(self.model_path) | ||||
| @@ -81,6 +102,14 @@ class IncrementalLearning(JobBase): | |||||
| return infer_res, res, is_hard_example | return infer_res, res, is_hard_example | ||||
| def evaluate(self, data, post_process=None, **kwargs): | def evaluate(self, data, post_process=None, **kwargs): | ||||
| """ | |||||
| Evaluate task for IncrementalLearning | |||||
| :param data: datasource use for evaluation | |||||
| :param post_process: post process | |||||
| :param kwargs: params for evaluate of customize estimator | |||||
| :return: evaluate metrics | |||||
| """ | |||||
| callback_func = None | callback_func = None | ||||
| if callable(post_process): | if callable(post_process): | ||||
| callback_func = post_process | callback_func = post_process | ||||