- Reduce parameters for initial - show all interfaces of lifelong learning in example - fix bugs from deep learning framework Signed-off-by: JoeyHwong <joeyhwong@gknow.cn>tags/v0.3.1
| @@ -15,7 +15,7 @@ | |||||
| import json | import json | ||||
| from sedna.datasources import CSVDataParse | from sedna.datasources import CSVDataParse | ||||
| from sedna.common.config import Context, BaseConfig | |||||
| from sedna.common.config import BaseConfig | |||||
| from sedna.core.lifelong_learning import LifelongLearning | from sedna.core.lifelong_learning import LifelongLearning | ||||
| from interface import DATACONF, Estimator, feature_process | from interface import DATACONF, Estimator, feature_process | ||||
| @@ -26,17 +26,24 @@ def main(): | |||||
| valid_data = CSVDataParse(data_type="valid", func=feature_process) | valid_data = CSVDataParse(data_type="valid", func=feature_process) | ||||
| valid_data.parse(test_dataset_url, label=DATACONF["LABEL"]) | valid_data.parse(test_dataset_url, label=DATACONF["LABEL"]) | ||||
| attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]}) | attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]}) | ||||
| model_threshold = float(Context.get_parameters('model_threshold', 0)) | |||||
| task_definition = { | |||||
| "method": "TaskDefinitionByDataAttr", | |||||
| "param": attribute | |||||
| } | |||||
| ll_job = LifelongLearning( | ll_job = LifelongLearning( | ||||
| estimator=Estimator, | estimator=Estimator, | ||||
| task_definition="TaskDefinitionByDataAttr", | |||||
| task_definition_param=attribute | |||||
| task_definition=task_definition, | |||||
| task_relationship_discovery=None, | |||||
| task_mining=None, | |||||
| task_remodeling=None, | |||||
| inference_integrate=None, | |||||
| unseen_task_detect=None | |||||
| ) | ) | ||||
| eval_experiment = ll_job.evaluate( | eval_experiment = ll_job.evaluate( | ||||
| data=valid_data, metrics="precision_score", | data=valid_data, metrics="precision_score", | ||||
| metrics_param={"average": "micro"}, | |||||
| model_threshold=model_threshold | |||||
| metrics_param={"average": "micro"} | |||||
| ) | ) | ||||
| return eval_experiment | return eval_experiment | ||||
| @@ -26,17 +26,29 @@ from interface import DATACONF, Estimator, feature_process | |||||
| def main(): | def main(): | ||||
| utd = Context.get_parameters("UTD_NAME", "TaskAttr") | |||||
| utd = Context.get_parameters("UTD_NAME", "TaskAttrFilter") | |||||
| attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]}) | attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]}) | ||||
| utd_parameters = Context.get_parameters("UTD_PARAMETERS", {}) | utd_parameters = Context.get_parameters("UTD_PARAMETERS", {}) | ||||
| ut_saved_url = Context.get_parameters("UTD_SAVED_URL", "/tmp") | ut_saved_url = Context.get_parameters("UTD_SAVED_URL", "/tmp") | ||||
| ll_job = LifelongLearning( | |||||
| task_mining = { | |||||
| "method": "TaskMiningByDataAttr", | |||||
| "param": attribute | |||||
| } | |||||
| unseen_task_detect = { | |||||
| "method": utd, | |||||
| "param": utd_parameters | |||||
| } | |||||
| ll_service = LifelongLearning( | |||||
| estimator=Estimator, | estimator=Estimator, | ||||
| task_mining="TaskMiningByDataAttr", | |||||
| task_mining_param=attribute, | |||||
| unseen_task_detect=utd, | |||||
| unseen_task_detect_param=utd_parameters) | |||||
| task_mining=task_mining, | |||||
| task_definition=None, | |||||
| task_relationship_discovery=None, | |||||
| task_remodeling=None, | |||||
| inference_integrate=None, | |||||
| unseen_task_detect=unseen_task_detect) | |||||
| infer_dataset_url = Context.get_parameters('infer_dataset_url') | infer_dataset_url = Context.get_parameters('infer_dataset_url') | ||||
| file_handle = open(infer_dataset_url, "r", encoding="utf-8") | file_handle = open(infer_dataset_url, "r", encoding="utf-8") | ||||
| @@ -60,12 +72,14 @@ def main(): | |||||
| rows = reader[0] | rows = reader[0] | ||||
| data = dict(zip(header, rows)) | data = dict(zip(header, rows)) | ||||
| infer_data.parse(data, label=DATACONF["LABEL"]) | infer_data.parse(data, label=DATACONF["LABEL"]) | ||||
| rsl, is_unseen, target_task = ll_job.inference(infer_data) | |||||
| rsl, is_unseen, target_task = ll_service.inference(infer_data) | |||||
| rows.append(list(rsl)[0]) | rows.append(list(rsl)[0]) | ||||
| output = "\t".join(map(str, rows)) + "\n" | |||||
| if is_unseen: | if is_unseen: | ||||
| unseen_sample.write("\t".join(map(str, rows)) + "\n") | |||||
| output_sample.write("\t".join(map(str, rows)) + "\n") | |||||
| unseen_sample.write(output) | |||||
| output_sample.write(output) | |||||
| unseen_sample.close() | unseen_sample.close() | ||||
| output_sample.close() | output_sample.close() | ||||
| @@ -28,13 +28,23 @@ def main(): | |||||
| train_data.parse(train_dataset_url, label=DATACONF["LABEL"]) | train_data.parse(train_dataset_url, label=DATACONF["LABEL"]) | ||||
| attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]}) | attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]}) | ||||
| early_stopping_rounds = int( | early_stopping_rounds = int( | ||||
| Context.get_parameters( | |||||
| "early_stopping_rounds", 100)) | |||||
| Context.get_parameters("early_stopping_rounds", 100) | |||||
| ) | |||||
| metric_name = Context.get_parameters("metric_name", "mlogloss") | metric_name = Context.get_parameters("metric_name", "mlogloss") | ||||
| task_definition = { | |||||
| "method": "TaskDefinitionByDataAttr", | |||||
| "param": attribute | |||||
| } | |||||
| ll_job = LifelongLearning( | ll_job = LifelongLearning( | ||||
| estimator=Estimator, | estimator=Estimator, | ||||
| task_definition="TaskDefinitionByDataAttr", | |||||
| task_definition_param=attribute | |||||
| task_definition=task_definition, | |||||
| task_relationship_discovery=None, | |||||
| task_mining=None, | |||||
| task_remodeling=None, | |||||
| inference_integrate=None, | |||||
| unseen_task_detect=None | |||||
| ) | ) | ||||
| train_experiment = ll_job.train( | train_experiment = ll_job.train( | ||||
| train_data=train_data, | train_data=train_data, | ||||
| @@ -12,15 +12,15 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| import os | |||||
| import json | import json | ||||
| import pandas as pd | |||||
| from sedna.datasources import BaseDataSource | from sedna.datasources import BaseDataSource | ||||
| from sedna.backend import set_backend | from sedna.backend import set_backend | ||||
| from sedna.common.log import LOGGER | from sedna.common.log import LOGGER | ||||
| from sedna.common.file_ops import FileOps | |||||
| from sedna.common.config import Context | from sedna.common.config import Context | ||||
| from sedna.common.constant import KBResourceConstant | from sedna.common.constant import KBResourceConstant | ||||
| from sedna.common.file_ops import FileOps | |||||
| from sedna.common.class_factory import ClassFactory, ClassType | from sedna.common.class_factory import ClassFactory, ClassType | ||||
| from .task_jobs.artifact import Model, Task, TaskGroup | from .task_jobs.artifact import Model, Task, TaskGroup | ||||
| @@ -36,118 +36,113 @@ class MulTaskLearning: | |||||
| def __init__(self, | def __init__(self, | ||||
| estimator=None, | estimator=None, | ||||
| task_definition="TaskDefinitionByDataAttr", | |||||
| task_definition=None, | |||||
| task_relationship_discovery=None, | task_relationship_discovery=None, | ||||
| task_mining=None, | task_mining=None, | ||||
| task_remodeling=None, | task_remodeling=None, | ||||
| inference_integrate=None, | |||||
| task_definition_param=None, | |||||
| relationship_discovery_param=None, | |||||
| task_mining_param=None, | |||||
| task_remodeling_param=None, | |||||
| inference_integrate_param=None | |||||
| inference_integrate=None | |||||
| ): | ): | ||||
| if not task_relationship_discovery: | |||||
| task_relationship_discovery = "DefaultTaskRelationDiscover" | |||||
| if not task_remodeling: | |||||
| task_remodeling = "DefaultTaskRemodeling" | |||||
| if not inference_integrate: | |||||
| inference_integrate = "DefaultInferenceIntegrate" | |||||
| self.method_selection = dict( | |||||
| task_definition=task_definition, | |||||
| task_relationship_discovery=task_relationship_discovery, | |||||
| task_mining=task_mining, | |||||
| task_remodeling=task_remodeling, | |||||
| inference_integrate=inference_integrate, | |||||
| task_definition_param=task_definition_param, | |||||
| task_relationship_discovery_param=relationship_discovery_param, | |||||
| task_mining_param=task_mining_param, | |||||
| task_remodeling_param=task_remodeling_param, | |||||
| inference_integrate_param=inference_integrate_param) | |||||
| self.task_definition = task_definition or { | |||||
| "method": "TaskDefinitionByDataAttr" | |||||
| } | |||||
| self.task_relationship_discovery = task_relationship_discovery or { | |||||
| "method": "DefaultTaskRelationDiscover" | |||||
| } | |||||
| self.task_mining = task_mining or {} | |||||
| self.task_remodeling = task_remodeling or { | |||||
| "method": "DefaultTaskRemodeling" | |||||
| } | |||||
| self.inference_integrate = inference_integrate or { | |||||
| "method": "DefaultInferenceIntegrate" | |||||
| } | |||||
| self.models = None | self.models = None | ||||
| self.extractor = None | self.extractor = None | ||||
| self.base_model = estimator | self.base_model = estimator | ||||
| self.task_groups = None | self.task_groups = None | ||||
| self.task_index_url = KBResourceConstant.KB_INDEX_NAME.value | self.task_index_url = KBResourceConstant.KB_INDEX_NAME.value | ||||
| self.min_train_sample = int( | |||||
| Context.get_parameters( | |||||
| "MIN_TRAIN_SAMPLE", | |||||
| KBResourceConstant.MIN_TRAIN_SAMPLE.value | |||||
| ) | |||||
| ) | |||||
| self.min_train_sample = int(Context.get_parameters( | |||||
| "MIN_TRAIN_SAMPLE", KBResourceConstant.MIN_TRAIN_SAMPLE.value | |||||
| )) | |||||
| @staticmethod | @staticmethod | ||||
| def parse_param(param_str): | def parse_param(param_str): | ||||
| if not param_str: | if not param_str: | ||||
| return {} | return {} | ||||
| if isinstance(param_str, dict): | |||||
| return param_str | |||||
| try: | try: | ||||
| raw_dict = json.loads(param_str, encoding="utf-8") | raw_dict = json.loads(param_str, encoding="utf-8") | ||||
| except json.JSONDecodeError: | except json.JSONDecodeError: | ||||
| raw_dict = {} | raw_dict = {} | ||||
| return raw_dict | return raw_dict | ||||
| def task_definition(self, samples): | |||||
| def _task_definition(self, samples): | |||||
| """ | """ | ||||
| Task attribute extractor and multi-task definition | Task attribute extractor and multi-task definition | ||||
| """ | """ | ||||
| method_name = self.method_selection.get( | |||||
| "task_definition", "TaskDefinitionByDataAttr") | |||||
| method_name = self.task_definition.get( | |||||
| "method", "TaskDefinitionByDataAttr" | |||||
| ) | |||||
| extend_param = self.parse_param( | extend_param = self.parse_param( | ||||
| self.method_selection.get("task_definition_param")) | |||||
| self.task_definition.get("param") | |||||
| ) | |||||
| method_cls = ClassFactory.get_cls( | method_cls = ClassFactory.get_cls( | ||||
| ClassType.MTL, method_name)(**extend_param) | ClassType.MTL, method_name)(**extend_param) | ||||
| return method_cls(samples) | return method_cls(samples) | ||||
| def task_relationship_discovery(self, tasks): | |||||
| def _task_relationship_discovery(self, tasks): | |||||
| """ | """ | ||||
| Merge tasks from task_definition | Merge tasks from task_definition | ||||
| """ | """ | ||||
| method_name = self.method_selection.get("task_relationship_discovery") | |||||
| method_name = self.task_relationship_discovery.get("method") | |||||
| extend_param = self.parse_param( | extend_param = self.parse_param( | ||||
| self.method_selection.get("task_relationship_discovery_param") | |||||
| self.task_relationship_discovery.get("param") | |||||
| ) | ) | ||||
| method_cls = ClassFactory.get_cls( | method_cls = ClassFactory.get_cls( | ||||
| ClassType.MTL, method_name)(**extend_param) | ClassType.MTL, method_name)(**extend_param) | ||||
| return method_cls(tasks) | return method_cls(tasks) | ||||
| def task_mining(self, samples): | |||||
| def _task_mining(self, samples): | |||||
| """ | """ | ||||
| Mining tasks of inference sample base on task attribute extractor | Mining tasks of inference sample base on task attribute extractor | ||||
| """ | """ | ||||
| method_name = self.method_selection.get("task_mining") | |||||
| method_name = self.task_mining.get("method") | |||||
| extend_param = self.parse_param( | extend_param = self.parse_param( | ||||
| self.method_selection.get("task_mining_param")) | |||||
| self.task_mining.get("param") | |||||
| ) | |||||
| if not method_name: | if not method_name: | ||||
| task_definition = self.method_selection.get( | |||||
| "task_definition", "TaskDefinitionByDataAttr") | |||||
| task_definition = self.task_definition.get( | |||||
| "method", "TaskDefinitionByDataAttr" | |||||
| ) | |||||
| method_name = self._method_pair.get(task_definition, | method_name = self._method_pair.get(task_definition, | ||||
| 'TaskMiningByDataAttr') | 'TaskMiningByDataAttr') | ||||
| extend_param = self.parse_param( | extend_param = self.parse_param( | ||||
| self.method_selection.get("task_definition_param")) | |||||
| self.task_definition.get("param")) | |||||
| method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | ||||
| task_extractor=self.extractor, **extend_param | task_extractor=self.extractor, **extend_param | ||||
| ) | ) | ||||
| return method_cls(samples=samples) | return method_cls(samples=samples) | ||||
| def task_remodeling(self, samples, mappings): | |||||
| def _task_remodeling(self, samples, mappings): | |||||
| """ | """ | ||||
| Remodeling tasks from task mining | Remodeling tasks from task mining | ||||
| """ | """ | ||||
| method_name = self.method_selection.get("task_remodeling") | |||||
| method_name = self.task_remodeling.get("method") | |||||
| extend_param = self.parse_param( | extend_param = self.parse_param( | ||||
| self.method_selection.get("task_remodeling_param")) | |||||
| self.task_remodeling.get("param")) | |||||
| method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | ||||
| models=self.models, **extend_param) | models=self.models, **extend_param) | ||||
| return method_cls(samples=samples, mappings=mappings) | return method_cls(samples=samples, mappings=mappings) | ||||
| def inference_integrate(self, tasks): | |||||
| def _inference_integrate(self, tasks): | |||||
| """ | """ | ||||
| Aggregate inference results from target models | Aggregate inference results from target models | ||||
| """ | """ | ||||
| method_name = self.method_selection.get("inference_integrate") | |||||
| method_name = self.inference_integrate.get("method") | |||||
| extend_param = self.parse_param( | extend_param = self.parse_param( | ||||
| self.method_selection.get("inference_integrate_param")) | |||||
| self.inference_integrate.get("param")) | |||||
| method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | method_cls = ClassFactory.get_cls(ClassType.MTL, method_name)( | ||||
| models=self.models, **extend_param) | models=self.models, **extend_param) | ||||
| return method_cls(tasks=tasks) if method_cls else tasks | return method_cls(tasks=tasks) if method_cls else tasks | ||||
| @@ -155,12 +150,12 @@ class MulTaskLearning: | |||||
| def train(self, train_data: BaseDataSource, | def train(self, train_data: BaseDataSource, | ||||
| valid_data: BaseDataSource = None, | valid_data: BaseDataSource = None, | ||||
| post_process=None, **kwargs): | post_process=None, **kwargs): | ||||
| tasks, task_extractor, train_data = self.task_definition(train_data) | |||||
| tasks, task_extractor, train_data = self._task_definition(train_data) | |||||
| self.extractor = task_extractor | self.extractor = task_extractor | ||||
| task_groups = self.task_relationship_discovery(tasks) | |||||
| task_groups = self._task_relationship_discovery(tasks) | |||||
| self.models = [] | self.models = [] | ||||
| callback = None | callback = None | ||||
| if post_process: | |||||
| if isinstance(post_process, str): | |||||
| callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)() | callback = ClassFactory.get_cls(ClassType.CALLBACK, post_process)() | ||||
| self.task_groups = [] | self.task_groups = [] | ||||
| feedback = {} | feedback = {} | ||||
| @@ -181,18 +176,31 @@ class MulTaskLearning: | |||||
| continue | continue | ||||
| LOGGER.info(f"MTL Train start {i} : {task.entry}") | LOGGER.info(f"MTL Train start {i} : {task.entry}") | ||||
| model_obj = set_backend(estimator=self.base_model) | |||||
| res = model_obj.train(train_data=task.samples, **kwargs) | |||||
| if callback: | |||||
| res = callback(model_obj, res) | |||||
| model_path = model_obj.save(model_name=f"{task.entry}.model") | |||||
| model = Model(index=i, entry=task.entry, | |||||
| model=model_path, result=res) | |||||
| model.meta_attr = [t.meta_attr for t in task.tasks] | |||||
| model = None | |||||
| for t in task.tasks: # if model has train in tasks | |||||
| if not (t.model and t.result): | |||||
| continue | |||||
| model_path = t.model.save(model_name=f"{task.entry}.model") | |||||
| t.model = model_path | |||||
| model = Model(index=i, entry=t.entry, | |||||
| model=model_path, result=t.result) | |||||
| model.meta_attr = t.meta_attr | |||||
| break | |||||
| if not model: | |||||
| model_obj = set_backend(estimator=self.base_model) | |||||
| res = model_obj.train(train_data=task.samples, **kwargs) | |||||
| if callback: | |||||
| res = callback(model_obj, res) | |||||
| model_path = model_obj.save(model_name=f"{task.entry}.model") | |||||
| model = Model(index=i, entry=task.entry, | |||||
| model=model_path, result=res) | |||||
| model.meta_attr = [t.meta_attr for t in task.tasks] | |||||
| task.model = model | task.model = model | ||||
| self.models.append(model) | self.models.append(model) | ||||
| feedback[task.entry] = res | |||||
| feedback[task.entry] = model.result | |||||
| self.task_groups.append(task) | self.task_groups.append(task) | ||||
| if len(rare_task): | if len(rare_task): | ||||
| model_obj = set_backend(estimator=self.base_model) | model_obj = set_backend(estimator=self.base_model) | ||||
| res = model_obj.train(train_data=train_data, **kwargs) | res = model_obj.train(train_data=train_data, **kwargs) | ||||
| @@ -240,11 +248,13 @@ class MulTaskLearning: | |||||
| def predict(self, data: BaseDataSource, | def predict(self, data: BaseDataSource, | ||||
| post_process=None, **kwargs): | post_process=None, **kwargs): | ||||
| if not (self.models and self.extractor): | if not (self.models and self.extractor): | ||||
| self.load() | self.load() | ||||
| data, mappings = self.task_mining(samples=data) | |||||
| samples, models = self.task_remodeling(samples=data, mappings=mappings) | |||||
| data, mappings = self._task_mining(samples=data) | |||||
| samples, models = self._task_remodeling(samples=data, | |||||
| mappings=mappings) | |||||
| callback = None | callback = None | ||||
| if post_process: | if post_process: | ||||
| @@ -255,17 +265,19 @@ class MulTaskLearning: | |||||
| m = models[inx] | m = models[inx] | ||||
| if not isinstance(m, Model): | if not isinstance(m, Model): | ||||
| continue | continue | ||||
| model_obj = set_backend(estimator=self.base_model) | |||||
| evaluator = model_obj.load(m.model) if isinstance( | |||||
| m.model, str) else m.model | |||||
| pred = evaluator.predict(df.x, kwargs=kwargs) | |||||
| if isinstance(m.model, str): | |||||
| evaluator = set_backend(estimator=self.base_model) | |||||
| evaluator.load(m.model) | |||||
| else: | |||||
| evaluator = m.model | |||||
| pred = evaluator.predict(df.x, **kwargs) | |||||
| if callable(callback): | if callable(callback): | ||||
| pred = callback(pred, df) | pred = callback(pred, df) | ||||
| task = Task(entry=m.entry, samples=df) | task = Task(entry=m.entry, samples=df) | ||||
| task.result = pred | task.result = pred | ||||
| task.model = m | task.model = m | ||||
| tasks.append(task) | tasks.append(task) | ||||
| res = self.inference_integrate(tasks) | |||||
| res = self._inference_integrate(tasks) | |||||
| return res, tasks | return res, tasks | ||||
| def evaluate(self, data: BaseDataSource, | def evaluate(self, data: BaseDataSource, | ||||
| @@ -294,7 +306,7 @@ class MulTaskLearning: | |||||
| m_dict = { | m_dict = { | ||||
| metrics: getattr(sk_metrics, metrics, sk_metrics.log_loss) | metrics: getattr(sk_metrics, metrics, sk_metrics.log_loss) | ||||
| } | } | ||||
| elif isinstance(metrics, dict): # if metrics with name | |||||
| elif isinstance(metrics, dict): # if metrics with name | |||||
| for k, v in metrics.items(): | for k, v in metrics.items(): | ||||
| if isinstance(v, str): | if isinstance(v, str): | ||||
| v = getattr(sk_metrics, v) | v = getattr(sk_metrics, v) | ||||
| @@ -307,8 +319,9 @@ class MulTaskLearning: | |||||
| } | } | ||||
| metrics_param = {"average": "micro"} | metrics_param = {"average": "micro"} | ||||
| data.x['pred_y'] = result | |||||
| data.x['real_y'] = data.y | |||||
| if isinstance(data.x, pd.DataFrame): | |||||
| data.x['pred_y'] = result | |||||
| data.x['real_y'] = data.y | |||||
| if not metrics_param: | if not metrics_param: | ||||
| metrics_param = {} | metrics_param = {} | ||||
| elif isinstance(metrics_param, str): | elif isinstance(metrics_param, str): | ||||
| @@ -22,6 +22,7 @@ class Task: | |||||
| self.entry = entry | self.entry = entry | ||||
| self.samples = samples | self.samples = samples | ||||
| self.meta_attr = meta_attr | self.meta_attr = meta_attr | ||||
| self.test_samples = None # assign on task definition and use in TRD | |||||
| self.model = None # assign on running | self.model = None # assign on running | ||||
| self.result = None # assign on running | self.result = None # assign on running | ||||
| @@ -15,6 +15,7 @@ | |||||
| from typing import List | from typing import List | ||||
| import numpy as np | import numpy as np | ||||
| import pandas as pd | |||||
| from sedna.datasources import BaseDataSource | from sedna.datasources import BaseDataSource | ||||
| from sedna.common.class_factory import ClassFactory, ClassType | from sedna.common.class_factory import ClassFactory, ClassType | ||||
| @@ -34,11 +35,18 @@ class DefaultTaskRemodeling: | |||||
| for m in np.unique(mappings): | for m in np.unique(mappings): | ||||
| task_df = BaseDataSource(data_type=d_type) | task_df = BaseDataSource(data_type=d_type) | ||||
| _inx = np.where(mappings == m) | _inx = np.where(mappings == m) | ||||
| task_df.x = samples.x.iloc[_inx] | |||||
| if isinstance(samples.x, pd.DataFrame): | |||||
| task_df.x = samples.x.iloc[_inx] | |||||
| else: | |||||
| task_df.x = np.array(samples.x)[_inx] | |||||
| if d_type != "test": | if d_type != "test": | ||||
| task_df.y = samples.y.iloc[_inx] | |||||
| if isinstance(samples.x, pd.DataFrame): | |||||
| task_df.y = samples.y.iloc[_inx] | |||||
| else: | |||||
| task_df.y = np.array(samples.y)[_inx] | |||||
| task_df.inx = _inx[0].tolist() | task_df.inx = _inx[0].tolist() | ||||
| task_df.meta_attr = samples.meta_attr.iloc[_inx].values | |||||
| if samples.meta_attr is not None: | |||||
| task_df.meta_attr = np.array(samples.meta_attr)[_inx] | |||||
| data.append(task_df) | data.append(task_df) | ||||
| model = self.models[m] or self.models[0] | model = self.models[m] or self.models[0] | ||||
| models.append(model) | models.append(model) | ||||
| @@ -12,60 +12,4 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| """Unseen Task detect Algorithms for Lifelong Learning""" | |||||
| import abc | |||||
| from typing import List | |||||
| import numpy as np | |||||
| from sedna.algorithms.multi_task_learning.task_jobs.artifact import Task | |||||
| from sedna.common.class_factory import ClassFactory, ClassType | |||||
| __all__ = ('ModelProbeFilter', 'TaskAttrFilter') | |||||
| class BaseFilter(metaclass=abc.ABCMeta): | |||||
| """The base class to define unified interface.""" | |||||
| def __call__(self, task: Task = None): | |||||
| """predict function, and it must be implemented by | |||||
| different methods class. | |||||
| :param task: inference task | |||||
| :return: `True` means unseen task, `False` means not an unseen task. | |||||
| """ | |||||
| raise NotImplementedError | |||||
| @ClassFactory.register(ClassType.UTD, alias="ModelProbe") | |||||
| class ModelProbeFilter(BaseFilter, abc.ABC): | |||||
| def __init__(self): | |||||
| pass | |||||
| def __call__(self, tasks: List[Task] = None, threshold=0.5, **kwargs): | |||||
| all_proba = [] | |||||
| for task in tasks: | |||||
| sample = task.samples | |||||
| model = task.model | |||||
| if hasattr(model, "predict_proba"): | |||||
| proba = model.predict_proba(sample) | |||||
| all_proba.append(np.max(proba)) | |||||
| return np.mean(all_proba) > threshold if all_proba else True | |||||
| @ClassFactory.register(ClassType.UTD, alias="TaskAttr") | |||||
| class TaskAttrFilter(BaseFilter, abc.ABC): | |||||
| def __init__(self): | |||||
| pass | |||||
| def __call__(self, tasks: List[Task] = None, **kwargs): | |||||
| for task in tasks: | |||||
| model_attr = list(map(list, task.model.meta_attr)) | |||||
| sample_attr = list(map(list, task.samples.meta_attr)) | |||||
| if not (model_attr and sample_attr): | |||||
| continue | |||||
| if list(model_attr) == list(sample_attr): | |||||
| return False | |||||
| return True | |||||
| from .unseen_task_detect import ModelProbeFilter, TaskAttrFilter | |||||
| @@ -0,0 +1,71 @@ | |||||
| # Copyright 2021 The KubeEdge Authors. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Unseen task detection algorithms for Lifelong Learning""" | |||||
| import abc | |||||
| from typing import List | |||||
| import numpy as np | |||||
| from sedna.algorithms.multi_task_learning.task_jobs.artifact import Task | |||||
| from sedna.common.class_factory import ClassFactory, ClassType | |||||
| __all__ = ('ModelProbeFilter', 'TaskAttrFilter') | |||||
| class BaseFilter(metaclass=abc.ABCMeta): | |||||
| """The base class to define unified interface.""" | |||||
| def __call__(self, task: Task = None): | |||||
| """predict function, and it must be implemented by | |||||
| different methods class. | |||||
| :param task: inference task | |||||
| :return: `True` means unseen task, `False` means not an unseen task. | |||||
| """ | |||||
| raise NotImplementedError | |||||
| @ClassFactory.register(ClassType.UTD) | |||||
| class ModelProbeFilter(BaseFilter, abc.ABC): | |||||
| def __init__(self): | |||||
| pass | |||||
| def __call__(self, tasks: List[Task] = None, threshold=0.5, **kwargs): | |||||
| all_proba = [] | |||||
| for task in tasks: | |||||
| sample = task.samples | |||||
| model = task.model | |||||
| if hasattr(model, "predict_proba"): | |||||
| proba = model.predict_proba(sample) | |||||
| all_proba.append(np.max(proba)) | |||||
| return np.mean(all_proba) > threshold if all_proba else True | |||||
| @ClassFactory.register(ClassType.UTD) | |||||
| class TaskAttrFilter(BaseFilter, abc.ABC): | |||||
| def __init__(self): | |||||
| pass | |||||
| def __call__(self, tasks: List[Task] = None, **kwargs): | |||||
| for task in tasks: | |||||
| model_attr = list(map(list, task.model.meta_attr)) | |||||
| sample_attr = list(map(list, task.samples.meta_attr)) | |||||
| if not (model_attr and sample_attr): | |||||
| continue | |||||
| if list(model_attr) == list(sample_attr): | |||||
| return False | |||||
| return True | |||||
| @@ -35,7 +35,7 @@ class BackendBase: | |||||
| if self.default_name: | if self.default_name: | ||||
| return self.default_name | return self.default_name | ||||
| model_postfix = {"pytorch": ".pth", | model_postfix = {"pytorch": ".pth", | ||||
| "keras": ".h5", "tensorflow": ".pb"} | |||||
| "keras": ".pb", "tensorflow": ".pb"} | |||||
| continue_flag = "_finetune_" if self.fine_tune else "" | continue_flag = "_finetune_" if self.fine_tune else "" | ||||
| post_fix = model_postfix.get(self.framework, ".pkl") | post_fix = model_postfix.get(self.framework, ".pkl") | ||||
| return f"model{continue_flag}{self.framework}{post_fix}" | return f"model{continue_flag}{self.framework}{post_fix}" | ||||
| @@ -107,9 +107,11 @@ class BackendBase: | |||||
| self.model_save_path, mname = os.path.split(self.model_save_path) | self.model_save_path, mname = os.path.split(self.model_save_path) | ||||
| model_path = FileOps.join_path(self.model_save_path, mname) | model_path = FileOps.join_path(self.model_save_path, mname) | ||||
| if model_url: | if model_url: | ||||
| FileOps.download(model_url, model_path) | |||||
| model_path = FileOps.download(model_url, model_path) | |||||
| self.has_load = True | self.has_load = True | ||||
| if not (hasattr(self.estimator, "load") | |||||
| and os.path.exists(model_path)): | |||||
| return | |||||
| return self.estimator.load(model_url=model_path) | return self.estimator.load(model_url=model_path) | ||||
| def set_weights(self, weights): | def set_weights(self, weights): | ||||
| @@ -25,10 +25,12 @@ if hasattr(tf, "compat"): | |||||
| # version 2.0 tf | # version 2.0 tf | ||||
| ConfigProto = tf.compat.v1.ConfigProto | ConfigProto = tf.compat.v1.ConfigProto | ||||
| Session = tf.compat.v1.Session | Session = tf.compat.v1.Session | ||||
| reset_default_graph = tf.compat.v1.reset_default_graph | |||||
| else: | else: | ||||
| # version 1 | # version 1 | ||||
| ConfigProto = tf.ConfigProto | ConfigProto = tf.ConfigProto | ||||
| Session = tf.Session | Session = tf.Session | ||||
| reset_default_graph = tf.reset_default_graph | |||||
| class TFBackend(BackendBase): | class TFBackend(BackendBase): | ||||
| @@ -64,24 +66,27 @@ class TFBackend(BackendBase): | |||||
| self.estimator = self.estimator() | self.estimator = self.estimator() | ||||
| if self.fine_tune and FileOps.exists(self.model_save_path): | if self.fine_tune and FileOps.exists(self.model_save_path): | ||||
| self.finetune() | self.finetune() | ||||
| self.has_load = True | |||||
| varkw = self.parse_kwargs(self.estimator.train, **kwargs) | |||||
| return self.estimator.train( | return self.estimator.train( | ||||
| train_data=train_data, | train_data=train_data, | ||||
| valid_data=valid_data, | valid_data=valid_data, | ||||
| **kwargs | |||||
| **varkw | |||||
| ) | ) | ||||
| def predict(self, data, **kwargs): | def predict(self, data, **kwargs): | ||||
| if not self.has_load: | if not self.has_load: | ||||
| tf.reset_default_graph() | |||||
| self.sess = self.load() | |||||
| return self.estimator.predict(data, **kwargs) | |||||
| reset_default_graph() | |||||
| self.load() | |||||
| varkw = self.parse_kwargs(self.estimator.predict, **kwargs) | |||||
| return self.estimator.predict(data=data, **varkw) | |||||
| def evaluate(self, data, **kwargs): | def evaluate(self, data, **kwargs): | ||||
| if not self.has_load: | if not self.has_load: | ||||
| tf.reset_default_graph() | |||||
| self.sess = self.load() | |||||
| return self.estimator.evaluate(data, **kwargs) | |||||
| reset_default_graph() | |||||
| self.load() | |||||
| varkw = self.parse_kwargs(self.estimator.evaluate, **kwargs) | |||||
| return self.estimator.evaluate(data, **varkw) | |||||
| def finetune(self): | def finetune(self): | ||||
| """todo: no support yet""" | """todo: no support yet""" | ||||
| @@ -99,23 +104,25 @@ class TFBackend(BackendBase): | |||||
| def model_info(self, model, relpath=None, result=None): | def model_info(self, model, relpath=None, result=None): | ||||
| ckpt = os.path.dirname(model) | ckpt = os.path.dirname(model) | ||||
| _, _type = os.path.splitext(model) | |||||
| if relpath: | if relpath: | ||||
| _url = FileOps.remove_path_prefix(model, relpath) | _url = FileOps.remove_path_prefix(model, relpath) | ||||
| ckpt_url = FileOps.remove_path_prefix(ckpt, relpath) | ckpt_url = FileOps.remove_path_prefix(ckpt, relpath) | ||||
| else: | else: | ||||
| _url = model | _url = model | ||||
| ckpt_url = ckpt | ckpt_url = ckpt | ||||
| results = [ | |||||
| { | |||||
| "format": "pb", | |||||
| _type = _type.lstrip(".").lower() | |||||
| results = [{ | |||||
| "format": _type, | |||||
| "url": _url, | "url": _url, | ||||
| "metrics": result | "metrics": result | ||||
| }, { | |||||
| }] | |||||
| if _type == "pb": # report ckpt path when model save as pb file | |||||
| results.append({ | |||||
| "format": "ckpt", | "format": "ckpt", | ||||
| "url": ckpt_url, | "url": ckpt_url, | ||||
| "metrics": result | "metrics": result | ||||
| } | |||||
| ] | |||||
| }) | |||||
| return results | return results | ||||
| @@ -12,23 +12,8 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| import logging | |||||
| from enum import Enum | from enum import Enum | ||||
| LOG = logging.getLogger(__name__) | |||||
| class ModelType(Enum): | |||||
| GlobalModel = 1 | |||||
| PersonalizedModel = 2 | |||||
| class Framework(Enum): | |||||
| Tensorflow = "tensorflow" | |||||
| Keras = "keras" | |||||
| Pytorch = "pytorch" | |||||
| Mindspore = "mindspore" | |||||
| class K8sResourceKind(Enum): | class K8sResourceKind(Enum): | ||||
| DEFAULT = "default" | DEFAULT = "default" | ||||
| @@ -27,35 +27,7 @@ from sedna.common.class_factory import ClassFactory, ClassType | |||||
| __all__ = ('JobBase',) | __all__ = ('JobBase',) | ||||
| class DistributedWorker: | |||||
| """"Class of Distributed Worker use to manage all jobs""" | |||||
| # original params | |||||
| __worker_path__ = None | |||||
| __worker_module__ = None | |||||
| # id params | |||||
| __worker_id__ = 0 | |||||
| def __init__(self): | |||||
| DistributedWorker.__worker_id__ += 1 | |||||
| self._worker_id = DistributedWorker.__worker_id__ | |||||
| self.timeout = 0 | |||||
| @property | |||||
| def worker_id(self): | |||||
| """Property: worker_id.""" | |||||
| return self._worker_id | |||||
| @worker_id.setter | |||||
| def worker_id(self, value): | |||||
| """Setter: set worker_id with value. | |||||
| :param value: worker id | |||||
| :type value: int | |||||
| """ | |||||
| self._worker_id = value | |||||
| class JobBase(DistributedWorker): | |||||
| class JobBase: | |||||
| """ sedna feature base class """ | """ sedna feature base class """ | ||||
| parameters = Context | parameters = Context | ||||
| @@ -68,8 +40,7 @@ class JobBase(DistributedWorker): | |||||
| self.estimator = set_backend(estimator=estimator, config=self.config) | self.estimator = set_backend(estimator=estimator, config=self.config) | ||||
| self.job_kind = K8sResourceKind.DEFAULT.value | self.job_kind = K8sResourceKind.DEFAULT.value | ||||
| self.job_name = self.config.job_name or self.config.service_name | self.job_name = self.config.job_name or self.config.service_name | ||||
| work_name = f"{self.job_name}-{self.worker_id}" | |||||
| self.worker_name = self.config.worker_name or work_name | |||||
| self.worker_name = self.config.worker_name or self.job_name | |||||
| @property | @property | ||||
| def initial_hem(self): | def initial_hem(self): | ||||
| @@ -18,8 +18,7 @@ import tempfile | |||||
| from sedna.backend import set_backend | from sedna.backend import set_backend | ||||
| from sedna.core.base import JobBase | from sedna.core.base import JobBase | ||||
| from sedna.common.file_ops import FileOps | from sedna.common.file_ops import FileOps | ||||
| from sedna.common.constant import K8sResourceKind | |||||
| from sedna.common.constant import K8sResourceKindStatus | |||||
| from sedna.common.constant import K8sResourceKind, K8sResourceKindStatus | |||||
| from sedna.common.constant import KBResourceConstant | from sedna.common.constant import KBResourceConstant | ||||
| from sedna.common.config import Context | from sedna.common.config import Context | ||||
| from sedna.common.class_factory import ClassType, ClassFactory | from sedna.common.class_factory import ClassType, ClassFactory | ||||
| @@ -34,43 +33,39 @@ class LifelongLearning(JobBase): | |||||
| def __init__(self, | def __init__(self, | ||||
| estimator, | estimator, | ||||
| task_definition="TaskDefinitionByDataAttr", | |||||
| task_definition=None, | |||||
| task_relationship_discovery=None, | task_relationship_discovery=None, | ||||
| task_mining=None, | task_mining=None, | ||||
| task_remodeling=None, | task_remodeling=None, | ||||
| inference_integrate=None, | inference_integrate=None, | ||||
| unseen_task_detect="TaskAttrFilter", | |||||
| task_definition_param=None, | |||||
| relationship_discovery_param=None, | |||||
| task_mining_param=None, | |||||
| task_remodeling_param=None, | |||||
| inference_integrate_param=None, | |||||
| unseen_task_detect_param=None): | |||||
| unseen_task_detect=None): | |||||
| if not task_definition: | |||||
| task_definition = { | |||||
| "method": "TaskDefinitionByDataAttr" | |||||
| } | |||||
| if not unseen_task_detect: | |||||
| unseen_task_detect = { | |||||
| "method": "TaskAttrFilter" | |||||
| } | |||||
| e = MulTaskLearning( | e = MulTaskLearning( | ||||
| estimator=estimator, | estimator=estimator, | ||||
| task_definition=task_definition, | task_definition=task_definition, | ||||
| task_relationship_discovery=task_relationship_discovery, | task_relationship_discovery=task_relationship_discovery, | ||||
| task_mining=task_mining, | task_mining=task_mining, | ||||
| task_remodeling=task_remodeling, | task_remodeling=task_remodeling, | ||||
| inference_integrate=inference_integrate, | |||||
| task_definition_param=task_definition_param, | |||||
| relationship_discovery_param=relationship_discovery_param, | |||||
| task_mining_param=task_mining_param, | |||||
| task_remodeling_param=task_remodeling_param, | |||||
| inference_integrate_param=inference_integrate_param) | |||||
| self.unseen_task_detect = unseen_task_detect | |||||
| inference_integrate=inference_integrate) | |||||
| self.unseen_task_detect = unseen_task_detect.get("method", | |||||
| "TaskAttrFilter") | |||||
| self.unseen_task_detect_param = e.parse_param( | self.unseen_task_detect_param = e.parse_param( | ||||
| unseen_task_detect_param | |||||
| unseen_task_detect.get("param", {}) | |||||
| ) | ) | ||||
| config = dict( | config = dict( | ||||
| ll_kb_server=Context.get_parameters("KB_SERVER"), | ll_kb_server=Context.get_parameters("KB_SERVER"), | ||||
| output_url=Context.get_parameters("OUTPUT_URL", "/tmp") | output_url=Context.get_parameters("OUTPUT_URL", "/tmp") | ||||
| ) | ) | ||||
| task_index = FileOps.join_path( | |||||
| config['output_url'], | |||||
| KBResourceConstant.KB_INDEX_NAME | |||||
| ) | |||||
| task_index = FileOps.join_path(config['output_url'], | |||||
| KBResourceConstant.KB_INDEX_NAME) | |||||
| config['task_index'] = task_index | config['task_index'] = task_index | ||||
| super(LifelongLearning, self).__init__( | super(LifelongLearning, self).__init__( | ||||
| estimator=e, config=config | estimator=e, config=config | ||||
| @@ -100,30 +95,62 @@ class LifelongLearning(JobBase): | |||||
| **kwargs | **kwargs | ||||
| ) # todo: Distinguishing incremental update and fully overwrite | ) # todo: Distinguishing incremental update and fully overwrite | ||||
| task_groups = self.estimator.estimator.task_groups | |||||
| extractor_file = FileOps.join_path( | |||||
| os.path.dirname(self.estimator.estimator.task_index_url), | |||||
| "kb_extractor.pkl" | |||||
| ) | |||||
| try: | |||||
| extractor_file = self.kb_server.upload_file(extractor_file) | |||||
| except Exception as err: | |||||
| self.log.error( | |||||
| f"Upload task extractor_file fail {extractor_file}: {err}") | |||||
| extractor_file = FileOps.load(extractor_file) | |||||
| if isinstance(task_index_url, str) and FileOps.exists(task_index_url): | |||||
| task_index = FileOps.load(task_index_url) | |||||
| else: | |||||
| task_index = task_index_url | |||||
| extractor = task_index['extractor'] | |||||
| task_groups = task_index['task_groups'] | |||||
| model_upload_key = {} | |||||
| for task in task_groups: | for task in task_groups: | ||||
| model_file = task.model.model | |||||
| save_model = FileOps.join_path( | |||||
| self.config.output_url, | |||||
| os.path.basename(model_file) | |||||
| ) | |||||
| if model_file not in model_upload_key: | |||||
| model_upload_key[model_file] = FileOps.upload(model_file, | |||||
| save_model) | |||||
| model_file = model_upload_key[model_file] | |||||
| try: | try: | ||||
| model = self.kb_server.upload_file(task.model.model) | |||||
| except Exception: | |||||
| model_obj = set_backend( | |||||
| model = self.kb_server.upload_file(save_model) | |||||
| except Exception as err: | |||||
| self.log.error( | |||||
| f"Upload task model of {model_file} fail: {err}" | |||||
| ) | |||||
| model = set_backend( | |||||
| estimator=self.estimator.estimator.base_model | estimator=self.estimator.estimator.base_model | ||||
| ) | ) | ||||
| model = model_obj.load(task.model.model) | |||||
| model.load(model_file) | |||||
| task.model.model = model | task.model.model = model | ||||
| for _task in task.tasks: | |||||
| sample_dir = FileOps.join_path( | |||||
| self.config.output_url, | |||||
| f"{_task.samples.data_type}_{_task.entry}.sample") | |||||
| task.samples.save(sample_dir) | |||||
| try: | |||||
| sample_dir = self.kb_server.upload_file(sample_dir) | |||||
| except Exception as err: | |||||
| self.log.error( | |||||
| f"Upload task samples of {_task.entry} fail: {err}") | |||||
| _task.samples.data_url = sample_dir | |||||
| save_extractor = FileOps.join_path( | |||||
| self.config.output_url, | |||||
| KBResourceConstant.TASK_EXTRACTOR_NAME | |||||
| ) | |||||
| extractor = FileOps.dump(extractor, save_extractor) | |||||
| try: | |||||
| extractor = self.kb_server.upload_file(extractor) | |||||
| except Exception as err: | |||||
| self.log.error(f"Upload task extractor fail: {err}") | |||||
| task_info = { | task_info = { | ||||
| "task_groups": task_groups, | "task_groups": task_groups, | ||||
| "extractor": extractor_file | |||||
| "extractor": extractor | |||||
| } | } | ||||
| fd, name = tempfile.mkstemp() | fd, name = tempfile.mkstemp() | ||||
| FileOps.dump(task_info, name) | FileOps.dump(task_info, name) | ||||
| @@ -132,13 +159,10 @@ class LifelongLearning(JobBase): | |||||
| if not index_file: | if not index_file: | ||||
| self.log.error(f"KB update Fail !") | self.log.error(f"KB update Fail !") | ||||
| index_file = name | index_file = name | ||||
| FileOps.upload(index_file, self.config.task_index) | FileOps.upload(index_file, self.config.task_index) | ||||
| if os.path.isfile(name): | |||||
| os.close(fd) | |||||
| os.remove(name) | |||||
| task_info_res = self.estimator.model_info( | task_info_res = self.estimator.model_info( | ||||
| self.config.task_index, result=res, | |||||
| self.config.task_index, | |||||
| relpath=self.config.data_path_prefix) | relpath=self.config.data_path_prefix) | ||||
| self.report_task_info( | self.report_task_info( | ||||
| None, K8sResourceKindStatus.COMPLETED.value, task_info_res) | None, K8sResourceKindStatus.COMPLETED.value, task_info_res) | ||||
| @@ -155,7 +179,7 @@ class LifelongLearning(JobBase): | |||||
| **kwargs | **kwargs | ||||
| ) | ) | ||||
| def evaluate(self, data, post_process=None, model_threshold=0.1, **kwargs): | |||||
| def evaluate(self, data, post_process=None, **kwargs): | |||||
| callback_func = None | callback_func = None | ||||
| if callable(post_process): | if callable(post_process): | ||||
| callback_func = post_process | callback_func = post_process | ||||
| @@ -170,14 +194,35 @@ class LifelongLearning(JobBase): | |||||
| FileOps.download(task_index_url, index_url) | FileOps.download(task_index_url, index_url) | ||||
| res, tasks_detail = self.estimator.evaluate(data=data, **kwargs) | res, tasks_detail = self.estimator.evaluate(data=data, **kwargs) | ||||
| drop_tasks = [] | drop_tasks = [] | ||||
| model_filter_operator = self.get_parameters("operator", ">") | |||||
| model_threshold = float(self.get_parameters('model_threshold', 0.1)) | |||||
| operator_map = { | |||||
| ">": lambda x, y: x > y, | |||||
| "<": lambda x, y: x < y, | |||||
| "=": lambda x, y: x == y, | |||||
| ">=": lambda x, y: x >= y, | |||||
| "<=": lambda x, y: x <= y, | |||||
| } | |||||
| if model_filter_operator not in operator_map: | |||||
| self.log.warn( | |||||
| f"operator {model_filter_operator} use to " | |||||
| f"compare is not allow, set to <" | |||||
| ) | |||||
| model_filter_operator = "<" | |||||
| operator_func = operator_map[model_filter_operator] | |||||
| for detail in tasks_detail: | for detail in tasks_detail: | ||||
| scores = detail.scores | scores = detail.scores | ||||
| entry = detail.entry | entry = detail.entry | ||||
| self.log.info(f"{entry} socres: {scores}") | |||||
| if any(map(lambda x: float(x) < model_threshold, scores.values())): | |||||
| self.log.info(f"{entry} scores: {scores}") | |||||
| if any(map(lambda x: operator_func(float(x), | |||||
| model_threshold), | |||||
| scores.values())): | |||||
| self.log.warn( | self.log.warn( | ||||
| f"{entry} will not be deploy " | |||||
| f"because scores lt {model_threshold}") | |||||
| f"{entry} will not be deploy because all " | |||||
| f"scores {model_filter_operator} {model_threshold}") | |||||
| drop_tasks.append(entry) | drop_tasks.append(entry) | ||||
| continue | continue | ||||
| drop_task = ",".join(drop_tasks) | drop_task = ",".join(drop_tasks) | ||||
| @@ -199,6 +244,7 @@ class LifelongLearning(JobBase): | |||||
| return callback_func(res) if callback_func else res | return callback_func(res) if callback_func else res | ||||
| def inference(self, data=None, post_process=None, **kwargs): | def inference(self, data=None, post_process=None, **kwargs): | ||||
| task_index_url = self.get_parameters( | task_index_url = self.get_parameters( | ||||
| "MODEL_URLS", self.config.task_index) | "MODEL_URLS", self.config.task_index) | ||||
| index_url = self.estimator.estimator.task_index_url | index_url = self.estimator.estimator.task_index_url | ||||