- fix file_ops method
- fix kb save bug

Signed-off-by: JoeyHwong <joeyhwong@gknow.cn>

tags/v0.3.1
@@ -14,12 +14,12 @@
 import os
 import json
-import joblib
 from sedna.datasources import BaseDataSource
 from sedna.backend import set_backend
 from sedna.common.log import LOGGER
 from sedna.common.config import Context
+from sedna.common.constant import KBResourceConstant
 from sedna.common.file_ops import FileOps
 from sedna.common.class_factory import ClassFactory, ClassType
@@ -68,12 +68,13 @@ class MulTaskLearning:
         self.extractor = None
         self.base_model = estimator
         self.task_groups = None
-        self.task_index_url = Context.get_parameters(
-            "MODEL_URLS", '/tmp/index.pkl'
-        )
-        self.min_train_sample = int(Context.get_parameters(
-            "MIN_TRAIN_SAMPLE", '10'
-        ))
+        self.task_index_url = KBResourceConstant.KB_INDEX_NAME.value
+        self.min_train_sample = int(
+            Context.get_parameters(
+                "MIN_TRAIN_SAMPLE",
+                KBResourceConstant.MIN_TRAIN_SAMPLE.value
+            )
+        )

     @staticmethod
     def parse_param(param_str):
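For context, a minimal sketch (not part of the diff) of how the constant-backed default behaves, assuming `Context.get_parameters` falls back to its second argument when the environment variable is unset:

```python
from sedna.common.config import Context
from sedna.common.constant import KBResourceConstant

# With MIN_TRAIN_SAMPLE unset in the environment, the enum default applies.
min_n = int(Context.get_parameters(
    "MIN_TRAIN_SAMPLE",
    KBResourceConstant.MIN_TRAIN_SAMPLE.value
))
print(min_n)  # 10, unless the MIN_TRAIN_SAMPLE env var overrides it
```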
@@ -211,37 +212,37 @@ class MulTaskLearning:
             self.models[i] = model
             feedback[entry] = res
             self.task_groups[i] = task
-        extractor_file = FileOps.join_path(
-            os.path.dirname(self.task_index_url),
-            "kb_extractor.pkl"
-        )
-        joblib.dump(self.extractor, extractor_file)
         task_index = {
-            "extractor": extractor_file,
+            "extractor": self.extractor,
             "task_groups": self.task_groups
         }
-        joblib.dump(task_index, self.task_index_url)
         if valid_data:
-            feedback = self.evaluate(valid_data, **kwargs)
-        return feedback
+            feedback, _ = self.evaluate(valid_data, **kwargs)
+        try:
+            FileOps.dump(task_index, self.task_index_url)
+        except TypeError:
+            return feedback, task_index
+        return feedback, self.task_index_url
+
+    def load(self, task_index_url=None):
+        if task_index_url:
+            self.task_index_url = task_index_url
+        assert FileOps.exists(self.task_index_url), FileExistsError(
+            f"Task index miss: {self.task_index_url}"
+        )
+        task_index = FileOps.load(self.task_index_url)
+        self.extractor = task_index['extractor']
+        if isinstance(self.extractor, str):
+            self.extractor = FileOps.load(self.extractor)
+        self.task_groups = task_index['task_groups']
+        self.models = [task.model for task in self.task_groups]

     def predict(self, data: BaseDataSource,
                 post_process=None, **kwargs):
         if not (self.models and self.extractor):
-            task_index = joblib.load(self.task_index_url)
-            extractor_file = FileOps.join_path(
-                os.path.dirname(self.task_index_url),
-                "kb_extractor.pkl"
-            )
-            if (not callable(task_index['extractor']) and
-                    isinstance(task_index['extractor'], str)):
-                FileOps.download(task_index['extractor'], extractor_file)
-                self.extractor = joblib.load(extractor_file)
-            else:
-                self.extractor = task_index['extractor']
-            self.task_groups = task_index['task_groups']
-            self.models = [task.model for task in self.task_groups]
+            self.load()

         data, mappings = self.task_mining(samples=data)
         samples, models = self.task_remodeling(samples=data, mappings=mappings)
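A rough sketch of the index round-trip that `train()`, `load()` and `predict()` now share; the dict contents are dummy stand-ins for the real extractor and `TaskGroup` objects:

```python
from sedna.common.file_ops import FileOps

# Mirrors the task_index dict that train() builds before dumping it.
task_index = {"extractor": {"attr": "city"}, "task_groups": []}
url = FileOps.dump(task_index, "/tmp/index.pkl")  # dump() returns the dst URL

restored = FileOps.load(url)  # what MulTaskLearning.load() does internally
assert restored["task_groups"] == []
```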
@@ -42,3 +42,9 @@ class K8sResourceKindStatus(Enum):
     COMPLETED = "completed"
     FAILED = "failed"
     RUNNING = "running"
+
+
+class KBResourceConstant(Enum):
+    MIN_TRAIN_SAMPLE = 10
+    KB_INDEX_NAME = "index.pkl"
+    TASK_EXTRACTOR_NAME = "task_attr_extractor.pkl"
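Since these are `Enum` members, call sites must unwrap them with `.value`, as the hunks above and below do:

```python
from sedna.common.constant import KBResourceConstant

index_name = KBResourceConstant.KB_INDEX_NAME.value      # "index.pkl"
min_samples = KBResourceConstant.MIN_TRAIN_SAMPLE.value  # 10
```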
@@ -17,11 +17,12 @@
 import os
+import re
 import joblib
 import codecs
 import pickle
 import shutil
-import tempfile
 import hashlib
+import tempfile
+from urllib.parse import urlparse

 from .utils import singleton
@@ -98,15 +99,23 @@ class FileOps:
             if not args[0]:
                 args[0] = os.path.sep
             _path = cls.join_path(*args)
-            if os.path.isdir(_path) and clean:
-                shutil.rmtree(_path)
+            if clean:
+                cls.delete(_path)
             if os.path.isfile(_path):
-                if clean:
-                    os.remove(_path)
                 _path = cls.join_path(*args[:len(args) - 1])
             os.makedirs(_path, exist_ok=True)
         return target

+    @classmethod
+    def delete(cls, path):
+        try:
+            if os.path.isdir(path):
+                shutil.rmtree(path)
+            if os.path.isfile(path):
+                os.remove(path)
+        except Exception:
+            pass
+
     @classmethod
     def make_base_dir(cls, *args):
         """Make new a base directory.
@@ -179,6 +188,7 @@ class FileOps:
         :rtype: object or None.
         """
+        filename = cls.download(filename)
         if not os.path.isfile(filename):
             return None

         with open(filename, "rb") as f:
@@ -203,8 +213,7 @@
             name = os.path.join(src, files)
             back_name = os.path.join(dst, files)
             if os.path.isfile(name):
-                if os.path.isfile(back_name):
-                    shutil.copy(name, back_name)
+                shutil.copy(name, back_name)
             else:
                 if not os.path.isdir(back_name):
                     shutil.copytree(name, back_name)
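The removed guard meant a file was copied only if it already existed at the destination, so fresh destination folders stayed empty. A quick check of the fixed behavior (paths created on the fly):

```python
import os
import tempfile
from sedna.common.file_ops import FileOps

src, dst = tempfile.mkdtemp(), tempfile.mkdtemp()
open(os.path.join(src, "a.txt"), "w").close()

FileOps.copy_folder(src, dst)
# Before the fix this assertion failed: a.txt was skipped because it
# did not already exist under dst.
assert os.path.isfile(os.path.join(dst, "a.txt"))
```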
@@ -219,7 +228,7 @@
         :param str dst: destination path.
         """
-        if dst is None or dst == "":
+        if not dst:
             return

         if os.path.isfile(src):
@@ -237,10 +246,34 @@
             cls.copy_folder(src, dst)

     @classmethod
-    def download(cls, src, dst, unzip=False) -> str:
-        if dst is None:
-            dst = tempfile.mkdtemp()
+    def dump(cls, obj, dst=None) -> str:
+        fd, name = tempfile.mkstemp()
+        os.close(fd)
+        joblib.dump(obj, name)
+        return cls.upload(name, dst)
+
+    @classmethod
+    def load(cls, src: str):
+        src = cls.download(src)
+        obj = joblib.load(src)
+        return obj
+
+    @classmethod
+    def is_remote(cls, src):
+        if src.startswith((
+                cls._GCS_PREFIX,
+                cls._S3_PREFIX
+        )):
+            return True
+        if re.search(cls._URI_RE, src):
+            return True
+        return False
+
+    @classmethod
+    def download(cls, src, dst=None, unzip=False) -> str:
+        if dst is None:
+            fd, dst = tempfile.mkstemp()
+            os.close(fd)
+        cls.clean_folder([os.path.dirname(dst)], clean=False)
         if src.startswith(cls._GCS_PREFIX):
             cls.gcs_download(src, dst)
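A short sketch of the new helpers, assuming the usual `gs://`/`s3://` prefixes behind `_GCS_PREFIX` and `_S3_PREFIX`:

```python
from sedna.common.file_ops import FileOps

FileOps.is_remote("s3://bucket/kb/index.pkl")  # True: matches _S3_PREFIX
FileOps.is_remote("/tmp/index.pkl")            # False: plain local path

# download() now accepts dst=None and allocates a temp file itself.
local = FileOps.download("/tmp/index.pkl")
```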
@@ -255,18 +288,29 @@
         return dst

     @classmethod
-    def upload(cls, src, dst, tar=False) -> str:
+    def upload(cls, src, dst, tar=False, clean=True) -> str:
         if dst is None:
-            dst = tempfile.mkdtemp()
+            fd, dst = tempfile.mkstemp()
+            os.close(fd)
+        if not cls.is_local(src):
+            fd, name = tempfile.mkstemp()
+            os.close(fd)
+            cls.download(src, name)
+            src = name
         if tar:
             cls._tar(src, f"{src}.tar.gz")
             src = f"{src}.tar.gz"

         if dst.startswith(cls._GCS_PREFIX):
             cls.gcs_upload(src, dst)
         elif dst.startswith(cls._S3_PREFIX):
             cls.s3_upload(src, dst)
-        elif cls.is_local(dst):
+        else:
             cls.copy_file(src, dst)
+        if cls.is_local(src) and clean:
+            if cls.is_local(dst) and os.path.samefile(src, dst):
+                return dst
+            cls.delete(src)
         return dst

     @classmethod
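With the new `clean` flag, `upload()` removes the local source after a successful transfer, except when source and destination resolve to the same file; a sketch of that behavior:

```python
import os
import tempfile
from sedna.common.file_ops import FileOps

fd, src = tempfile.mkstemp()
os.close(fd)
dst = src + ".uploaded"

FileOps.upload(src, dst)               # clean=True by default
assert os.path.exists(dst)
assert not os.path.exists(src)         # local source cleaned up

FileOps.upload(dst, dst, clean=True)   # samefile guard: dst survives
assert os.path.exists(dst)
```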
@@ -287,21 +331,24 @@
         bucket_name = bucket_args[0]
         bucket_path = len(bucket_args) > 1 and bucket_args[1] or ""

-        objects = client.list_objects(bucket_name,
-                                      prefix=bucket_path,
-                                      recursive=True,
-                                      use_api_v1=True)
+        objects = list(client.list_objects(bucket_name,
+                                           prefix=bucket_path,
+                                           recursive=True,
+                                           use_api_v1=True))
         count = 0
+        num = len(objects)
         for obj in objects:
             # Replace any prefix from the object key with out_dir
             subdir_object_key = obj.object_name[len(bucket_path):].strip("/")
             # fget_object handles directory creation if does not exist
             if not obj.is_dir:
-                local_file = os.path.join(
-                    out_dir,
-                    subdir_object_key or os.path.basename(obj.object_name)
-                )
+                if num == 1 and not os.path.isdir(out_dir):
+                    local_file = out_dir
+                else:
+                    local_file = os.path.join(
+                        out_dir,
+                        subdir_object_key or os.path.basename(obj.object_name)
+                    )
                 client.fget_object(bucket_name, obj.object_name, local_file)
                 count += 1
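The `num == 1` branch lets a single-object prefix download straight to a file path instead of being nested under it; for example (bucket and key are illustrative):

```python
from sedna.common.file_ops import FileOps

# One object under the prefix + a non-directory dst: the object is
# written to /tmp/index.pkl itself, not /tmp/index.pkl/<key>.
FileOps.s3_download("s3://kb/index.pkl", "/tmp/index.pkl")
```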
@@ -311,9 +358,10 @@
     def s3_download(cls, src, dst):
         s3 = _create_minio_client()
         count = cls._download_s3(s3, src, dst)
+
         if count == 0:
             raise RuntimeError("Failed to fetch files."
-                               "The path %s does not exist." % (src))
+                               "The path %s does not exist." % src)

     @classmethod
     def s3_upload(cls, src, dst):
@@ -15,12 +15,12 @@
 import os
 import tempfile
-import joblib

 from sedna.backend import set_backend
 from sedna.core.base import JobBase
 from sedna.common.file_ops import FileOps
-from sedna.common.constant import K8sResourceKind, K8sResourceKindStatus
+from sedna.common.constant import K8sResourceKind
+from sedna.common.constant import K8sResourceKindStatus
+from sedna.common.constant import KBResourceConstant
 from sedna.common.config import Context
 from sedna.common.class_factory import ClassType, ClassFactory
 from sedna.algorithms.multi_task_learning import MulTaskLearning
@@ -67,7 +67,10 @@ class LifelongLearning(JobBase):
             ll_kb_server=Context.get_parameters("KB_SERVER"),
             output_url=Context.get_parameters("OUTPUT_URL", "/tmp")
         )
-        task_index = FileOps.join_path(config['output_url'], 'index.pkl')
+        task_index = FileOps.join_path(
+            config['output_url'],
+            KBResourceConstant.KB_INDEX_NAME.value
+        )
         config['task_index'] = task_index
         super(LifelongLearning, self).__init__(
             estimator=e, config=config
@@ -91,7 +94,7 @@ class LifelongLearning(JobBase):
         if post_process is not None:
             callback_func = ClassFactory.get_cls(
                 ClassType.CALLBACK, post_process)
-        res = self.estimator.train(
+        res, task_index_url = self.estimator.train(
             train_data=train_data,
             valid_data=valid_data,
             **kwargs
@@ -107,7 +110,7 @@
         except Exception as err:
             self.log.error(
                 f"Upload task extractor_file fail {extractor_file}: {err}")
-        extractor_file = joblib.load(extractor_file)
+        extractor_file = FileOps.load(extractor_file)
         for task in task_groups:
             try:
                 model = self.kb_server.upload_file(task.model.model)
@@ -123,7 +126,7 @@
             "extractor": extractor_file
         }
         fd, name = tempfile.mkstemp()
-        joblib.dump(task_info, name)
+        FileOps.dump(task_info, name)

         index_file = self.kb_server.update_db(name)
         if not index_file:
@@ -14,7 +14,6 @@
 from abc import ABC

-import joblib
 import numpy as np
 import pandas as pd

@@ -51,7 +50,7 @@ class BaseDataSource:
         return self.data_type == "test"

     def save(self, output=""):
-        joblib.dump(self, output)
+        return FileOps.dump(self, output)


 class TxtDataParse(BaseDataSource, ABC):
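Because `save()` now returns `FileOps.dump()`'s destination, callers get the persisted URL back; a sketch with an illustrative subclass call and output path:

```python
# Constructor arguments and the output path are illustrative.
data = TxtDataParse(data_type="train")
url = data.save("/tmp/train_samples.pkl")  # returns where the dump landed
```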
@@ -22,13 +22,14 @@ from sedna.service.server.knowledgeBase.server import KBServer

 def main():
     init_db()
     server = os.getenv("KnowledgeBaseServer", "")
+    kb_dir = os.getenv("KnowledgeBasePath", "")
     match = re.compile(
         "(https?)://([0-9]{1,3}(?:\\.[0-9]{1,3}){3}):([0-9]+)").match(server)
     if match:
         _, host, port = match.groups()
     else:
         host, port = '0.0.0.0', 9020
-    KBServer(host=host, http_port=int(port)).start()
+    KBServer(host=host, http_port=int(port), save_dir=kb_dir).start()


 if __name__ == '__main__':
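The server now reads its storage directory from the new `KnowledgeBasePath` env var; a sketch of configuring both variables before start-up (values illustrative):

```python
import os

os.environ["KnowledgeBaseServer"] = "http://127.0.0.1:9020"
os.environ["KnowledgeBasePath"] = "/var/lib/sedna/kb"

# main() then parses host/port from KnowledgeBaseServer and passes
# KnowledgeBasePath through as KBServer's save_dir.
main()
```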
@@ -27,6 +27,7 @@ from starlette.responses import JSONResponse

 from sedna.service.server.base import BaseServer
 from sedna.common.file_ops import FileOps
+from sedna.common.constant import KBResourceConstant

 from .model import *
@@ -52,7 +53,7 @@ class KBServer(BaseServer):
                          http_port=http_port, workers=workers)
         self.save_dir = FileOps.clean_folder([save_dir], clean=False)[0]
         self.url = f"{self.url}/{servername}"
-        self.latest = 0
+        self.kb_index = KBResourceConstant.KB_INDEX_NAME.value
         self.app = FastAPI(
             routes=[
                 APIRoute(
@@ -94,8 +95,7 @@ class KBServer(BaseServer):
         pass

     def _get_db_index(self):
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
         if not FileOps.exists(_index_path):  # todo: get from kb
             pass
         return _index_path
@@ -130,8 +130,7 @@ class KBServer(BaseServer):
             }, synchronize_session=False)

             # todo: get from kb
-            _index_path = FileOps.join_path(self.save_dir,
-                                            f"kb_index_{self.latest}.pkl")
+            _index_path = FileOps.join_path(self.save_dir, self.kb_index)
             task_info = joblib.load(_index_path)
             new_task_group = []
@@ -143,13 +142,9 @@ class KBServer(BaseServer):
                     continue
                 new_task_group.append(task_group)
             task_info["task_groups"] = new_task_group
-            self.latest += 1
-            _index_path = FileOps.join_path(self.save_dir,
-                                            f"kb_index_{self.latest}.pkl")
-            joblib.dump(task_info, _index_path)
-            res = f"/file/download?files=kb_index_{self.latest}.pkl&name=index.pkl"
-            return res
+            _index_path = FileOps.join_path(self.save_dir, self.kb_index)
+            FileOps.dump(task_info, _index_path)
+            return f"/file/download?files={self.kb_index}&name={self.kb_index}"

     def update(self, task: UploadFile = File(...)):
         tasks = task.file.read()
@@ -178,21 +173,16 @@ class KBServer(BaseServer):
                 if t_create:
                     session.add(t_obj)

-                sampel_obj = Samples(
+                sample_obj = Samples(
                     data_type=task.samples.data_type,
-                    sample_num=len(task.samples)
+                    sample_num=len(task.samples),
+                    data_url=getattr(task, 'data_url', '')
                 )
-                session.add(sampel_obj)
+                session.add(sample_obj)
                 session.flush()
                 session.commit()

-                sample_dir = FileOps.join_path(
-                    self.save_dir,
-                    f"{sampel_obj.data_type}_{sampel_obj.id}.pkl")
-                task.samples.save(sample_dir)
-                sampel_obj.data_url = sample_dir
-
-                tsample = TaskSample(sample=sampel_obj, task=t_obj)
+                tsample = TaskSample(sample=sample_obj, task=t_obj)
                 session.add(tsample)
                 session.flush()
                 t_id.append(t_obj.id)
@@ -221,15 +211,8 @@ class KBServer(BaseServer):
             session.commit()

-        self.latest += 1
         extractor_file = upload_info["extractor"]
-        extractor_path = FileOps.join_path(self.save_dir,
-                                           f"kb_extractor.pkl")
-        FileOps.upload(extractor_file, extractor_path)
-        # todo: get from kb
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
-        FileOps.upload(name, _index_path)
-        res = f"/file/download?files=kb_index_{self.latest}.pkl&name=index.pkl"
-        return res
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
+        _index_path = FileOps.dump(upload_info, _index_path)
+        return f"/file/download?files={self.kb_index}&name={self.kb_index}"