| @@ -27,6 +27,7 @@ federated-learning-surface-defect-detection-train.Dockerfile | |||
| incremental-learning-helmet-detection.Dockerfile | |||
| joint-inference-helmet-detection-big.Dockerfile | |||
| joint-inference-helmet-detection-little.Dockerfile | |||
| lifelong-learning-atcii-classifier.Dockerfile | |||
| ) | |||
| for dockerfile in ${dockerfiles[@]}; do | |||
| @@ -0,0 +1,177 @@ | |||
| # Using Lifelong Learning Job in Thermal Comfort Prediction Scenario | |||
| This document describes how to use a lifelong learning job in a thermal comfort prediction scenario. | |||
| With a lifelong learning job, our application can automatically retrain, evaluate, | |||
| and update models based on the data generated at the edge. | |||
| ## Thermal Comfort Prediction Experiment | |||
| ### Install Sedna | |||
| Follow the [Sedna installation document](/docs/setup/install.md) to install Sedna. | |||
| ### Prepare Dataset | |||
| In this example, you can download the ASHRAE Global Thermal Comfort Database II to initialize the lifelong learning knowledge base. | |||
| Download the [datasets](https://kubeedge.obs.cn-north-1.myhuaweicloud.com/examples/atcii-classifier/dataset.tar.gz), which include the train, evaluation, and incremental datasets. | |||
| ``` | |||
| cd / | |||
| wget https://kubeedge.obs.cn-north-1.myhuaweicloud.com/examples/atcii-classifier/dataset.tar.gz | |||
| tar -zxvf dataset.tar.gz | |||
| ``` | |||
| ### Prepare for Knowledgebase Server | |||
| In this example, we create a knowledge base RESTful server backed by SQLite3. The database is stored at `LIFELONG_KB_URL`, and | |||
| the server runs on the `GM Node`. | |||
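| For reference, the worker code in this example locates the knowledge base and its output location through environment variables read via `Context`. A minimal sketch (the server address in the comment is only an assumed example): | |||
| ``` | |||
| from sedna.common.config import Context | |||
| # KB_SERVER and OUTPUT_URL are the env names read by this example's worker code. | |||
| kb_server = Context.get_parameters("KB_SERVER")  # e.g. "http://kb-server:9020" (assumed) | |||
| output_url = Context.get_parameters("OUTPUT_URL", "/tmp") | |||
| ``` | |||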
| ### Prepare Image | |||
| This example uses the image: | |||
| ``` | |||
| swr.cn-southwest-2.myhuaweicloud.com/sedna-test/sedna/kb:v0.0.1 | |||
| ``` | |||
| ### Create KB Deployment | |||
| ``` | |||
| kubectl create -f scripts/knowledge-server/kb.yaml | |||
| ``` | |||
| ### Create Lifelong Job | |||
| In this example, `$WORKER_NODE` is a custom node name; set it to the node where you actually run the workers. | |||
| ``` | |||
| WORKER_NODE="edge-node" | |||
| ``` | |||
| Create Dataset | |||
| ``` | |||
| kubectl create -f - <<EOF | |||
| apiVersion: sedna.io/v1alpha1 | |||
| kind: Dataset | |||
| metadata: | |||
| name: lifelong-dataset | |||
| spec: | |||
| url: "/data/lifelong_learning/trainData.csv" | |||
| format: "csv" | |||
| nodeName: $WORKER_NODE | |||
| EOF | |||
| ``` | |||
| You can also trigger retraining by using the incremental dataset (`trainData2.csv`) to replace `trainData.csv`. | |||
| Start the Lifelong Learning Job | |||
| ``` | |||
| kubectl create -f - <<EOF | |||
| apiVersion: sedna.io/v1alpha1 | |||
| kind: LifelongLearningJob | |||
| metadata: | |||
| name: atcii-classifier-demo | |||
| spec: | |||
| dataset: | |||
| name: "lifelong-dataset" | |||
| trainProb: 0.8 | |||
| trainSpec: | |||
| template: | |||
| spec: | |||
| nodeName: "edge-node" | |||
| containers: | |||
| - image: kubeedge/sedna-example-lifelong-learning-atcii-classifier:v0.1.0 | |||
| name: train-worker | |||
| imagePullPolicy: IfNotPresent | |||
| args: ["train.py"] | |||
| env: | |||
| - name: "early_stopping_rounds" | |||
| value: "100" | |||
| - name: "metric_name" | |||
| value: "mlogloss" | |||
| trigger: | |||
| checkPeriodSeconds: 60 | |||
| timer: | |||
| start: 02:00 | |||
| end: 20:00 | |||
| condition: | |||
| operator: ">" | |||
| threshold: 500 | |||
| metric: num_of_samples | |||
| evalSpec: | |||
| template: | |||
| spec: | |||
| nodeName: "edge-node" | |||
| containers: | |||
| - image: kubeedge/sedna-example-lifelong-learning-atcii-classifier:v0.1.0 | |||
| name: eval-worker | |||
| imagePullPolicy: IfNotPresent | |||
| args: ["eval.py"] | |||
| env: | |||
| - name: "metrics" | |||
| value: "precision_score" | |||
| - name: "metric_param" | |||
| value: "{'average': 'micro'}" | |||
| - name: "model_threshold" | |||
| value: "0.5" | |||
| deploySpec: | |||
| template: | |||
| spec: | |||
| nodeName: "edge-node" | |||
| containers: | |||
| - image: kubeedge/sedna-example-lifelong-learning-atcii-classifier:v0.1.0 | |||
| name: infer-worker | |||
| imagePullPolicy: IfNotPresent | |||
| args: ["inference.py"] | |||
| env: | |||
| - name: "UT_SAVED_URL" | |||
| value: "/ut_saved_url" | |||
| - name: "infer_dataset_url" | |||
| value: "/data/testData.csv" | |||
| volumeMounts: | |||
| - name: utdir | |||
| mountPath: /ut_saved_url | |||
| - name: inferdata | |||
| mountPath: /data/ | |||
| resources: # user defined resources | |||
| limits: | |||
| memory: 2Gi | |||
| volumes: # user defined volumes | |||
| - name: utdir | |||
| hostPath: | |||
| path: /lifelong/unseen_task/ | |||
| type: DirectoryOrCreate | |||
| - name: inferdata | |||
| hostPath: | |||
| path: /lifelong/data/ | |||
| type: DirectoryOrCreate | |||
| outputDir: "/output" | |||
| EOF | |||
| ``` | |||
| ### Check Lifelong Learning Job | |||
| Query the service status: | |||
| ``` | |||
| kubectl get lifelonglearningjob atcii-classifier-demo | |||
| ``` | |||
| In the `lifelonglearningjob` resource `atcii-classifier-demo`, the following trigger is configured: | |||
| ``` | |||
| trigger: | |||
| checkPeriodSeconds: 60 | |||
| timer: | |||
| start: 02:00 | |||
| end: 20:00 | |||
| condition: | |||
| operator: ">" | |||
| threshold: 500 | |||
| metric: num_of_samples | |||
| ``` | |||
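| That is, Sedna checks every 60 seconds and triggers retraining once more than 500 new samples have accumulated, but only within the 02:00–20:00 window. | |||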
| ### Unseen Task Samples Labeling | |||
| In the real world, we need to label the hard examples of unseen tasks, which are stored in `UT_SAVED_URL`, with annotation tools, and then add the labeled examples to the `Dataset`'s URL. | |||
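| A minimal sketch of this round trip, assuming the paths used in this example and a hypothetical `annotate` helper standing in for a real annotation tool: | |||
| ``` | |||
| import csv | |||
| UT_DIR = "/lifelong/unseen_task"  # host path backing UT_SAVED_URL in this example | |||
| TRAIN_CSV = "/data/lifelong_learning/trainData.csv" | |||
| def annotate(row): | |||
|     # Hypothetical stand-in for a real annotation tool. | |||
|     return "1" | |||
| # unseen_sample.csv is tab-separated and carries the original header plus a 'pred' column. | |||
| with open(f"{UT_DIR}/unseen_sample.csv", encoding="utf-8") as f: | |||
|     rows = list(csv.reader(f, delimiter="\t")) | |||
| header, samples = rows[0], rows[1:] | |||
| label_idx = header.index("Thermal preference")  # the label column of this example | |||
| # Append labeled samples to the training data so the trigger can fire retraining. | |||
| with open(TRAIN_CSV, "a", encoding="utf-8", newline="") as f: | |||
|     writer = csv.writer(f) | |||
|     for row in samples: | |||
|         row = row[:-1]  # drop the trailing 'pred' column | |||
|         row[label_idx] = annotate(row) | |||
|         writer.writerow(row) | |||
| ``` | |||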
| ### Effect Display | |||
| In this example, false and missed detections occur at the inference stage before lifelong learning; after lifelong learning, | |||
| the precision and accuracy on the dataset improve greatly. | |||
|  | |||
| @@ -0,0 +1,44 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import json | |||
| from interface import DATACONF, Estimator, feature_process | |||
| from sedna.common.config import Context, BaseConfig | |||
| from sedna.datasources import CSVDataParse | |||
| from sedna.core.lifelong_learning import LifelongLearning | |||
| def main(): | |||
| test_dataset_url = BaseConfig.test_dataset_url | |||
| valid_data = CSVDataParse(data_type="valid", func=feature_process) | |||
| valid_data.parse(test_dataset_url, label=DATACONF["LABEL"]) | |||
| attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]}) | |||
| model_threshold = float(Context.get_parameters('model_threshold', 0)) | |||
| ll_job = LifelongLearning( | |||
| estimator=Estimator, | |||
| task_definition="TaskDefinitionByDataAttr", | |||
| task_definition_param=attribute | |||
| ) | |||
| eval_experiment = ll_job.evaluate( | |||
| data=valid_data, metrics="precision_score", | |||
| metrics_param={"average": "micro"}, | |||
| model_threshold=model_threshold | |||
| ) | |||
| return eval_experiment | |||
| if __name__ == '__main__': | |||
| print(main()) | |||
| @@ -0,0 +1,72 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import os | |||
| import csv | |||
| import json | |||
| import time | |||
| from interface import DATACONF, Estimator, feature_process | |||
| from sedna.common.config import Context | |||
| from sedna.datasources import CSVDataParse | |||
| from sedna.core.lifelong_learning import LifelongLearning | |||
| def main(): | |||
| utd = Context.get_parameters("UTD_NAME", "TaskAttr") | |||
| attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]}) | |||
| utd_parameters = Context.get_parameters("UTD_PARAMETERS", {}) | |||
| ut_saved_url = Context.get_parameters("UT_SAVED_URL", "/tmp")  # matches the UT_SAVED_URL env set in the deploy spec | |||
| ll_job = LifelongLearning( | |||
| estimator=Estimator, | |||
| task_mining="TaskMiningByDataAttr", | |||
| task_mining_param=attribute, | |||
| unseen_task_detect=utd, | |||
| unseen_task_detect_param=utd_parameters) | |||
| infer_dataset_url = Context.get_parameters('infer_dataset_url') | |||
| file_handle = open(infer_dataset_url, "r", encoding="utf-8") | |||
| header = list(csv.reader([file_handle.readline().strip()]))[0] | |||
| infer_data = CSVDataParse(data_type="test", func=feature_process) | |||
| unseen_sample = open(os.path.join(ut_saved_url, "unseen_sample.csv"), | |||
| "w", encoding="utf-8") | |||
| unseen_sample.write("\t".join(header + ['pred']) + "\n") | |||
| output_sample = open(f"{infer_dataset_url}_out.csv", "w", encoding="utf-8") | |||
| output_sample.write("\t".join(header + ['pred']) + "\n") | |||
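| # Tail the inference file: wait for new rows and run lifelong inference on each as it arrives. | |||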
| while True: | |||
| where = file_handle.tell() | |||
| line = file_handle.readline() | |||
| if not line: | |||
| time.sleep(1) | |||
| file_handle.seek(where) | |||
| continue | |||
| reader = list(csv.reader([line.strip()])) | |||
| rows = reader[0] | |||
| data = dict(zip(header, rows)) | |||
| infer_data.parse(data, label=DATACONF["LABEL"]) | |||
| rsl, is_unseen, target_task = ll_job.inference(infer_data) | |||
| rows.append(list(rsl)[0]) | |||
| if is_unseen: | |||
| unseen_sample.write("\t".join(map(str, rows)) + "\n") | |||
| output_sample.write("\t".join(map(str, rows)) + "\n") | |||
| unseen_sample.close() | |||
| output_sample.close() | |||
| if __name__ == '__main__': | |||
| print(main()) | |||
| @@ -0,0 +1,124 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import os | |||
| import pandas as pd | |||
| import numpy as np | |||
| import xgboost | |||
| from sklearn.model_selection import train_test_split | |||
| from sklearn.metrics import precision_score | |||
| os.environ['BACKEND_TYPE'] = 'SKLEARN' | |||
| DATACONF = { | |||
| "ATTRIBUTES": ["Season", "Cooling startegy_building level"], | |||
| "LABEL": "Thermal preference", | |||
| } | |||
| def feature_process(df: pd.DataFrame): | |||
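| """Drop the City column, keep Season categorical, cast the remaining features to float, and coerce the label to int.""" | |||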
| if "City" in df.columns: | |||
| df.drop(["City"], axis=1, inplace=True) | |||
| for feature in df.columns: | |||
| if feature in ["Season", ]: | |||
| continue | |||
| df[feature] = df[feature].apply(lambda x: float(x) if x else 0.0) | |||
| df['Thermal preference'] = df['Thermal preference'].apply( | |||
| lambda x: int(float(x)) if x else 1) | |||
| return df | |||
| class Estimator: | |||
| def __init__(self): | |||
| """Model init""" | |||
| self.model = xgboost.XGBClassifier( | |||
| learning_rate=0.1, | |||
| n_estimators=600, | |||
| max_depth=2, | |||
| min_child_weight=1, | |||
| gamma=0, | |||
| subsample=0.8, | |||
| colsample_bytree=0.8, | |||
| objective="multi:softmax", | |||
| num_class=3, | |||
| nthread=4, | |||
| seed=27) | |||
| def train(self, train_data, valid_data=None, | |||
| save_best=True, | |||
| metric_name="mlogloss", | |||
| early_stopping_rounds=100 | |||
| ): | |||
| es = [ | |||
| xgboost.callback.EarlyStopping( | |||
| metric_name=metric_name, | |||
| rounds=early_stopping_rounds, | |||
| save_best=save_best | |||
| ) | |||
| ] | |||
| x, y = train_data.x, train_data.y | |||
| if valid_data: | |||
| x1, y1 = valid_data.x, valid_data.y | |||
| else: | |||
| x, x1, y, y1 = train_test_split( | |||
| x, y, test_size=0.1, random_state=42) | |||
| history = self.model.fit(x, y, eval_set=[(x1, y1), ], callbacks=es) | |||
| d = {} | |||
| for k, v in history.evals_result().items(): | |||
| for k1, v1, in v.items(): | |||
| m = np.mean(v1) | |||
| if k1 not in d: | |||
| d[k1] = [] | |||
| d[k1].append(m) | |||
| for k, v in d.items(): | |||
| d[k] = np.mean(v) | |||
| return d | |||
| def predict(self, datas, **kwargs): | |||
| """ Model inference """ | |||
| return self.model.predict(datas) | |||
| def predict_proba(self, datas, **kwargs): | |||
| return self.model.predict_proba(datas) | |||
| def evaluate(self, test_data, **kwargs): | |||
| """ Model evaluate """ | |||
| y_pred = self.predict(test_data.x) | |||
| return precision_score(test_data.y, y_pred, average="micro") | |||
| def load(self, model_url): | |||
| self.model.load_model(model_url) | |||
| return self | |||
| def save(self, model_path=None): | |||
| """ | |||
| Save the XGBoost model to the given path. | |||
| """ | |||
| return self.model.save_model(model_path) | |||
| if __name__ == '__main__': | |||
| from sedna.datasources import CSVDataParse | |||
| from sedna.common.config import BaseConfig | |||
| train_dataset_url = BaseConfig.train_dataset_url | |||
| train_data = CSVDataParse(data_type="train", func=feature_process) | |||
| train_data.parse(train_dataset_url, label=DATACONF["LABEL"]) | |||
| test_dataset_url = BaseConfig.test_dataset_url | |||
| valid_data = CSVDataParse(data_type="valid", func=feature_process) | |||
| valid_data.parse(test_dataset_url, label=DATACONF["LABEL"]) | |||
| model = Estimator() | |||
| print(model.train(train_data)) | |||
| print(model.evaluate(test_data=valid_data)) | |||
| @@ -0,0 +1,48 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import json | |||
| from interface import DATACONF, Estimator, feature_process | |||
| from sedna.common.config import Context, BaseConfig | |||
| from sedna.datasources import CSVDataParse | |||
| from sedna.core.lifelong_learning import LifelongLearning | |||
| def main(): | |||
| # load dataset. | |||
| train_dataset_url = BaseConfig.train_dataset_url | |||
| train_data = CSVDataParse(data_type="train", func=feature_process) | |||
| train_data.parse(train_dataset_url, label=DATACONF["LABEL"]) | |||
| attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]}) | |||
| early_stopping_rounds = int( | |||
| Context.get_parameters( | |||
| "early_stopping_rounds", 100)) | |||
| metric_name = Context.get_parameters("metric_name", "mlogloss") | |||
| ll_job = LifelongLearning( | |||
| estimator=Estimator, | |||
| task_definition="TaskDefinitionByDataAttr", | |||
| task_definition_param=attribute | |||
| ) | |||
| train_experiment = ll_job.train( | |||
| train_data=train_data, | |||
| metric_name=metric_name, | |||
| early_stopping_rounds=early_stopping_rounds | |||
| ) | |||
| return train_experiment | |||
| if __name__ == '__main__': | |||
| print(main()) | |||
| @@ -14,3 +14,5 @@ | |||
| from .aggregation import * | |||
| from .hard_example_mining import * | |||
| from .multi_task_learning import * | |||
| from .unseen_task_detect import * | |||
| @@ -0,0 +1,16 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from .task_jobs import * | |||
| from .multi_task_learning import MulTaskLearning | |||
| @@ -0,0 +1,312 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import os | |||
| import json | |||
| import joblib | |||
| from .task_jobs.artifact import Model, Task, TaskGroup | |||
| from sedna.datasources import BaseDataSource | |||
| from sedna.backend import set_backend | |||
| from sedna.common.log import LOGGER | |||
| from sedna.common.config import Context | |||
| from sedna.common.file_ops import FileOps | |||
| from sedna.common.class_factory import ClassFactory, ClassType | |||
| __all__ = ('MulTaskLearning',) | |||
| class MulTaskLearning: | |||
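| """Multi-task learning pipeline: define tasks from samples, discover task | |||
| relationships, train one model per task group, and integrate per-task results.""" | |||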
| _method_pair = { | |||
| 'TaskDefinitionBySVC': 'TaskMiningBySVC', | |||
| 'TaskDefinitionByDataAttr': 'TaskMiningByDataAttr', | |||
| } | |||
| def __init__(self, | |||
| estimator=None, | |||
| task_definition="TaskDefinitionByDataAttr", | |||
| task_relationship_discovery=None, | |||
| task_mining=None, | |||
| task_remodeling=None, | |||
| inference_integrate=None, | |||
| task_definition_param=None, | |||
| relationship_discovery_param=None, | |||
| task_mining_param=None, | |||
| task_remodeling_param=None, | |||
| inference_integrate_param=None | |||
| ): | |||
| if not task_relationship_discovery: | |||
| task_relationship_discovery = "DefaultTaskRelationDiscover" | |||
| if not task_remodeling: | |||
| task_remodeling = "DefaultTaskRemodeling" | |||
| if not inference_integrate: | |||
| inference_integrate = "DefaultInferenceIntegrate" | |||
| self.method_selection = dict( | |||
| task_definition=task_definition, | |||
| task_relationship_discovery=task_relationship_discovery, | |||
| task_mining=task_mining, | |||
| task_remodeling=task_remodeling, | |||
| inference_integrate=inference_integrate, | |||
| task_definition_param=task_definition_param, | |||
| task_relationship_discovery_param=relationship_discovery_param, | |||
| task_mining_param=task_mining_param, | |||
| task_remodeling_param=task_remodeling_param, | |||
| inference_integrate_param=inference_integrate_param) | |||
| self.models = None | |||
| self.extractor = None | |||
| self.base_model = estimator | |||
| self.task_groups = None | |||
| self.task_index_url = Context.get_parameters( | |||
| "MODEL_URLS", '/tmp/index.pkl' | |||
| ) | |||
| self.min_train_sample = int(Context.get_parameters( | |||
| "MIN_TRAIN_SAMPLE", '10' | |||
| )) | |||
| @staticmethod | |||
| def parse_param(param_str): | |||
| if not param_str: | |||
| return {} | |||
| try: | |||
| raw_dict = json.loads(param_str)  # json.loads no longer accepts an encoding argument | |||
| except json.JSONDecodeError: | |||
| raw_dict = {} | |||
| return raw_dict | |||
| def task_definition(self, samples): | |||
| method_name = self.method_selection.get( | |||
| "task_definition", "TaskDefinitionByDataAttr") | |||
| extend_param = self.parse_param( | |||
| self.method_selection.get("task_definition_param")) | |||
| method_cls = ClassFactory.get_cls( | |||
| ClassType.MTL, method_name)(**extend_param) | |||
| return method_cls(samples) | |||
| def task_relationship_discovery(self, tasks): | |||
| method_name = self.method_selection.get("task_relationship_discovery") | |||
| extend_param = self.parse_param( | |||
| self.method_selection.get("task_relationship_discovery_param") | |||
| ) | |||
| method_cls = ClassFactory.get_cls( | |||
| ClassType.MTL, method_name)(**extend_param) | |||
| return method_cls(tasks) | |||
| def task_mining(self, samples): | |||
| method_name = self.method_selection.get("task_mining") | |||
| extend_param = self.parse_param( | |||
| self.method_selection.get("task_mining_param")) | |||
| if not method_name: | |||
| task_definition = self.method_selection.get( | |||
| "task_definition", "TaskDefinitionByDataAttr") | |||
| method_name = self._method_pair.get(task_definition, | |||
| 'TaskMiningByDataAttr') | |||
| extend_param = self.parse_param( | |||
| self.method_selection.get("task_definition_param")) | |||
| method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | |||
| task_extractor=self.extractor, **extend_param | |||
| ) | |||
| return method_cls(samples=samples) | |||
| def task_remodeling(self, samples, mappings): | |||
| method_name = self.method_selection.get("task_remodeling") | |||
| extend_param = self.parse_param( | |||
| self.method_selection.get("task_remodeling_param")) | |||
| method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | |||
| models=self.models, **extend_param) | |||
| return method_cls(samples=samples, mappings=mappings) | |||
| def inference_integrate(self, tasks): | |||
| method_name = self.method_selection.get("inference_integrate") | |||
| extend_param = self.parse_param( | |||
| self.method_selection.get("inference_integrate_param")) | |||
| method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | |||
| models=self.models, **extend_param) | |||
| return method_cls(tasks=tasks) if method_cls else tasks | |||
| def train(self, train_data: BaseDataSource, | |||
| valid_data: BaseDataSource = None, | |||
| post_process=None, **kwargs): | |||
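| """Split train_data into tasks, train one model per task group, and dump the task index; rare tasks fall back to a shared global model.""" | |||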
| tasks, task_extractor, train_data = self.task_definition(train_data) | |||
| self.extractor = task_extractor | |||
| task_groups = self.task_relationship_discovery(tasks) | |||
| self.models = [] | |||
| callback = None | |||
| if post_process: | |||
| callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)() | |||
| self.task_groups = [] | |||
| feedback = {} | |||
| rare_task = [] | |||
| for i, task in enumerate(task_groups): | |||
| if not isinstance(task, TaskGroup): | |||
| rare_task.append(i) | |||
| self.models.append(None) | |||
| self.task_groups.append(None) | |||
| continue | |||
| if not (task.samples and len(task.samples) | |||
| > self.min_train_sample): | |||
| self.models.append(None) | |||
| self.task_groups.append(None) | |||
| rare_task.append(i) | |||
| continue | |||
| LOGGER.info(f"MTL Train start {i} : {task.entry}") | |||
| model_obj = set_backend(estimator=self.base_model) | |||
| res = model_obj.train(train_data=task.samples, **kwargs) | |||
| if callback: | |||
| res = callback(model_obj, res) | |||
| model_path = model_obj.save(model_name=f"{task.entry}.model") | |||
| model = Model(index=i, entry=task.entry, | |||
| model=model_path, result=res) | |||
| model.meta_attr = [t.meta_attr for t in task.tasks] | |||
| task.model = model | |||
| self.models.append(model) | |||
| feedback[task.entry] = res | |||
| self.task_groups.append(task) | |||
| if len(rare_task): | |||
| model_obj = set_backend(estimator=self.base_model) | |||
| res = model_obj.train(train_data=train_data, **kwargs) | |||
| model_path = model_obj.save(model_name="global.model") | |||
| for i in rare_task: | |||
| task = task_groups[i] | |||
| entry = getattr(task, 'entry', "global") | |||
| if not isinstance(task, TaskGroup): | |||
| task = TaskGroup( | |||
| entry=entry, tasks=[] | |||
| ) | |||
| model = Model(index=i, entry=entry, | |||
| model=model_path, result=res) | |||
| model.meta_attr = [t.meta_attr for t in task.tasks] | |||
| task.model = model | |||
| task.samples = train_data | |||
| self.models[i] = model | |||
| feedback[entry] = res | |||
| self.task_groups[i] = task | |||
| extractor_file = FileOps.join_path( | |||
| os.path.dirname(self.task_index_url), | |||
| "kb_extractor.pkl" | |||
| ) | |||
| joblib.dump(self.extractor, extractor_file) | |||
| task_index = { | |||
| "extractor": extractor_file, | |||
| "task_groups": self.task_groups | |||
| } | |||
| joblib.dump(task_index, self.task_index_url) | |||
| if valid_data: | |||
| feedback = self.evaluate(valid_data, **kwargs) | |||
| return feedback | |||
| def predict(self, data: BaseDataSource, | |||
| post_process=None, **kwargs): | |||
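| """Mine the task for each sample, dispatch each partition to its task model, and integrate the predictions.""" | |||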
| if not (self.models and self.extractor): | |||
| task_index = joblib.load(self.task_index_url) | |||
| extractor_file = FileOps.join_path( | |||
| os.path.dirname(self.task_index_url), | |||
| "kb_extractor.pkl" | |||
| ) | |||
| if not callable(task_index['extractor']) and \ | |||
| isinstance(task_index['extractor'], str): | |||
| FileOps.download(task_index['extractor'], extractor_file) | |||
| self.extractor = joblib.load(extractor_file) | |||
| else: | |||
| self.extractor = task_index['extractor'] | |||
| self.task_groups = task_index['task_groups'] | |||
| self.models = [task.model for task in self.task_groups] | |||
| data, mappings = self.task_mining(samples=data) | |||
| samples, models = self.task_remodeling(samples=data, mappings=mappings) | |||
| callback = None | |||
| if post_process: | |||
| callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)() | |||
| tasks = [] | |||
| for inx, df in enumerate(samples): | |||
| m = models[inx] | |||
| if not isinstance(m, Model): | |||
| continue | |||
| model_obj = set_backend(estimator=self.base_model) | |||
| evaluator = model_obj.load(m.model) if isinstance( | |||
| m.model, str) else m.model | |||
| pred = evaluator.predict(df.x, **kwargs) | |||
| if callable(callback): | |||
| pred = callback(pred, df) | |||
| task = Task(entry=m.entry, samples=df) | |||
| task.result = pred | |||
| task.model = m | |||
| tasks.append(task) | |||
| res = self.inference_integrate(tasks) | |||
| return res, tasks | |||
| def evaluate(self, data: BaseDataSource, | |||
| metrics=None, | |||
| metrics_param=None, | |||
| **kwargs): | |||
| from sklearn import metrics as sk_metrics | |||
| result, tasks = self.predict(data, **kwargs) | |||
| m_dict = {} | |||
| if metrics: | |||
| if callable(metrics): # if metrics is a function | |||
| m_name = getattr(metrics, '__name__', "mtl_eval") | |||
| m_dict = { | |||
| m_name: metrics | |||
| } | |||
| elif isinstance(metrics, (set, list)): # if metrics is multiple | |||
| for inx, m in enumerate(metrics): | |||
| m_name = getattr(m, '__name__', f"mtl_eval_{inx}") | |||
| if isinstance(m, str): | |||
| m = getattr(sk_metrics, m) | |||
| if not callable(m): | |||
| continue | |||
| m_dict[m_name] = m | |||
| elif isinstance(metrics, str): # if metrics is single | |||
| m_dict = { | |||
| metrics: getattr(sk_metrics, metrics, sk_metrics.log_loss) | |||
| } | |||
| elif isinstance(metrics, dict): # if metrics with name | |||
| for k, v in metrics.items(): | |||
| if isinstance(v, str): | |||
| v = getattr(sk_metrics, v) | |||
| if not callable(v): | |||
| continue | |||
| m_dict[k] = v | |||
| if not len(m_dict): | |||
| m_dict = { | |||
| 'precision_score': sk_metrics.precision_score | |||
| } | |||
| metrics_param = {"average": "micro"} | |||
| data.x['pred_y'] = result | |||
| data.x['real_y'] = data.y | |||
| if not metrics_param: | |||
| metrics_param = {} | |||
| elif isinstance(metrics_param, str): | |||
| metrics_param = self.parse_param(metrics_param) | |||
| tasks_detail = [] | |||
| for task in tasks: | |||
| sample = task.samples | |||
| pred = task.result | |||
| scores = { | |||
| name: metric(sample.y, pred, **metrics_param) | |||
| for name, metric in m_dict.items() | |||
| } | |||
| task.scores = scores | |||
| tasks_detail.append(task) | |||
| task_eval_res = { | |||
| name: metric(data.y, result, **metrics_param) | |||
| for name, metric in m_dict.items() | |||
| } | |||
| return task_eval_res, tasks_detail | |||
| @@ -0,0 +1,21 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from .task_definition import * | |||
| from .task_relation_discover import * | |||
| from .task_mining import * | |||
| from .task_remodeling import * | |||
| from .inference_integrate import * | |||
| @@ -0,0 +1,44 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from typing import List | |||
| __all__ = ('Task', 'TaskGroup', 'Model') | |||
| class Task: | |||
| def __init__(self, entry, samples, meta_attr=None): | |||
| self.entry = entry | |||
| self.samples = samples | |||
| self.meta_attr = meta_attr | |||
| self.model = None # define on running | |||
| self.result = None # define on running | |||
| class TaskGroup: | |||
| def __init__(self, entry, tasks: List[Task]): | |||
| self.entry = entry | |||
| self.tasks = tasks | |||
| self.samples = None # define by task_relation_discover algorithms | |||
| self.model = None # define on train | |||
| class Model: | |||
| def __init__(self, index: int, entry, model, result): | |||
| self.index = index # integer | |||
| self.entry = entry | |||
| self.model = model | |||
| self.result = result | |||
| self.meta_attr = None # define on running | |||
| @@ -0,0 +1,33 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| from typing import List | |||
| from .artifact import Task | |||
| from sedna.common.class_factory import ClassFactory, ClassType | |||
| __all__ = ('DefaultInferenceIntegrate', ) | |||
| @ClassFactory.register(ClassType.MTL) | |||
| class DefaultInferenceIntegrate: | |||
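| """Merge per-task predictions back into one result, ordered by the original sample index.""" | |||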
| def __init__(self, models: list, **kwargs): | |||
| self.models = models | |||
| def __call__(self, tasks: List[Task]): | |||
| res = {} | |||
| for task in tasks: | |||
| res.update(dict(zip(task.samples.inx, task.result))) | |||
| return np.array([z[1] | |||
| for z in sorted(res.items(), key=lambda x: x[0])]) | |||
| @@ -0,0 +1,110 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import pandas as pd | |||
| from typing import List, Any, Tuple | |||
| from .artifact import Task | |||
| from sedna.datasources import BaseDataSource | |||
| from sedna.common.class_factory import ClassType, ClassFactory | |||
| __all__ = ('TaskDefinitionBySVC', 'TaskDefinitionByDataAttr') | |||
| @ClassFactory.register(ClassType.MTL) | |||
| class TaskDefinitionBySVC: | |||
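| """Cluster float features with agglomerative clustering to define tasks, and fit an SVC as the task extractor.""" | |||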
| def __init__(self, **kwargs): | |||
| n_class = kwargs.get("n_class", "") | |||
| self.n_class = max(2, int(n_class)) if str(n_class).isdigit() else 2 | |||
| def __call__(self, | |||
| samples: BaseDataSource) -> Tuple[List[Task], | |||
| Any, | |||
| BaseDataSource]: | |||
| from sklearn.svm import SVC | |||
| from sklearn.cluster import AgglomerativeClustering | |||
| d_type = samples.data_type | |||
| x_data = samples.x | |||
| y_data = samples.y | |||
| if not isinstance(x_data, pd.DataFrame): | |||
| raise TypeError(f"{d_type} data should only be pd.DataFrame") | |||
| tasks = [] | |||
| legal = list( | |||
| filter(lambda col: x_data[col].dtype == 'float64', x_data.columns)) | |||
| df = x_data[legal] | |||
| c1 = AgglomerativeClustering(n_clusters=self.n_class).fit_predict(df) | |||
| c2 = SVC(gamma=0.01) | |||
| c2.fit(df, c1) | |||
| for task in range(self.n_class): | |||
| g_attr = f"svc_{task}" | |||
| task_df = BaseDataSource(data_type=d_type) | |||
| task_df.x = x_data.iloc[np.where(c1 == task)] | |||
| task_df.y = y_data.iloc[np.where(c1 == task)] | |||
| task_obj = Task(entry=g_attr, samples=task_df) | |||
| tasks.append(task_obj) | |||
| samples.x = df | |||
| return tasks, c2, samples | |||
| @ClassFactory.register(ClassType.MTL) | |||
| class TaskDefinitionByDataAttr: | |||
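| """Define one task per unique value combination of the configured attribute columns.""" | |||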
| def __init__(self, **kwargs): | |||
| self.attr_filed = kwargs.get("attribute", []) | |||
| def __call__(self, | |||
| samples: BaseDataSource) -> Tuple[List[Task], | |||
| Any, | |||
| BaseDataSource]: | |||
| tasks = [] | |||
| d_type = samples.data_type | |||
| x_data = samples.x | |||
| y_data = samples.y | |||
| if not isinstance(x_data, pd.DataFrame): | |||
| raise TypeError(f"{d_type} data should only be pd.DataFrame") | |||
| _inx = 0 | |||
| task_index = {} | |||
| for meta_attr, df in x_data.groupby(self.attr_filed): | |||
| if isinstance(meta_attr, (list, tuple, set)): | |||
| g_attr = "_".join( | |||
| map(lambda x: str(x).replace("_", "-"), meta_attr)) | |||
| meta_attr = list(meta_attr) | |||
| else: | |||
| g_attr = str(meta_attr).replace("_", "-") | |||
| meta_attr = [meta_attr] | |||
| g_attr = g_attr.replace(" ", "") | |||
| if g_attr in task_index: | |||
| old_task = tasks[task_index[g_attr]] | |||
| old_task.samples.x = pd.concat([old_task.samples.x, df.drop(self.attr_filed, axis=1)]) | |||
| old_task.samples.y = pd.concat([old_task.samples.y, y_data.iloc[df.index]]) | |||
| continue | |||
| task_index[g_attr] = _inx | |||
| task_df = BaseDataSource(data_type=d_type) | |||
| task_df.x = df.drop(self.attr_filed, axis=1) | |||
| task_df.y = y_data.iloc[df.index] | |||
| task_obj = Task(entry=g_attr, samples=task_df, meta_attr=meta_attr) | |||
| tasks.append(task_obj) | |||
| _inx += 1 | |||
| x_data.drop(self.attr_filed, axis=1, inplace=True) | |||
| samples = BaseDataSource(data_type=d_type) | |||
| samples.x = x_data | |||
| samples.y = y_data | |||
| return tasks, task_index, samples | |||
| @@ -0,0 +1,59 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from sedna.datasources import BaseDataSource | |||
| from sedna.common.class_factory import ClassFactory, ClassType | |||
| __all__ = ('TaskMiningBySVC', 'TaskMiningByDataAttr') | |||
| @ClassFactory.register(ClassType.MTL) | |||
| class TaskMiningBySVC: | |||
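| """Allocate inference samples to tasks with the SVC extractor fitted during task definition.""" | |||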
| def __init__(self, task_extractor, **kwargs): | |||
| self.task_extractor = task_extractor | |||
| def __call__(self, samples: BaseDataSource): | |||
| df = samples.x | |||
| allocations = [0, ] * len(df) | |||
| legal = list( | |||
| filter(lambda col: df[col].dtype == 'float64', df.columns)) | |||
| if not len(legal): | |||
| return samples, allocations | |||
| allocations = list(self.task_extractor.predict(df[legal])) | |||
| return samples, allocations | |||
| @ClassFactory.register(ClassType.MTL) | |||
| class TaskMiningByDataAttr: | |||
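| """Allocate inference samples to tasks by looking up their attribute values in the task index.""" | |||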
| def __init__(self, task_extractor, **kwargs): | |||
| self.task_extractor = task_extractor | |||
| self.attr_filed = kwargs.get("attribute", []) | |||
| def __call__(self, samples: BaseDataSource): | |||
| df = samples.x | |||
| meta_attr = df[self.attr_filed] | |||
| allocations = meta_attr.apply( | |||
| lambda x: self.task_extractor.get( | |||
| "_".join( | |||
| map(lambda y: str(x[y]).replace("_", "-").replace(" ", ""), | |||
| self.attr_filed) | |||
| ), | |||
| 0), | |||
| axis=1).values.tolist() | |||
| samples.x = df.drop(self.attr_filed, axis=1) | |||
| samples.meta_attr = meta_attr | |||
| return samples, allocations | |||
| @@ -0,0 +1,34 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from typing import List | |||
| from .artifact import Task, TaskGroup | |||
| from sedna.common.class_factory import ClassType, ClassFactory | |||
| __all__ = ('DefaultTaskRelationDiscover', ) | |||
| @ClassFactory.register(ClassType.MTL) | |||
| class DefaultTaskRelationDiscover: | |||
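| """Default relation discovery: wrap each task in its own single-task group.""" | |||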
| def __init__(self, **kwargs): | |||
| pass | |||
| def __call__(self, tasks: List[Task]) -> List[TaskGroup]: | |||
| tg = [] | |||
| for task in tasks: | |||
| tg_obj = TaskGroup(entry=task.entry, tasks=[task]) | |||
| tg_obj.samples = task.samples | |||
| tg.append(tg_obj) | |||
| return tg | |||
| @@ -0,0 +1,44 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| from typing import List, Tuple | |||
| from sedna.datasources import BaseDataSource | |||
| from sedna.common.class_factory import ClassFactory, ClassType | |||
| __all__ = ('DefaultTaskRemodeling',) | |||
| @ClassFactory.register(ClassType.MTL) | |||
| class DefaultTaskRemodeling: | |||
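| """Partition samples by their task allocation and pair each partition with its task model.""" | |||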
| def __init__(self, models: list, **kwargs): | |||
| self.models = models | |||
| def __call__(self, samples: BaseDataSource, mappings: List) \ | |||
| -> Tuple[List[BaseDataSource], List]: | |||
| mappings = np.array(mappings) | |||
| data, models = [], [] | |||
| d_type = samples.data_type | |||
| for m in np.unique(mappings): | |||
| task_df = BaseDataSource(data_type=d_type) | |||
| _inx = np.where(mappings == m) | |||
| task_df.x = samples.x.iloc[_inx] | |||
| if d_type != "test": | |||
| task_df.y = samples.y.iloc[_inx] | |||
| task_df.inx = _inx[0].tolist() | |||
| task_df.meta_attr = samples.meta_attr.iloc[_inx].values | |||
| data.append(task_df) | |||
| model = self.models[m] or self.models[0] | |||
| models.append(model) | |||
| return data, models | |||
| @@ -0,0 +1,68 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Unseen Task detect Algorithms for Lifelong Learning""" | |||
| import abc | |||
| import numpy as np | |||
| from typing import List | |||
| from sedna.algorithms.multi_task_learning.task_jobs.artifact import Task | |||
| from sedna.common.class_factory import ClassFactory, ClassType | |||
| __all__ = ('ModelProbeFilter', 'TaskAttrFilter') | |||
| class BaseFilter(metaclass=abc.ABCMeta): | |||
| """The base class to define unified interface.""" | |||
| def __call__(self, task: Task = None): | |||
| """predict function, and it must be implemented by | |||
| different methods class. | |||
| :param task: inference task | |||
| :return: `True` means unseen task, `False` means not an unseen task. | |||
| """ | |||
| raise NotImplementedError | |||
| @ClassFactory.register(ClassType.UTD, alias="ModelProbe") | |||
| class ModelProbeFilter(BaseFilter, abc.ABC): | |||
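| """Probe the task models' prediction probabilities against a threshold to detect unseen tasks.""" | |||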
| def __init__(self): | |||
| pass | |||
| def __call__(self, tasks: List[Task] = None, threshold=0.5, **kwargs): | |||
| all_proba = [] | |||
| for task in tasks: | |||
| sample = task.samples | |||
| model = task.model | |||
| if hasattr(model, "predict_proba"): | |||
| proba = model.predict_proba(sample) | |||
| all_proba.append(np.max(proba)) | |||
| return np.mean(all_proba) > threshold if all_proba else True | |||
| @ClassFactory.register(ClassType.UTD, alias="TaskAttr") | |||
| class TaskAttrFilter(BaseFilter, abc.ABC): | |||
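| """Detect an unseen task when the samples' meta attributes match no mined task's model attributes.""" | |||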
| def __init__(self): | |||
| pass | |||
| def __call__(self, tasks: List[Task] = None, **kwargs): | |||
| for task in tasks: | |||
| model_attr = list(map(list, task.model.meta_attr)) | |||
| sample_attr = list(map(list, task.samples.meta_attr)) | |||
| if not (model_attr and sample_attr): | |||
| continue | |||
| if list(model_attr) == list(sample_attr): | |||
| return False | |||
| return True | |||
| @@ -0,0 +1,15 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from .lifelong_learning import * | |||
| @@ -0,0 +1,221 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import os | |||
| import joblib | |||
| import tempfile | |||
| from sedna.backend import set_backend | |||
| from sedna.core.base import JobBase | |||
| from sedna.common.file_ops import FileOps | |||
| from sedna.common.constant import K8sResourceKind, K8sResourceKindStatus | |||
| from sedna.common.config import Context | |||
| from sedna.common.class_factory import ClassType, ClassFactory | |||
| from sedna.algorithms.multi_task_learning import MulTaskLearning | |||
| from sedna.service.client import KBClient | |||
| class LifelongLearning(JobBase): | |||
| """ | |||
| Lifelong learning Experiment | |||
| """ | |||
| def __init__(self, | |||
| estimator, | |||
| task_definition="TaskDefinitionByDataAttr", | |||
| task_relationship_discovery=None, | |||
| task_mining=None, | |||
| task_remodeling=None, | |||
| inference_integrate=None, | |||
| unseen_task_detect="TaskAttrFilter", | |||
| task_definition_param=None, | |||
| relationship_discovery_param=None, | |||
| task_mining_param=None, | |||
| task_remodeling_param=None, | |||
| inference_integrate_param=None, | |||
| unseen_task_detect_param=None): | |||
| e = MulTaskLearning( | |||
| estimator=estimator, | |||
| task_definition=task_definition, | |||
| task_relationship_discovery=task_relationship_discovery, | |||
| task_mining=task_mining, | |||
| task_remodeling=task_remodeling, | |||
| inference_integrate=inference_integrate, | |||
| task_definition_param=task_definition_param, | |||
| relationship_discovery_param=relationship_discovery_param, | |||
| task_mining_param=task_mining_param, | |||
| task_remodeling_param=task_remodeling_param, | |||
| inference_integrate_param=inference_integrate_param) | |||
| self.unseen_task_detect = unseen_task_detect | |||
| self.unseen_task_detect_param = e.parse_param( | |||
| unseen_task_detect_param) | |||
| config = dict( | |||
| ll_kb_server=Context.get_parameters("KB_SERVER"), | |||
| output_url=Context.get_parameters("OUTPUT_URL", "/tmp") | |||
| ) | |||
| task_index = FileOps.join_path(config['output_url'], 'index.pkl') | |||
| config['task_index'] = task_index | |||
| super(LifelongLearning, self).__init__( | |||
| estimator=estimator, config=config) | |||
| self.job_kind = K8sResourceKind.LIFELONG_JOB.value | |||
| self.kb_server = KBClient(kbserver=self.config.ll_kb_server) | |||
| def train(self, train_data, | |||
| valid_data=None, | |||
| post_process=None, | |||
| action="initial", | |||
| **kwargs): | |||
| """ | |||
| :param train_data: data use to train model | |||
| :param valid_data: data use to valid model | |||
| :param post_process: callback function | |||
| :param action: initial - kb init, update - kb incremental update | |||
| """ | |||
| callback_func = None | |||
| if post_process is not None: | |||
| callback_func = ClassFactory.get_cls( | |||
| ClassType.CALLBACK, post_process) | |||
| res = self.estimator.train( | |||
| train_data=train_data, | |||
| valid_data=valid_data, | |||
| **kwargs | |||
| ) # TODO: distinguish incremental update from full overwrite | |||
| task_groups = self.estimator.estimator.task_groups | |||
| extractor_file = FileOps.join_path( | |||
| os.path.dirname(self.estimator.estimator.task_index_url), | |||
| "kb_extractor.pkl" | |||
| ) | |||
| try: | |||
| extractor_file = self.kb_server.upload_file(extractor_file) | |||
| except Exception as err: | |||
| self.log.error( | |||
| f"Upload task extractor_file fail {extractor_file}: {err}") | |||
| extractor_file = joblib.load(extractor_file) | |||
| for task in task_groups: | |||
| try: | |||
| model = self.kb_server.upload_file(task.model.model) | |||
| except Exception: | |||
| model_obj = set_backend( | |||
| estimator=self.estimator.estimator.base_model | |||
| ) | |||
| model = model_obj.load(task.model.model) | |||
| task.model.model = model | |||
| task_info = { | |||
| "task_groups": task_groups, | |||
| "extractor": extractor_file | |||
| } | |||
| fd, name = tempfile.mkstemp() | |||
| joblib.dump(task_info, name) | |||
| index_file = self.kb_server.update_db(name) | |||
| if not index_file: | |||
| self.log.error(f"KB update Fail !") | |||
| index_file = name | |||
| FileOps.download(index_file, self.config.task_index) | |||
| if os.path.isfile(name): | |||
| os.close(fd) | |||
| os.remove(name) | |||
| task_info_res = self.estimator.model_info( | |||
| self.config.task_index, result=res, | |||
| relpath=self.config.data_path_prefix) | |||
| self.report_task_info( | |||
| None, K8sResourceKindStatus.COMPLETED.value, task_info_res) | |||
| self.log.info(f"Lifelong learning Experiment Finished, " | |||
| f"KB idnex save in {self.config.task_index}") | |||
| return callback_func(self.estimator, res) if callback_func else res | |||
| def update(self, train_data, valid_data=None, post_process=None, **kwargs): | |||
| return self.train( | |||
| train_data=train_data, | |||
| valid_data=valid_data, | |||
| post_process=post_process, | |||
| action="update", | |||
| **kwargs | |||
| ) | |||
| def evaluate(self, data, post_process=None, model_threshold=0.1, **kwargs): | |||
| callback_func = None | |||
| if callable(post_process): | |||
| callback_func = post_process | |||
| elif post_process is not None: | |||
| callback_func = ClassFactory.get_cls( | |||
| ClassType.CALLBACK, post_process) | |||
| task_index_url = self.get_parameters( | |||
| "MODEL_URLS", self.config.task_index) | |||
| index_url = self.estimator.estimator.task_index_url | |||
| self.log.info( | |||
| f"Download kb index from {task_index_url} to {index_url}") | |||
| FileOps.download(task_index_url, index_url) | |||
| res, tasks_detail = self.estimator.evaluate(data=data, **kwargs) | |||
| drop_tasks = [] | |||
| for detail in tasks_detail: | |||
| scores = detail.scores | |||
| entry = detail.entry | |||
| self.log.info(f"{entry} socres: {scores}") | |||
| if any(map(lambda x: float(x) < model_threshold, scores.values())): | |||
| self.log.warning( | |||
| f"{entry} will not be deployed " | |||
| f"because its scores are below {model_threshold}") | |||
| drop_tasks.append(entry) | |||
| continue | |||
| drop_task = ",".join(drop_tasks) | |||
| index_file = self.kb_server.update_task_status(drop_task, new_status=0) | |||
| if not index_file: | |||
| self.log.error(f"KB update Fail !") | |||
| index_file = str(index_url) | |||
| self.log.info( | |||
| f"upload kb index from {index_file} to {self.config.task_index}") | |||
| FileOps.download(index_file, self.config.task_index) | |||
| task_info_res = self.estimator.model_info( | |||
| self.config.task_index, result=res, | |||
| relpath=self.config.data_path_prefix) | |||
| self.report_task_info( | |||
| None, | |||
| K8sResourceKindStatus.COMPLETED.value, | |||
| task_info_res, | |||
| kind="eval") | |||
| return callback_func(res) if callback_func else res | |||
| def inference(self, data=None, post_process=None, **kwargs): | |||
| task_index_url = self.get_parameters( | |||
| "MODEL_URLS", self.config.task_index) | |||
| index_url = self.estimator.estimator.task_index_url | |||
| FileOps.download(task_index_url, index_url) | |||
| res, tasks = self.estimator.predict( | |||
| data=data, post_process=post_process, **kwargs | |||
| ) | |||
| is_unseen_task = False | |||
| if self.unseen_task_detect: | |||
| try: | |||
| if callable(self.unseen_task_detect): | |||
| unseen_task_detect_algorithm = self.unseen_task_detect() | |||
| else: | |||
| unseen_task_detect_algorithm = ClassFactory.get_cls( | |||
| ClassType.UTD, self.unseen_task_detect | |||
| )() | |||
| except ValueError as err: | |||
| self.log.error( | |||
| "Lifelong learning Experiment " | |||
| "Inference [UTD] : {}".format(err)) | |||
| else: | |||
| is_unseen_task = unseen_task_detect_algorithm( | |||
| tasks=tasks, result=res, **self.unseen_task_detect_param | |||
| ) | |||
| return res, is_unseen_task, tasks | |||