
lifelong learning s3 support

- fix file_ops method
- fix kb save bug

Signed-off-by: JoeyHwong <joeyhwong@gknow.cn>
JoeyHwong, 4 years ago · tags/v0.3.1 · parent commit cfd99d4c7a
7 changed files with 136 additions and 95 deletions

  1. +29 -28  lib/sedna/algorithms/multi_task_learning/multi_task_learning.py
  2.  +6  -0  lib/sedna/common/constant.py
  3. +72 -24  lib/sedna/common/file_ops.py
  4. +10  -7  lib/sedna/core/lifelong_learning/lifelong_learning.py
  5.  +1  -2  lib/sedna/datasources/__init__.py
  6.  +2  -1  lib/sedna/service/run_kb.py
  7. +16 -33  lib/sedna/service/server/knowledgeBase/server.py

+29 -28  lib/sedna/algorithms/multi_task_learning/multi_task_learning.py

@@ -14,12 +14,12 @@

 import os
 import json
-import joblib

 from sedna.datasources import BaseDataSource
 from sedna.backend import set_backend
 from sedna.common.log import LOGGER
 from sedna.common.config import Context
+from sedna.common.constant import KBResourceConstant
 from sedna.common.file_ops import FileOps
 from sedna.common.class_factory import ClassFactory, ClassType

@@ -68,12 +68,13 @@ class MulTaskLearning:
         self.extractor = None
         self.base_model = estimator
         self.task_groups = None
-        self.task_index_url = Context.get_parameters(
-            "MODEL_URLS", '/tmp/index.pkl'
-        )
-        self.min_train_sample = int(Context.get_parameters(
-            "MIN_TRAIN_SAMPLE", '10'
-        ))
+        self.task_index_url = KBResourceConstant.KB_INDEX_NAME.value
+        self.min_train_sample = int(
+            Context.get_parameters(
+                "MIN_TRAIN_SAMPLE",
+                KBResourceConstant.MIN_TRAIN_SAMPLE.value
+            )
+        )

     @staticmethod
     def parse_param(param_str):
@@ -211,37 +212,37 @@ class MulTaskLearning:
                 self.models[i] = model
             feedback[entry] = res
             self.task_groups[i] = task
-        extractor_file = FileOps.join_path(
-            os.path.dirname(self.task_index_url),
-            "kb_extractor.pkl"
-        )
-        joblib.dump(self.extractor, extractor_file)

         task_index = {
-            "extractor": extractor_file,
+            "extractor": self.extractor,
             "task_groups": self.task_groups
         }
-        joblib.dump(task_index, self.task_index_url)
         if valid_data:
-            feedback = self.evaluate(valid_data, **kwargs)
-
-        return feedback
+            feedback, _ = self.evaluate(valid_data, **kwargs)
+        try:
+            FileOps.dump(task_index, self.task_index_url)
+        except TypeError:
+            return feedback, task_index
+        return feedback, self.task_index_url
+
+    def load(self, task_index_url=None):
+        if task_index_url:
+            self.task_index_url = task_index_url
+        assert FileOps.exists(self.task_index_url), FileExistsError(
+            f"Task index miss: {self.task_index_url}"
+        )
+        task_index = FileOps.load(self.task_index_url)
+        self.extractor = task_index['extractor']
+        if isinstance(self.extractor, str):
+            self.extractor = FileOps.load(self.extractor)
+        self.task_groups = task_index['task_groups']
+        self.models = [task.model for task in self.task_groups]

     def predict(self, data: BaseDataSource,
                 post_process=None, **kwargs):
         if not (self.models and self.extractor):
-            task_index = joblib.load(self.task_index_url)
-            extractor_file = FileOps.join_path(
-                os.path.dirname(self.task_index_url),
-                "kb_extractor.pkl"
-            )
-            if (not callable(task_index['extractor']) and
-                    isinstance(task_index['extractor'], str)):
-                FileOps.download(task_index['extractor'], extractor_file)
-                self.extractor = joblib.load(extractor_file)
-            else:
-                self.extractor = task_index['extractor']
-            self.task_groups = task_index['task_groups']
-            self.models = [task.model for task in self.task_groups]
+            self.load()

         data, mappings = self.task_mining(samples=data)
         samples, models = self.task_remodeling(samples=data, mappings=mappings)

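Reviewer note: the net effect of this file is that the task index no longer ships a separate kb_extractor.pkl next to it; the extractor rides inside the index (as an object, or as a URL string), and FileOps.dump/FileOps.load replace the bare joblib calls so the index can live on S3. A minimal sketch of the new round-trip, using plain joblib and a local temp path in place of an object-store URL (the extractor dict and task names here are made-up stand-ins):

import os
import tempfile

import joblib

# Made-up, picklable stand-ins for the extractor and task groups.
task_index = {
    "extractor": {"attr": "city"},   # inline object; no kb_extractor.pkl side file
    "task_groups": ["task_a", "task_b"],
}

fd, index_url = tempfile.mkstemp(suffix=".pkl")
os.close(fd)
joblib.dump(task_index, index_url)   # FileOps.dump adds an upload hop for s3:// dsts

restored = joblib.load(index_url)    # FileOps.load adds a download hop for s3:// srcs
extractor = restored["extractor"]
if isinstance(extractor, str):       # URL case handled by the new load()
    extractor = joblib.load(extractor)
assert restored["task_groups"] == ["task_a", "task_b"]
os.remove(index_url)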


+6 -0  lib/sedna/common/constant.py

@@ -42,3 +42,9 @@ class K8sResourceKindStatus(Enum):
     COMPLETED = "completed"
     FAILED = "failed"
     RUNNING = "running"
+
+
+class KBResourceConstant(Enum):
+    MIN_TRAIN_SAMPLE = 10
+    KB_INDEX_NAME = "index.pkl"
+    TASK_EXTRACTOR_NAME = "task_attr_extractor.pkl"

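These members are read with .value at the call sites added elsewhere in this commit (KB_INDEX_NAME.value in MulTaskLearning and KBServer); a quick illustration of the Enum semantics that relies on, since a member is not its payload:

from enum import Enum

class KBResourceConstant(Enum):
    MIN_TRAIN_SAMPLE = 10
    KB_INDEX_NAME = "index.pkl"

# The member itself is an Enum object; .value unwraps the payload.
assert KBResourceConstant.KB_INDEX_NAME.value == "index.pkl"
assert KBResourceConstant.KB_INDEX_NAME != "index.pkl"
assert int(KBResourceConstant.MIN_TRAIN_SAMPLE.value) == 10

Worth checking that FileOps.join_path stringifies its arguments, since one hunk in lifelong_learning.py below passes the bare member rather than .value.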
+72 -24  lib/sedna/common/file_ops.py

@@ -17,11 +17,12 @@

 import os
 import re

 import joblib
 import codecs
 import pickle
 import shutil
-import tempfile
 import hashlib
+import tempfile
 from urllib.parse import urlparse

 from .utils import singleton
@@ -98,15 +99,23 @@ class FileOps:
             if not args[0]:
                 args[0] = os.path.sep
             _path = cls.join_path(*args)
-            if os.path.isdir(_path) and clean:
-                shutil.rmtree(_path)
+            if clean:
+                cls.delete(_path)
             if os.path.isfile(_path):
-                if clean:
-                    os.remove(_path)
                 _path = cls.join_path(*args[:len(args) - 1])
             os.makedirs(_path, exist_ok=True)
         return target

+    @classmethod
+    def delete(cls, path):
+        try:
+            if os.path.isdir(path):
+                shutil.rmtree(path)
+            if os.path.isfile(path):
+                os.remove(path)
+        except Exception:
+            pass
+
     @classmethod
     def make_base_dir(cls, *args):
         """Make new a base directory.
@@ -179,6 +188,7 @@ class FileOps:
         :rtype: object or None.

         """
+        filename = cls.download(filename)
         if not os.path.isfile(filename):
             return None
         with open(filename, "rb") as f:
@@ -203,8 +213,7 @@ class FileOps:
             name = os.path.join(src, files)
             back_name = os.path.join(dst, files)
             if os.path.isfile(name):
-                if os.path.isfile(back_name):
-                    shutil.copy(name, back_name)
+                shutil.copy(name, back_name)
             else:
                 if not os.path.isdir(back_name):
                     shutil.copytree(name, back_name)
@@ -219,7 +228,7 @@ class FileOps:
         :param str dst: destination path.

         """
-        if dst is None or dst == "":
+        if not dst:
             return

         if os.path.isfile(src):
@@ -237,10 +246,34 @@ class FileOps:
             cls.copy_folder(src, dst)

     @classmethod
-    def download(cls, src, dst, unzip=False) -> str:
-        if dst is None:
-            dst = tempfile.mkdtemp()
+    def dump(cls, obj, dst=None) -> str:
+        fd, name = tempfile.mkstemp()
+        os.close(fd)
+        joblib.dump(obj, name)
+        return cls.upload(name, dst)
+
+    @classmethod
+    def load(cls, src: str):
+        src = cls.download(src)
+        obj = joblib.load(src)
+        return obj
+
+    @classmethod
+    def is_remote(cls, src):
+        if src.startswith((
+            cls._GCS_PREFIX,
+            cls._S3_PREFIX
+        )):
+            return True
+        if re.search(cls._URI_RE, src):
+            return True
+        return False
+
+    @classmethod
+    def download(cls, src, dst=None, unzip=False) -> str:
+        if dst is None:
+            fd, dst = tempfile.mkstemp()
+            os.close(fd)
         cls.clean_folder([os.path.dirname(dst)], clean=False)
         if src.startswith(cls._GCS_PREFIX):
             cls.gcs_download(src, dst)
@@ -255,18 +288,29 @@ class FileOps:
         return dst

     @classmethod
-    def upload(cls, src, dst, tar=False) -> str:
+    def upload(cls, src, dst, tar=False, clean=True) -> str:
         if dst is None:
-            dst = tempfile.mkdtemp()
+            fd, dst = tempfile.mkstemp()
+            os.close(fd)
+        if not cls.is_local(src):
+            fd, name = tempfile.mkstemp()
+            os.close(fd)
+            cls.download(src, name)
+            src = name
         if tar:
             cls._tar(src, f"{src}.tar.gz")
             src = f"{src}.tar.gz"

         if dst.startswith(cls._GCS_PREFIX):
             cls.gcs_upload(src, dst)
         elif dst.startswith(cls._S3_PREFIX):
             cls.s3_upload(src, dst)
-        elif cls.is_local(dst):
+        else:
             cls.copy_file(src, dst)
+        if cls.is_local(src) and clean:
+            if cls.is_local(dst) and os.path.samefile(src, dst):
+                return dst
+            cls.delete(src)
         return dst

     @classmethod
@@ -287,21 +331,24 @@ class FileOps:
         bucket_name = bucket_args[0]
         bucket_path = len(bucket_args) > 1 and bucket_args[1] or ""

-        objects = client.list_objects(bucket_name,
-                                      prefix=bucket_path,
-                                      recursive=True,
-                                      use_api_v1=True)
+        objects = list(client.list_objects(bucket_name,
+                                           prefix=bucket_path,
+                                           recursive=True,
+                                           use_api_v1=True))
         count = 0
+        num = len(objects)
         for obj in objects:
             # Replace any prefix from the object key with out_dir
             subdir_object_key = obj.object_name[len(bucket_path):].strip("/")
             # fget_object handles directory creation if does not exist
             if not obj.is_dir:
-                local_file = os.path.join(
-                    out_dir,
-                    subdir_object_key or os.path.basename(obj.object_name)
-                )
+                if num == 1 and not os.path.isdir(out_dir):
+                    local_file = out_dir
+                else:
+                    local_file = os.path.join(
+                        out_dir,
+                        subdir_object_key or os.path.basename(obj.object_name)
+                    )
                 client.fget_object(bucket_name, obj.object_name, local_file)
                 count += 1

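Why the new list(...) wrapper matters: MinIO's list_objects returns a lazy generator, which has no len(); materializing it is what makes the single-object fast path (num == 1, download straight to out_dir) possible. A generic illustration of the pitfall:

def fake_list_objects():
    yield from ("a/x.pkl",)   # stands in for minio's lazy object listing

objects = fake_list_objects()
try:
    len(objects)              # generators have no length
except TypeError:
    objects = list(objects)   # materialize first, as the patch now does

assert len(objects) == 1      # enables the num == 1 download-to-file fast path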
@@ -311,9 +358,10 @@ class FileOps:
     def s3_download(cls, src, dst):
         s3 = _create_minio_client()
         count = cls._download_s3(s3, src, dst)
+
         if count == 0:
             raise RuntimeError("Failed to fetch files."
-                               "The path %s does not exist." % (src))
+                               "The path %s does not exist." % src)

     @classmethod
     def s3_upload(cls, src, dst):

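Taken together, dump/load/is_remote give callers one call that works for local paths and object-store URLs alike. A minimal local-only sketch of the intended contract (an s3:// destination would go through s3_upload/s3_download instead; buckets and credentials are outside this sketch):

import os
import tempfile

import joblib

def dump(obj, dst):
    """Mirrors FileOps.dump: serialize to a temp file, then hand off to upload."""
    fd, name = tempfile.mkstemp()
    os.close(fd)
    joblib.dump(obj, name)
    os.replace(name, dst)      # local stand-in for upload(); s3:// would call s3_upload
    return dst

def load(src):
    """Mirrors FileOps.load: fetch (a no-op locally), then deserialize."""
    return joblib.load(src)

dst = os.path.join(tempfile.mkdtemp(), "index.pkl")
assert load(dump({"task_groups": []}, dst)) == {"task_groups": []}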

+10 -7  lib/sedna/core/lifelong_learning/lifelong_learning.py

@@ -15,12 +15,12 @@

 import os
 import tempfile

-import joblib
-
 from sedna.backend import set_backend
 from sedna.core.base import JobBase
 from sedna.common.file_ops import FileOps
-from sedna.common.constant import K8sResourceKind, K8sResourceKindStatus
+from sedna.common.constant import K8sResourceKind
+from sedna.common.constant import K8sResourceKindStatus
+from sedna.common.constant import KBResourceConstant
 from sedna.common.config import Context
 from sedna.common.class_factory import ClassType, ClassFactory
 from sedna.algorithms.multi_task_learning import MulTaskLearning
@@ -67,7 +67,10 @@ class LifelongLearning(JobBase):
             ll_kb_server=Context.get_parameters("KB_SERVER"),
             output_url=Context.get_parameters("OUTPUT_URL", "/tmp")
         )
-        task_index = FileOps.join_path(config['output_url'], 'index.pkl')
+        task_index = FileOps.join_path(
+            config['output_url'],
+            KBResourceConstant.KB_INDEX_NAME
+        )
         config['task_index'] = task_index
         super(LifelongLearning, self).__init__(
             estimator=e, config=config
@@ -91,7 +94,7 @@ class LifelongLearning(JobBase):
         if post_process is not None:
             callback_func = ClassFactory.get_cls(
                 ClassType.CALLBACK, post_process)
-        res = self.estimator.train(
+        res, task_index_url = self.estimator.train(
             train_data=train_data,
             valid_data=valid_data,
             **kwargs
@@ -107,7 +110,7 @@ class LifelongLearning(JobBase):
             except Exception as err:
                 self.log.error(
                     f"Upload task extractor_file fail {extractor_file}: {err}")
-        extractor_file = joblib.load(extractor_file)
+        extractor_file = FileOps.load(extractor_file)
         for task in task_groups:
             try:
                 model = self.kb_server.upload_file(task.model.model)
@@ -123,7 +126,7 @@ class LifelongLearning(JobBase):
             "extractor": extractor_file
         }
         fd, name = tempfile.mkstemp()
-        joblib.dump(task_info, name)
+        FileOps.dump(task_info, name)

         index_file = self.kb_server.update_db(name)
         if not index_file:

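The res, task_index_url unpacking tracks the new two-value contract of MulTaskLearning.train above: on success the second value is the dumped index URL, and when FileOps.dump raises TypeError (an un-picklable index) it is the in-memory dict instead. A toy model of what callers now have to handle (all names illustrative):

def train(task_index_url="/tmp/index.pkl", picklable=True):
    """Toy stand-in for the new MulTaskLearning.train return contract."""
    feedback = {"mse": 0.1}
    task_index = {"extractor": object(), "task_groups": []}
    if not picklable:              # models FileOps.dump raising TypeError
        return feedback, task_index
    return feedback, task_index_url

res, index = train()
assert isinstance(index, str)      # common case: a URL to the dumped index
res, index = train(picklable=False)
assert isinstance(index, dict)     # fallback: the raw index object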

+1 -2  lib/sedna/datasources/__init__.py

@@ -14,7 +14,6 @@

 from abc import ABC

-import joblib
 import numpy as np
 import pandas as pd

@@ -51,7 +50,7 @@ class BaseDataSource:
         return self.data_type == "test"

     def save(self, output=""):
-        joblib.dump(self, output)
+        return FileOps.dump(self, output)


 class TxtDataParse(BaseDataSource, ABC):

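This one-line change matters downstream: save now returns the destination produced by FileOps.dump, so callers (like the KB server, which previously computed a sample_dir itself) can record where the samples landed. A toy version of the new contract:

import os
import tempfile

import joblib

class ToyDataSource:
    """Stand-in for BaseDataSource; only the save() contract is modeled."""
    def __init__(self, x):
        self.x = x

    def save(self, output=""):
        joblib.dump(self, output)
        return output              # new: callers learn the stored location

fd, path = tempfile.mkstemp(suffix=".pkl")
os.close(fd)
url = ToyDataSource([1, 2, 3]).save(output=path)
assert joblib.load(url).x == [1, 2, 3]
os.remove(path)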

+2 -1  lib/sedna/service/run_kb.py

@@ -22,13 +22,14 @@ from sedna.service.server.knowledgeBase.server import KBServer
 def main():
     init_db()
     server = os.getenv("KnowledgeBaseServer", "")
+    kb_dir = os.getenv("KnowledgeBasePath", "")
     match = re.compile(
         "(https?)://([0-9]{1,3}(?:\\.[0-9]{1,3}){3}):([0-9]+)").match(server)
     if match:
         _, host, port = match.groups()
     else:
         host, port = '0.0.0.0', 9020
-    KBServer(host=host, http_port=int(port)).start()
+    KBServer(host=host, http_port=int(port), save_dir=kb_dir).start()


 if __name__ == '__main__':

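One thing to note about the address parsing this hunk sits next to: the pattern only matches dotted IPv4 literals, so a hostname-style KnowledgeBaseServer value (common in-cluster, e.g. a service name) silently falls back to 0.0.0.0:9020. A quick check with an illustrative address:

import re

pattern = re.compile("(https?)://([0-9]{1,3}(?:\\.[0-9]{1,3}){3}):([0-9]+)")

assert pattern.match("http://127.0.0.1:9020").groups() == \
       ("http", "127.0.0.1", "9020")
assert pattern.match("http://kb-server:9020") is None   # hostname: falls back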

+16 -33  lib/sedna/service/server/knowledgeBase/server.py

@@ -27,6 +27,7 @@ from starlette.responses import JSONResponse

 from sedna.service.server.base import BaseServer
 from sedna.common.file_ops import FileOps
+from sedna.common.constant import KBResourceConstant

 from .model import *

@@ -52,7 +53,7 @@ class KBServer(BaseServer):
             http_port=http_port, workers=workers)
         self.save_dir = FileOps.clean_folder([save_dir], clean=False)[0]
         self.url = f"{self.url}/{servername}"
-        self.latest = 0
+        self.kb_index = KBResourceConstant.KB_INDEX_NAME.value
         self.app = FastAPI(
             routes=[
                 APIRoute(
@@ -94,8 +95,7 @@ class KBServer(BaseServer):
         pass

     def _get_db_index(self):
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
         if not FileOps.exists(_index_path):  # todo: get from kb
             pass
         return _index_path
@@ -130,8 +130,7 @@ class KBServer(BaseServer):
             }, synchronize_session=False)

         # todo: get from kb
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
         task_info = joblib.load(_index_path)
         new_task_group = []
@@ -143,13 +142,9 @@ class KBServer(BaseServer):
                 continue
             new_task_group.append(task_group)
         task_info["task_groups"] = new_task_group
-        self.latest += 1

-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
-        joblib.dump(task_info, _index_path)
-        res = f"/file/download?files=kb_index_{self.latest}.pkl&name=index.pkl"
-        return res
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
+        FileOps.dump(task_info, _index_path)
+        return f"/file/download?files={self.kb_index}&name={self.kb_index}"

     def update(self, task: UploadFile = File(...)):
         tasks = task.file.read()
@@ -178,21 +173,16 @@ class KBServer(BaseServer):
             if t_create:
                 session.add(t_obj)

-            sampel_obj = Samples(
+            sample_obj = Samples(
                 data_type=task.samples.data_type,
-                sample_num=len(task.samples)
+                sample_num=len(task.samples),
+                data_url=getattr(task, 'data_url', '')
             )
-            session.add(sampel_obj)
+            session.add(sample_obj)

             session.flush()
             session.commit()
-            sample_dir = FileOps.join_path(
-                self.save_dir,
-                f"{sampel_obj.data_type}_{sampel_obj.id}.pkl")
-            task.samples.save(sample_dir)
-            sampel_obj.data_url = sample_dir

-            tsample = TaskSample(sample=sampel_obj, task=t_obj)
+            tsample = TaskSample(sample=sample_obj, task=t_obj)
             session.add(tsample)
             session.flush()
             t_id.append(t_obj.id)
@@ -221,15 +211,8 @@ class KBServer(BaseServer):

         session.commit()

-        self.latest += 1
-        extractor_file = upload_info["extractor"]
-        extractor_path = FileOps.join_path(self.save_dir,
-                                           f"kb_extractor.pkl")
-        FileOps.upload(extractor_file, extractor_path)
-
-        # todo: get from kb
-        _index_path = FileOps.join_path(self.save_dir,
-                                        f"kb_index_{self.latest}.pkl")
-        FileOps.upload(name, _index_path)
-        res = f"/file/download?files=kb_index_{self.latest}.pkl&name=index.pkl"
-        return res
+        _index_path = FileOps.join_path(self.save_dir, self.kb_index)
+        _index_path = FileOps.dump(upload_info, _index_path)
+
+        return f"/file/download?files={self.kb_index}&name={self.kb_index}"

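With self.latest gone, the knowledge base no longer versions its index as kb_index_{n}.pkl; every update overwrites the single index.pkl, so clients always fetch the same URL. A sketch of the route string both endpoints now return (the kb_index value comes from KBResourceConstant):

kb_index = "index.pkl"   # KBResourceConstant.KB_INDEX_NAME.value
route = f"/file/download?files={kb_index}&name={kb_index}"
assert route == "/file/download?files=index.pkl&name=index.pkl"

The trade-off is that older index snapshots are no longer retained on disk; whether that history was needed is a reviewer question rather than something the patch answers.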