[incremental learning] example:keep all results whether is hardExample or not, fixed the issue of using s3 to save modeltags/v0.3.1
| @@ -20,8 +20,6 @@ from sedna.datasources import TxtDataParse | |||
| from interface import Estimator | |||
| max_epochs = 1 | |||
| def _load_txt_dataset(dataset_url): | |||
| # use original dataset url, | |||
| @@ -43,9 +41,9 @@ def main(): | |||
| input_shape = Context.get_parameters("input_shape") | |||
| input_shape = tuple(int(shape) for shape in input_shape.split(',')) | |||
| model = IncrementalLearning(estimator=Estimator) | |||
| return model.evaluate(valid_data, class_names=class_names, | |||
| input_shape=input_shape) | |||
| incremental_instance = IncrementalLearning(estimator=Estimator) | |||
| return incremental_instance.evaluate(valid_data, class_names=class_names, | |||
| input_shape=input_shape) | |||
| if __name__ == '__main__': | |||
| @@ -25,10 +25,11 @@ from sedna.core.incremental_learning import IncrementalLearning | |||
| from interface import Estimator | |||
| he_saved_url = Context.get_parameters("HE_SAVED_URL") | |||
| he_saved_url = Context.get_parameters("HE_SAVED_URL", '/tmp') | |||
| rsl_saved_url = Context.get_parameters("RESULT_SAVED_URL", '/tmp') | |||
| class_names = ['person', 'helmet', 'helmet_on', 'helmet_off'] | |||
| FileOps.clean_folder([he_saved_url], clean=False) | |||
| FileOps.clean_folder([he_saved_url, rsl_saved_url], clean=False) | |||
| def draw_boxes(img, labels, scores, bboxes, class_names, colors): | |||
| @@ -59,11 +60,14 @@ def draw_boxes(img, labels, scores, bboxes, class_names, colors): | |||
| p2 = (int(bbox[2]), int(bbox[3])) | |||
| if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1): | |||
| continue | |||
| cv2.rectangle(img, p1[::-1], p2[::-1], | |||
| colors_code[labels[i]], box_thickness) | |||
| cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)), | |||
| cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), | |||
| text_thickness, line_type) | |||
| try: | |||
| cv2.rectangle(img, p1[::-1], p2[::-1], | |||
| colors_code[labels[i]], box_thickness) | |||
| cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)), | |||
| cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), | |||
| text_thickness, line_type) | |||
| except TypeError as err: | |||
| warnings.warn(f"Draw box fail: {err}") | |||
| return img | |||
| @@ -72,12 +76,13 @@ def output_deal(is_hard_example, infer_result, nframe, img_rgb): | |||
| img_rgb = np.array(img_rgb) | |||
| img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) | |||
| colors = 'yellow,blue,green,red' | |||
| if not is_hard_example: | |||
| return | |||
| lables, scores, bbox_list_pred = infer_result | |||
| img = draw_boxes(img_rgb, lables, scores, bbox_list_pred, class_names, | |||
| colors) | |||
| cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img) | |||
| if is_hard_example: | |||
| cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img) | |||
| cv2.imwrite(f"{rsl_saved_url}/{nframe}.jpeg", img) | |||
| def mkdir(path): | |||
| @@ -100,10 +105,17 @@ def deal_infer_rsl(model_output): | |||
| def run(): | |||
| camera_address = Context.get_parameters('video_url') | |||
| # get hard exmaple mining algorithm from config | |||
| hard_example_mining = IncrementalLearning.get_hem_algorithm_from_config( | |||
| threshold_img=0.9 | |||
| ) | |||
| input_shape_str = Context.get_parameters("input_shape") | |||
| input_shape = tuple(int(v) for v in input_shape_str.split(",")) | |||
| # create little model object | |||
| model = IncrementalLearning(estimator=Estimator) | |||
| # create Incremental Learning instance | |||
| incremental_instance = IncrementalLearning( | |||
| estimator=Estimator, hard_example_mining=hard_example_mining | |||
| ) | |||
| # use video streams for testing | |||
| camera = cv2.VideoCapture(camera_address) | |||
| fps = 10 | |||
| @@ -123,7 +135,7 @@ def run(): | |||
| img_rgb = cv2.cvtColor(input_yuv, cv2.COLOR_BGR2RGB) | |||
| nframe += 1 | |||
| warnings.warn(f"camera is open, current frame index is {nframe}") | |||
| results, _, is_hard_example = model.inference( | |||
| results, _, is_hard_example = incremental_instance.inference( | |||
| img_rgb, post_process=deal_infer_rsl, input_shape=input_shape) | |||
| output_deal(is_hard_example, results, nframe, img_rgb) | |||
| @@ -15,6 +15,7 @@ | |||
| import os | |||
| import six | |||
| import logging | |||
| from urllib.parse import urlparse | |||
| import cv2 | |||
| import numpy as np | |||
| @@ -26,8 +27,21 @@ from validate_utils import validate | |||
| from yolo3_multiscale import Yolo3 | |||
| from yolo3_multiscale import YoloConfig | |||
| os.environ['BACKEND_TYPE'] = 'TENSORFLOW' | |||
| os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | |||
| s3_url = os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com") | |||
| if not (s3_url.startswith("http://") or s3_url.startswith("https://")): | |||
| _url = f"https://{s3_url}" | |||
| s3_url = urlparse(s3_url) | |||
| s3_use_ssl = s3_url.scheme == 'https' if s3_url.scheme else True | |||
| os.environ["AWS_ACCESS_KEY_ID"] = os.getenv("ACCESS_KEY_ID", "") | |||
| os.environ["AWS_SECRET_ACCESS_KEY"] = os.getenv("SECRET_ACCESS_KEY", "") | |||
| os.environ["S3_ENDPOINT"] = s3_url.netloc | |||
| os.environ["S3_USE_HTTPS"] = "1" if s3_use_ssl else "0" | |||
| LOG = logging.getLogger(__name__) | |||
| flags = tf.flags.FLAGS | |||
| def preprocess(image, input_shape): | |||
| @@ -89,7 +103,7 @@ class Estimator: | |||
| data_gen = DataGen(yolo_config, train_data.x) | |||
| max_epochs = int(kwargs.get("max_epochs", "1")) | |||
| max_epochs = int(kwargs.get("epochs", flags.max_epochs)) | |||
| config = tf.ConfigProto(allow_soft_placement=True) | |||
| config.gpu_options.allow_growth = True | |||
| @@ -24,7 +24,6 @@ from interface import Estimator | |||
| def _load_txt_dataset(dataset_url): | |||
| # use original dataset url, | |||
| # see https://github.com/kubeedge/sedna/issues/35 | |||
| original_dataset_url = Context.get_parameters('original_dataset_url') | |||
| @@ -93,13 +92,13 @@ def main(): | |||
| tf.flags.DEFINE_string('result_url', default=None, | |||
| help='result url for training') | |||
| model = IncrementalLearning(estimator=Estimator) | |||
| return model.train(train_data=train_data, epochs=epochs, | |||
| batch_size=batch_size, | |||
| class_names=class_names, | |||
| input_shape=input_shape, | |||
| obj_threshold=obj_threshold, | |||
| nms_threshold=nms_threshold) | |||
| incremental_instance = IncrementalLearning(estimator=Estimator) | |||
| return incremental_instance.train(train_data=train_data, epochs=epochs, | |||
| batch_size=batch_size, | |||
| class_names=class_names, | |||
| input_shape=input_shape, | |||
| obj_threshold=obj_threshold, | |||
| nms_threshold=nms_threshold) | |||
| if __name__ == '__main__': | |||
| @@ -12,142 +12,4 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Hard Example Mining Algorithms""" | |||
| import abc | |||
| import math | |||
| from sedna.common.class_factory import ClassFactory, ClassType | |||
| __all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter') | |||
| class BaseFilter(metaclass=abc.ABCMeta): | |||
| """The base class to define unified interface.""" | |||
| def __call__(self, infer_result=None): | |||
| """predict function, and it must be implemented by | |||
| different methods class. | |||
| :param infer_result: prediction result | |||
| :return: `True` means hard sample, `False` means not a hard sample. | |||
| """ | |||
| raise NotImplementedError | |||
| @classmethod | |||
| def data_check(cls, data): | |||
| """Check the data in [0,1].""" | |||
| return 0 <= float(data) <= 1 | |||
| @ClassFactory.register(ClassType.HEM, alias="Threshold") | |||
| class ThresholdFilter(BaseFilter, abc.ABC): | |||
| def __init__(self, threshold=0.5, **kwargs): | |||
| self.threshold = float(threshold) | |||
| def __call__(self, infer_result=None): | |||
| """ | |||
| :param infer_result: [N, 6], (x0, y0, x1, y1, score, class) | |||
| :return: `True` means hard sample, `False` means not a hard sample. | |||
| """ | |||
| # if invalid input, return False | |||
| if not (infer_result | |||
| and all(map(lambda x: len(x) > 4, infer_result))): | |||
| return False | |||
| image_score = 0 | |||
| for bbox in infer_result: | |||
| image_score += bbox[4] | |||
| average_score = image_score / (len(infer_result) or 1) | |||
| return average_score < self.threshold | |||
| @ClassFactory.register(ClassType.HEM, alias="CrossEntropy") | |||
| class CrossEntropyFilter(BaseFilter, abc.ABC): | |||
| """ Implement the hard samples discovery methods named IBT | |||
| (image-box-thresholds). | |||
| :param threshold_cross_entropy: threshold_cross_entropy to filter img, | |||
| whose hard coefficient is less than | |||
| threshold_cross_entropy. And its default value is | |||
| threshold_cross_entropy=0.5 | |||
| """ | |||
| def __init__(self, threshold_cross_entropy=0.5, **kwargs): | |||
| self.threshold_cross_entropy = float(threshold_cross_entropy) | |||
| def __call__(self, infer_result=None): | |||
| """judge the img is hard sample or not. | |||
| :param infer_result: | |||
| prediction classes list, | |||
| such as [class1-score, class2-score, class2-score,....], | |||
| where class-score is the score corresponding to the class, | |||
| class-score value is in [0,1], who will be ignored if its value | |||
| not in [0,1]. | |||
| :return: `True` means a hard sample, `False` means not a hard sample. | |||
| """ | |||
| if not infer_result: | |||
| # if invalid input, return False | |||
| return False | |||
| log_sum = 0.0 | |||
| data_check_list = [class_probability for class_probability | |||
| in infer_result | |||
| if self.data_check(class_probability)] | |||
| if len(data_check_list) != len(infer_result): | |||
| return False | |||
| for class_data in data_check_list: | |||
| log_sum += class_data * math.log(class_data) | |||
| confidence_score = 1 + 1.0 * log_sum / math.log( | |||
| len(infer_result)) | |||
| return confidence_score < self.threshold_cross_entropy | |||
| @ClassFactory.register(ClassType.HEM, alias="IBT") | |||
| class IBTFilter(BaseFilter, abc.ABC): | |||
| """Implement the hard samples discovery methods named IBT | |||
| (image-box-thresholds). | |||
| :param threshold_img: threshold_img to filter img, whose hard coefficient | |||
| is less than threshold_img. | |||
| :param threshold_box: threshold_box to calculate hard coefficient, formula | |||
| is hard coefficient = number(prediction_boxes less than | |||
| threshold_box)/number(prediction_boxes) | |||
| """ | |||
| def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs): | |||
| self.threshold_box = float(threshold_box) | |||
| self.threshold_img = float(threshold_img) | |||
| def __call__(self, infer_result=None): | |||
| """Judge the img is hard sample or not. | |||
| :param infer_result: | |||
| prediction boxes list, | |||
| such as [bbox1, bbox2, bbox3,....], | |||
| where bbox = [xmin, ymin, xmax, ymax, score, label] | |||
| score should be in [0,1], who will be ignored if its value not | |||
| in [0,1]. | |||
| :return: `True` means a hard sample, `False` means not a hard sample. | |||
| """ | |||
| if not (infer_result | |||
| and all(map(lambda x: len(x) > 4, infer_result))): | |||
| # if invalid input, return False | |||
| return False | |||
| data_check_list = [bbox[4] for bbox in infer_result | |||
| if self.data_check(bbox[4])] | |||
| if len(data_check_list) != len(infer_result): | |||
| return False | |||
| confidence_score_list = [ | |||
| float(box_score) for box_score in data_check_list | |||
| if float(box_score) <= self.threshold_box] | |||
| return (len(confidence_score_list) / len(infer_result) | |||
| >= (1 - self.threshold_img)) | |||
| from .hard_example_mining import * | |||
| @@ -0,0 +1,153 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Hard Example Mining Algorithms""" | |||
| import abc | |||
| import math | |||
| from sedna.common.class_factory import ClassFactory, ClassType | |||
| __all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter') | |||
| class BaseFilter(metaclass=abc.ABCMeta): | |||
| """The base class to define unified interface.""" | |||
| def __call__(self, infer_result=None): | |||
| """predict function, and it must be implemented by | |||
| different methods class. | |||
| :param infer_result: prediction result | |||
| :return: `True` means hard sample, `False` means not a hard sample. | |||
| """ | |||
| raise NotImplementedError | |||
| @classmethod | |||
| def data_check(cls, data): | |||
| """Check the data in [0,1].""" | |||
| return 0 <= float(data) <= 1 | |||
| @ClassFactory.register(ClassType.HEM, alias="Threshold") | |||
| class ThresholdFilter(BaseFilter, abc.ABC): | |||
| def __init__(self, threshold=0.5, **kwargs): | |||
| self.threshold = float(threshold) | |||
| def __call__(self, infer_result=None): | |||
| """ | |||
| :param infer_result: [N, 6], (x0, y0, x1, y1, score, class) | |||
| :return: `True` means hard sample, `False` means not a hard sample. | |||
| """ | |||
| # if invalid input, return False | |||
| if not (infer_result | |||
| and all(map(lambda x: len(x) > 4, infer_result))): | |||
| return False | |||
| image_score = 0 | |||
| for bbox in infer_result: | |||
| image_score += bbox[4] | |||
| average_score = image_score / (len(infer_result) or 1) | |||
| return average_score < self.threshold | |||
| @ClassFactory.register(ClassType.HEM, alias="CrossEntropy") | |||
| class CrossEntropyFilter(BaseFilter, abc.ABC): | |||
| """ Implement the hard samples discovery methods named IBT | |||
| (image-box-thresholds). | |||
| :param threshold_cross_entropy: threshold_cross_entropy to filter img, | |||
| whose hard coefficient is less than | |||
| threshold_cross_entropy. And its default value is | |||
| threshold_cross_entropy=0.5 | |||
| """ | |||
| def __init__(self, threshold_cross_entropy=0.5, **kwargs): | |||
| self.threshold_cross_entropy = float(threshold_cross_entropy) | |||
| def __call__(self, infer_result=None): | |||
| """judge the img is hard sample or not. | |||
| :param infer_result: | |||
| prediction classes list, | |||
| such as [class1-score, class2-score, class2-score,....], | |||
| where class-score is the score corresponding to the class, | |||
| class-score value is in [0,1], who will be ignored if its value | |||
| not in [0,1]. | |||
| :return: `True` means a hard sample, `False` means not a hard sample. | |||
| """ | |||
| if not infer_result: | |||
| # if invalid input, return False | |||
| return False | |||
| log_sum = 0.0 | |||
| data_check_list = [class_probability for class_probability | |||
| in infer_result | |||
| if self.data_check(class_probability)] | |||
| if len(data_check_list) != len(infer_result): | |||
| return False | |||
| for class_data in data_check_list: | |||
| log_sum += class_data * math.log(class_data) | |||
| confidence_score = 1 + 1.0 * log_sum / math.log( | |||
| len(infer_result)) | |||
| return confidence_score < self.threshold_cross_entropy | |||
| @ClassFactory.register(ClassType.HEM, alias="IBT") | |||
| class IBTFilter(BaseFilter, abc.ABC): | |||
| """Implement the hard samples discovery methods named IBT | |||
| (image-box-thresholds). | |||
| :param threshold_img: threshold_img to filter img, whose hard coefficient | |||
| is less than threshold_img. | |||
| :param threshold_box: threshold_box to calculate hard coefficient, formula | |||
| is hard coefficient = number(prediction_boxes less than | |||
| threshold_box)/number(prediction_boxes) | |||
| """ | |||
| def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs): | |||
| self.threshold_box = float(threshold_box) | |||
| self.threshold_img = float(threshold_img) | |||
| def __call__(self, infer_result=None): | |||
| """Judge the img is hard sample or not. | |||
| :param infer_result: | |||
| prediction boxes list, | |||
| such as [bbox1, bbox2, bbox3,....], | |||
| where bbox = [xmin, ymin, xmax, ymax, score, label] | |||
| score should be in [0,1], who will be ignored if its value not | |||
| in [0,1]. | |||
| :return: `True` means a hard sample, `False` means not a hard sample. | |||
| """ | |||
| if not (infer_result | |||
| and all(map(lambda x: len(x) > 4, infer_result))): | |||
| # if invalid input, return False | |||
| return False | |||
| data_check_list = [bbox[4] for bbox in infer_result | |||
| if self.data_check(bbox[4])] | |||
| if len(data_check_list) != len(infer_result): | |||
| return False | |||
| confidence_score_list = [ | |||
| float(box_score) for box_score in data_check_list | |||
| if float(box_score) <= self.threshold_box] | |||
| return (len(confidence_score_list) / len(infer_result) | |||
| >= (1 - self.threshold_img)) | |||
| @@ -48,7 +48,7 @@ def set_backend(estimator=None, config=None): | |||
| warnings.warn(f"{backend_type} Not Support yet, use itself") | |||
| from sedna.backend.base import BackendBase as REGISTER | |||
| model_save_url = config.get("model_url") | |||
| base_model_save = config.get("base_model_save") or model_save_url | |||
| base_model_save = config.get("base_model_url") or model_save_url | |||
| model_save_name = config.get("model_name") | |||
| return REGISTER( | |||
| estimator=estimator, use_cuda=use_cuda, | |||
| @@ -37,8 +37,9 @@ class TFBackend(BackendBase): | |||
| super(TFBackend, self).__init__( | |||
| estimator=estimator, fine_tune=fine_tune, **kwargs) | |||
| self.framework = "tensorflow" | |||
| sess_config = self._init_gpu_session_config( | |||
| ) if self.use_cuda else self._init_cpu_session_config() | |||
| sess_config = (self._init_gpu_session_config() | |||
| if self.use_cuda else self._init_cpu_session_config()) | |||
| self.graph = tf.Graph() | |||
| with self.graph.as_default(): | |||
| @@ -291,3 +291,27 @@ class Context: | |||
| value = cls.parameters.get( | |||
| param) or cls.parameters.get(str(param).upper()) | |||
| return value if value else default | |||
| @classmethod | |||
| def get_algorithm_from_api(cls, algorithm, **param) -> dict: | |||
| """get the algorithm and parameter from api""" | |||
| hard_example_name = cls.get_parameters(f'{algorithm}_NAME') | |||
| hem_parameters = cls.get_parameters(f'{algorithm}_PARAMETERS') | |||
| try: | |||
| hem_parameters = json.loads(hem_parameters) | |||
| hem_parameters = { | |||
| p["key"]: p.get("value", "") | |||
| for p in hem_parameters if "key" in p | |||
| } | |||
| except Exception: | |||
| hem_parameters = {} | |||
| hem_parameters.update(**param) | |||
| hard_example_mining = { | |||
| "method": hard_example_name, | |||
| "param": hem_parameters | |||
| } | |||
| return hard_example_mining | |||
| @@ -71,27 +71,6 @@ class JobBase(DistributedWorker): | |||
| work_name = f"{self.job_name}-{self.worker_id}" | |||
| self.worker_name = self.config.worker_name or work_name | |||
| @property | |||
| def initial_hem(self): | |||
| hem = self.get_parameters("HEM_NAME") | |||
| hem_parameters = self.get_parameters("HEM_PARAMETERS") | |||
| try: | |||
| hem_parameters = json.loads(hem_parameters) | |||
| hem_parameters = { | |||
| p["key"]: p.get("value", "") | |||
| for p in hem_parameters if "key" in p | |||
| } | |||
| except Exception as err: | |||
| self.log.warn(f"Parse HEM_PARAMETERS failure, " | |||
| f"fallback to empty: {err}") | |||
| hem_parameters = {} | |||
| if hem is None: | |||
| hem = self.config.get("hem_name") or "IBT" | |||
| return ClassFactory.get_cls(ClassType.HEM, hem)(**hem_parameters) | |||
| @property | |||
| def model_path(self): | |||
| if os.path.isfile(self.config.model_url): | |||
| @@ -12,7 +12,6 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import json | |||
| from copy import deepcopy | |||
| from sedna.common.file_ops import FileOps | |||
| @@ -28,20 +27,55 @@ class IncrementalLearning(JobBase): | |||
| Incremental learning | |||
| """ | |||
| def __init__(self, estimator, config=None): | |||
| super(IncrementalLearning, self).__init__( | |||
| estimator=estimator, config=config) | |||
| def __init__(self, estimator, hard_example_mining: dict = None): | |||
| """ | |||
| Initial a IncrementalLearning job | |||
| :param estimator: Customize estimator | |||
| :param hard_example_mining: dict, hard example mining | |||
| algorithms with parameters | |||
| """ | |||
| super(IncrementalLearning, self).__init__(estimator=estimator) | |||
| self.model_urls = self.get_parameters( | |||
| "MODEL_URLS") # use in evaluation | |||
| self.job_kind = K8sResourceKind.INCREMENTAL_JOB.value | |||
| FileOps.clean_folder([self.config.model_url], clean=False) | |||
| self.hard_example_mining_algorithm = self.initial_hem | |||
| self.hard_example_mining_algorithm = None | |||
| if not hard_example_mining: | |||
| hard_example_mining = self.get_hem_algorithm_from_config() | |||
| if hard_example_mining: | |||
| hem = hard_example_mining.get("method", "IBT") | |||
| hem_parameters = hard_example_mining.get("param", {}) | |||
| self.hard_example_mining_algorithm = ClassFactory.get_cls( | |||
| ClassType.HEM, hem | |||
| )(**hem_parameters) | |||
| @classmethod | |||
| def get_hem_algorithm_from_config(cls, **param): | |||
| """ | |||
| get the `algorithm` name and `param` of hard_example_mining from crd | |||
| :param param: update value in parameters of hard_example_mining | |||
| :return: dict, e.g.: {"method": "IBT", "param": {"threshold_img": 0.5}} | |||
| """ | |||
| return cls.parameters.get_algorithm_from_api( | |||
| algorithm="HEM", | |||
| **param | |||
| ) | |||
| def train(self, train_data, | |||
| valid_data=None, | |||
| post_process=None, | |||
| **kwargs): | |||
| """ | |||
| Training task for IncrementalLearning | |||
| :param train_data: datasource use for train | |||
| :param valid_data: datasource use for evaluation | |||
| :param post_process: post process | |||
| :param kwargs: params for training of customize estimator | |||
| :return: estimator | |||
| """ | |||
| callback_func = None | |||
| if post_process is not None: | |||
| callback_func = ClassFactory.get_cls( | |||
| @@ -58,6 +92,14 @@ class IncrementalLearning(JobBase): | |||
| self.estimator) if callback_func else self.estimator | |||
| def inference(self, data=None, post_process=None, **kwargs): | |||
| """ | |||
| Inference task for IncrementalLearning | |||
| :param data: inference sample | |||
| :param post_process: post process | |||
| :param kwargs: params for inference of customize estimator | |||
| :return: inference result, result after post_process, if is hard sample | |||
| """ | |||
| if not self.estimator.has_load: | |||
| self.estimator.load(self.model_path) | |||
| @@ -81,6 +123,14 @@ class IncrementalLearning(JobBase): | |||
| return infer_res, res, is_hard_example | |||
| def evaluate(self, data, post_process=None, **kwargs): | |||
| """ | |||
| Evaluate task for IncrementalLearning | |||
| :param data: datasource use for evaluation | |||
| :param post_process: post process | |||
| :param kwargs: params for evaluate of customize estimator | |||
| :return: evaluate metrics | |||
| """ | |||
| callback_func = None | |||
| if callable(post_process): | |||
| callback_func = post_process | |||