diff --git a/examples/build_image.sh b/examples/build_image.sh
index 449d7dec..c3121a3f 100644
--- a/examples/build_image.sh
+++ b/examples/build_image.sh
@@ -27,6 +27,7 @@ federated-learning-surface-defect-detection-train.Dockerfile
 incremental-learning-helmet-detection.Dockerfile
 joint-inference-helmet-detection-big.Dockerfile
 joint-inference-helmet-detection-little.Dockerfile
+lifelong-learning-atcii-classifier.Dockerfile
 )
 
 for dockerfile in ${dockerfiles[@]}; do
diff --git a/examples/lifelong_learning/atcii/README.md b/examples/lifelong_learning/atcii/README.md
new file mode 100644
index 00000000..9b096807
--- /dev/null
+++ b/examples/lifelong_learning/atcii/README.md
@@ -0,0 +1,177 @@
+# Using Lifelong Learning Job in Thermal Comfort Prediction Scenario
+
+This document introduces how to use a lifelong learning job in a thermal comfort prediction scenario.
+With a lifelong learning job, our application can automatically retrain, evaluate,
+and update models based on the data generated at the edge.
+
+## Thermal Comfort Prediction Experiment
+
+### Install Sedna
+
+Follow the [Sedna installation document](/docs/setup/install.md) to install Sedna.
+
+### Prepare Dataset
+In this example, you can download the ASHRAE Global Thermal Comfort Database II to initialize the lifelong learning knowledge base.
+
+Download the [datasets](https://kubeedge.obs.cn-north-1.myhuaweicloud.com/examples/atcii-classifier/dataset.tar.gz), which include the train, evaluation, and incremental datasets.
+```
+cd /
+wget https://kubeedge.obs.cn-north-1.myhuaweicloud.com/examples/atcii-classifier/dataset.tar.gz
+tar -zxvf dataset.tar.gz
+```
+### Prepare Knowledge Base Server
+In this example, we create a knowledge base RESTful server with SQLite3; the database is stored at `LIFELONG_KB_URL` and runs on the `GM Node`.
+### Prepare Image
+This example uses the image:
+```
+swr.cn-southwest-2.myhuaweicloud.com/sedna-test/sedna/kb:v0.0.1
+```
+
+### Create KB Deployment
+
+```
+kubectl create -f scripts/knowledge-server/kb.yaml
+```
+
+### Create Lifelong Learning Job
+In this example, `$WORKER_NODE` is a custom node; replace it with the name of the node you actually run on.
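+
+If you are unsure of the node name, you can list the nodes in your cluster first (this assumes `kubectl` is already configured against the cluster):
+```
+kubectl get nodes
+```
+Then set it as the worker node, for example: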
+
+```
+WORKER_NODE="edge-node"
+```
+Create Dataset
+
+```
+kubectl create -f - <<EOF
+...
+    trigger:
+      checkPeriodSeconds: 60
+      timer:
+        start: 02:00
+        end: 20:00
+      condition:
+        operator: ">"
+        threshold: 500
+        metric: num_of_samples
+  evalSpec:
+    template:
+      spec:
+        nodeName: "edge-node"
+        containers:
+          - image: kubeedge/sedna-example-lifelong-learning-atcii-classifier:v0.1.0
+            name: eval-worker
+            imagePullPolicy: IfNotPresent
+            args: ["eval.py"]
+            env:
+              - name: "metrics"
+                value: "precision_score"
+              - name: "metric_param"
+                value: "{'average': 'micro'}"
+              - name: "model_threshold"
+                value: "0.5"
+  deploySpec:
+    template:
+      spec:
+        nodeName: "edge-node"
+        containers:
+          - image: kubeedge/sedna-example-lifelong-learning-atcii-classifier:v0.1.0
+            name: infer-worker
+            imagePullPolicy: IfNotPresent
+            args: ["inference.py"]
+            env:
+              - name: "UT_SAVED_URL"
+                value: "/ut_saved_url"
+              - name: "infer_dataset_url"
+                value: "/data/testData.csv"
+            volumeMounts:
+              - name: utdir
+                mountPath: /ut_saved_url
+              - name: inferdata
+                mountPath: /data/
+            resources:  # user defined resources
+              limits:
+                memory: 2Gi
+        volumes:  # user defined volumes
+          - name: utdir
+            hostPath:
+              path: /lifelong/unseen_task/
+              type: DirectoryOrCreate
+          - name: inferdata
+            hostPath:
+              path: /lifelong/data/
+              type: DirectoryOrCreate
+  outputDir: "/output"
+EOF
+```
+
+### Check Lifelong Learning Job
+Query the service status:
+```
+kubectl get lifelonglearningjob atcii-classifier-demo
+```
+In the `lifelonglearningjob` resource atcii-classifier-demo, the following trigger is configured:
+```
+trigger:
+  checkPeriodSeconds: 60
+  timer:
+    start: 02:00
+    end: 20:00
+  condition:
+    operator: ">"
+    threshold: 500
+    metric: num_of_samples
+```
+
+### Unseen Task Samples Labeling
+In the real world, we need to label the hard examples of our unseen tasks stored in `UT_SAVED_URL` with annotation tools, and then feed the labeled examples back into the `Dataset`'s URL.
+
+### Effect Display
+In this example, false and failed detections occur at the inference stage before lifelong learning; after lifelong learning, the precision and accuracy on the dataset improve greatly.
+
+![img_1.png](image/effect_comparison.png)
diff --git a/examples/lifelong_learning/atcii/eval.py b/examples/lifelong_learning/atcii/eval.py
new file mode 100644
index 00000000..8e52bcd5
--- /dev/null
+++ b/examples/lifelong_learning/atcii/eval.py
@@ -0,0 +1,44 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
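+
+# Evaluation worker for this example: it loads the test dataset configured
+# through the CRD, evaluates every task model in the knowledge base with
+# precision_score, and reports the results so that models scoring below
+# `model_threshold` can be filtered out before deployment.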
+ +import json +from interface import DATACONF, Estimator, feature_process +from sedna.common.config import Context, BaseConfig +from sedna.datasources import CSVDataParse + +from sedna.core.lifelong_learning import LifelongLearning + + +def main(): + test_dataset_url = BaseConfig.test_dataset_url + valid_data = CSVDataParse(data_type="valid", func=feature_process) + valid_data.parse(test_dataset_url, label=DATACONF["LABEL"]) + attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]}) + model_threshold = float(Context.get_parameters('model_threshold', 0)) + + ll_job = LifelongLearning( + estimator=Estimator, + task_definition="TaskDefinitionByDataAttr", + task_definition_param=attribute + ) + eval_experiment = ll_job.evaluate( + data=valid_data, metrics="precision_score", + metrics_param={"average": "micro"}, + model_threshold=model_threshold + ) + return eval_experiment + + +if __name__ == '__main__': + print(main()) diff --git a/examples/lifelong_learning/atcii/image/effect_comparison.png b/examples/lifelong_learning/atcii/image/effect_comparison.png new file mode 100644 index 00000000..bfe50465 Binary files /dev/null and b/examples/lifelong_learning/atcii/image/effect_comparison.png differ diff --git a/examples/lifelong_learning/atcii/inference.py b/examples/lifelong_learning/atcii/inference.py new file mode 100644 index 00000000..31ffdd48 --- /dev/null +++ b/examples/lifelong_learning/atcii/inference.py @@ -0,0 +1,72 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
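+
+# Inference worker for this example: it tails the CSV file given by
+# `infer_dataset_url` line by line, runs lifelong-learning inference on each
+# sample, writes every prediction to an output CSV, and additionally appends
+# samples flagged as unseen tasks to `unseen_sample.csv` under the
+# unseen-task save path so they can be labeled later (see the README).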
+
+import os
+import csv
+import json
+import time
+from interface import DATACONF, Estimator, feature_process
+from sedna.common.config import Context
+from sedna.datasources import CSVDataParse
+from sedna.core.lifelong_learning import LifelongLearning
+
+
+def main():
+    utd = Context.get_parameters("UTD_NAME", "TaskAttr")
+    attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
+    utd_parameters = Context.get_parameters("UTD_PARAMETERS", {})
+    # matches the UT_SAVED_URL env configured in the deploySpec
+    ut_saved_url = Context.get_parameters("UT_SAVED_URL", "/tmp")
+
+    ll_job = LifelongLearning(
+        estimator=Estimator,
+        task_mining="TaskMiningByDataAttr",
+        task_mining_param=attribute,
+        unseen_task_detect=utd,
+        unseen_task_detect_param=utd_parameters)
+
+    infer_dataset_url = Context.get_parameters('infer_dataset_url')
+    file_handle = open(infer_dataset_url, "r", encoding="utf-8")
+    header = list(csv.reader([file_handle.readline().strip()]))[0]
+    infer_data = CSVDataParse(data_type="test", func=feature_process)
+
+    unseen_sample = open(os.path.join(ut_saved_url, "unseen_sample.csv"),
+                         "w", encoding="utf-8")
+    unseen_sample.write("\t".join(header + ['pred']) + "\n")
+    output_sample = open(f"{infer_dataset_url}_out.csv", "w", encoding="utf-8")
+    output_sample.write("\t".join(header + ['pred']) + "\n")
+
+    while True:  # tail the dataset file: wait for new rows, infer on each one
+        where = file_handle.tell()
+        line = file_handle.readline()
+        if not line:
+            time.sleep(1)
+            file_handle.seek(where)
+            continue
+        reader = list(csv.reader([line.strip()]))
+        rows = reader[0]
+        data = dict(zip(header, rows))
+        infer_data.parse(data, label=DATACONF["LABEL"])
+        rsl, is_unseen, target_task = ll_job.inference(infer_data)
+
+        rows.append(list(rsl)[0])
+        if is_unseen:
+            unseen_sample.write("\t".join(map(str, rows)) + "\n")
+        output_sample.write("\t".join(map(str, rows)) + "\n")
+    unseen_sample.close()
+    output_sample.close()
+
+
+if __name__ == '__main__':
+    print(main())
diff --git a/examples/lifelong_learning/atcii/interface.py b/examples/lifelong_learning/atcii/interface.py
new file mode 100644
index 00000000..6fee43ea
--- /dev/null
+++ b/examples/lifelong_learning/atcii/interface.py
@@ -0,0 +1,124 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
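+
+# Shared model interface for the train/eval/inference workers: feature
+# preprocessing for the ASHRAE thermal-comfort data, plus an XGBoost-based
+# Estimator exposing the train/predict/evaluate/load/save hooks expected by
+# the Sedna lifelong-learning backend.
+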
+import os
+import pandas as pd
+import numpy as np
+import xgboost
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import precision_score
+
+os.environ['BACKEND_TYPE'] = 'SKLEARN'
+
+DATACONF = {
+    # "startegy" is spelled this way in the original ASHRAE dataset column
+    "ATTRIBUTES": ["Season", "Cooling startegy_building level"],
+    "LABEL": "Thermal preference",
+}
+
+
+def feature_process(df: pd.DataFrame):
+    if "City" in df.columns:
+        df.drop(["City"], axis=1, inplace=True)
+    for feature in df.columns:
+        if feature in ["Season", ]:
+            continue
+        df[feature] = df[feature].apply(lambda x: float(x) if x else 0.0)
+    df['Thermal preference'] = df['Thermal preference'].apply(
+        lambda x: int(float(x)) if x else 1)
+    return df
+
+
+class Estimator:
+    def __init__(self):
+        """Model init"""
+        self.model = xgboost.XGBClassifier(
+            learning_rate=0.1,
+            n_estimators=600,
+            max_depth=2,
+            min_child_weight=1,
+            gamma=0,
+            subsample=0.8,
+            colsample_bytree=0.8,
+            objective="multi:softmax",
+            num_class=3,
+            nthread=4,
+            seed=27)
+
+    def train(self, train_data, valid_data=None,
+              save_best=True,
+              metric_name="mlogloss",
+              early_stopping_rounds=100
+              ):
+        es = [
+            xgboost.callback.EarlyStopping(
+                metric_name=metric_name,
+                rounds=early_stopping_rounds,
+                save_best=save_best
+            )
+        ]
+        x, y = train_data.x, train_data.y
+        if valid_data:
+            x1, y1 = valid_data.x, valid_data.y
+        else:
+            x, x1, y, y1 = train_test_split(
+                x, y, test_size=0.1, random_state=42)
+        history = self.model.fit(x, y, eval_set=[(x1, y1), ], callbacks=es)
+        d = {}
+        for k, v in history.evals_result().items():
+            for k1, v1 in v.items():
+                m = np.mean(v1)
+                if k1 not in d:
+                    d[k1] = []
+                d[k1].append(m)
+        for k, v in d.items():
+            d[k] = np.mean(v)
+        return d
+
+    def predict(self, datas, **kwargs):
+        """Model inference"""
+        return self.model.predict(datas)
+
+    def predict_proba(self, datas, **kwargs):
+        return self.model.predict_proba(datas)
+
+    def evaluate(self, test_data, **kwargs):
+        """Model evaluation"""
+        y_pred = self.predict(test_data.x)
+        return precision_score(test_data.y, y_pred, average="micro")
+
+    def load(self, model_url):
+        self.model.load_model(model_url)
+        return self
+
+    def save(self, model_path=None):
+        """Save the trained XGBoost model to the given path"""
+        return self.model.save_model(model_path)
+
+
+if __name__ == '__main__':
+    from sedna.datasources import CSVDataParse
+    from sedna.common.config import BaseConfig
+
+    train_dataset_url = BaseConfig.train_dataset_url
+    train_data = CSVDataParse(data_type="train", func=feature_process)
+    train_data.parse(train_dataset_url, label=DATACONF["LABEL"])
+
+    test_dataset_url = BaseConfig.test_dataset_url
+    valid_data = CSVDataParse(data_type="valid", func=feature_process)
+    valid_data.parse(test_dataset_url, label=DATACONF["LABEL"])
+
+    model = Estimator()
+    print(model.train(train_data))
+    print(model.evaluate(test_data=valid_data))
diff --git a/examples/lifelong_learning/atcii/train.py b/examples/lifelong_learning/atcii/train.py
new file mode 100644
index 00000000..d5901a5a
--- /dev/null
+++ b/examples/lifelong_learning/atcii/train.py
@@ -0,0 +1,48 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from interface import DATACONF, Estimator, feature_process +from sedna.common.config import Context, BaseConfig +from sedna.datasources import CSVDataParse + +from sedna.core.lifelong_learning import LifelongLearning + + +def main(): + # load dataset. + train_dataset_url = BaseConfig.train_dataset_url + train_data = CSVDataParse(data_type="train", func=feature_process) + train_data.parse(train_dataset_url, label=DATACONF["LABEL"]) + attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]}) + early_stopping_rounds = int( + Context.get_parameters( + "early_stopping_rounds", 100)) + metric_name = Context.get_parameters("metric_name", "mlogloss") + ll_job = LifelongLearning( + estimator=Estimator, + task_definition="TaskDefinitionByDataAttr", + task_definition_param=attribute + ) + train_experiment = ll_job.train( + train_data=train_data, + metric_name=metric_name, + early_stopping_rounds=early_stopping_rounds + ) + + return train_experiment + + +if __name__ == '__main__': + print(main()) diff --git a/lib/sedna/algorithms/__init__.py b/lib/sedna/algorithms/__init__.py index 2e9f7352..43a49928 100644 --- a/lib/sedna/algorithms/__init__.py +++ b/lib/sedna/algorithms/__init__.py @@ -14,3 +14,5 @@ from .aggregation import * from .hard_example_mining import * +from .multi_task_learning import * +from .unseen_task_detect import * diff --git a/lib/sedna/algorithms/multi_task_learning/__init__.py b/lib/sedna/algorithms/multi_task_learning/__init__.py new file mode 100644 index 00000000..b43a52b5 --- /dev/null +++ b/lib/sedna/algorithms/multi_task_learning/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .task_jobs import * +from .multi_task_learning import MulTaskLearning diff --git a/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py b/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py new file mode 100644 index 00000000..1a31eae3 --- /dev/null +++ b/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py @@ -0,0 +1,312 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import joblib + +from .task_jobs.artifact import Model, Task, TaskGroup +from sedna.datasources import BaseDataSource +from sedna.backend import set_backend +from sedna.common.log import LOGGER +from sedna.common.config import Context +from sedna.common.file_ops import FileOps +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('MulTaskLearning',) + + +class MulTaskLearning: + _method_pair = { + 'TaskDefinitionBySVC': 'TaskMiningBySVC', + 'TaskDefinitionByDataAttr': 'TaskMiningByDataAttr', + } + + def __init__(self, + estimator=None, + task_definition="TaskDefinitionByDataAttr", + task_relationship_discovery=None, + task_mining=None, + task_remodeling=None, + inference_integrate=None, + task_definition_param=None, + relationship_discovery_param=None, + task_mining_param=None, + task_remodeling_param=None, + inference_integrate_param=None + ): + if not task_relationship_discovery: + task_relationship_discovery = "DefaultTaskRelationDiscover" + if not task_remodeling: + task_remodeling = "DefaultTaskRemodeling" + if not inference_integrate: + inference_integrate = "DefaultInferenceIntegrate" + self.method_selection = dict( + task_definition=task_definition, + task_relationship_discovery=task_relationship_discovery, + task_mining=task_mining, + task_remodeling=task_remodeling, + inference_integrate=inference_integrate, + task_definition_param=task_definition_param, + task_relationship_discovery_param=relationship_discovery_param, + task_mining_param=task_mining_param, + task_remodeling_param=task_remodeling_param, + inference_integrate_param=inference_integrate_param) + self.models = None + self.extractor = None + self.base_model = estimator + self.task_groups = None + self.task_index_url = Context.get_parameters( + "MODEL_URLS", '/tmp/index.pkl' + ) + self.min_train_sample = int(Context.get_parameters( + "MIN_TRAIN_SAMPLE", '10' + )) + + @staticmethod + def parse_param(param_str): + if not param_str: + return {} + try: + raw_dict = json.loads(param_str, encoding="utf-8") + except json.JSONDecodeError: + raw_dict = {} + return raw_dict + + def task_definition(self, samples): + method_name = self.method_selection.get( + "task_definition", "TaskDefinitionByDataAttr") + extend_param = self.parse_param( + self.method_selection.get("task_definition_param")) + method_cls = ClassFactory.get_cls( + ClassType.MTL, method_name)(**extend_param) + return method_cls(samples) + + def task_relationship_discovery(self, tasks): + method_name = self.method_selection.get("task_relationship_discovery") + extend_param = self.parse_param( + self.method_selection.get("task_relationship_discovery_param") + ) + method_cls = ClassFactory.get_cls( + ClassType.MTL, method_name)(**extend_param) + return method_cls(tasks) + + def task_mining(self, samples): + method_name = self.method_selection.get("task_mining") + extend_param = self.parse_param( + self.method_selection.get("task_mining_param")) + + if not method_name: + task_definition = self.method_selection.get( + "task_definition", "TaskDefinitionByDataAttr") + method_name = self._method_pair.get(task_definition, + 'TaskMiningByDataAttr') + extend_param = self.parse_param( + self.method_selection.get("task_definition_param")) + method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( + task_extractor=self.extractor, **extend_param + ) + return method_cls(samples=samples) + + def task_remodeling(self, 
samples, mappings): + method_name = self.method_selection.get("task_remodeling") + extend_param = self.parse_param( + self.method_selection.get("task_remodeling_param")) + method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( + models=self.models, **extend_param) + return method_cls(samples=samples, mappings=mappings) + + def inference_integrate(self, tasks): + method_name = self.method_selection.get("inference_integrate") + extend_param = self.parse_param( + self.method_selection.get("inference_integrate_param")) + method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( + models=self.models, **extend_param) + return method_cls(tasks=tasks) if method_cls else tasks + + def train(self, train_data: BaseDataSource, + valid_data: BaseDataSource = None, + post_process=None, **kwargs): + tasks, task_extractor, train_data = self.task_definition(train_data) + self.extractor = task_extractor + task_groups = self.task_relationship_discovery(tasks) + self.models = [] + callback = None + if post_process: + callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)() + self.task_groups = [] + feedback = {} + rare_task = [] + for i, task in enumerate(task_groups): + if not isinstance(task, TaskGroup): + rare_task.append(i) + self.models.append(None) + self.task_groups.append(None) + continue + if not (task.samples and len(task.samples) + > self.min_train_sample): + self.models.append(None) + self.task_groups.append(None) + rare_task.append(i) + n = len(task.samples) + continue + LOGGER.info(f"MTL Train start {i} : {task.entry}") + + model_obj = set_backend(estimator=self.base_model) + res = model_obj.train(train_data=task.samples, **kwargs) + if callback: + res = callback(model_obj, res) + model_path = model_obj.save(model_name=f"{task.entry}.model") + model = Model(index=i, entry=task.entry, + model=model_path, result=res) + model.meta_attr = [t.meta_attr for t in task.tasks] + task.model = model + self.models.append(model) + feedback[task.entry] = res + self.task_groups.append(task) + if len(rare_task): + model_obj = set_backend(estimator=self.base_model) + res = model_obj.train(train_data=train_data, **kwargs) + model_path = model_obj.save(model_name="global.model") + for i in rare_task: + task = task_groups[i] + entry = getattr(task, 'entry', "global") + if not isinstance(task, TaskGroup): + task = TaskGroup( + entry=entry, tasks=[] + ) + model = Model(index=i, entry=entry, + model=model_path, result=res) + model.meta_attr = [t.meta_attr for t in task.tasks] + task.model = model + task.samples = train_data + self.models[i] = model + feedback[entry] = res + self.task_groups[i] = task + extractor_file = FileOps.join_path( + os.path.dirname(self.task_index_url), + "kb_extractor.pkl" + ) + joblib.dump(self.extractor, extractor_file) + task_index = { + "extractor": extractor_file, + "task_groups": self.task_groups + } + joblib.dump(task_index, self.task_index_url) + if valid_data: + feedback = self.evaluate(valid_data, **kwargs) + + return feedback + + def predict(self, data: BaseDataSource, + post_process=None, **kwargs): + if not (self.models and self.extractor): + task_index = joblib.load(self.task_index_url) + extractor_file = FileOps.join_path( + os.path.dirname(self.task_index_url), + "kb_extractor.pkl" + ) + if not callable(task_index['extractor']) and \ + isinstance(task_index['extractor'], str): + FileOps.download(task_index['extractor'], extractor_file) + self.extractor = joblib.load(extractor_file) + else: + self.extractor = task_index['extractor'] + self.task_groups = 
task_index['task_groups'] + self.models = [task.model for task in self.task_groups] + data, mappings = self.task_mining(samples=data) + samples, models = self.task_remodeling(samples=data, mappings=mappings) + + callback = None + if post_process: + callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)() + + tasks = [] + for inx, df in enumerate(samples): + m = models[inx] + if not isinstance(m, Model): + continue + model_obj = set_backend(estimator=self.base_model) + evaluator = model_obj.load(m.model) if isinstance( + m.model, str) else m.model + pred = evaluator.predict(df.x, kwargs=kwargs) + if callable(callback): + pred = callback(pred, df) + task = Task(entry=m.entry, samples=df) + task.result = pred + task.model = m + tasks.append(task) + res = self.inference_integrate(tasks) + return res, tasks + + def evaluate(self, data: BaseDataSource, + metrics=None, + metrics_param=None, + **kwargs): + from sklearn import metrics as sk_metrics + + result, tasks = self.predict(data, kwargs=kwargs) + m_dict = {} + if metrics: + if callable(metrics): # if metrics is a function + m_name = getattr(metrics, '__name__', "mtl_eval") + m_dict = { + m_name: metrics + } + elif isinstance(metrics, (set, list)): # if metrics is multiple + for inx, m in enumerate(metrics): + m_name = getattr(m, '__name__', f"mtl_eval_{inx}") + if isinstance(m, str): + m = getattr(sk_metrics, m) + if not callable(m): + continue + m_dict[m_name] = m + elif isinstance(metrics, str): # if metrics is single + m_dict = { + metrics: getattr(sk_metrics, metrics, sk_metrics.log_loss) + } + elif isinstance(metrics, dict): # if metrics with name + for k, v in metrics.items(): + if isinstance(v, str): + v = getattr(sk_metrics, v) + if not callable(v): + continue + m_dict[k] = v + if not len(m_dict): + m_dict = { + 'precision_score': sk_metrics.precision_score + } + metrics_param = {"average": "micro"} + + data.x['pred_y'] = result + data.x['real_y'] = data.y + if not metrics_param: + metrics_param = {} + elif isinstance(metrics_param, str): + metrics_param = self.parse_param(metrics_param) + tasks_detail = [] + for task in tasks: + sample = task.samples + pred = task.result + scores = { + name: metric(sample.y, pred, **metrics_param) + for name, metric in m_dict.items() + } + task.scores = scores + tasks_detail.append(task) + task_eval_res = { + name: metric(data.y, result, **metrics_param) + for name, metric in m_dict.items() + } + return task_eval_res, tasks_detail diff --git a/lib/sedna/algorithms/multi_task_learning/task_jobs/__init__.py b/lib/sedna/algorithms/multi_task_learning/task_jobs/__init__.py new file mode 100644 index 00000000..6cf77600 --- /dev/null +++ b/lib/sedna/algorithms/multi_task_learning/task_jobs/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
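+
+# task_jobs bundles the pluggable stages of multi-task learning: task
+# definition, relationship discovery, task mining, task remodeling and
+# inference integration. Each stage registers itself with ClassFactory so
+# MulTaskLearning can select it by name at runtime.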
+ +from .task_definition import * +from .task_relation_discover import * + +from .task_mining import * +from .task_remodeling import * + +from .inference_integrate import * diff --git a/lib/sedna/algorithms/multi_task_learning/task_jobs/artifact.py b/lib/sedna/algorithms/multi_task_learning/task_jobs/artifact.py new file mode 100644 index 00000000..711772b2 --- /dev/null +++ b/lib/sedna/algorithms/multi_task_learning/task_jobs/artifact.py @@ -0,0 +1,44 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +__all__ = ('Task', 'TaskGroup', 'Model') + + +class Task: + def __init__(self, entry, samples, meta_attr=None): + self.entry = entry + self.samples = samples + self.meta_attr = meta_attr + self.model = None # define on running + self.result = None # define on running + + +class TaskGroup: + + def __init__(self, entry, tasks: List[Task]): + self.entry = entry + self.tasks = tasks + self.samples = None # define by task_relation_discover algorithms + self.model = None # define on train + + +class Model: + def __init__(self, index: int, entry, model, result): + self.index = index # integer + self.entry = entry + self.model = model + self.result = result + self.meta_attr = None # define on running diff --git a/lib/sedna/algorithms/multi_task_learning/task_jobs/inference_integrate.py b/lib/sedna/algorithms/multi_task_learning/task_jobs/inference_integrate.py new file mode 100644 index 00000000..c6f72d6f --- /dev/null +++ b/lib/sedna/algorithms/multi_task_learning/task_jobs/inference_integrate.py @@ -0,0 +1,33 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
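+
+# Default strategy for merging per-task predictions: each task keeps the
+# original row indexes of its samples (`samples.inx`), so results are sorted
+# back into the input order before being returned as a single array.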
+ +import numpy as np +from typing import List +from .artifact import Task +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('DefaultInferenceIntegrate', ) + + +@ClassFactory.register(ClassType.MTL) +class DefaultInferenceIntegrate: + def __init__(self, models: list, **kwargs): + self.models = models + + def __call__(self, tasks: List[Task]): + res = {} + for task in tasks: + res.update(dict(zip(task.samples.inx, task.result))) + return np.array([z[1] + for z in sorted(res.items(), key=lambda x: x[0])]) diff --git a/lib/sedna/algorithms/multi_task_learning/task_jobs/task_definition.py b/lib/sedna/algorithms/multi_task_learning/task_jobs/task_definition.py new file mode 100644 index 00000000..f853698f --- /dev/null +++ b/lib/sedna/algorithms/multi_task_learning/task_jobs/task_definition.py @@ -0,0 +1,110 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pandas as pd +from typing import List, Any, Tuple +from .artifact import Task +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassType, ClassFactory + + +__all__ = ('TaskDefinitionBySVC', 'TaskDefinitionByDataAttr') + + +@ClassFactory.register(ClassType.MTL) +class TaskDefinitionBySVC: + def __init__(self, **kwargs): + n_class = kwargs.get("n_class", "") + self.n_class = max(2, int(n_class)) if str(n_class).isdigit() else 2 + + def __call__(self, + samples: BaseDataSource) -> Tuple[List[Task], + Any, + BaseDataSource]: + from sklearn.svm import SVC + from sklearn.cluster import AgglomerativeClustering + + d_type = samples.data_type + x_data = samples.x + y_data = samples.y + if not isinstance(x_data, pd.DataFrame): + raise TypeError(f"{d_type} data should only be pd.DataFrame") + tasks = [] + legal = list( + filter(lambda col: x_data[col].dtype == 'float64', x_data.columns)) + + df = x_data[legal] + c1 = AgglomerativeClustering(n_clusters=self.n_class).fit_predict(df) + c2 = SVC(gamma=0.01) + c2.fit(df, c1) + + for task in range(self.n_class): + g_attr = f"svc_{task}" + task_df = BaseDataSource(data_type=d_type) + task_df.x = x_data.iloc[np.where(c1 == task)] + task_df.y = y_data.iloc[np.where(c1 == task)] + + task_obj = Task(entry=g_attr, samples=task_df) + tasks.append(task_obj) + samples.x = df + return tasks, c2, samples + + +@ClassFactory.register(ClassType.MTL) +class TaskDefinitionByDataAttr: + def __init__(self, **kwargs): + self.attr_filed = kwargs.get("attribute", []) + + def __call__(self, + samples: BaseDataSource) -> Tuple[List[Task], + Any, + BaseDataSource]: + tasks = [] + d_type = samples.data_type + x_data = samples.x + y_data = samples.y + if not isinstance(x_data, pd.DataFrame): + raise TypeError(f"{d_type} data should only be pd.DataFrame") + + _inx = 0 + task_index = {} + for meta_attr, df in x_data.groupby(self.attr_filed): + if isinstance(meta_attr, (list, tuple, set)): + g_attr = "_".join( + map(lambda x: str(x).replace("_", "-"), meta_attr)) + meta_attr = list(meta_attr) + else: + 
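+                # single-attribute grouping: normalize the scalar key the
+                # same way as the multi-attribute case above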
g_attr = str(meta_attr).replace("_", "-") + meta_attr = [meta_attr] + g_attr = g_attr.replace(" ", "") + if g_attr in task_index: + old_task = tasks[task_index[g_attr]] + old_task.x = pd.concat([old_task.x, df]) + old_task.y = pd.concat([old_task.y, y_data.iloc[df.index]]) + continue + task_index[g_attr] = _inx + + task_df = BaseDataSource(data_type=d_type) + task_df.x = df.drop(self.attr_filed, axis=1) + task_df.y = y_data.iloc[df.index] + + task_obj = Task(entry=g_attr, samples=task_df, meta_attr=meta_attr) + tasks.append(task_obj) + _inx += 1 + x_data.drop(self.attr_filed, axis=1, inplace=True) + samples = BaseDataSource(data_type=d_type) + samples.x = x_data + samples.y = y_data + return tasks, task_index, samples diff --git a/lib/sedna/algorithms/multi_task_learning/task_jobs/task_mining.py b/lib/sedna/algorithms/multi_task_learning/task_jobs/task_mining.py new file mode 100644 index 00000000..8a09eabd --- /dev/null +++ b/lib/sedna/algorithms/multi_task_learning/task_jobs/task_mining.py @@ -0,0 +1,59 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassFactory, ClassType + + +__all__ = ('TaskMiningBySVC', 'TaskMiningByDataAttr') + + +@ClassFactory.register(ClassType.MTL) +class TaskMiningBySVC: + def __init__(self, task_extractor, **kwargs): + self.task_extractor = task_extractor + + def __call__(self, samples: BaseDataSource): + df = samples.x + allocations = [0, ] * len(df) + legal = list( + filter(lambda col: df[col].dtype == 'float64', df.columns)) + if not len(legal): + return allocations + + allocations = list(self.task_extractor.predict(df[legal])) + return samples, allocations + + +@ClassFactory.register(ClassType.MTL) +class TaskMiningByDataAttr: + def __init__(self, task_extractor, **kwargs): + self.task_extractor = task_extractor + self.attr_filed = kwargs.get("attribute", []) + + def __call__(self, samples: BaseDataSource): + df = samples.x + meta_attr = df[self.attr_filed] + + allocations = meta_attr.apply( + lambda x: self.task_extractor.get( + "_".join( + map(lambda y: str(x[y]).replace("_", "-").replace(" ", ""), + self.attr_filed) + ), + 0), + axis=1).values.tolist() + samples.x = df.drop(self.attr_filed, axis=1) + samples.meta_attr = meta_attr + return samples, allocations diff --git a/lib/sedna/algorithms/multi_task_learning/task_jobs/task_relation_discover.py b/lib/sedna/algorithms/multi_task_learning/task_jobs/task_relation_discover.py new file mode 100644 index 00000000..4c8c22ea --- /dev/null +++ b/lib/sedna/algorithms/multi_task_learning/task_jobs/task_relation_discover.py @@ -0,0 +1,34 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List +from .artifact import Task, TaskGroup +from sedna.common.class_factory import ClassType, ClassFactory + + +__all__ = ('DefaultTaskRelationDiscover', ) + + +@ClassFactory.register(ClassType.MTL) +class DefaultTaskRelationDiscover: + def __init__(self, **kwargs): + pass + + def __call__(self, tasks: List[Task]) -> List[TaskGroup]: + tg = [] + for task in tasks: + tg_obj = TaskGroup(entry=task.entry, tasks=[task]) + tg_obj.samples = task.samples + tg.append(tg_obj) + return tg diff --git a/lib/sedna/algorithms/multi_task_learning/task_jobs/task_remodeling.py b/lib/sedna/algorithms/multi_task_learning/task_jobs/task_remodeling.py new file mode 100644 index 00000000..8ee9560a --- /dev/null +++ b/lib/sedna/algorithms/multi_task_learning/task_jobs/task_remodeling.py @@ -0,0 +1,44 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from typing import List, Tuple +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('DefaultTaskRemodeling',) + + +@ClassFactory.register(ClassType.MTL) +class DefaultTaskRemodeling: + def __init__(self, models: list, **kwargs): + self.models = models + + def __call__(self, samples: BaseDataSource, mappings: List) \ + -> Tuple[List[BaseDataSource], List]: + mappings = np.array(mappings) + data, models = [], [] + d_type = samples.data_type + for m in np.unique(mappings): + task_df = BaseDataSource(data_type=d_type) + _inx = np.where(mappings == m) + task_df.x = samples.x.iloc[_inx] + if d_type != "test": + task_df.y = samples.y.iloc[_inx] + task_df.inx = _inx[0].tolist() + task_df.meta_attr = samples.meta_attr.iloc[_inx].values + data.append(task_df) + model = self.models[m] or self.models[0] + models.append(model) + return data, models diff --git a/lib/sedna/algorithms/unseen_task_detect/__init__.py b/lib/sedna/algorithms/unseen_task_detect/__init__.py new file mode 100644 index 00000000..dc477813 --- /dev/null +++ b/lib/sedna/algorithms/unseen_task_detect/__init__.py @@ -0,0 +1,68 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unseen Task detect Algorithms for Lifelong Learning""" +import abc +import numpy as np +from typing import List +from sedna.algorithms.multi_task_learning.task_jobs.artifact import Task +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('ModelProbeFilter', 'TaskAttrFilter') + + +class BaseFilter(metaclass=abc.ABCMeta): + """The base class to define unified interface.""" + + def __call__(self, task: Task = None): + """predict function, and it must be implemented by + different methods class. + + :param task: inference task + :return: `True` means unseen task, `False` means not an unseen task. + """ + raise NotImplementedError + + +@ClassFactory.register(ClassType.UTD, alias="ModelProbe") +class ModelProbeFilter(BaseFilter, abc.ABC): + def __init__(self): + pass + + def __call__(self, tasks: List[Task] = None, threshold=0.5, **kwargs): + all_proba = [] + for task in tasks: + sample = task.samples + model = task.model + if hasattr(model, "predict_proba"): + proba = model.predict_proba(sample) + all_proba.append(np.max(proba)) + return np.mean(all_proba) > threshold if all_proba else True + + +@ClassFactory.register(ClassType.UTD, alias="TaskAttr") +class TaskAttrFilter(BaseFilter, abc.ABC): + def __init__(self): + pass + + def __call__(self, tasks: List[Task] = None, **kwargs): + for task in tasks: + model_attr = list(map(list, task.model.meta_attr)) + sample_attr = list(map(list, task.samples.meta_attr)) + + if not (model_attr and sample_attr): + continue + if list(model_attr) == list(sample_attr): + return False + return True diff --git a/lib/sedna/core/lifelong_learning/__init__.py b/lib/sedna/core/lifelong_learning/__init__.py new file mode 100644 index 00000000..faaadf9e --- /dev/null +++ b/lib/sedna/core/lifelong_learning/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .lifelong_learning import * diff --git a/lib/sedna/core/lifelong_learning/lifelong_learning.py b/lib/sedna/core/lifelong_learning/lifelong_learning.py new file mode 100644 index 00000000..9adb6642 --- /dev/null +++ b/lib/sedna/core/lifelong_learning/lifelong_learning.py @@ -0,0 +1,221 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
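+
+# LifelongLearning wraps MulTaskLearning with knowledge-base bookkeeping:
+# train/update upload task models and the extractor to the KB server,
+# evaluate drops task models whose scores fall below the threshold, and
+# inference additionally runs the configured unseen-task detection over the
+# mined tasks.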
+
+import os
+import joblib
+import tempfile
+from sedna.backend import set_backend
+from sedna.core.base import JobBase
+from sedna.common.file_ops import FileOps
+from sedna.common.constant import K8sResourceKind, K8sResourceKindStatus
+from sedna.common.config import Context
+from sedna.common.class_factory import ClassType, ClassFactory
+from sedna.algorithms.multi_task_learning import MulTaskLearning
+from sedna.service.client import KBClient
+
+
+class LifelongLearning(JobBase):
+    """
+    Lifelong learning Experiment
+    """
+
+    def __init__(self,
+                 estimator,
+                 task_definition="TaskDefinitionByDataAttr",
+                 task_relationship_discovery=None,
+                 task_mining=None,
+                 task_remodeling=None,
+                 inference_integrate=None,
+                 unseen_task_detect="TaskAttrFilter",
+
+                 task_definition_param=None,
+                 relationship_discovery_param=None,
+                 task_mining_param=None,
+                 task_remodeling_param=None,
+                 inference_integrate_param=None,
+                 unseen_task_detect_param=None):
+        e = MulTaskLearning(
+            estimator=estimator,
+            task_definition=task_definition,
+            task_relationship_discovery=task_relationship_discovery,
+            task_mining=task_mining,
+            task_remodeling=task_remodeling,
+            inference_integrate=inference_integrate,
+            task_definition_param=task_definition_param,
+            relationship_discovery_param=relationship_discovery_param,
+            task_mining_param=task_mining_param,
+            task_remodeling_param=task_remodeling_param,
+            inference_integrate_param=inference_integrate_param)
+        self.unseen_task_detect = unseen_task_detect
+        self.unseen_task_detect_param = e.parse_param(
+            unseen_task_detect_param)
+        config = dict(
+            ll_kb_server=Context.get_parameters("KB_SERVER"),
+            output_url=Context.get_parameters("OUTPUT_URL", "/tmp")
+        )
+        task_index = FileOps.join_path(config['output_url'], 'index.pkl')
+        config['task_index'] = task_index
+        super(LifelongLearning, self).__init__(
+            estimator=estimator, config=config)
+        self.job_kind = K8sResourceKind.LIFELONG_JOB.value
+        self.kb_server = KBClient(kbserver=self.config.ll_kb_server)
+
+    def train(self, train_data,
+              valid_data=None,
+              post_process=None,
+              action="initial",
+              **kwargs):
+        """
+        :param train_data: data used to train the model
+        :param valid_data: data used to validate the model
+        :param post_process: callback function
+        :param action: initial - kb init, update - kb incremental update
+        """
+
+        callback_func = None
+        if post_process is not None:
+            callback_func = ClassFactory.get_cls(
+                ClassType.CALLBACK, post_process)
+        res = self.estimator.train(
+            train_data=train_data,
+            valid_data=valid_data,
+            **kwargs
+        )  # TODO: distinguish incremental update from full overwrite
+
+        task_groups = self.estimator.estimator.task_groups
+        extractor_file = FileOps.join_path(
+            os.path.dirname(self.estimator.estimator.task_index_url),
+            "kb_extractor.pkl"
+        )
+        try:
+            extractor_file = self.kb_server.upload_file(extractor_file)
+        except Exception as err:
+            self.log.error(
+                f"Upload task extractor_file failed {extractor_file}: {err}")
+            extractor_file = joblib.load(extractor_file)
+        for task in task_groups:
+            try:
+                model = self.kb_server.upload_file(task.model.model)
+            except Exception:
+                model_obj = set_backend(
+                    estimator=self.estimator.estimator.base_model
+                )
+                model = model_obj.load(task.model.model)
+            task.model.model = model
+
+        task_info = {
+            "task_groups": task_groups,
+            "extractor": extractor_file
+        }
+        fd, name = tempfile.mkstemp()
+        joblib.dump(task_info, name)
+
+        index_file = self.kb_server.update_db(name)
+        if not index_file:
+            self.log.error("KB update failed!")
+            index_file = name
+
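+        # publish the (possibly updated) task index to the configured output
+        # path so that eval / inference workers can fetch it from there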
+        FileOps.download(index_file, self.config.task_index)
+        if os.path.isfile(name):
+            os.close(fd)
+            os.remove(name)
+        task_info_res = self.estimator.model_info(
+            self.config.task_index, result=res,
+            relpath=self.config.data_path_prefix)
+        self.report_task_info(
+            None, K8sResourceKindStatus.COMPLETED.value, task_info_res)
+        self.log.info(f"Lifelong learning Experiment Finished, "
+                      f"KB index saved in {self.config.task_index}")
+        return callback_func(self.estimator, res) if callback_func else res
+
+    def update(self, train_data, valid_data=None, post_process=None, **kwargs):
+        return self.train(
+            train_data=train_data,
+            valid_data=valid_data,
+            post_process=post_process,
+            action="update",
+            **kwargs
+        )
+
+    def evaluate(self, data, post_process=None, model_threshold=0.1, **kwargs):
+        callback_func = None
+        if callable(post_process):
+            callback_func = post_process
+        elif post_process is not None:
+            callback_func = ClassFactory.get_cls(
+                ClassType.CALLBACK, post_process)
+        task_index_url = self.get_parameters(
+            "MODEL_URLS", self.config.task_index)
+        index_url = self.estimator.estimator.task_index_url
+        self.log.info(
+            f"Download kb index from {task_index_url} to {index_url}")
+        FileOps.download(task_index_url, index_url)
+        res, tasks_detail = self.estimator.evaluate(data=data, **kwargs)
+        drop_tasks = []
+        for detail in tasks_detail:
+            scores = detail.scores
+            entry = detail.entry
+            self.log.info(f"{entry} scores: {scores}")
+            if any(map(lambda x: float(x) < model_threshold, scores.values())):
+                self.log.warn(
+                    f"{entry} will not be deployed "
+                    f"because its score is less than {model_threshold}")
+                drop_tasks.append(entry)
+                continue
+        drop_task = ",".join(drop_tasks)
+        index_file = self.kb_server.update_task_status(drop_task, new_status=0)
+        if not index_file:
+            self.log.error("KB update failed!")
+            index_file = str(index_url)
+        self.log.info(
+            f"Download kb index from {index_file} to {self.config.task_index}")
+        FileOps.download(index_file, self.config.task_index)
+        task_info_res = self.estimator.model_info(
+            self.config.task_index, result=res,
+            relpath=self.config.data_path_prefix)
+        self.report_task_info(
+            None,
+            K8sResourceKindStatus.COMPLETED.value,
+            task_info_res,
+            kind="eval")
+        return callback_func(res) if callback_func else res
+
+    def inference(self, data=None, post_process=None, **kwargs):
+        task_index_url = self.get_parameters(
+            "MODEL_URLS", self.config.task_index)
+        index_url = self.estimator.estimator.task_index_url
+        FileOps.download(task_index_url, index_url)
+        res, tasks = self.estimator.predict(
+            data=data, post_process=post_process, **kwargs
+        )
+
+        is_unseen_task = False
+        if self.unseen_task_detect:
+
+            try:
+                if callable(self.unseen_task_detect):
+                    unseen_task_detect_algorithm = self.unseen_task_detect()
+                else:
+                    unseen_task_detect_algorithm = ClassFactory.get_cls(
+                        ClassType.UTD, self.unseen_task_detect
+                    )()
+            except ValueError as err:
+                self.log.error(
+                    "Lifelong learning Experiment "
+                    "Inference [UTD] : {}".format(err))
+            else:
+                is_unseen_task = unseen_task_detect_algorithm(
+                    tasks=tasks, result=res, **self.unseen_task_detect_param
+                )
+        return res, is_unseen_task, tasks