@@ -1,2 +1,229 @@
### Sedna Python SDK

# Sedna Python SDK

The Sedna Python Software Development Kit (SDK) aims to provide developers with a convenient yet flexible tool to write Sedna applications.

This document introduces how to obtain and call the Sedna Python SDK.

## Introduction

Expose the Edge AI features to applications, i.e. training or inference programs.

## Requirements and Installation

The build process has been tested with Python 3.6 on Ubuntu 18.04.5 LTS.

```bash
# Clone the repo
git clone --recursive https://github.com/kubeedge/sedna.git
cd sedna/lib

# Build the pip package
python setup.py bdist_wheel

# Install the pip package
pip install dist/sedna*.whl
```

Alternatively, install via Setuptools:

```bash
python setup.py install --user
```

## Use Python SDK

0. (optional) Check the `Sedna` version.

```bash
$ python -c "import sedna; print(sedna.__version__)"
```
1. Import the required modules as follows:

```python
from sedna.core.joint_inference import JointInference, BigModelService
from sedna.core.federated_learning import FederatedLearning
from sedna.core.incremental_learning import IncrementalLearning
from sedna.core.lifelong_learning import LifelongLearning
```

2. Define an `Estimator`:
```python
import os

# Keras
import keras
from keras.layers import Dense, MaxPooling2D, Conv2D, Flatten, Dropout
from keras.models import Sequential

os.environ['BACKEND_TYPE'] = 'KERAS'


def KerasEstimator():
    model = Sequential()
    model.add(Conv2D(64, kernel_size=(3, 3),
                     activation="relu", strides=(2, 2),
                     input_shape=(128, 128, 3)))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(32, kernel_size=(3, 3), activation="relu"))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dropout(0.25))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(32, activation="relu"))
    model.add(Dropout(0.5))
    model.add(Dense(2, activation="softmax"))

    # Compile once: the final layer already applies softmax, so the loss
    # must not be configured with from_logits=True.
    loss = keras.losses.CategoricalCrossentropy()
    metrics = [keras.metrics.categorical_accuracy]
    optimizer = keras.optimizers.Adam(learning_rate=0.1)
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
    return model
```
```python
# XGBoost
import os

import xgboost

os.environ['BACKEND_TYPE'] = 'SKLEARN'

XGBEstimator = xgboost.XGBClassifier(
    learning_rate=0.1,
    n_estimators=600,
    max_depth=2,
    min_child_weight=1,
    gamma=0,
    subsample=0.8,
    colsample_bytree=0.8,
    objective="multi:softmax",
    num_class=3,
    nthread=4,
    seed=27
)
```
```python
# Customize
class Estimator:

    def __init__(self, **kwargs):
        ...

    def load(self, model_url=""):
        ...

    def save(self, model_path=None):
        ...

    def predict(self, data, **kwargs):
        ...

    def evaluate(self, valid_data, **kwargs):
        ...

    def train(self, train_data, valid_data=None, **kwargs):
        ...
```

> **Notes**: Estimator is a high-level API that greatly simplifies machine learning programming. Estimators encapsulate training, evaluation, prediction, and exporting for your model.
3. Initialize an Incremental Learning job:

```python
# get the hard example mining algorithm from config
hard_example_mining = IncrementalLearning.get_hem_algorithm_from_config(
    threshold_img=0.9
)

# create an Incremental Learning inference instance
il_job = IncrementalLearning(
    estimator=Estimator,
    hard_example_mining=hard_example_mining
)
```

where:

- `IncrementalLearning` is the cloud-edge job you want to access.
- `Estimator` is the base model for your ML job.
- `hard_example_mining` holds the parameters of the incremental learning job.
Inference
---------

> **Note:** The `job parameters` of each feature are different.

4. Run the job: training / inference / evaluation.

```python
results, final_res, is_hard_example = il_job.inference(
    img_rgb,
    post_process=deal_infer_rsl,
    input_shape=input_shape
)
```
where:

- `img_rgb` is the sample used for inference
- `deal_infer_rsl` is a function used to process the result after the model predicts; a sketch of such a callback is shown below
- `input_shape` is the input-shape parameter passed to the `Estimator` during inference
- `results` is the raw result predicted by the model
- `final_res` is the result after processing by `deal_infer_rsl`
- `is_hard_example` indicates whether the sample is a hard example
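The post-process callback is an ordinary Python callable. A minimal sketch of what `deal_infer_rsl` might look like (the real function is application-specific; the `[x0, y0, x1, y1, score, class]` box layout is an assumption borrowed from the `Threshold` filter shown later):

```python
def deal_infer_rsl(model_output, score_thresh=0.5):
    """Hypothetical post-process: keep only confident boxes.

    Assumes each prediction is [x0, y0, x1, y1, score, class].
    """
    return [bbox for bbox in model_output if bbox[4] >= score_thresh]
```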
## Customize algorithm

Sedna provides a module called `class_factory.py` in the `common` package, in which only a few lines of changes are required to turn your own algorithm into a module of Sedna.

Two classes are defined in `class_factory.py`, namely `ClassType` and `ClassFactory`.

`ClassFactory` can register the modules you want to reuse through decorators. For example, if you have customized a **hard_example_mining algorithm**, you only need to add a single line, `ClassFactory.register(ClassType.HEM)`, to complete the registration, as in the following code example.
```python
@ClassFactory.register(ClassType.HEM, alias="Threshold")
class ThresholdFilter(BaseFilter, abc.ABC):

    def __init__(self, threshold=0.5, **kwargs):
        self.threshold = float(threshold)

    def __call__(self, infer_result=None):
        # if invalid input, return False
        if not (infer_result
                and all(map(lambda x: len(x) > 4, infer_result))):
            return False

        image_score = 0
        for bbox in infer_result:
            image_score += bbox[4]

        average_score = image_score / (len(infer_result) or 1)
        return average_score < self.threshold
```
After registration, you only need to change the name of the HEM algorithm and its parameters in the YAML file, and the corresponding class will then be called automatically according to that name (a lookup sketch follows the YAML).

```yaml
deploySpec:
  hardExampleMining:
    name: "Threshold"
    parameters:
      - key: "threshold"
        value: "0.9"
```
| @@ -1 +1 @@ | |||||
| 0.0.3 | |||||
| 0.3.1 | |||||
| @@ -14,4 +14,4 @@ | |||||
| """sedna version information.""" | """sedna version information.""" | ||||
| __version__ = '0.0.2' | |||||
| __version__ = '0.3.1' | |||||
| @@ -12,7 +12,7 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| from .aggregation import * | |||||
| from .hard_example_mining import * | |||||
| from .multi_task_learning import * | |||||
| from .unseen_task_detect import * | |||||
| from . import aggregation # federated_learning | |||||
| from . import hard_example_mining # joint_inference incremental_learning | |||||
| from . import multi_task_learning # lifelong_learning | |||||
| from . import unseen_task_detect # lifelong_learning | |||||
| @@ -12,4 +12,4 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| from .aggregation import * | |||||
| from . import aggregation | |||||
| @@ -18,21 +18,32 @@ import abc | |||||
| from copy import deepcopy | from copy import deepcopy | ||||
| from typing import List | from typing import List | ||||
| import numpy as np | |||||
| from sedna.common.class_factory import ClassFactory, ClassType | from sedna.common.class_factory import ClassFactory, ClassType | ||||
| __all__ = ('AggClient', 'FedAvg',) | __all__ = ('AggClient', 'FedAvg',) | ||||
| class AggClient: | class AggClient: | ||||
| """Aggregation clients""" | |||||
| """ | |||||
| Client that interacts with cloud aggregator | |||||
| Parameters | |||||
| ---------- | |||||
| num_samples: int | |||||
| number of samples for the current weights | |||||
| weights: List | |||||
| weights of the layer as a list of number-like array, | |||||
| such as [[0, 0, 0, 0], [0, 0, 0, 0] ... ] | |||||
| """ | |||||
| num_samples: int | num_samples: int | ||||
| weights: List | weights: List | ||||
| class BaseAggregation(metaclass=abc.ABCMeta): | class BaseAggregation(metaclass=abc.ABCMeta): | ||||
| """Abstract class of aggregator""" | |||||
| """ | |||||
| Abstract class of aggregator | |||||
| """ | |||||
| def __init__(self): | def __init__(self): | ||||
| self.total_size = 0 | self.total_size = 0 | ||||
| @@ -45,19 +56,41 @@ class BaseAggregation(metaclass=abc.ABCMeta): | |||||
| but some can be calculated only after all aggregated data is uploaded. | but some can be calculated only after all aggregated data is uploaded. | ||||
| therefore, this abstractmethod should consider that all weights are | therefore, this abstractmethod should consider that all weights are | ||||
| uploaded. | uploaded. | ||||
| :param clients: All clients in federated learning job | |||||
| :return: final weights | |||||
| Parameters | |||||
| ---------- | |||||
| clients: List | |||||
| All clients in federated learning job | |||||
| Returns | |||||
| ------- | |||||
| Array-like | |||||
| final weights use to update model layer | |||||
| """ | """ | ||||
| @ClassFactory.register(ClassType.FL_AGG) | @ClassFactory.register(ClassType.FL_AGG) | ||||
| class FedAvg(BaseAggregation, abc.ABC): | class FedAvg(BaseAggregation, abc.ABC): | ||||
| """ | """ | ||||
| Federated averaging algorithm : Calculate the average weight | |||||
| according to the number of samples | |||||
| Federated averaging algorithm | |||||
| """ | """ | ||||
| def aggregate(self, clients: List[AggClient]): | def aggregate(self, clients: List[AggClient]): | ||||
| """ | |||||
| Calculate the average weight according to the number of samples | |||||
| Parameters | |||||
| ---------- | |||||
| clients: List | |||||
| All clients in federated learning job | |||||
| Returns | |||||
| ------- | |||||
| update_weights : Array-like | |||||
| final weights use to update model layer | |||||
| """ | |||||
| import numpy as np | |||||
| if not len(clients): | if not len(clients): | ||||
| return self.weights | return self.weights | ||||
| self.total_size = sum([c.num_samples for c in clients]) | self.total_size = sum([c.num_samples for c in clients]) | ||||
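The hunk ends before the averaging itself. For intuition, here is a standalone sketch of sample-weighted federated averaging; it is an illustration of the technique, not the verbatim remainder of Sedna's `aggregate`:

```python
import numpy as np

def fed_avg(clients):
    """Sample-weighted average of client weights, layer by layer."""
    total_size = sum(c.num_samples for c in clients)
    num_layers = len(clients[0].weights)
    updates = []
    for layer in range(num_layers):
        # Weight each client's layer by its share of the total samples.
        layer_avg = np.sum(
            [np.array(c.weights[layer]) * (c.num_samples / total_size)
             for c in clients],
            axis=0)
        updates.append(layer_avg.tolist())
    return updates
```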
| @@ -12,4 +12,4 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| from .hard_example_mining import * | |||||
| from . import hard_example_mining | |||||
| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """Hard Example Mining Algorithms""" | """Hard Example Mining Algorithms""" | ||||
| import abc | import abc | ||||
| import math | import math | ||||
| @@ -25,11 +26,18 @@ class BaseFilter(metaclass=abc.ABCMeta): | |||||
| """The base class to define unified interface.""" | """The base class to define unified interface.""" | ||||
| def __call__(self, infer_result=None): | def __call__(self, infer_result=None): | ||||
| """predict function, and it must be implemented by | |||||
| different methods class. | |||||
| """ | |||||
| predict function, judge the sample is hard or not. | |||||
| :param infer_result: prediction result | |||||
| :return: `True` means hard sample, `False` means not a hard sample. | |||||
| Parameters | |||||
| ---------- | |||||
| infer_result : array_like | |||||
| prediction result | |||||
| Returns | |||||
| ------- | |||||
| is_hard_sample : bool | |||||
| `True` means hard sample, `False` means not. | |||||
| """ | """ | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @@ -41,14 +49,18 @@ class BaseFilter(metaclass=abc.ABCMeta): | |||||
| @ClassFactory.register(ClassType.HEM, alias="Threshold") | @ClassFactory.register(ClassType.HEM, alias="Threshold") | ||||
| class ThresholdFilter(BaseFilter, abc.ABC): | class ThresholdFilter(BaseFilter, abc.ABC): | ||||
| def __init__(self, threshold=0.5, **kwargs): | |||||
| """ | |||||
| **Object detection** Hard samples discovery methods named `Threshold` | |||||
| Parameters | |||||
| ---------- | |||||
| threshold: float | |||||
| hard coefficient threshold score to filter img, default to 0.5. | |||||
| """ | |||||
| def __init__(self, threshold: float = 0.5, **kwargs): | |||||
| self.threshold = float(threshold) | self.threshold = float(threshold) | ||||
| def __call__(self, infer_result=None): | |||||
| """ | |||||
| :param infer_result: [N, 6], (x0, y0, x1, y1, score, class) | |||||
| :return: `True` means hard sample, `False` means not a hard sample. | |||||
| """ | |||||
| def __call__(self, infer_result=None) -> bool: | |||||
| # if invalid input, return False | # if invalid input, return False | ||||
| if not (infer_result | if not (infer_result | ||||
| and all(map(lambda x: len(x) > 4, infer_result))): | and all(map(lambda x: len(x) > 4, infer_result))): | ||||
| @@ -65,28 +77,34 @@ class ThresholdFilter(BaseFilter, abc.ABC): | |||||
| @ClassFactory.register(ClassType.HEM, alias="CrossEntropy") | @ClassFactory.register(ClassType.HEM, alias="CrossEntropy") | ||||
| class CrossEntropyFilter(BaseFilter, abc.ABC): | class CrossEntropyFilter(BaseFilter, abc.ABC): | ||||
| """ Implement the hard samples discovery methods named IBT | |||||
| (image-box-thresholds). | |||||
| """ | |||||
| **Object detection** Hard samples discovery methods named `CrossEntropy` | |||||
| :param threshold_cross_entropy: threshold_cross_entropy to filter img, | |||||
| whose hard coefficient is less than | |||||
| threshold_cross_entropy. And its default value is | |||||
| threshold_cross_entropy=0.5 | |||||
| Parameters | |||||
| ---------- | |||||
| threshold_cross_entropy: float | |||||
| hard coefficient threshold score to filter img, default to 0.5. | |||||
| """ | """ | ||||
| def __init__(self, threshold_cross_entropy=0.5, **kwargs): | def __init__(self, threshold_cross_entropy=0.5, **kwargs): | ||||
| self.threshold_cross_entropy = float(threshold_cross_entropy) | self.threshold_cross_entropy = float(threshold_cross_entropy) | ||||
| def __call__(self, infer_result=None): | |||||
| def __call__(self, infer_result=None) -> bool: | |||||
| """judge the img is hard sample or not. | """judge the img is hard sample or not. | ||||
| :param infer_result: | |||||
| prediction classes list, | |||||
| such as [class1-score, class2-score, class2-score,....], | |||||
| where class-score is the score corresponding to the class, | |||||
| class-score value is in [0,1], who will be ignored if its value | |||||
| not in [0,1]. | |||||
| :return: `True` means a hard sample, `False` means not a hard sample. | |||||
| Parameters | |||||
| ---------- | |||||
| infer_result: array_like | |||||
| prediction classes list, such as | |||||
| [class1-score, class2-score, class2-score,....], | |||||
| where class-score is the score corresponding to the class, | |||||
| class-score value is in [0,1], who will be ignored if its | |||||
| value not in [0,1]. | |||||
| Returns | |||||
| ------- | |||||
| is hard sample: bool | |||||
| `True` means hard sample, `False` means not. | |||||
| """ | """ | ||||
| if not infer_result: | if not infer_result: | ||||
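The scoring logic is cut off here. As a sketch of the idea the docstring describes (score the class distribution by its entropy and compare it against `threshold_cross_entropy`; an illustration, not the verbatim Sedna code):

```python
import math

def is_hard_by_cross_entropy(class_scores, threshold=0.5):
    """High entropy of the class scores means uncertainty, i.e. a hard sample."""
    scores = [s for s in class_scores if 0 <= s <= 1]  # ignore invalid scores
    if not scores:
        return False
    entropy = -sum(s * math.log(s) for s in scores if s > 0)
    return entropy >= threshold
```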
| @@ -110,30 +128,38 @@ class CrossEntropyFilter(BaseFilter, abc.ABC): | |||||
| @ClassFactory.register(ClassType.HEM, alias="IBT") | @ClassFactory.register(ClassType.HEM, alias="IBT") | ||||
| class IBTFilter(BaseFilter, abc.ABC): | class IBTFilter(BaseFilter, abc.ABC): | ||||
| """Implement the hard samples discovery methods named IBT | |||||
| (image-box-thresholds). | |||||
| :param threshold_img: threshold_img to filter img, whose hard coefficient | |||||
| is less than threshold_img. | |||||
| :param threshold_box: threshold_box to calculate hard coefficient, formula | |||||
| is hard coefficient = number(prediction_boxes less than | |||||
| threshold_box)/number(prediction_boxes) | |||||
| """ | |||||
| **Object detection** Hard samples discovery methods named `IBT` | |||||
| Parameters | |||||
| ---------- | |||||
| threshold_img: float | |||||
| hard coefficient threshold score to filter img, default to 0.5. | |||||
| threshold_box: float | |||||
| threshold_box to calculate hard coefficient, formula is hard | |||||
| coefficient = number(prediction_boxes less than threshold_box) / | |||||
| number(prediction_boxes) | |||||
| """ | """ | ||||
| def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs): | def __init__(self, threshold_img=0.5, threshold_box=0.5, **kwargs): | ||||
| self.threshold_box = float(threshold_box) | self.threshold_box = float(threshold_box) | ||||
| self.threshold_img = float(threshold_img) | self.threshold_img = float(threshold_img) | ||||
| def __call__(self, infer_result=None): | |||||
| def __call__(self, infer_result=None) -> bool: | |||||
| """Judge the img is hard sample or not. | """Judge the img is hard sample or not. | ||||
| :param infer_result: | |||||
| prediction boxes list, | |||||
| such as [bbox1, bbox2, bbox3,....], | |||||
| where bbox = [xmin, ymin, xmax, ymax, score, label] | |||||
| score should be in [0,1], who will be ignored if its value not | |||||
| in [0,1]. | |||||
| :return: `True` means a hard sample, `False` means not a hard sample. | |||||
| Parameters | |||||
| ---------- | |||||
| infer_result: array_like | |||||
| prediction boxes list, such as [bbox1, bbox2, bbox3,....], | |||||
| where bbox = [xmin, ymin, xmax, ymax, score, label] | |||||
| score should be in [0,1], who will be ignored if its value not | |||||
| in [0,1]. | |||||
| Returns | |||||
| ------- | |||||
| is hard sample: bool | |||||
| `True` means hard sample, `False` means not. | |||||
| """ | """ | ||||
| if not (infer_result | if not (infer_result | ||||
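The rest of `__call__` is cut off. A sketch that follows the formula given in the docstring (hard coefficient = share of boxes scoring below `threshold_box`); the exact Sedna implementation may differ:

```python
def is_hard_by_ibt(bboxes, threshold_img=0.5, threshold_box=0.5):
    """Image-box-thresholds: flag the image when too many boxes score low."""
    scores = [b[4] for b in bboxes if 0 <= b[4] <= 1]  # ignore invalid scores
    if not scores:
        return False
    hard_coefficient = sum(s < threshold_box for s in scores) / len(scores)
    return hard_coefficient >= threshold_img
```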
| @@ -12,5 +12,5 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| from .task_jobs import * | |||||
| from . import task_jobs | |||||
| from .multi_task_learning import MulTaskLearning | from .multi_task_learning import MulTaskLearning | ||||
| @@ -12,8 +12,9 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| """Multiple task transfer learning algorithms""" | |||||
| import json | import json | ||||
| import pandas as pd | |||||
| from sedna.datasources import BaseDataSource | from sedna.datasources import BaseDataSource | ||||
| from sedna.backend import set_backend | from sedna.backend import set_backend | ||||
| @@ -29,6 +30,72 @@ __all__ = ('MulTaskLearning',) | |||||
| class MulTaskLearning: | class MulTaskLearning: | ||||
| """ | |||||
| An auto machine learning framework for edge-cloud multitask learning | |||||
| See Also | |||||
| -------- | |||||
| Train: Data + Estimator -> Task Definition -> Task Relationship Discovery | |||||
| -> Feature Engineering -> Training | |||||
| Inference: Data -> Task Allocation -> Task Mining -> Feature Engineering | |||||
| -> Task Remodeling -> Inference | |||||
| Parameters | |||||
| ---------- | |||||
| estimator : Instance | |||||
| An instance with the high-level API that greatly simplifies | |||||
| machine learning programming. Estimators encapsulate training, | |||||
| evaluation, prediction, and exporting for your model. | |||||
| task_definition : Dict | |||||
| Divide multiple tasks based on data, | |||||
| see `task_jobs.task_definition` for more detail. | |||||
| task_relationship_discovery : Dict | |||||
| Discover relationships between all tasks, see | |||||
| `task_jobs.task_relationship_discovery` for more detail. | |||||
| task_mining : Dict | |||||
| Mining tasks of inference sample, | |||||
| see `task_jobs.task_mining` for more detail. | |||||
| task_remodeling : Dict | |||||
| Remodeling tasks based on their relationships, | |||||
| see `task_jobs.task_remodeling` for more detail. | |||||
| inference_integrate : Dict | |||||
| Integrate the inference results of all related | |||||
| tasks, see `task_jobs.inference_integrate` for more detail. | |||||
| Examples | |||||
| -------- | |||||
| >>> from xgboost import XGBClassifier | |||||
| >>> from sedna.algorithms.multi_task_learning import MulTaskLearning | |||||
| >>> estimator = XGBClassifier(objective="binary:logistic") | |||||
| >>> task_definition = { | |||||
| "method": "TaskDefinitionByDataAttr", | |||||
| "param": {"attribute": ["season", "city"]} | |||||
| } | |||||
| >>> task_relationship_discovery = { | |||||
| "method": "DefaultTaskRelationDiscover", "param": {} | |||||
| } | |||||
| >>> task_mining = { | |||||
| "method": "TaskMiningByDataAttr", | |||||
| "param": {"attribute": ["season", "city"]} | |||||
| } | |||||
| >>> task_remodeling = None | |||||
| >>> inference_integrate = { | |||||
| "method": "DefaultInferenceIntegrate", "param": {} | |||||
| } | |||||
| >>> mul_task_instance = MulTaskLearning( | |||||
| estimator=estimator, | |||||
| task_definition=task_definition, | |||||
| task_relationship_discovery=task_relationship_discovery, | |||||
| task_mining=task_mining, | |||||
| task_remodeling=task_remodeling, | |||||
| inference_integrate=inference_integrate | |||||
| ) | |||||
| Notes | |||||
| ----- | |||||
| All method defined under `task_jobs` and registered in `ClassFactory`. | |||||
| """ | |||||
| _method_pair = { | _method_pair = { | ||||
| 'TaskDefinitionBySVC': 'TaskMiningBySVC', | 'TaskDefinitionBySVC': 'TaskMiningBySVC', | ||||
| 'TaskDefinitionByDataAttr': 'TaskMiningByDataAttr', | 'TaskDefinitionByDataAttr': 'TaskMiningByDataAttr', | ||||
| @@ -66,7 +133,7 @@ class MulTaskLearning: | |||||
| )) | )) | ||||
| @staticmethod | @staticmethod | ||||
| def parse_param(param_str): | |||||
| def _parse_param(param_str): | |||||
| if not param_str: | if not param_str: | ||||
| return {} | return {} | ||||
| if isinstance(param_str, dict): | if isinstance(param_str, dict): | ||||
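`_parse_param` normalizes the `param` field, which may arrive as a dict or as a JSON string. The hunk stops mid-function; a plausible completion under that reading (an assumption, since the remaining lines are not shown):

```python
import json

def _parse_param(param_str):
    """Normalize an algorithm `param` field into a dict."""
    if not param_str:
        return {}
    if isinstance(param_str, dict):
        return param_str
    try:
        # Accept JSON strings such as '{"attribute": ["season", "city"]}'
        return json.loads(param_str)
    except (TypeError, ValueError):
        return {}
```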
| @@ -84,7 +151,7 @@ class MulTaskLearning: | |||||
| method_name = self.task_definition.get( | method_name = self.task_definition.get( | ||||
| "method", "TaskDefinitionByDataAttr" | "method", "TaskDefinitionByDataAttr" | ||||
| ) | ) | ||||
| extend_param = self.parse_param( | |||||
| extend_param = self._parse_param( | |||||
| self.task_definition.get("param") | self.task_definition.get("param") | ||||
| ) | ) | ||||
| method_cls = ClassFactory.get_cls( | method_cls = ClassFactory.get_cls( | ||||
| @@ -96,7 +163,7 @@ class MulTaskLearning: | |||||
| Merge tasks from task_definition | Merge tasks from task_definition | ||||
| """ | """ | ||||
| method_name = self.task_relationship_discovery.get("method") | method_name = self.task_relationship_discovery.get("method") | ||||
| extend_param = self.parse_param( | |||||
| extend_param = self._parse_param( | |||||
| self.task_relationship_discovery.get("param") | self.task_relationship_discovery.get("param") | ||||
| ) | ) | ||||
| method_cls = ClassFactory.get_cls( | method_cls = ClassFactory.get_cls( | ||||
| @@ -108,7 +175,7 @@ class MulTaskLearning: | |||||
| Mining tasks of inference sample base on task attribute extractor | Mining tasks of inference sample base on task attribute extractor | ||||
| """ | """ | ||||
| method_name = self.task_mining.get("method") | method_name = self.task_mining.get("method") | ||||
| extend_param = self.parse_param( | |||||
| extend_param = self._parse_param( | |||||
| self.task_mining.get("param") | self.task_mining.get("param") | ||||
| ) | ) | ||||
| @@ -118,7 +185,7 @@ class MulTaskLearning: | |||||
| ) | ) | ||||
| method_name = self._method_pair.get(task_definition, | method_name = self._method_pair.get(task_definition, | ||||
| 'TaskMiningByDataAttr') | 'TaskMiningByDataAttr') | ||||
| extend_param = self.parse_param( | |||||
| extend_param = self._parse_param( | |||||
| self.task_definition.get("param")) | self.task_definition.get("param")) | ||||
| method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | ||||
| task_extractor=self.extractor, **extend_param | task_extractor=self.extractor, **extend_param | ||||
| @@ -130,7 +197,7 @@ class MulTaskLearning: | |||||
| Remodeling tasks from task mining | Remodeling tasks from task mining | ||||
| """ | """ | ||||
| method_name = self.task_remodeling.get("method") | method_name = self.task_remodeling.get("method") | ||||
| extend_param = self.parse_param( | |||||
| extend_param = self._parse_param( | |||||
| self.task_remodeling.get("param")) | self.task_remodeling.get("param")) | ||||
| method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | ||||
| models=self.models, **extend_param) | models=self.models, **extend_param) | ||||
| @@ -141,7 +208,7 @@ class MulTaskLearning: | |||||
| Aggregate inference results from target models | Aggregate inference results from target models | ||||
| """ | """ | ||||
| method_name = self.inference_integrate.get("method") | method_name = self.inference_integrate.get("method") | ||||
| extend_param = self.parse_param( | |||||
| extend_param = self._parse_param( | |||||
| self.inference_integrate.get("param")) | self.inference_integrate.get("param")) | ||||
| method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | ||||
| models=self.models, **extend_param) | models=self.models, **extend_param) | ||||
| @@ -150,6 +217,29 @@ class MulTaskLearning: | |||||
| def train(self, train_data: BaseDataSource, | def train(self, train_data: BaseDataSource, | ||||
| valid_data: BaseDataSource = None, | valid_data: BaseDataSource = None, | ||||
| post_process=None, **kwargs): | post_process=None, **kwargs): | ||||
| """ | |||||
| fit for update the knowledge based on training data. | |||||
| Parameters | |||||
| ---------- | |||||
| train_data : BaseDataSource | |||||
| Train data, see `sedna.datasources.BaseDataSource` for more detail. | |||||
| valid_data : BaseDataSource | |||||
| Valid data, BaseDataSource or None. | |||||
| post_process : function | |||||
| function or a registered method, callback after `estimator` train. | |||||
| kwargs : Dict | |||||
| parameters for `estimator` training, Like: | |||||
| `early_stopping_rounds` in Xgboost.XGBClassifier | |||||
| Returns | |||||
| ------- | |||||
| feedback : Dict | |||||
| contain all training result in each tasks. | |||||
| task_index_url : str | |||||
| task extractor model path, used for task mining. | |||||
| """ | |||||
| tasks, task_extractor, train_data = self._task_definition(train_data) | tasks, task_extractor, train_data = self._task_definition(train_data) | ||||
| self.extractor = task_extractor | self.extractor = task_extractor | ||||
| task_groups = self._task_relationship_discovery(tasks) | task_groups = self._task_relationship_discovery(tasks) | ||||
| @@ -234,6 +324,16 @@ class MulTaskLearning: | |||||
| return feedback, self.task_index_url | return feedback, self.task_index_url | ||||
| def load(self, task_index_url=None): | def load(self, task_index_url=None): | ||||
| """ | |||||
| load task_detail (tasks/models etc ...) from task index file. | |||||
| It'll automatically loaded during `inference` and `evaluation` phases. | |||||
| Parameters | |||||
| ---------- | |||||
| task_index_url : str | |||||
| task index file path, default self.task_index_url. | |||||
| """ | |||||
| if task_index_url: | if task_index_url: | ||||
| self.task_index_url = task_index_url | self.task_index_url = task_index_url | ||||
| assert FileOps.exists(self.task_index_url), FileExistsError( | assert FileOps.exists(self.task_index_url), FileExistsError( | ||||
| @@ -248,6 +348,28 @@ class MulTaskLearning: | |||||
| def predict(self, data: BaseDataSource, | def predict(self, data: BaseDataSource, | ||||
| post_process=None, **kwargs): | post_process=None, **kwargs): | ||||
| """ | |||||
| predict the result for input data based on training knowledge. | |||||
| Parameters | |||||
| ---------- | |||||
| data : BaseDataSource | |||||
| inference sample, see `sedna.datasources.BaseDataSource` for | |||||
| more detail. | |||||
| post_process: function | |||||
| function or a registered method, effected after `estimator` | |||||
| prediction, like: label transform. | |||||
| kwargs: Dict | |||||
| parameters for `estimator` predict, Like: | |||||
| `ntree_limit` in Xgboost.XGBClassifier | |||||
| Returns | |||||
| ------- | |||||
| result : array_like | |||||
| results array, contain all inference results in each sample. | |||||
| tasks : List | |||||
| tasks assigned to each sample. | |||||
| """ | |||||
| if not (self.models and self.extractor): | if not (self.models and self.extractor): | ||||
| self.load() | self.load() | ||||
| @@ -284,6 +406,31 @@ class MulTaskLearning: | |||||
| metrics=None, | metrics=None, | ||||
| metrics_param=None, | metrics_param=None, | ||||
| **kwargs): | **kwargs): | ||||
| """ | |||||
| evaluated the performance of each task from training, filter tasks | |||||
| based on the defined rules. | |||||
| Parameters | |||||
| ---------- | |||||
| data : BaseDataSource | |||||
| valid data, see `sedna.datasources.BaseDataSource` for more detail. | |||||
| metrics : function / str | |||||
| Metrics to assess performance on the task by given prediction. | |||||
| metrics_param : Dict | |||||
| parameter for metrics function. | |||||
| kwargs: Dict | |||||
| parameters for `estimator` evaluate, Like: | |||||
| `ntree_limit` in Xgboost.XGBClassifier | |||||
| Returns | |||||
| ------- | |||||
| task_eval_res : Dict | |||||
| all metric results. | |||||
| tasks_detail : List[Object] | |||||
| all metric results in each task. | |||||
| """ | |||||
| import pandas as pd | |||||
| from sklearn import metrics as sk_metrics | from sklearn import metrics as sk_metrics | ||||
| result, tasks = self.predict(data, **kwargs) | result, tasks = self.predict(data, **kwargs) | ||||
| @@ -325,7 +472,7 @@ class MulTaskLearning: | |||||
| if not metrics_param: | if not metrics_param: | ||||
| metrics_param = {} | metrics_param = {} | ||||
| elif isinstance(metrics_param, str): | elif isinstance(metrics_param, str): | ||||
| metrics_param = self.parse_param(metrics_param) | |||||
| metrics_param = self._parse_param(metrics_param) | |||||
| tasks_detail = [] | tasks_detail = [] | ||||
| for task in tasks: | for task in tasks: | ||||
| sample = task.samples | sample = task.samples | ||||
| @@ -12,10 +12,13 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| from .task_definition import * | |||||
| from .task_relation_discover import * | |||||
| # train | |||||
| from . import task_definition | |||||
| from . import task_relation_discover | |||||
| from .task_mining import * | |||||
| from .task_remodeling import * | |||||
| # inference | |||||
| from . import task_mining | |||||
| from . import task_remodeling | |||||
| from .inference_integrate import * | |||||
| # result integrate | |||||
| from . import inference_integrate | |||||
| @@ -12,6 +12,10 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| """ | |||||
| Integrate the inference results of all related tasks | |||||
| """ | |||||
| from typing import List | from typing import List | ||||
| import numpy as np | import numpy as np | ||||
| @@ -25,10 +29,26 @@ __all__ = ('DefaultInferenceIntegrate', ) | |||||
| @ClassFactory.register(ClassType.MTL) | @ClassFactory.register(ClassType.MTL) | ||||
| class DefaultInferenceIntegrate: | class DefaultInferenceIntegrate: | ||||
| """ | |||||
| Default calculation algorithm for inference integration | |||||
| Parameters | |||||
| ---------- | |||||
| models: All models used for sample inference | |||||
| """ | |||||
| def __init__(self, models: list, **kwargs): | def __init__(self, models: list, **kwargs): | ||||
| self.models = models | self.models = models | ||||
| def __call__(self, tasks: List[Task]): | def __call__(self, tasks: List[Task]): | ||||
| """ | |||||
| Parameters | |||||
| ---------- | |||||
| tasks: All tasks with sample result | |||||
| Returns | |||||
| ------- | |||||
| result: minimum result | |||||
| """ | |||||
| res = {} | res = {} | ||||
| for task in tasks: | for task in tasks: | ||||
| res.update(dict(zip(task.samples.inx, task.result))) | res.update(dict(zip(task.samples.inx, task.result))) | ||||
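The method body continues past the cut. Given the docstring's "minimum result", one plausible reading (an assumption, not the verbatim code) is that each sample keeps the smallest prediction among its related tasks:

```python
import numpy as np

def integrate_minimum(tasks):
    """Keep, for every sample index, the minimum result across tasks."""
    res = {}
    for task in tasks:
        for inx, r in zip(task.samples.inx, task.result):
            res[inx] = min(res[inx], r) if inx in res else r
    # Return the results ordered by sample index.
    return np.array([res[i] for i in sorted(res)])
```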
| @@ -12,6 +12,19 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| """ | |||||
| Divide multiple tasks based on data | |||||
| Parameters | |||||
| ---------- | |||||
| samples: Train data, see `sedna.datasources.BaseDataSource` for more detail. | |||||
| Returns | |||||
| ------- | |||||
| tasks: All tasks based on training data. | |||||
| task_extractor: Model with a method to predicting target tasks | |||||
| """ | |||||
| from typing import List, Any, Tuple | from typing import List, Any, Tuple | ||||
| import numpy as np | import numpy as np | ||||
| @@ -28,6 +41,16 @@ __all__ = ('TaskDefinitionBySVC', 'TaskDefinitionByDataAttr') | |||||
| @ClassFactory.register(ClassType.MTL) | @ClassFactory.register(ClassType.MTL) | ||||
| class TaskDefinitionBySVC: | class TaskDefinitionBySVC: | ||||
| """ | |||||
| Dividing datasets with `AgglomerativeClustering` based on kernel distance, | |||||
| Using SVC to fit the clustering result. | |||||
| Parameters | |||||
| ---------- | |||||
| n_class: int or None | |||||
| The number of clusters to find, default=2. | |||||
| """ | |||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| n_class = kwargs.get("n_class", "") | n_class = kwargs.get("n_class", "") | ||||
| self.n_class = max(2, int(n_class)) if str(n_class).isdigit() else 2 | self.n_class = max(2, int(n_class)) if str(n_class).isdigit() else 2 | ||||
| @@ -67,6 +90,15 @@ class TaskDefinitionBySVC: | |||||
| @ClassFactory.register(ClassType.MTL) | @ClassFactory.register(ClassType.MTL) | ||||
| class TaskDefinitionByDataAttr: | class TaskDefinitionByDataAttr: | ||||
| """ | |||||
| Dividing datasets based on the common attributes, | |||||
| generally used for structured data. | |||||
| Parameters | |||||
| ---------- | |||||
| attribute: List[Metadata] | |||||
| metadata is usually a class feature label with a finite values. | |||||
| """ | |||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| self.attr_filed = kwargs.get("attribute", []) | self.attr_filed = kwargs.get("attribute", []) | ||||
| @@ -12,6 +12,18 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| """ | |||||
| Mining tasks of inference sample base on task attribute extractor | |||||
| Parameters | |||||
| ---------- | |||||
| samples : infer sample, see `sedna.datasources.BaseDataSource` for more detail. | |||||
| Returns | |||||
| ------- | |||||
| allocations : tasks that assigned to each sample | |||||
| """ | |||||
| from sedna.datasources import BaseDataSource | from sedna.datasources import BaseDataSource | ||||
| from sedna.common.class_factory import ClassFactory, ClassType | from sedna.common.class_factory import ClassFactory, ClassType | ||||
| @@ -21,6 +33,14 @@ __all__ = ('TaskMiningBySVC', 'TaskMiningByDataAttr') | |||||
| @ClassFactory.register(ClassType.MTL) | @ClassFactory.register(ClassType.MTL) | ||||
| class TaskMiningBySVC: | class TaskMiningBySVC: | ||||
| """ | |||||
| Corresponding to `TaskDefinitionBySVC` | |||||
| Parameters | |||||
| ---------- | |||||
| task_extractor : Model | |||||
| SVC Model used to predicting target tasks | |||||
| """ | |||||
| def __init__(self, task_extractor, **kwargs): | def __init__(self, task_extractor, **kwargs): | ||||
| self.task_extractor = task_extractor | self.task_extractor = task_extractor | ||||
| @@ -38,6 +58,17 @@ class TaskMiningBySVC: | |||||
| @ClassFactory.register(ClassType.MTL) | @ClassFactory.register(ClassType.MTL) | ||||
| class TaskMiningByDataAttr: | class TaskMiningByDataAttr: | ||||
| """ | |||||
| Corresponding to `TaskDefinitionByDataAttr` | |||||
| Parameters | |||||
| ---------- | |||||
| task_extractor : Dict | |||||
| used to match target tasks | |||||
| attr_filed: List[Metadata] | |||||
| metadata is usually a class feature | |||||
| label with a finite values. | |||||
| """ | |||||
| def __init__(self, task_extractor, **kwargs): | def __init__(self, task_extractor, **kwargs): | ||||
| self.task_extractor = task_extractor | self.task_extractor = task_extractor | ||||
| self.attr_filed = kwargs.get("attribute", []) | self.attr_filed = kwargs.get("attribute", []) | ||||
| @@ -12,6 +12,18 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| """ | |||||
| Discover relationships between all tasks | |||||
| Parameters | |||||
| ---------- | |||||
| tasks :all tasks form `task_definition` | |||||
| Returns | |||||
| ------- | |||||
| task_groups : List of groups which including at least 1 task. | |||||
| """ | |||||
| from typing import List | from typing import List | ||||
| from sedna.common.class_factory import ClassType, ClassFactory | from sedna.common.class_factory import ClassType, ClassFactory | ||||
| @@ -24,6 +36,10 @@ __all__ = ('DefaultTaskRelationDiscover', ) | |||||
| @ClassFactory.register(ClassType.MTL) | @ClassFactory.register(ClassType.MTL) | ||||
| class DefaultTaskRelationDiscover: | class DefaultTaskRelationDiscover: | ||||
| """ | |||||
| Assume that each task is independent of each other | |||||
| """ | |||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| pass | pass | ||||
| @@ -12,6 +12,19 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| """ | |||||
| Remodeling tasks based on their relationships | |||||
| Parameters | |||||
| ---------- | |||||
| mappings :all assigned tasks get from the `task_mining` | |||||
| samples : input samples | |||||
| Returns | |||||
| ------- | |||||
| models : List of groups which including at least 1 task. | |||||
| """ | |||||
| from typing import List | from typing import List | ||||
| import numpy as np | import numpy as np | ||||
| @@ -25,10 +38,17 @@ __all__ = ('DefaultTaskRemodeling',) | |||||
| @ClassFactory.register(ClassType.MTL) | @ClassFactory.register(ClassType.MTL) | ||||
| class DefaultTaskRemodeling: | class DefaultTaskRemodeling: | ||||
| """ | |||||
| Assume that each task is independent of each other | |||||
| """ | |||||
| def __init__(self, models: list, **kwargs): | def __init__(self, models: list, **kwargs): | ||||
| self.models = models | self.models = models | ||||
| def __call__(self, samples: BaseDataSource, mappings: List): | def __call__(self, samples: BaseDataSource, mappings: List): | ||||
| """ | |||||
| Grouping based on assigned tasks | |||||
| """ | |||||
| mappings = np.array(mappings) | mappings = np.array(mappings) | ||||
| data, models = [], [] | data, models = [], [] | ||||
| d_type = samples.data_type | d_type = samples.data_type | ||||
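The grouping loop itself is cut off. An illustrative sketch of the idea (samples assigned to the same task are batched together and paired with that task's model; the names are assumptions):

```python
import numpy as np

def group_by_task(samples_x, mappings, models):
    """Split samples into per-model groups based on task assignment."""
    mappings = np.array(mappings)
    groups = []
    for m in np.unique(mappings):
        index = np.where(mappings == m)[0]  # samples assigned to model m
        groups.append((models[int(m)], [samples_x[i] for i in index]))
    return groups
```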
| @@ -12,7 +12,9 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| """Unseen task detection algorithms for Lifelong Learning""" | |||||
| """ | |||||
| Unseen task detection algorithms for Lifelong Learning | |||||
| """ | |||||
| import abc | import abc | ||||
| from typing import List | from typing import List | ||||
| @@ -28,22 +30,46 @@ __all__ = ('ModelProbeFilter', 'TaskAttrFilter') | |||||
| class BaseFilter(metaclass=abc.ABCMeta): | class BaseFilter(metaclass=abc.ABCMeta): | ||||
| """The base class to define unified interface.""" | """The base class to define unified interface.""" | ||||
| def __call__(self, task: Task = None): | |||||
| """predict function, and it must be implemented by | |||||
| def __call__(self, tasks: Task = None): | |||||
| """ | |||||
| predict function, and it must be implemented by | |||||
| different methods class. | different methods class. | ||||
| :param task: inference task | |||||
| :return: `True` means unseen task, `False` means not an unseen task. | |||||
| Parameters | |||||
| ---------- | |||||
| tasks : inference task | |||||
| Returns | |||||
| ------- | |||||
| is unseen task : bool | |||||
| `True` means unseen task, `False` means not. | |||||
| """ | """ | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @ClassFactory.register(ClassType.UTD) | @ClassFactory.register(ClassType.UTD) | ||||
| class ModelProbeFilter(BaseFilter, abc.ABC): | class ModelProbeFilter(BaseFilter, abc.ABC): | ||||
| """ | |||||
| Judgment based on the confidence of the prediction result, | |||||
| typically used for classification problems | |||||
| """ | |||||
| def __init__(self): | def __init__(self): | ||||
| pass | pass | ||||
| def __call__(self, tasks: List[Task] = None, threshold=0.5, **kwargs): | def __call__(self, tasks: List[Task] = None, threshold=0.5, **kwargs): | ||||
| """ | |||||
| Parameters | |||||
| ---------- | |||||
| tasks : inference task | |||||
| threshold : float | |||||
| threshold considered credible | |||||
| Returns | |||||
| ------- | |||||
| is unseen task: bool | |||||
| `True` means unseen task, `False` means not. | |||||
| """ | |||||
| all_proba = [] | all_proba = [] | ||||
| for task in tasks: | for task in tasks: | ||||
| sample = task.samples | sample = task.samples | ||||
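The remainder of `__call__` is cut off; conceptually it averages the highest class probability over all samples and flags an unseen task when that confidence falls below `threshold`. A hedged sketch (`predict_proba` and `samples.x` are assumptions about the estimator and datasource interfaces):

```python
import numpy as np

def is_unseen_by_probe(tasks, threshold=0.5):
    """Low average max-probability across tasks suggests an unseen task."""
    all_proba = []
    for task in tasks:
        model = task.model
        if hasattr(model, "predict_proba"):
            proba = model.predict_proba(task.samples.x)
            # A sample's confidence is its highest class probability.
            all_proba.append(np.max(proba, axis=1))
    if not all_proba:
        return False
    return float(np.mean(np.concatenate(all_proba))) < threshold
```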
| @@ -56,10 +82,23 @@ class ModelProbeFilter(BaseFilter, abc.ABC): | |||||
| @ClassFactory.register(ClassType.UTD) | @ClassFactory.register(ClassType.UTD) | ||||
| class TaskAttrFilter(BaseFilter, abc.ABC): | class TaskAttrFilter(BaseFilter, abc.ABC): | ||||
| """ | |||||
| Judgment based on whether the metadata of the sample has been found in KB | |||||
| """ | |||||
| def __init__(self): | def __init__(self): | ||||
| pass | pass | ||||
| def __call__(self, tasks: List[Task] = None, **kwargs): | def __call__(self, tasks: List[Task] = None, **kwargs): | ||||
| """ | |||||
| Parameters | |||||
| ---------- | |||||
| tasks : inference task | |||||
| Returns | |||||
| ------- | |||||
| is unseen task: bool | |||||
| `True` means unseen task, `False` means not. | |||||
| """ | |||||
| for task in tasks: | for task in tasks: | ||||
| model_attr = list(map(list, task.model.meta_attr)) | model_attr = list(map(list, task.model.meta_attr)) | ||||
| sample_attr = list(map(list, task.samples.meta_attr)) | sample_attr = list(map(list, task.samples.meta_attr)) | ||||
| @@ -21,7 +21,7 @@ from sedna.common.config import BaseConfig | |||||
| def set_backend(estimator=None, config=None): | def set_backend(estimator=None, config=None): | ||||
| """Create Trainer clss.""" | |||||
| """Create Trainer class""" | |||||
| if estimator is None: | if estimator is None: | ||||
| return | return | ||||
| if config is None: | if config is None: | ||||
| @@ -34,6 +34,7 @@ else: | |||||
| class TFBackend(BackendBase): | class TFBackend(BackendBase): | ||||
| """Tensorflow Framework Backend base Class""" | |||||
| def __init__(self, estimator, fine_tune=True, **kwargs): | def __init__(self, estimator, fine_tune=True, **kwargs): | ||||
| super(TFBackend, self).__init__( | super(TFBackend, self).__init__( | ||||
| @@ -128,6 +129,8 @@ class TFBackend(BackendBase): | |||||
| class KerasBackend(TFBackend): | class KerasBackend(TFBackend): | ||||
| """Keras Framework Backend base Class""" | |||||
| def __init__(self, estimator, fine_tune=True, **kwargs): | def __init__(self, estimator, fine_tune=True, **kwargs): | ||||
| super(TFBackend, self).__init__( | super(TFBackend, self).__init__( | ||||
| estimator=estimator, fine_tune=fine_tune, **kwargs) | estimator=estimator, fine_tune=fine_tune, **kwargs) | ||||
| @@ -15,6 +15,11 @@ | |||||
| # Copy from https://github.com/huawei-noah/vega/blob/master/zeus/common/class_factory.py # noqa | # Copy from https://github.com/huawei-noah/vega/blob/master/zeus/common/class_factory.py # noqa | ||||
| # We made a re-modify due to vega is exceed out needs | # We made a re-modify due to vega is exceed out needs | ||||
| """ | |||||
| Management class registration and bind configuration properties, | |||||
| provides the type of class supported. | |||||
| """ | |||||
| from inspect import isfunction, isclass | from inspect import isfunction, isclass | ||||
| @@ -34,7 +39,7 @@ class ClassType: | |||||
| class ClassFactory(object): | class ClassFactory(object): | ||||
| """ | """ | ||||
| A Factory Class to manage all class need to register with config. | |||||
| A Factory Class to manage all class need to register with config. | |||||
| """ | """ | ||||
| __registry__ = {} | __registry__ = {} | ||||
| @@ -124,7 +129,7 @@ class ClassFactory(object): | |||||
| :param type_name: type name of class registry | :param type_name: type name of class registry | ||||
| :param t_cls_name: class name | :param t_cls_name: class name | ||||
| :return:t_cls | |||||
| :return: t_cls | |||||
| """ | """ | ||||
| if not cls.is_exists(type_name, t_cls_name): | if not cls.is_exists(type_name, t_cls_name): | ||||
| raise ValueError( | raise ValueError( | ||||
| @@ -16,6 +16,10 @@ from enum import Enum | |||||
| class K8sResourceKind(Enum): | class K8sResourceKind(Enum): | ||||
| """ | |||||
| Sedna job/service kind | |||||
| """ | |||||
| DEFAULT = "default" | DEFAULT = "default" | ||||
| JOINT_INFERENCE_SERVICE = "jointinferenceservice" | JOINT_INFERENCE_SERVICE = "jointinferenceservice" | ||||
| FEDERATED_LEARNING_JOB = "federatedlearningjob" | FEDERATED_LEARNING_JOB = "federatedlearningjob" | ||||
| @@ -24,12 +28,20 @@ class K8sResourceKind(Enum): | |||||
| class K8sResourceKindStatus(Enum): | class K8sResourceKindStatus(Enum): | ||||
| """ | |||||
| Job/Service status | |||||
| """ | |||||
| COMPLETED = "completed" | COMPLETED = "completed" | ||||
| FAILED = "failed" | FAILED = "failed" | ||||
| RUNNING = "running" | RUNNING = "running" | ||||
| class KBResourceConstant(Enum): | class KBResourceConstant(Enum): | ||||
| """ | |||||
| Knowledge used constant | |||||
| """ | |||||
| MIN_TRAIN_SAMPLE = 10 | MIN_TRAIN_SAMPLE = 10 | ||||
| KB_INDEX_NAME = "index.pkl" | KB_INDEX_NAME = "index.pkl" | ||||
| TASK_EXTRACTOR_NAME = "task_attr_extractor.pkl" | TASK_EXTRACTOR_NAME = "task_attr_extractor.pkl" | ||||
| @@ -48,8 +48,11 @@ def _create_minio_client(): | |||||
| class FileOps: | class FileOps: | ||||
| """This is a class with some class methods | |||||
| to handle some files or folder.""" | |||||
| """ | |||||
| This is a class with some class methods | |||||
| to handle some files or folder. | |||||
| """ | |||||
| _GCS_PREFIX = "gs://" | _GCS_PREFIX = "gs://" | ||||
| _S3_PREFIX = "s3://" | _S3_PREFIX = "s3://" | ||||
| _LOCAL_PREFIX = "file://" | _LOCAL_PREFIX = "file://" | ||||
| @@ -49,23 +49,3 @@ def singleton(cls): | |||||
| return __instances__[cls] | return __instances__[cls] | ||||
| return get_instance | return get_instance | ||||
def model_layer_flatten(weights):
    """like this:
    weights.shape=[(3, 3, 3, 64), (64,), (3, 3, 64, 32), (32,), (6272, 64),
                   (64,), (64, 32), (32,), (32, 2), (2,)]
    flatten_weights=[(1728,), (64,), (18432,), (32,), (401408,), (64,),
                     (2048,), (32,), (64,), (2,)]
    :param weights:
    :return:
    """
    flatten = [layer.reshape((-1)) for layer in weights]
    return flatten


def model_layer_reshape(flatten_weights, shapes):
    shaped_model = []
    for idx, flatten_layer in enumerate(flatten_weights):
        shaped_model.append(flatten_layer.reshape(shapes[idx]))
    return shaped_model
| @@ -12,6 +12,7 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| import time | import time | ||||
| from sedna.core.base import JobBase | from sedna.core.base import JobBase | ||||
| @@ -24,15 +25,32 @@ from sedna.common.constant import K8sResourceKindStatus | |||||
| class FederatedLearning(JobBase): | class FederatedLearning(JobBase): | ||||
| """ | """ | ||||
| Federated learning | |||||
| """ | |||||
| Federated learning enables multiple actors to build a common, robust | |||||
| machine learning model without sharing data, thus allowing to address | |||||
| critical issues such as data privacy, data security, data access rights | |||||
| and access to heterogeneous data. | |||||
| Sedna provide the related interfaces for application development. | |||||
| Parameters | |||||
| ---------- | |||||
| estimator: Instance | |||||
| An instance with the high-level API that greatly simplifies | |||||
| machine learning programming. Estimators encapsulate training, | |||||
| evaluation, prediction, and exporting for your model. | |||||
| aggregation: str | |||||
| aggregation algo which has registered to ClassFactory, | |||||
| see `sedna.algorithms.aggregation` for more detail. | |||||
| Examples | |||||
| -------- | |||||
| >>> Estimator = keras.models.Sequential() | |||||
| >>> fl_model = FederatedLearning( | |||||
| estimator=Estimator, | |||||
| aggregation="FedAvg" | |||||
| ) | |||||
| """ | |||||
| def __init__(self, estimator, aggregation="FedAvg"): | def __init__(self, estimator, aggregation="FedAvg"): | ||||
| """ | |||||
| Initial a FederatedLearning job | |||||
| :param estimator: Customize estimator | |||||
| :param aggregation: aggregation algorithm for FederatedLearning | |||||
| """ | |||||
| protocol = Context.get_parameters("AGG_PROTOCOL", "ws") | protocol = Context.get_parameters("AGG_PROTOCOL", "ws") | ||||
| agg_ip = Context.get_parameters("AGG_IP", "127.0.0.1") | agg_ip = Context.get_parameters("AGG_IP", "127.0.0.1") | ||||
| @@ -53,6 +71,13 @@ class FederatedLearning(JobBase): | |||||
| self.register(timeout=connect_timeout) | self.register(timeout=connect_timeout) | ||||
| def register(self, timeout=300): | def register(self, timeout=300): | ||||
| """ | |||||
| Deprecated, Client proactively subscribes to the aggregation service. | |||||
| Parameters | |||||
| ---------- | |||||
| timeout: int, connect timeout. Default: 300 | |||||
| """ | |||||
| self.log.info( | self.log.info( | ||||
| f"Node {self.worker_name} connect to : {self.config.agg_uri}") | f"Node {self.worker_name} connect to : {self.config.agg_uri}") | ||||
| self.node = AggregationClient( | self.node = AggregationClient( | ||||
| @@ -73,10 +98,20 @@ class FederatedLearning(JobBase): | |||||
| **kwargs): | **kwargs): | ||||
| """ | """ | ||||
| Training task for FederatedLearning | Training task for FederatedLearning | ||||
| :param train_data: datasource use for train | |||||
| :param valid_data: datasource use for evaluation | |||||
| :param post_process: post process | |||||
| :param kwargs: params for training of customize estimator | |||||
| Parameters | |||||
| ---------- | |||||
| train_data: BaseDataSource | |||||
| datasource use for train, see | |||||
| `sedna.datasources.BaseDataSource` for more detail. | |||||
| valid_data: BaseDataSource | |||||
| datasource use for evaluation, see | |||||
| `sedna.datasources.BaseDataSource` for more detail. | |||||
| post_process: function or a registered method | |||||
| effected after `estimator` training. | |||||
| kwargs: Dict | |||||
| parameters for `estimator` training, | |||||
| Like: `early_stopping_rounds` in Xgboost.XGBClassifier | |||||
| """ | """ | ||||
| callback_func = None | callback_func = None | ||||
| @@ -24,17 +24,44 @@ __all__ = ("IncrementalLearning",) | |||||
| class IncrementalLearning(JobBase): | class IncrementalLearning(JobBase): | ||||
| """ | """ | ||||
| Incremental learning | |||||
| Incremental learning is a method of machine learning in which input data | |||||
| is continuously used to extend the existing model's knowledge i.e. to | |||||
| further train the model. It represents a dynamic technique of supervised | |||||
| learning and unsupervised learning that can be applied when training data | |||||
| becomes available gradually over time. | |||||
| Sedna provide the related interfaces for application development. | |||||
| Parameters | |||||
| ---------- | |||||
| estimator : Instance | |||||
| An instance with the high-level API that greatly simplifies | |||||
| machine learning programming. Estimators encapsulate training, | |||||
| evaluation, prediction, and exporting for your model. | |||||
| hard_example_mining : Dict | |||||
| HEM algorithms with parameters which has registered to ClassFactory, | |||||
| see `sedna.algorithms.hard_example_mining` for more detail. | |||||
| Examples | |||||
| -------- | |||||
| >>> Estimator = keras.models.Sequential() | |||||
| >>> il_model = IncrementalLearning( | |||||
| estimator=Estimator, | |||||
| hard_example_mining={ | |||||
| "method": "IBT", | |||||
| "param": { | |||||
| "threshold_img": 0.9 | |||||
| } | |||||
| } | |||||
| ) | |||||
| Notes | |||||
| ----- | |||||
| Sedna provide an interface call `get_hem_algorithm_from_config` to build | |||||
| the `hard_example_mining` parameter from CRD definition. | |||||
| """ | """ | ||||
| def __init__(self, estimator, hard_example_mining: dict = None): | def __init__(self, estimator, hard_example_mining: dict = None): | ||||
| """ | |||||
| Initial a IncrementalLearning job | |||||
| :param estimator: Customize estimator | |||||
| :param hard_example_mining: dict, hard example mining | |||||
| algorithms with parameters | |||||
| """ | |||||
| super(IncrementalLearning, self).__init__(estimator=estimator) | super(IncrementalLearning, self).__init__(estimator=estimator) | ||||
| self.model_urls = self.get_parameters( | self.model_urls = self.get_parameters( | ||||
| @@ -54,9 +81,24 @@ class IncrementalLearning(JobBase): | |||||
| @classmethod | @classmethod | ||||
| def get_hem_algorithm_from_config(cls, **param): | def get_hem_algorithm_from_config(cls, **param): | ||||
| """ | """ | ||||
| get the `algorithm` name and `param` of hard_example_mining from crd | |||||
| :param param: update value in parameters of hard_example_mining | |||||
| :return: dict, e.g.: {"method": "IBT", "param": {"threshold_img": 0.5}} | |||||
| get the `algorithm` name and `param` of hard_example_mining from crd | |||||
| Parameters | |||||
| ---------- | |||||
| param : Dict | |||||
| update value in parameters of hard_example_mining | |||||
| Returns | |||||
| ------- | |||||
| dict | |||||
| e.g.: {"method": "IBT", "param": {"threshold_img": 0.5}} | |||||
| Examples | |||||
| -------- | |||||
| >>> IncrementalLearning.get_hem_algorithm_from_config( | |||||
| threshold_img=0.9 | |||||
| ) | |||||
| {"method": "IBT", "param": {"threshold_img": 0.9}} | |||||
| """ | """ | ||||
| return cls.parameters.get_algorithm_from_api( | return cls.parameters.get_algorithm_from_api( | ||||
| algorithm="HEM", | algorithm="HEM", | ||||
| @@ -69,11 +111,24 @@ class IncrementalLearning(JobBase): | |||||
| **kwargs): | **kwargs): | ||||
| """ | """ | ||||
| Training task for IncrementalLearning | Training task for IncrementalLearning | ||||
| :param train_data: datasource use for train | |||||
| :param valid_data: datasource use for evaluation | |||||
| :param post_process: post process | |||||
| :param kwargs: params for training of customize estimator | |||||
| :return: estimator | |||||
| Parameters | |||||
| ---------- | |||||
| train_data: BaseDataSource | |||||
| datasource used for training, see | |||||
| `sedna.datasources.BaseDataSource` for more detail. | |||||
| valid_data: BaseDataSource | |||||
| datasource used for evaluation, see | |||||
| `sedna.datasources.BaseDataSource` for more detail. | |||||
| post_process: function or a registered method | |||||
| invoked after `estimator` training. | |||||
| kwargs: Dict | |||||
| parameters for `estimator` training, | |||||
| e.g. `early_stopping_rounds` in xgboost.XGBClassifier | |||||
| Returns | |||||
| ------- | |||||
| estimator | |||||
| """ | """ | ||||
| callback_func = None | callback_func = None | ||||
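| To make the documented flow concrete, here is a minimal training sketch built from the class-level example above; the index-file paths, the `parse` call arguments, and the forwarded `epochs` keyword are illustrative assumptions, not part of the documented API: | |||||
| ```python | |||||
| import keras | |||||
| from sedna.core.incremental_learning import IncrementalLearning | |||||
| from sedna.datasources import TxtDataParse | |||||
| # build the job as in the class-level example | |||||
| il_job = IncrementalLearning( | |||||
|     estimator=keras.models.Sequential(), | |||||
|     hard_example_mining={"method": "IBT", "param": {"threshold_img": 0.9}} | |||||
| ) | |||||
| # hypothetical index files parsed into train/eval datasources | |||||
| train_data = TxtDataParse(data_type="train") | |||||
| train_data.parse("train_data/index.txt") | |||||
| valid_data = TxtDataParse(data_type="eval") | |||||
| valid_data.parse("eval_data/index.txt") | |||||
| # extra kwargs are forwarded to the estimator's training routine | |||||
| il_job.train(train_data=train_data, valid_data=valid_data, epochs=10) | |||||
| ``` | |||||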
| @@ -94,10 +149,23 @@ class IncrementalLearning(JobBase): | |||||
| def inference(self, data=None, post_process=None, **kwargs): | def inference(self, data=None, post_process=None, **kwargs): | ||||
| """ | """ | ||||
| Inference task for IncrementalLearning | Inference task for IncrementalLearning | ||||
| :param data: inference sample | |||||
| :param post_process: post process | |||||
| :param kwargs: params for inference of customize estimator | |||||
| :return: inference result, result after post_process, if is hard sample | |||||
| Parameters | |||||
| ---------- | |||||
| data: BaseDataSource | |||||
| datasource used for inference, see | |||||
| `sedna.datasources.BaseDataSource` for more detail. | |||||
| post_process: function or a registered method | |||||
| invoked after `estimator` inference. | |||||
| kwargs: Dict | |||||
| parameters for `estimator` inference, | |||||
| e.g. `ntree_limit` in xgboost.XGBClassifier | |||||
| Returns | |||||
| ------- | |||||
| inference result : object | |||||
| result after post_process : object | |||||
| is hard sample : bool | |||||
| """ | """ | ||||
| if not self.estimator.has_load: | if not self.estimator.has_load: | ||||
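| Following the three documented return values, a call site can unpack them as sketched below; `il_job` comes from the training sketch above and `sample` is a hypothetical single inference input: | |||||
| ```python | |||||
| infer_res, post_res, is_hard_sample = il_job.inference(sample) | |||||
| if is_hard_sample: | |||||
|     # in the incremental learning workflow, hard examples are the | |||||
|     # samples collected for the next round of training | |||||
|     print("hard example mined") | |||||
| ``` | |||||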
| @@ -125,10 +193,21 @@ class IncrementalLearning(JobBase): | |||||
| def evaluate(self, data, post_process=None, **kwargs): | def evaluate(self, data, post_process=None, **kwargs): | ||||
| """ | """ | ||||
| Evaluate task for IncrementalLearning | Evaluate task for IncrementalLearning | ||||
| :param data: datasource use for evaluation | |||||
| :param post_process: post process | |||||
| :param kwargs: params for evaluate of customize estimator | |||||
| :return: evaluate metrics | |||||
| Parameters | |||||
| ---------- | |||||
| data: BaseDataSource | |||||
| datasource used for evaluation, see | |||||
| `sedna.datasources.BaseDataSource` for more detail. | |||||
| post_process: function or a registered method | |||||
| invoked after `estimator` evaluation. | |||||
| kwargs: Dict | |||||
| parameters for `estimator` evaluation, | |||||
| e.g. `metric_name` in xgboost.XGBClassifier | |||||
| Returns | |||||
| ------- | |||||
| evaluation metrics : List | |||||
| """ | """ | ||||
| callback_func = None | callback_func = None | ||||
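| A matching evaluation sketch, reusing `valid_data` from the training example; the `metric_name` keyword is only an illustration of a forwarded parameter: | |||||
| ```python | |||||
| metrics = il_job.evaluate(valid_data, metric_name="recall") | |||||
| print(metrics)  # list of evaluation metrics, as documented above | |||||
| ``` | |||||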
| @@ -29,6 +29,18 @@ class BigModelService(JobBase): | |||||
| """ | """ | ||||
| Large model services implemented | Large model services implemented | ||||
| Provides RESTful interfaces for large-model inference. | Provides RESTful interfaces for large-model inference. | ||||
| Parameters | |||||
| ---------- | |||||
| estimator : Instance (big model) | |||||
| An instance with the high-level API that greatly simplifies | |||||
| machine learning programming. Estimators encapsulate training, | |||||
| evaluation, prediction, and exporting for your model. | |||||
| Examples | |||||
| -------- | |||||
| >>> Estimator = xgboost.XGBClassifier() | |||||
| >>> BigModelService(estimator=Estimator).start() | |||||
| """ | """ | ||||
| def __init__(self, estimator=None): | def __init__(self, estimator=None): | ||||
| @@ -44,7 +56,6 @@ class BigModelService(JobBase): | |||||
| def start(self): | def start(self): | ||||
| """ | """ | ||||
| Start inference rest server | Start inference rest server | ||||
| :return: | |||||
| """ | """ | ||||
| if callable(self.estimator): | if callable(self.estimator): | ||||
| @@ -66,10 +77,21 @@ class BigModelService(JobBase): | |||||
| def inference(self, data=None, post_process=None, **kwargs): | def inference(self, data=None, post_process=None, **kwargs): | ||||
| """ | """ | ||||
| Inference task for JointInference | Inference task for JointInference | ||||
| :param data: inference sample | |||||
| :param post_process: post process | |||||
| :param kwargs: params for inference of big model | |||||
| :return: inference result | |||||
| Parameters | |||||
| ---------- | |||||
| data: BaseDataSource | |||||
| datasource used for inference, see | |||||
| `sedna.datasources.BaseDataSource` for more detail. | |||||
| post_process: function or a registered method | |||||
| invoked after `estimator` inference. | |||||
| kwargs: Dict | |||||
| parameters for `estimator` inference, | |||||
| e.g. `ntree_limit` in xgboost.XGBClassifier | |||||
| Returns | |||||
| ------- | |||||
| inference result | |||||
| """ | """ | ||||
| callback_func = None | callback_func = None | ||||
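| On the cloud side the service wraps the big model and serves it over REST; a minimal sketch, directly following the class-level example: | |||||
| ```python | |||||
| import xgboost | |||||
| from sedna.core.joint_inference import BigModelService | |||||
| # start() exposes the RESTful inference interface; inference() | |||||
| # is then driven by requests coming from the edge side | |||||
| BigModelService(estimator=xgboost.XGBClassifier()).start() | |||||
| ``` | |||||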
| @@ -87,16 +109,40 @@ class BigModelService(JobBase): | |||||
| class JointInference(JobBase): | class JointInference(JobBase): | ||||
| """ | """ | ||||
| Joint inference | |||||
| Sedna provides a framework which ensures that, under the condition of | |||||
| limited resources on the edge, difficult inference tasks are offloaded to | |||||
| the cloud to improve the overall performance while keeping the throughput. | |||||
| Parameters | |||||
| ---------- | |||||
| estimator : Instance | |||||
| An instance with the high-level API that greatly simplifies | |||||
| machine learning programming. Estimators encapsulate training, | |||||
| evaluation, prediction, and exporting for your model. | |||||
| hard_example_mining : Dict | |||||
| HEM algorithms with parameters which have been registered to ClassFactory, | |||||
| see `sedna.algorithms.hard_example_mining` for more detail. | |||||
| Examples | |||||
| -------- | |||||
| >>> Estimator = keras.models.Sequential() | |||||
| >>> ji_service = JointInference( | |||||
| estimator=Estimator, | |||||
| hard_example_mining={ | |||||
| "method": "IBT", | |||||
| "param": { | |||||
| "threshold_img": 0.9 | |||||
| } | |||||
| } | |||||
| ) | |||||
| Notes | |||||
| ----- | |||||
| Sedna provides an interface called `get_hem_algorithm_from_config` to | |||||
| build the `hard_example_mining` parameter from the CRD definition. | |||||
| """ | """ | ||||
| def __init__(self, estimator=None, hard_example_mining: dict = None): | def __init__(self, estimator=None, hard_example_mining: dict = None): | ||||
| """ | |||||
| Initial a JointInference Job | |||||
| :param estimator: Customize estimator | |||||
| :param hard_example_mining: dict, hard example mining | |||||
| """ | |||||
| super(JointInference, self).__init__(estimator=estimator) | super(JointInference, self).__init__(estimator=estimator) | ||||
| self.job_kind = K8sResourceKind.JOINT_INFERENCE_SERVICE.value | self.job_kind = K8sResourceKind.JOINT_INFERENCE_SERVICE.value | ||||
| self.local_ip = get_host_ip() | self.local_ip = get_host_ip() | ||||
| @@ -141,8 +187,23 @@ class JointInference(JobBase): | |||||
| def get_hem_algorithm_from_config(cls, **param): | def get_hem_algorithm_from_config(cls, **param): | ||||
| """ | """ | ||||
| get the `algorithm` name and `param` of hard_example_mining from crd | get the `algorithm` name and `param` of hard_example_mining from crd | ||||
| :param param: update value in parameters of hard_example_mining | |||||
| :return: dict, e.g.: {"method": "IBT", "param": {"threshold_img": 0.5}} | |||||
| Parameters | |||||
| ---------- | |||||
| param : Dict | |||||
| values to update in the parameters of hard_example_mining | |||||
| Returns | |||||
| ------- | |||||
| dict | |||||
| e.g.: {"method": "IBT", "param": {"threshold_img": 0.5}} | |||||
| Examples | |||||
| -------- | |||||
| >>> JointInference.get_hem_algorithm_from_config( | |||||
| threshold_img=0.9 | |||||
| ) | |||||
| {"method": "IBT", "param": {"threshold_img": 0.9}} | |||||
| """ | """ | ||||
| return cls.parameters.get_algorithm_from_api( | return cls.parameters.get_algorithm_from_api( | ||||
| algorithm="HEM", | algorithm="HEM", | ||||
| @@ -151,12 +212,25 @@ class JointInference(JobBase): | |||||
| def inference(self, data=None, post_process=None, **kwargs): | def inference(self, data=None, post_process=None, **kwargs): | ||||
| """ | """ | ||||
| Inference task for IncrementalLearning | |||||
| :param data: inference sample | |||||
| :param post_process: post process | |||||
| :param kwargs: params for inference of customize estimator | |||||
| :return: if is hard sample, real result, | |||||
| little model result, big model result | |||||
| Inference task with JointInference | |||||
| Parameters | |||||
| ---------- | |||||
| data: BaseDataSource | |||||
| datasource used for inference, see | |||||
| `sedna.datasources.BaseDataSource` for more detail. | |||||
| post_process: function or a registered method | |||||
| invoked after `estimator` inference. | |||||
| kwargs: Dict | |||||
| parameters for `estimator` inference, | |||||
| e.g. `ntree_limit` in xgboost.XGBClassifier | |||||
| Returns | |||||
| ------- | |||||
| is hard sample : bool | |||||
| inference result : object | |||||
| result from little model : object | |||||
| result from big model : object | |||||
| """ | """ | ||||
| callback_func = None | callback_func = None | ||||
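| The four documented return values can be unpacked as sketched below; `sample` is a hypothetical inference input and the service is built as in the class-level example: | |||||
| ```python | |||||
| import keras | |||||
| from sedna.core.joint_inference import JointInference | |||||
| ji_service = JointInference( | |||||
|     estimator=keras.models.Sequential(),  # little model on the edge | |||||
|     hard_example_mining={"method": "IBT", "param": {"threshold_img": 0.9}} | |||||
| ) | |||||
| is_hard, final_res, little_res, big_res = ji_service.inference(sample) | |||||
| # when is_hard is True, the result offloaded to the big model is used | |||||
| ``` | |||||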
| @@ -28,7 +28,68 @@ from sedna.service.client import KBClient | |||||
| class LifelongLearning(JobBase): | class LifelongLearning(JobBase): | ||||
| """ | """ | ||||
| Lifelong learning | |||||
| Lifelong Learning (LL) is an advanced machine learning (ML) paradigm that | |||||
| learns continuously, accumulates the knowledge learned in the past, and | |||||
| uses/adapts it to help future learning and problem solving. | |||||
| Sedna provides the related interfaces for application development. | |||||
| Parameters | |||||
| ---------- | |||||
| estimator : Instance | |||||
| An instance with the high-level API that greatly simplifies | |||||
| machine learning programming. Estimators encapsulate training, | |||||
| evaluation, prediction, and exporting for your model. | |||||
| task_definition : Dict | |||||
| Divide multiple tasks based on data, | |||||
| see `task_jobs.task_definition` for more detail. | |||||
| task_relationship_discovery : Dict | |||||
| Discover relationships between all tasks, see | |||||
| `task_jobs.task_relationship_discovery` for more detail. | |||||
| task_mining : Dict | |||||
| Mining tasks of inference sample, | |||||
| see `task_jobs.task_mining` for more detail. | |||||
| task_remodeling : Dict | |||||
| Remodeling tasks based on their relationships, | |||||
| see `task_jobs.task_remodeling` for more detail. | |||||
| inference_integrate : Dict | |||||
| Integrate the inference results of all related | |||||
| tasks, see `task_jobs.inference_integrate` for more detail. | |||||
| unseen_task_detect: Dict | |||||
| unseen task detect algorithms with parameters which have been registered | |||||
| to ClassFactory, see `sedna.algorithms.unseen_task_detect` for more detail | |||||
| Examples | |||||
| -------- | |||||
| >>> estimator = XGBClassifier(objective="binary:logistic") | |||||
| >>> task_definition = { | |||||
| "method": "TaskDefinitionByDataAttr", | |||||
| "param": {"attribute": ["season", "city"]} | |||||
| } | |||||
| >>> task_relationship_discovery = { | |||||
| "method": "DefaultTaskRelationDiscover", "param": {} | |||||
| } | |||||
| >>> task_mining = { | |||||
| "method": "TaskMiningByDataAttr", | |||||
| "param": {"attribute": ["season", "city"]} | |||||
| } | |||||
| >>> task_remodeling = None | |||||
| >>> inference_integrate = { | |||||
| "method": "DefaultInferenceIntegrate", "param": {} | |||||
| } | |||||
| >>> unseen_task_detect = { | |||||
| "method": "TaskAttrFilter", "param": {} | |||||
| } | |||||
| >>> ll_jobs = LifelongLearning( | |||||
| estimator=estimator, | |||||
| task_definition=task_definition, | |||||
| task_relationship_discovery=task_relationship_discovery, | |||||
| task_mining=task_mining, | |||||
| task_remodeling=task_remodeling, | |||||
| inference_integrate=inference_integrate, | |||||
| unseen_task_detect=unseen_task_detect | |||||
| ) | |||||
| """ | """ | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -57,7 +118,7 @@ class LifelongLearning(JobBase): | |||||
| inference_integrate=inference_integrate) | inference_integrate=inference_integrate) | ||||
| self.unseen_task_detect = unseen_task_detect.get("method", | self.unseen_task_detect = unseen_task_detect.get("method", | ||||
| "TaskAttrFilter") | "TaskAttrFilter") | ||||
| self.unseen_task_detect_param = e.parse_param( | |||||
| self.unseen_task_detect_param = e._parse_param( | |||||
| unseen_task_detect.get("param", {}) | unseen_task_detect.get("param", {}) | ||||
| ) | ) | ||||
| config = dict( | config = dict( | ||||
| @@ -79,10 +140,25 @@ class LifelongLearning(JobBase): | |||||
| action="initial", | action="initial", | ||||
| **kwargs): | **kwargs): | ||||
| """ | """ | ||||
| :param train_data: data use to train model | |||||
| :param valid_data: data use to valid model | |||||
| :param post_process: callback function | |||||
| :param action: initial - kb init, update - kb incremental update | |||||
| Fit the estimator and update the knowledge base with the training data. | |||||
| Parameters | |||||
| ---------- | |||||
| train_data : BaseDataSource | |||||
| Train data, see `sedna.datasources.BaseDataSource` for more detail. | |||||
| valid_data : BaseDataSource | |||||
| Valid data, BaseDataSource or None. | |||||
| post_process : function | |||||
| function or a registered method, callback after `estimator` training. | |||||
| action : str | |||||
| `initial` to initialize the knowledge base, `update` to update it | |||||
| incrementally. | |||||
| kwargs : Dict | |||||
| parameters for `estimator` training, e.g. | |||||
| `early_stopping_rounds` in xgboost.XGBClassifier | |||||
| Returns | |||||
| ------- | |||||
| train_history : object | |||||
| """ | """ | ||||
| callback_func = None | callback_func = None | ||||
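| Reusing the `ll_jobs` instance from the class-level example, an initial run followed by an incremental update might look like this; the csv path and the `label` keyword of `parse` are assumptions for illustration: | |||||
| ```python | |||||
| from sedna.datasources import CSVDataParse | |||||
| train_data = CSVDataParse(data_type="train") | |||||
| train_data.parse("train.csv", label="label")  # hypothetical arguments | |||||
| # first run: initialize the knowledge base | |||||
| train_history = ll_jobs.train(train_data=train_data, action="initial") | |||||
| # later runs: update the knowledge base incrementally | |||||
| train_history = ll_jobs.train(train_data=train_data, action="update") | |||||
| ``` | |||||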
| @@ -180,6 +256,19 @@ class LifelongLearning(JobBase): | |||||
| ) | ) | ||||
| def evaluate(self, data, post_process=None, **kwargs): | def evaluate(self, data, post_process=None, **kwargs): | ||||
| """ | |||||
| Evaluate the performance of each task from training, and filter tasks | |||||
| based on the defined rules. | |||||
| Parameters | |||||
| ---------- | |||||
| data : BaseDataSource | |||||
| valid data, see `sedna.datasources.BaseDataSource` for more detail. | |||||
| kwargs: Dict | |||||
| parameters for `estimator` evaluation, e.g. | |||||
| `ntree_limit` in xgboost.XGBClassifier | |||||
| """ | |||||
| callback_func = None | callback_func = None | ||||
| if callable(post_process): | if callable(post_process): | ||||
| callback_func = post_process | callback_func = post_process | ||||
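| A corresponding evaluation sketch, with the datasource built the same way as in the training example above: | |||||
| ```python | |||||
| eval_data = CSVDataParse(data_type="eval") | |||||
| eval_data.parse("eval.csv", label="label")  # hypothetical arguments | |||||
| ll_jobs.evaluate(eval_data) | |||||
| ``` | |||||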
| @@ -244,7 +333,30 @@ class LifelongLearning(JobBase): | |||||
| return callback_func(res) if callback_func else res | return callback_func(res) if callback_func else res | ||||
| def inference(self, data=None, post_process=None, **kwargs): | def inference(self, data=None, post_process=None, **kwargs): | ||||
| """ | |||||
| Predict the result for input data based on training knowledge. | |||||
| Parameters | |||||
| ---------- | |||||
| data : BaseDataSource | |||||
| inference sample, see `sedna.datasources.BaseDataSource` for | |||||
| more detail. | |||||
| post_process: function | |||||
| function or a registered method, invoked after `estimator` | |||||
| prediction, e.g. a label transform. | |||||
| kwargs: Dict | |||||
| parameters for `estimator` prediction, e.g. | |||||
| `ntree_limit` in xgboost.XGBClassifier | |||||
| Returns | |||||
| ------- | |||||
| result : array_like | |||||
| results array, containing the inference result of each sample. | |||||
| is_unseen_task : bool | |||||
| `True` if an unseen task is detected, otherwise `False`. | |||||
| tasks : List | |||||
| tasks assigned to each sample. | |||||
| """ | |||||
| task_index_url = self.get_parameters( | task_index_url = self.get_parameters( | ||||
| "MODEL_URLS", self.config.task_index) | "MODEL_URLS", self.config.task_index) | ||||
| index_url = self.estimator.estimator.task_index_url | index_url = self.estimator.estimator.task_index_url | ||||
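| The three documented return values can be consumed as sketched here; the test csv and its parsing arguments are again assumptions: | |||||
| ```python | |||||
| infer_data = CSVDataParse(data_type="test") | |||||
| infer_data.parse("test.csv")  # hypothetical arguments | |||||
| result, is_unseen_task, tasks = ll_jobs.inference(infer_data) | |||||
| if is_unseen_task: | |||||
|     # unseen samples would typically be fed back to update the | |||||
|     # knowledge base via train(..., action="update") | |||||
|     print("unseen task detected") | |||||
| ``` | |||||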
| @@ -24,6 +24,23 @@ __all__ = ('BaseDataSource', 'TxtDataParse', 'CSVDataParse') | |||||
| class BaseDataSource: | class BaseDataSource: | ||||
| """ | |||||
| An abstract class representing a :class:`BaseDataSource`. | |||||
| All datasets that represent a map from keys to data samples should subclass | |||||
| it. All subclasses should overwrite `parse`, which supports getting | |||||
| train/eval/infer data via a function, and should overwrite `x` for the | |||||
| feature embedding and `y` for the target label. Subclasses could also | |||||
| optionally overwrite `__len__`, which is expected to return the size of | |||||
| the dataset. | |||||
| Parameters | |||||
| ---------- | |||||
| data_type : str | |||||
| defines whether the datasource is train/eval/test | |||||
| func: function | |||||
| function used to parse an iterable object batch by batch | |||||
| """ | |||||
| def __init__(self, data_type="train", func=None): | def __init__(self, data_type="train", func=None): | ||||
| self.data_type = data_type # sample type: train/eval/test | self.data_type = data_type # sample type: train/eval/test | ||||
| self.process_func = None | self.process_func = None | ||||
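| A minimal subclass sketch following the contract above: overwrite `parse`, fill `x` and `y`, and optionally overwrite `__len__`; the one-sample-per-line file format here is purely an assumption: | |||||
| ```python | |||||
| from sedna.datasources import BaseDataSource | |||||
| class MyDataParse(BaseDataSource): | |||||
|     def parse(self, path): | |||||
|         # hypothetical format: one "feature,label" pair per line | |||||
|         self.x, self.y = [], [] | |||||
|         with open(path) as f: | |||||
|             for line in f: | |||||
|                 feature, label = line.strip().split(",") | |||||
|                 self.x.append(feature) | |||||
|                 self.y.append(label) | |||||
|     def __len__(self): | |||||
|         return len(self.x) | |||||
| # usage: ds = MyDataParse(data_type="train"); ds.parse("data.txt") | |||||
| ``` | |||||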
| @@ -54,7 +71,9 @@ class BaseDataSource: | |||||
| class TxtDataParse(BaseDataSource, ABC): | class TxtDataParse(BaseDataSource, ABC): | ||||
| """txt file which contain image list parser""" | |||||
| """ | |||||
| Parser for txt files which contain an image list. | |||||
| """ | |||||
| def __init__(self, data_type, func=None): | def __init__(self, data_type, func=None): | ||||
| super(TxtDataParse, self).__init__(data_type=data_type, func=func) | super(TxtDataParse, self).__init__(data_type=data_type, func=func) | ||||
| @@ -89,7 +108,9 @@ class TxtDataParse(BaseDataSource, ABC): | |||||
| class CSVDataParse(BaseDataSource, ABC): | class CSVDataParse(BaseDataSource, ABC): | ||||
| """csv file which contain Structured Data parser""" | |||||
| """ | |||||
| Parser for csv files which contain structured data. | |||||
| """ | |||||
| def __init__(self, data_type, func=None): | def __init__(self, data_type, func=None): | ||||
| super(CSVDataParse, self).__init__(data_type=data_type, func=func) | super(CSVDataParse, self).__init__(data_type=data_type, func=func) | ||||
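| Usage of the two built-in parsers is symmetrical; in this sketch the file paths and the `label` keyword are assumptions for illustration: | |||||
| ```python | |||||
| from sedna.datasources import TxtDataParse, CSVDataParse | |||||
| txt_data = TxtDataParse(data_type="train") | |||||
| txt_data.parse("train_index.txt")  # txt file containing an image list | |||||
| csv_data = CSVDataParse(data_type="train") | |||||
| csv_data.parse("train.csv", label="label")  # structured csv data | |||||
| ``` | |||||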
| @@ -117,7 +117,7 @@ class LCReporter(threading.Thread): | |||||
| class LCClient: | class LCClient: | ||||
| """send info to LC by http""" | |||||
| @classmethod | @classmethod | ||||
| def send(cls, lc_server, worker_name, message: dict): | def send(cls, lc_server, worker_name, message: dict): | ||||
| url = '{0}/sedna/workers/{1}/info'.format( | url = '{0}/sedna/workers/{1}/info'.format( | ||||