diff --git a/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py b/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py
index 1af4edc5..fab3e1c0 100644
--- a/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py
+++ b/lib/sedna/algorithms/multi_task_learning/multi_task_learning.py
@@ -14,12 +14,12 @@
 
 import os
 import json
-import joblib
 
 from sedna.datasources import BaseDataSource
 from sedna.backend import set_backend
 from sedna.common.log import LOGGER
 from sedna.common.config import Context
+from sedna.common.constant import KBResourceConstant
 from sedna.common.file_ops import FileOps
 from sedna.common.class_factory import ClassFactory, ClassType
 
@@ -68,12 +68,13 @@ class MulTaskLearning:
         self.extractor = None
         self.base_model = estimator
         self.task_groups = None
-        self.task_index_url = Context.get_parameters(
-            "MODEL_URLS", '/tmp/index.pkl'
+        self.task_index_url = KBResourceConstant.KB_INDEX_NAME.value
+        self.min_train_sample = int(
+            Context.get_parameters(
+                "MIN_TRAIN_SAMPLE",
+                KBResourceConstant.MIN_TRAIN_SAMPLE.value
+            )
         )
-        self.min_train_sample = int(Context.get_parameters(
-            "MIN_TRAIN_SAMPLE", '10'
-        ))
 
     @staticmethod
     def parse_param(param_str):
@@ -211,37 +212,37 @@ class MulTaskLearning:
             self.models[i] = model
             feedback[entry] = res
             self.task_groups[i] = task
-        extractor_file = FileOps.join_path(
-            os.path.dirname(self.task_index_url),
-            "kb_extractor.pkl"
-        )
-        joblib.dump(self.extractor, extractor_file)
+
         task_index = {
-            "extractor": extractor_file,
+            "extractor": self.extractor,
             "task_groups": self.task_groups
         }
-        joblib.dump(task_index, self.task_index_url)
         if valid_data:
-            feedback = self.evaluate(valid_data, **kwargs)
-        return feedback
+            feedback, _ = self.evaluate(valid_data, **kwargs)
+        try:
+            FileOps.dump(task_index, self.task_index_url)
+        except TypeError:
+            return feedback, task_index
+        return feedback, self.task_index_url
+
+    def load(self, task_index_url=None):
+        if task_index_url:
+            self.task_index_url = task_index_url
+        assert FileOps.exists(self.task_index_url), FileNotFoundError(
+            f"Task index missing: {self.task_index_url}"
+        )
+        task_index = FileOps.load(self.task_index_url)
+        self.extractor = task_index['extractor']
+        if isinstance(self.extractor, str):
+            self.extractor = FileOps.load(self.extractor)
+        self.task_groups = task_index['task_groups']
+        self.models = [task.model for task in self.task_groups]
 
     def predict(self, data: BaseDataSource,
                 post_process=None, **kwargs):
         if not (self.models and self.extractor):
-            task_index = joblib.load(self.task_index_url)
-            extractor_file = FileOps.join_path(
-                os.path.dirname(self.task_index_url),
-                "kb_extractor.pkl"
-            )
-            if (not callable(task_index['extractor']) and
-                    isinstance(task_index['extractor'], str)):
-                FileOps.download(task_index['extractor'], extractor_file)
-                self.extractor = joblib.load(extractor_file)
-            else:
-                self.extractor = task_index['extractor']
-            self.task_groups = task_index['task_groups']
-            self.models = [task.model for task in self.task_groups]
+            self.load()
+
         data, mappings = self.task_mining(samples=data)
         samples, models = self.task_remodeling(samples=data,
                                                mappings=mappings)
diff --git a/lib/sedna/common/constant.py b/lib/sedna/common/constant.py
index 79ca191b..029bb5c9 100644
--- a/lib/sedna/common/constant.py
+++ b/lib/sedna/common/constant.py
@@ -42,3 +42,9 @@ class K8sResourceKindStatus(Enum):
     COMPLETED = "completed"
     FAILED = "failed"
     RUNNING = "running"
+
+
+class KBResourceConstant(Enum):
+    MIN_TRAIN_SAMPLE = 10
+    KB_INDEX_NAME = "index.pkl"
+    TASK_EXTRACTOR_NAME = "task_attr_extractor.pkl"
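
A note for reviewers on the new constants: `KBResourceConstant` members are `Enum` members, not plain strings or ints, so call sites that build paths or numbers must go through `.value` (see the `FileOps.join_path` call in the `lifelong_learning.py` hunk below). A minimal sketch, assuming only that sedna is importable; nothing here is part of the patch:

```python
# Enum members are not str/int; forgetting .value breaks os.path.join.
import os

from sedna.common.constant import KBResourceConstant

assert KBResourceConstant.KB_INDEX_NAME.value == "index.pkl"
assert int(KBResourceConstant.MIN_TRAIN_SAMPLE.value) == 10

index_path = os.path.join("/tmp", KBResourceConstant.KB_INDEX_NAME.value)
# os.path.join("/tmp", KBResourceConstant.KB_INDEX_NAME) would raise
# TypeError, because the bare member is neither str nor os.PathLike.
```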
diff --git a/lib/sedna/common/file_ops.py b/lib/sedna/common/file_ops.py
index 7694a1a7..668fbbc5 100644
--- a/lib/sedna/common/file_ops.py
+++ b/lib/sedna/common/file_ops.py
@@ -17,11 +17,12 @@
 
 import os
 import re
+import joblib
 import codecs
 import pickle
 import shutil
-import tempfile
 import hashlib
+import tempfile
 from urllib.parse import urlparse
 
 from .utils import singleton
@@ -98,15 +99,23 @@ class FileOps:
         if not args[0]:
             args[0] = os.path.sep
         _path = cls.join_path(*args)
-        if os.path.isdir(_path) and clean:
-            shutil.rmtree(_path)
+        if clean:
+            cls.delete(_path)
         if os.path.isfile(_path):
-            if clean:
-                os.remove(_path)
             _path = cls.join_path(*args[:len(args) - 1])
         os.makedirs(_path, exist_ok=True)
         return target
 
+    @classmethod
+    def delete(cls, path):
+        try:
+            if os.path.isdir(path):
+                shutil.rmtree(path)
+            if os.path.isfile(path):
+                os.remove(path)
+        except Exception:
+            pass
+
     @classmethod
     def make_base_dir(cls, *args):
         """Make new a base directory.
@@ -179,6 +188,7 @@
 
         :rtype: object or None.
         """
+        filename = cls.download(filename)
         if not os.path.isfile(filename):
             return None
         with open(filename, "rb") as f:
@@ -203,8 +213,7 @@
             name = os.path.join(src, files)
             back_name = os.path.join(dst, files)
             if os.path.isfile(name):
-                if os.path.isfile(back_name):
-                    shutil.copy(name, back_name)
+                shutil.copy(name, back_name)
             else:
                 if not os.path.isdir(back_name):
                     shutil.copytree(name, back_name)
@@ -219,7 +228,7 @@
 
         :param str dst: destination path.
         """
-        if dst is None or dst == "":
+        if not dst:
             return
 
         if os.path.isfile(src):
@@ -237,10 +246,34 @@
             cls.copy_folder(src, dst)
 
     @classmethod
-    def download(cls, src, dst, unzip=False) -> str:
-        if dst is None:
-            dst = tempfile.mkdtemp()
+    def dump(cls, obj, dst=None) -> str:
+        fd, name = tempfile.mkstemp()
+        os.close(fd)
+        joblib.dump(obj, name)
+        return cls.upload(name, dst)
+
+    @classmethod
+    def load(cls, src: str):
+        src = cls.download(src)
+        obj = joblib.load(src)
+        return obj
+
+    @classmethod
+    def is_remote(cls, src):
+        if src.startswith((
+            cls._GCS_PREFIX,
+            cls._S3_PREFIX
+        )):
+            return True
+        if re.search(cls._URI_RE, src):
+            return True
+        return False
+
+    @classmethod
+    def download(cls, src, dst=None, unzip=False) -> str:
+        if dst is None:
+            fd, dst = tempfile.mkstemp()
+            os.close(fd)
         cls.clean_folder([os.path.dirname(dst)], clean=False)
         if src.startswith(cls._GCS_PREFIX):
             cls.gcs_download(src, dst)
@@ -255,18 +288,29 @@
-    def upload(cls, src, dst, tar=False) -> str:
+    def upload(cls, src, dst, tar=False, clean=True) -> str:
         if dst is None:
-            dst = tempfile.mkdtemp()
+            fd, dst = tempfile.mkstemp()
+            os.close(fd)
+        if not cls.is_local(src):
+            fd, name = tempfile.mkstemp()
+            os.close(fd)
+            cls.download(src, name)
+            src = name
         if tar:
             cls._tar(src, f"{src}.tar.gz")
             src = f"{src}.tar.gz"
+
         if dst.startswith(cls._GCS_PREFIX):
             cls.gcs_upload(src, dst)
         elif dst.startswith(cls._S3_PREFIX):
             cls.s3_upload(src, dst)
-        elif cls.is_local(dst):
+        else:
             cls.copy_file(src, dst)
+        if cls.is_local(src) and clean:
+            if cls.is_local(dst) and os.path.samefile(src, dst):
+                return dst
+            cls.delete(src)
         return dst
 
     @classmethod
@@ -287,21 +331,24 @@
         bucket_name = bucket_args[0]
         bucket_path = len(bucket_args) > 1 and bucket_args[1] or ""
 
-        objects = client.list_objects(bucket_name,
-                                      prefix=bucket_path,
-                                      recursive=True,
-                                      use_api_v1=True)
+        objects = list(client.list_objects(bucket_name,
+                                           prefix=bucket_path,
+                                           recursive=True,
+                                           use_api_v1=True))
         count = 0
-
+        num = len(objects)
         for obj in objects:
             # Replace any prefix from the object key with out_dir
             subdir_object_key = obj.object_name[len(bucket_path):].strip("/")
             # fget_object handles directory creation if does not exist
             if not obj.is_dir:
-                local_file = os.path.join(
-                    out_dir,
-                    subdir_object_key or os.path.basename(obj.object_name)
-                )
+                if num == 1 and not os.path.isdir(out_dir):
+                    local_file = out_dir
+                else:
+                    local_file = os.path.join(
+                        out_dir,
+                        subdir_object_key or os.path.basename(obj.object_name)
+                    )
                 client.fget_object(bucket_name, obj.object_name, local_file)
                 count += 1
@@ -311,9 +358,10 @@
     def s3_download(cls, src, dst):
         s3 = _create_minio_client()
         count = cls._download_s3(s3, src, dst)
+
         if count == 0:
             raise RuntimeError("Failed to fetch files."
-                               "The path %s does not exist." % (src))
+                               "The path %s does not exist." % src)
 
     @classmethod
     def s3_upload(cls, src, dst):
diff --git a/lib/sedna/core/lifelong_learning/lifelong_learning.py b/lib/sedna/core/lifelong_learning/lifelong_learning.py
index 64658a82..119f4fe7 100644
--- a/lib/sedna/core/lifelong_learning/lifelong_learning.py
+++ b/lib/sedna/core/lifelong_learning/lifelong_learning.py
@@ -15,12 +15,12 @@
 
 import os
 import tempfile
-import joblib
-
 from sedna.backend import set_backend
 from sedna.core.base import JobBase
 from sedna.common.file_ops import FileOps
-from sedna.common.constant import K8sResourceKind, K8sResourceKindStatus
+from sedna.common.constant import K8sResourceKind
+from sedna.common.constant import K8sResourceKindStatus
+from sedna.common.constant import KBResourceConstant
 from sedna.common.config import Context
 from sedna.common.class_factory import ClassType, ClassFactory
 from sedna.algorithms.multi_task_learning import MulTaskLearning
@@ -67,7 +67,10 @@
             ll_kb_server=Context.get_parameters("KB_SERVER"),
             output_url=Context.get_parameters("OUTPUT_URL", "/tmp")
         )
-        task_index = FileOps.join_path(config['output_url'], 'index.pkl')
+        task_index = FileOps.join_path(
+            config['output_url'],
+            KBResourceConstant.KB_INDEX_NAME.value
+        )
         config['task_index'] = task_index
         super(LifelongLearning, self).__init__(
             estimator=e, config=config
@@ -91,7 +94,7 @@
         if post_process is not None:
             callback_func = ClassFactory.get_cls(
                 ClassType.CALLBACK, post_process)
-        res = self.estimator.train(
+        res, task_index_url = self.estimator.train(
             train_data=train_data,
             valid_data=valid_data,
             **kwargs
@@ -107,7 +110,7 @@
         except Exception as err:
             self.log.error(
                 f"Upload task extractor_file fail {extractor_file}: {err}")
-        extractor_file = joblib.load(extractor_file)
+        extractor_file = FileOps.load(extractor_file)
         for task in task_groups:
             try:
                 model = self.kb_server.upload_file(task.model.model)
@@ -123,7 +126,7 @@
             "extractor": extractor_file
         }
         fd, name = tempfile.mkstemp()
-        joblib.dump(task_info, name)
+        FileOps.dump(task_info, name)
 
         index_file = self.kb_server.update_db(name)
         if not index_file:
diff --git a/lib/sedna/datasources/__init__.py b/lib/sedna/datasources/__init__.py
index 25d48dc8..b793e163 100644
--- a/lib/sedna/datasources/__init__.py
+++ b/lib/sedna/datasources/__init__.py
@@ -14,7 +14,8 @@
 
 from abc import ABC
 
-import joblib
 import numpy as np
 import pandas as pd
+
+from sedna.common.file_ops import FileOps
 
@@ -51,7 +50,7 @@ class BaseDataSource:
         return self.data_type == "test"
 
     def save(self, output=""):
-        joblib.dump(self, output)
+        return FileOps.dump(self, output)
 
 
 class TxtDataParse(BaseDataSource, ABC):
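
The `FileOps.dump`/`FileOps.load` pair introduced above is the seam that lets callers such as `BaseDataSource.save` and `MulTaskLearning.train` serialize to a local path or a remote object store through one code path. A minimal sketch of the round trip, assuming a local destination (with an `s3://` or GCS `dst`, `dump` would upload the temp file instead, which requires credentials):

```python
# Round-trip an arbitrary object through the new FileOps API.
from sedna.common.file_ops import FileOps

index = {
    "extractor": {"feature": "threshold"},  # stand-in for a real extractor
    "task_groups": [],                      # stand-in for TaskGroup objects
}

# dump() serializes via joblib to a temp file, then upload() moves the
# temp file to dst, cleans it up, and returns the final location.
url = FileOps.dump(index, "/tmp/index.pkl")

# load() first calls download() (a local copy for local paths), then
# deserializes with joblib.
restored = FileOps.load(url)
assert restored["task_groups"] == []
```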
diff --git a/lib/sedna/service/run_kb.py b/lib/sedna/service/run_kb.py
index 73d682f6..4cca7de4 100644
--- a/lib/sedna/service/run_kb.py
+++ b/lib/sedna/service/run_kb.py
@@ -22,13 +22,14 @@ from sedna.service.server.knowledgeBase.server import KBServer
 
 def main():
     init_db()
     server = os.getenv("KnowledgeBaseServer", "")
+    kb_dir = os.getenv("KnowledgeBasePath", "")
     match = re.compile(
         "(https?)://([0-9]{1,3}(?:\\.[0-9]{1,3}){3}):([0-9]+)").match(server)
     if match:
         _, host, port = match.groups()
     else:
         host, port = '0.0.0.0', 9020
-    KBServer(host=host, http_port=int(port)).start()
+    KBServer(host=host, http_port=int(port), save_dir=kb_dir).start()
 
 
 if __name__ == '__main__':
diff --git a/lib/sedna/service/server/knowledgeBase/server.py b/lib/sedna/service/server/knowledgeBase/server.py
index f570de05..a489e058 100644
--- a/lib/sedna/service/server/knowledgeBase/server.py
+++ b/lib/sedna/service/server/knowledgeBase/server.py
@@ -27,6 +27,7 @@ from starlette.responses import JSONResponse
 
 from sedna.service.server.base import BaseServer
 from sedna.common.file_ops import FileOps
+from sedna.common.constant import KBResourceConstant
 
 from .model import *
 
@@ -52,7 +53,7 @@ class KBServer(BaseServer):
                           http_port=http_port, workers=workers)
         self.save_dir = FileOps.clean_folder([save_dir], clean=False)[0]
         self.url = f"{self.url}/{servername}"
-        self.latest = 0
+        self.kb_index = KBResourceConstant.KB_INDEX_NAME.value
         self.app = FastAPI(
             routes=[
                 APIRoute(
@@ -94,8 +95,7 @@ class KBServer(BaseServer):
         pass
 
     def _get_db_index(self):
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
         if not FileOps.exists(_index_path):  # todo: get from kb
             pass
         return _index_path
@@ -130,8 +130,7 @@ class KBServer(BaseServer):
                 }, synchronize_session=False)
 
         # todo: get from kb
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
         task_info = joblib.load(_index_path)
 
         new_task_group = []
@@ -143,13 +142,9 @@ class KBServer(BaseServer):
                 continue
             new_task_group.append(task_group)
         task_info["task_groups"] = new_task_group
-        self.latest += 1
-
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
-        joblib.dump(task_info, _index_path)
-        res = f"/file/download?files=kb_index_{self.latest}.pkl&name=index.pkl"
-        return res
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
+        FileOps.dump(task_info, _index_path)
+        return f"/file/download?files={self.kb_index}&name={self.kb_index}"
 
     def update(self, task: UploadFile = File(...)):
         tasks = task.file.read()
@@ -178,21 +173,16 @@ class KBServer(BaseServer):
             if t_create:
                 session.add(t_obj)
 
-            sampel_obj = Samples(
+            sample_obj = Samples(
                 data_type=task.samples.data_type,
-                sample_num=len(task.samples)
+                sample_num=len(task.samples),
+                data_url=getattr(task, 'data_url', '')
             )
-            session.add(sampel_obj)
+            session.add(sample_obj)
             session.flush()
             session.commit()
 
-            sample_dir = FileOps.join_path(
-                self.save_dir,
-                f"{sampel_obj.data_type}_{sampel_obj.id}.pkl")
-            task.samples.save(sample_dir)
-            sampel_obj.data_url = sample_dir
-
-            tsample = TaskSample(sample=sampel_obj, task=t_obj)
+            tsample = TaskSample(sample=sample_obj, task=t_obj)
             session.add(tsample)
             session.flush()
             t_id.append(t_obj.id)
@@ -221,15 +211,8 @@ class KBServer(BaseServer):
 
         session.commit()
 
-        self.latest += 1
-        extractor_file = upload_info["extractor"]
-        extractor_path = FileOps.join_path(self.save_dir,
-                                           f"kb_extractor.pkl")
-        FileOps.upload(extractor_file, extractor_path)
-        # todo: get from kb
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
-        FileOps.upload(name, _index_path)
-        res = f"/file/download?files=kb_index_{self.latest}.pkl&name=index.pkl"
-        return res
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
+        _index_path = FileOps.dump(upload_info, _index_path)
+
+        return f"/file/download?files={self.kb_index}&name={self.kb_index}"
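
With the versioned `kb_index_{n}.pkl` files gone, the server now reads and rewrites a single `index.pkl` under `save_dir`, and the download route always advertises that one name. A minimal sketch of launching the KB server against a chosen storage directory via the two environment variables `run_kb.main()` reads; the directory path is illustrative, and it assumes the KB database is already configured for `init_db()`:

```python
# Configure and start the knowledge-base server; KnowledgeBasePath is the
# env var introduced above and becomes KBServer's save_dir.
import os

os.environ["KnowledgeBaseServer"] = "http://127.0.0.1:9020"
os.environ["KnowledgeBasePath"] = "/var/lib/sedna/kb"  # illustrative path

from sedna.service.run_kb import main

# Blocks serving HTTP; clients then fetch the refreshed index from the
# relative URL returned by the update endpoints, i.e.
# "/file/download?files=index.pkl&name=index.pkl".
main()
```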