1. fix backend env bug; 2. fix s3 upload bug; 3. fix joint_inference bug; Signed-off-by: JoeyHwong <joeyhwong@gknow.cn> tags/v0.3.1
| @@ -17,7 +17,7 @@ | |||||
| cd "$(dirname "${BASH_SOURCE[0]}")" | cd "$(dirname "${BASH_SOURCE[0]}")" | ||||
| IMAGE_REPO=${IMAGE_REPO:-kubeedge} | IMAGE_REPO=${IMAGE_REPO:-kubeedge} | ||||
| IMAGE_TAG=${IMAGE_TAG:-v0.1.0} | |||||
| IMAGE_TAG=${IMAGE_TAG:-v0.3.0} | |||||
| EXAMPLE_REPO_PREFIX=${IMAGE_REPO}/sedna-example- | EXAMPLE_REPO_PREFIX=${IMAGE_REPO}/sedna-example- | ||||
| @@ -6,6 +6,8 @@ RUN apt update \ | |||||
| COPY ./lib/requirements.txt /home | COPY ./lib/requirements.txt /home | ||||
| RUN pip install -r /home/requirements.txt | RUN pip install -r /home/requirements.txt | ||||
| RUN pip install keras~=2.4.3 | RUN pip install keras~=2.4.3 | ||||
| RUN pip install opencv-python==4.4.0.44 | |||||
| RUN pip install Pillow==8.0.1 | |||||
| ENV PYTHONPATH "/home/lib" | ENV PYTHONPATH "/home/lib" | ||||
| @@ -96,7 +96,7 @@ metadata: | |||||
| name: "surface-defect-detection-model" | name: "surface-defect-detection-model" | ||||
| spec: | spec: | ||||
| url: "/model" | url: "/model" | ||||
| format: "ckpt" | |||||
| format: "pb" | |||||
| EOF | EOF | ||||
| ``` | ``` | ||||
| @@ -82,6 +82,8 @@ class Estimator: | |||||
| self.model.set_weights(weights) | self.model.set_weights(weights) | ||||
| def load_weights(self, model): | def load_weights(self, model): | ||||
| if not os.path.isfile(model): | |||||
| return | |||||
| return self.model.load_weights(model) | return self.model.load_weights(model) | ||||
| def predict(self, datas): | def predict(self, datas): | ||||
| @@ -10,7 +10,8 @@ RUN pip install -r /home/requirements.txt | |||||
| # extra requirements for example | # extra requirements for example | ||||
| RUN pip install tqdm==4.56.0 | RUN pip install tqdm==4.56.0 | ||||
| RUN pip install matplotlib==3.3.3 | RUN pip install matplotlib==3.3.3 | ||||
| RUN pip install opencv-python==4.4.0.44 | |||||
| RUN pip install Pillow==8.0.1 | |||||
| ENV PYTHONPATH "/home/lib" | ENV PYTHONPATH "/home/lib" | ||||
| @@ -20,6 +20,7 @@ import cv2 | |||||
| import numpy as np | import numpy as np | ||||
| from sedna.common.config import Context | from sedna.common.config import Context | ||||
| from sedna.common.file_ops import FileOps | |||||
| from sedna.core.incremental_learning import IncrementalLearning | from sedna.core.incremental_learning import IncrementalLearning | ||||
| from interface import Estimator | from interface import Estimator | ||||
| @@ -27,6 +28,8 @@ from interface import Estimator | |||||
| he_saved_url = Context.get_parameters("HE_SAVED_URL") | he_saved_url = Context.get_parameters("HE_SAVED_URL") | ||||
| class_names = ['person', 'helmet', 'helmet_on', 'helmet_off'] | class_names = ['person', 'helmet', 'helmet_on', 'helmet_off'] | ||||
| FileOps.clean_folder([he_saved_url], clean=False) | |||||
| def draw_boxes(img, labels, scores, bboxes, class_names, colors): | def draw_boxes(img, labels, scores, bboxes, class_names, colors): | ||||
| line_type = 2 | line_type = 2 | ||||
| @@ -69,12 +72,12 @@ def output_deal(is_hard_example, infer_result, nframe, img_rgb): | |||||
| img_rgb = np.array(img_rgb) | img_rgb = np.array(img_rgb) | ||||
| img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) | img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR) | ||||
| colors = 'yellow,blue,green,red' | colors = 'yellow,blue,green,red' | ||||
| # if is_hard_example: | |||||
| label = 1 if is_hard_example else 0 | |||||
| if not is_hard_example: | |||||
| return | |||||
| lables, scores, bbox_list_pred = infer_result | lables, scores, bbox_list_pred = infer_result | ||||
| img = draw_boxes(img_rgb, lables, scores, bbox_list_pred, class_names, | img = draw_boxes(img_rgb, lables, scores, bbox_list_pred, class_names, | ||||
| colors) | colors) | ||||
| cv2.imwrite(f"{he_saved_url}/{nframe}-{label}.jpeg", img) | |||||
| cv2.imwrite(f"{he_saved_url}/{nframe}.jpeg", img) | |||||
| def mkdir(path): | def mkdir(path): | ||||
| @@ -5,6 +5,8 @@ RUN apt update \ | |||||
| COPY ./lib/requirements.txt /home | COPY ./lib/requirements.txt /home | ||||
| RUN pip install -r /home/requirements.txt | RUN pip install -r /home/requirements.txt | ||||
| RUN pip install opencv-python==4.4.0.44 | |||||
| RUN pip install Pillow==8.0.1 | |||||
| ENV PYTHONPATH "/home/lib" | ENV PYTHONPATH "/home/lib" | ||||
| @@ -5,6 +5,8 @@ RUN apt update \ | |||||
| COPY ./lib/requirements.txt /home | COPY ./lib/requirements.txt /home | ||||
| RUN pip install -r /home/requirements.txt | RUN pip install -r /home/requirements.txt | ||||
| RUN pip install opencv-python==4.4.0.44 | |||||
| RUN pip install Pillow==8.0.1 | |||||
| ENV PYTHONPATH "/home/lib" | ENV PYTHONPATH "/home/lib" | ||||
| @@ -21,6 +21,7 @@ import tensorflow as tf | |||||
| import numpy as np | import numpy as np | ||||
| from sedna.common.config import Context | from sedna.common.config import Context | ||||
| from sedna.common.file_ops import FileOps | |||||
| from sedna.core.joint_inference import JointInference | from sedna.core.joint_inference import JointInference | ||||
| from interface import Estimator | from interface import Estimator | ||||
| @@ -40,6 +41,12 @@ hard_example_cloud_output_path = Context.get_parameters( | |||||
| 'hard_example_cloud_inference_output' | 'hard_example_cloud_inference_output' | ||||
| ) | ) | ||||
| FileOps.clean_folder([ | |||||
| all_output_path, | |||||
| hard_example_cloud_output_path, | |||||
| hard_example_edge_output_path | |||||
| ], clean=False) | |||||
| class InferenceResult: | class InferenceResult: | ||||
| """The Result class for joint inference | """The Result class for joint inference | ||||
| @@ -128,7 +128,7 @@ spec: | |||||
| type: DirectoryOrCreate | type: DirectoryOrCreate | ||||
| - name: inferdata | - name: inferdata | ||||
| hostPath: | hostPath: | ||||
| path: /lifelong/data/ | |||||
| path: /data/ | |||||
| type: DirectoryOrCreate | type: DirectoryOrCreate | ||||
| outputDir: "/output" | outputDir: "/output" | ||||
| EOF | EOF | ||||
| @@ -17,7 +17,7 @@ | |||||
| cd "$(dirname "${BASH_SOURCE[0]}")" | cd "$(dirname "${BASH_SOURCE[0]}")" | ||||
| export IMAGE_REPO=${IMAGE_REPO:-kubeedge} | export IMAGE_REPO=${IMAGE_REPO:-kubeedge} | ||||
| export IMAGE_TAG=${IMAGE_TAG:-v0.1.0} | |||||
| export IMAGE_TAG=${IMAGE_TAG:-v0.3.0} | |||||
| bash build_image.sh | bash build_image.sh | ||||
| @@ -1,16 +1,15 @@ | |||||
| numpy>=1.13.3 | |||||
| colorlog~=4.7.2 | |||||
| websockets~=9.1 | |||||
| requests==2.24.0 | |||||
| PyYAML~=5.4.1 | |||||
| numpy>=1.13.3 # BSD | |||||
| colorlog~=4.7.2 # MIT | |||||
| websockets~=9.1 # BSD | |||||
| requests==2.24.0 # Apache-2.0 | |||||
| PyYAML~=5.4.1 # MIT | |||||
| setuptools~=54.2.0 | setuptools~=54.2.0 | ||||
| fastapi~=0.63.0 | |||||
| starlette~=0.13.6 | |||||
| pydantic~=1.8.1 | |||||
| retrying~=1.3.3 | |||||
| joblib~=1.0.1 | |||||
| pandas~=1.1.5 | |||||
| six~=1.15.0 | |||||
| opencv-python==4.4.0.44 | |||||
| Pillow==8.0.1 | |||||
| uvicorn~=0.14.0 | |||||
| fastapi~=0.63.0 # MIT | |||||
| starlette~=0.13.6 # BSD | |||||
| pydantic~=1.8.1 # MIT | |||||
| retrying~=1.3.3 # Apache-2.0 | |||||
| joblib~=1.0.1 # BSD | |||||
| pandas~=1.1.5 # BSD | |||||
| six~=1.15.0 # MIT | |||||
| minio~=7.0.3 # Apache-2.0 | |||||
| uvicorn~=0.14.0 # BSD | |||||
| @@ -13,9 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| """Hard Example Mining Algorithms""" | """Hard Example Mining Algorithms""" | ||||
| import abc | import abc | ||||
| import math | import math | ||||
| from sedna.common.class_factory import ClassFactory, ClassType | from sedna.common.class_factory import ClassFactory, ClassType | ||||
| @@ -51,10 +49,13 @@ class ThresholdFilter(BaseFilter, abc.ABC): | |||||
| :param infer_result: [N, 6], (x0, y0, x1, y1, score, class) | :param infer_result: [N, 6], (x0, y0, x1, y1, score, class) | ||||
| :return: `True` means hard sample, `False` means not a hard sample. | :return: `True` means hard sample, `False` means not a hard sample. | ||||
| """ | """ | ||||
| if not infer_result: | |||||
| return True | |||||
| # if invalid input, return False | |||||
| if not (infer_result | |||||
| and all(map(lambda x: len(x) > 4, infer_result))): | |||||
| return False | |||||
| image_score = 0 | image_score = 0 | ||||
| for bbox in infer_result: | for bbox in infer_result: | ||||
| image_score += bbox[4] | image_score += bbox[4] | ||||
| @@ -87,23 +88,24 @@ class CrossEntropyFilter(BaseFilter, abc.ABC): | |||||
| not in [0,1]. | not in [0,1]. | ||||
| :return: `True` means a hard sample, `False` means not a hard sample. | :return: `True` means a hard sample, `False` means not a hard sample. | ||||
| """ | """ | ||||
| if infer_result is None: | |||||
| if not infer_result: | |||||
| # if invalid input, return False | |||||
| return False | return False | ||||
| elif len(infer_result) == 0: | |||||
| log_sum = 0.0 | |||||
| data_check_list = [class_probability for class_probability | |||||
| in infer_result | |||||
| if self.data_check(class_probability)] | |||||
| if len(data_check_list) != len(infer_result): | |||||
| return False | return False | ||||
| else: | |||||
| log_sum = 0.0 | |||||
| data_check_list = [class_probability for class_probability | |||||
| in infer_result | |||||
| if self.data_check(class_probability)] | |||||
| if len(data_check_list) == len(infer_result): | |||||
| for class_data in data_check_list: | |||||
| log_sum += class_data * math.log(class_data) | |||||
| confidence_score = 1 + 1.0 * log_sum / math.log( | |||||
| len(infer_result)) | |||||
| return confidence_score < self.threshold_cross_entropy | |||||
| else: | |||||
| return False | |||||
| for class_data in data_check_list: | |||||
| log_sum += class_data * math.log(class_data) | |||||
| confidence_score = 1 + 1.0 * log_sum / math.log( | |||||
| len(infer_result)) | |||||
| return confidence_score < self.threshold_cross_entropy | |||||
| @ClassFactory.register(ClassType.HEM, alias="IBT") | @ClassFactory.register(ClassType.HEM, alias="IBT") | ||||
| @@ -133,21 +135,19 @@ class IBTFilter(BaseFilter, abc.ABC): | |||||
| in [0,1]. | in [0,1]. | ||||
| :return: `True` means a hard sample, `False` means not a hard sample. | :return: `True` means a hard sample, `False` means not a hard sample. | ||||
| """ | """ | ||||
| if infer_result is None: | |||||
| if not (infer_result | |||||
| and all(map(lambda x: len(x) > 4, infer_result))): | |||||
| # if invalid input, return False | |||||
| return False | return False | ||||
| elif len(infer_result) == 0: | |||||
| data_check_list = [bbox[4] for bbox in infer_result | |||||
| if self.data_check(bbox[4])] | |||||
| if len(data_check_list) != len(infer_result): | |||||
| return False | return False | ||||
| else: | |||||
| data_check_list = [bbox[4] for bbox in infer_result | |||||
| if self.data_check(bbox[4])] | |||||
| if len(data_check_list) == len(infer_result): | |||||
| confidence_score_list = [ | |||||
| float(box_score) for box_score in data_check_list | |||||
| if float(box_score) <= self.threshold_box] | |||||
| if (len(confidence_score_list) / len(infer_result) | |||||
| >= (1 - self.threshold_img)): | |||||
| return True | |||||
| else: | |||||
| return False | |||||
| else: | |||||
| return False | |||||
| confidence_score_list = [ | |||||
| float(box_score) for box_score in data_check_list | |||||
| if float(box_score) <= self.threshold_box] | |||||
| return (len(confidence_score_list) / len(infer_result) | |||||
| >= (1 - self.threshold_img)) | |||||
| @@ -27,24 +27,32 @@ def set_backend(estimator=None, config=None): | |||||
| if config is None: | if config is None: | ||||
| config = BaseConfig() | config = BaseConfig() | ||||
| use_cuda = False | use_cuda = False | ||||
| backend_type = os.getenv( | |||||
| 'BACKEND_TYPE', config.get("backend_type", "UNKNOWN") | |||||
| ) | |||||
| backend_type = str(backend_type).upper() | |||||
| device_category = os.getenv( | |||||
| 'DEVICE_CATEGORY', config.get("device_category", "CPU") | |||||
| ) | |||||
| if 'CUDA_VISIBLE_DEVICES' in os.environ: | if 'CUDA_VISIBLE_DEVICES' in os.environ: | ||||
| os.environ['DEVICE_CATEGORY'] = 'GPU' | os.environ['DEVICE_CATEGORY'] = 'GPU' | ||||
| use_cuda = True | use_cuda = True | ||||
| if config.get("device_category"): | |||||
| os.environ['DEVICE_CATEGORY'] = config.get("device_category") | |||||
| if config.is_tf_backend(): | |||||
| else: | |||||
| os.environ['DEVICE_CATEGORY'] = device_category | |||||
| if backend_type == "TENSORFLOW": | |||||
| from sedna.backend.tensorflow import TFBackend as REGISTER | from sedna.backend.tensorflow import TFBackend as REGISTER | ||||
| elif config.is_kr_backend(): | |||||
| elif backend_type == "KERAS": | |||||
| from sedna.backend.tensorflow import KerasBackend as REGISTER | from sedna.backend.tensorflow import KerasBackend as REGISTER | ||||
| else: | else: | ||||
| backend_type = config.get("backend_type") or "UNKNOWN" | |||||
| warnings.warn(f"{backend_type} Not Support yet, use itself") | warnings.warn(f"{backend_type} Not Support yet, use itself") | ||||
| from sedna.backend.base import BackendBase as REGISTER | from sedna.backend.base import BackendBase as REGISTER | ||||
| model_save_url = config.model_url | |||||
| model_save_url = config.get("model_url") | |||||
| base_model_save = config.get("base_model_save") or model_save_url | base_model_save = config.get("base_model_save") or model_save_url | ||||
| model_save_name = config.model_name | |||||
| return REGISTER(estimator=estimator, use_cuda=use_cuda, | |||||
| model_save_path=base_model_save, | |||||
| model_name=model_save_name, | |||||
| model_save_url=model_save_url | |||||
| ) | |||||
| model_save_name = config.get("model_name") | |||||
| return REGISTER( | |||||
| estimator=estimator, use_cuda=use_cuda, | |||||
| model_save_path=base_model_save, | |||||
| model_name=model_save_name, | |||||
| model_save_url=model_save_url | |||||
| ) | |||||
| @@ -20,6 +20,16 @@ from sedna.backend.base import BackendBase | |||||
| from sedna.common.file_ops import FileOps | from sedna.common.file_ops import FileOps | ||||
| if hasattr(tf, "compat"): | |||||
| # version 2.0 tf | |||||
| ConfigProto = tf.compat.v1.ConfigProto | |||||
| Session = tf.compat.v1.Session | |||||
| else: | |||||
| # version 1 | |||||
| ConfigProto = tf.ConfigProto | |||||
| Session = tf.Session | |||||
| class TFBackend(BackendBase): | class TFBackend(BackendBase): | ||||
| def __init__(self, estimator, fine_tune=True, **kwargs): | def __init__(self, estimator, fine_tune=True, **kwargs): | ||||
| @@ -31,25 +41,24 @@ class TFBackend(BackendBase): | |||||
| self.graph = tf.Graph() | self.graph = tf.Graph() | ||||
| with self.graph.as_default(): | with self.graph.as_default(): | ||||
| self.sess = tf.compat.v1.Session(config=sess_config) | |||||
| self.sess = Session(config=sess_config) | |||||
| if callable(self.estimator): | if callable(self.estimator): | ||||
| self.estimator = self.estimator() | self.estimator = self.estimator() | ||||
| @staticmethod | @staticmethod | ||||
| def _init_cpu_session_config(): | def _init_cpu_session_config(): | ||||
| sess_config = tf.ConfigProto(allow_soft_placement=True) | |||||
| sess_config = ConfigProto(allow_soft_placement=True) | |||||
| return sess_config | return sess_config | ||||
| @staticmethod | @staticmethod | ||||
| def _init_gpu_session_config(): | def _init_gpu_session_config(): | ||||
| sess_config = tf.ConfigProto( | |||||
| sess_config = ConfigProto( | |||||
| log_device_placement=True, allow_soft_placement=True) | log_device_placement=True, allow_soft_placement=True) | ||||
| sess_config.gpu_options.per_process_gpu_memory_fraction = 0.7 | sess_config.gpu_options.per_process_gpu_memory_fraction = 0.7 | ||||
| sess_config.gpu_options.allow_growth = True | sess_config.gpu_options.allow_growth = True | ||||
| return sess_config | return sess_config | ||||
| def train(self, train_data, valid_data=None, **kwargs): | def train(self, train_data, valid_data=None, **kwargs): | ||||
| # self.sess.run(tf.global_variables_initializer()) | |||||
| if callable(self.estimator): | if callable(self.estimator): | ||||
| self.estimator = self.estimator() | self.estimator = self.estimator() | ||||
| if self.fine_tune and FileOps.exists(self.model_save_path): | if self.fine_tune and FileOps.exists(self.model_save_path): | ||||
| @@ -279,36 +279,6 @@ class BaseConfig(ConfigSerializable): | |||||
| if self.parameters: | if self.parameters: | ||||
| self.parameter = _url2dict(self.parameters) | self.parameter = _url2dict(self.parameters) | ||||
| @classmethod | |||||
| def is_gpu_device(cls): | |||||
| """Return whether is gpu device or not.""" | |||||
| return getattr(cls, 'device_category', None) == 'GPU' | |||||
| @classmethod | |||||
| def is_npu_device(cls): | |||||
| """Return whether is npu device or not.""" | |||||
| return getattr(cls, 'device_category', None) == 'NPU' | |||||
| @classmethod | |||||
| def is_torch_backend(cls): | |||||
| """Return whether is pytorch backend or not.""" | |||||
| return getattr(cls, 'backend_type', None) == 'PYTORCH' | |||||
| @classmethod | |||||
| def is_tf_backend(cls): | |||||
| """Return whether is tensorflow backend or not.""" | |||||
| return getattr(cls, 'backend_type', None) == 'TENSORFLOW' | |||||
| @classmethod | |||||
| def is_kr_backend(cls): | |||||
| """Return whether is keras backend or not.""" | |||||
| return getattr(cls, 'backend_type', None) == 'KERAS' | |||||
| @classmethod | |||||
| def is_ms_backend(cls): | |||||
| """Return whether is mindspore backend or not.""" | |||||
| return getattr(cls, 'backend_type', None) == 'MINDSPORE' | |||||
| class Context: | class Context: | ||||
| """The Context provides the capability of obtaining the context""" | """The Context provides the capability of obtaining the context""" | ||||
| @@ -22,6 +22,28 @@ import pickle | |||||
| import shutil | import shutil | ||||
| import tempfile | import tempfile | ||||
| import hashlib | import hashlib | ||||
| from urllib.parse import urlparse | |||||
| from .utils import singleton | |||||
| @singleton | |||||
| def _create_minio_client(): | |||||
| import minio | |||||
| _url = os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com") | |||||
| if not (_url.startswith("http://") or _url.startswith("https://")): | |||||
| _url = f"https://{_url}" | |||||
| url = urlparse(_url) | |||||
| use_ssl = url.scheme == 'https' if url.scheme else True | |||||
| s3 = minio.Minio( | |||||
| url.netloc, | |||||
| access_key=os.getenv("ACCESS_KEY_ID", ""), | |||||
| secret_key=os.getenv("SECRET_ACCESS_KEY", ""), | |||||
| secure=use_ssl | |||||
| ) | |||||
| return s3 | |||||
| class FileOps: | class FileOps: | ||||
| @@ -239,11 +261,11 @@ class FileOps: | |||||
| if tar: | if tar: | ||||
| cls._tar(src, f"{src}.tar.gz") | cls._tar(src, f"{src}.tar.gz") | ||||
| src = f"{src}.tar.gz" | src = f"{src}.tar.gz" | ||||
| if src.startswith(cls._GCS_PREFIX): | |||||
| if dst.startswith(cls._GCS_PREFIX): | |||||
| cls.gcs_upload(src, dst) | cls.gcs_upload(src, dst) | ||||
| elif src.startswith(cls._S3_PREFIX): | |||||
| elif dst.startswith(cls._S3_PREFIX): | |||||
| cls.s3_upload(src, dst) | cls.s3_upload(src, dst) | ||||
| elif cls.is_local(src): | |||||
| elif cls.is_local(dst): | |||||
| cls.copy_file(src, dst) | cls.copy_file(src, dst) | ||||
| return dst | return dst | ||||
| @@ -287,18 +309,7 @@ class FileOps: | |||||
| @classmethod | @classmethod | ||||
| def s3_download(cls, src, dst): | def s3_download(cls, src, dst): | ||||
| import minio | |||||
| from urllib.parse import urlparse | |||||
| url = urlparse(os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com")) | |||||
| use_ssl = url.scheme == 'https' if url.scheme else True | |||||
| s3 = minio.Minio( | |||||
| url.netloc, | |||||
| access_key=os.getenv("ACCESS_KEY_ID", ""), | |||||
| secret_key=os.getenv("SECRET_ACCESS_KEY", ""), | |||||
| secure=use_ssl | |||||
| ) | |||||
| s3 = _create_minio_client() | |||||
| count = cls._download_s3(s3, src, dst) | count = cls._download_s3(s3, src, dst) | ||||
| if count == 0: | if count == 0: | ||||
| raise RuntimeError("Failed to fetch files." | raise RuntimeError("Failed to fetch files." | ||||
| @@ -306,29 +317,30 @@ class FileOps: | |||||
| @classmethod | @classmethod | ||||
| def s3_upload(cls, src, dst): | def s3_upload(cls, src, dst): | ||||
| import minio | |||||
| from urllib.parse import urlparse | |||||
| url = urlparse(os.getenv("S3_ENDPOINT_URL", "http://s3.amazonaws.com")) | |||||
| use_ssl = url.scheme == 'https' if url.scheme else True | |||||
| s3 = minio.Minio( | |||||
| url.netloc, | |||||
| access_key=os.getenv("ACCESS_KEY_ID", ""), | |||||
| secret_key=os.getenv("SECRET_ACCESS_KEY", ""), | |||||
| secure=use_ssl | |||||
| ) | |||||
| s3 = _create_minio_client() | |||||
| parsed = urlparse(dst, scheme='s3') | parsed = urlparse(dst, scheme='s3') | ||||
| bucket_name = parsed.netloc | bucket_name = parsed.netloc | ||||
| def _s3_upload(_file, fname=""): | |||||
| _file_handle = open(_file, 'rb') | |||||
| _file_handle.seek(0, os.SEEK_END) | |||||
| size = _file_handle.tell() | |||||
| _file_handle.seek(0) | |||||
| if not fname: | |||||
| fname = os.path.basename(fname) | |||||
| s3.put_object(bucket_name, fname, _file_handle, size) | |||||
| _file_handle.close() | |||||
| return size | |||||
| if os.path.isdir(src): | if os.path.isdir(src): | ||||
| for root, _, files in os.walk(src): | for root, _, files in os.walk(src): | ||||
| for file in files: | for file in files: | ||||
| filepath = os.path.join(root, file) | filepath = os.path.join(root, file) | ||||
| with open(filepath, 'rb') as data: | |||||
| s3.put_object(bucket_name, file, data, -1) | |||||
| name = os.path.relpath(filepath, src) | |||||
| _s3_upload(filepath, name) | |||||
| elif os.path.isfile(src): | elif os.path.isfile(src): | ||||
| with open(src, 'rb') as data: | |||||
| s3.put_object(bucket_name, os.path.basename(src), data, -1) | |||||
| _s3_upload(src, parsed.path.lstrip("/")) | |||||
| @classmethod | @classmethod | ||||
| def http_download(cls, src, dst): | def http_download(cls, src, dst): | ||||
| @@ -13,6 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| import os.path | import os.path | ||||
| import json | |||||
| from sedna.common.log import LOGGER | from sedna.common.log import LOGGER | ||||
| from sedna.common.file_ops import FileOps | from sedna.common.file_ops import FileOps | ||||
| @@ -70,6 +71,27 @@ class JobBase(DistributedWorker): | |||||
| work_name = f"{self.job_name}-{self.worker_id}" | work_name = f"{self.job_name}-{self.worker_id}" | ||||
| self.worker_name = self.config.worker_name or work_name | self.worker_name = self.config.worker_name or work_name | ||||
| @property | |||||
| def initial_hem(self): | |||||
| hem = self.get_parameters("HEM_NAME") | |||||
| hem_parameters = self.get_parameters("HEM_PARAMETERS") | |||||
| try: | |||||
| hem_parameters = json.loads(hem_parameters) | |||||
| hem_parameters = { | |||||
| p["key"]: p.get("value", "") | |||||
| for p in hem_parameters if "key" in p | |||||
| } | |||||
| except Exception as err: | |||||
| self.log.warn(f"Parse HEM_PARAMETERS failure, " | |||||
| f"fallback to empty: {err}") | |||||
| hem_parameters = {} | |||||
| if hem is None: | |||||
| hem = self.config.get("hem_name") or "IBT" | |||||
| return ClassFactory.get_cls(ClassType.HEM, hem)(**hem_parameters) | |||||
| @property | @property | ||||
| def model_path(self): | def model_path(self): | ||||
| if os.path.isfile(self.config.model_url): | if os.path.isfile(self.config.model_url): | ||||
| @@ -36,29 +36,7 @@ class IncrementalLearning(JobBase): | |||||
| "MODEL_URLS") # use in evaluation | "MODEL_URLS") # use in evaluation | ||||
| self.job_kind = K8sResourceKind.INCREMENTAL_JOB.value | self.job_kind = K8sResourceKind.INCREMENTAL_JOB.value | ||||
| FileOps.clean_folder([self.config.model_url], clean=False) | FileOps.clean_folder([self.config.model_url], clean=False) | ||||
| hem = self.get_parameters("HEM_NAME") | |||||
| hem_parameters = self.get_parameters("HEM_PARAMETERS") | |||||
| try: | |||||
| hem_parameters = json.loads(hem_parameters) | |||||
| if isinstance(hem_parameters, (list, tuple)): | |||||
| if isinstance(hem_parameters[0], dict): | |||||
| hem_parameters = { | |||||
| p["key"]: p.get("value", "") | |||||
| for p in hem_parameters if "key" in p | |||||
| } | |||||
| else: | |||||
| hem_parameters = dict(hem_parameters) | |||||
| except Exception: | |||||
| hem_parameters = None | |||||
| if hem is None: | |||||
| hem = self.config.get("hem_name") or "IBT" | |||||
| if hem_parameters is None: | |||||
| hem_parameters = {} | |||||
| self.hard_example_mining_algorithm = ClassFactory.get_cls( | |||||
| ClassType.HEM, hem)(**hem_parameters) | |||||
| self.hard_example_mining_algorithm = self.initial_hem | |||||
| def train(self, train_data, | def train(self, train_data, | ||||
| valid_data=None, | valid_data=None, | ||||
| @@ -99,7 +77,7 @@ class IncrementalLearning(JobBase): | |||||
| is_hard_example = False | is_hard_example = False | ||||
| if self.hard_example_mining_algorithm: | if self.hard_example_mining_algorithm: | ||||
| is_hard_example = self.hard_example_mining_algorithm(infer_res) | |||||
| is_hard_example = self.hard_example_mining_algorithm(res) | |||||
| return infer_res, res, is_hard_example | return infer_res, res, is_hard_example | ||||
| def evaluate(self, data, post_process=None, **kwargs): | def evaluate(self, data, post_process=None, **kwargs): | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| import os | import os | ||||
| import json | |||||
| from copy import deepcopy | from copy import deepcopy | ||||
| from sedna.common.utils import get_host_ip | from sedna.common.utils import get_host_ip | ||||
| @@ -81,8 +81,8 @@ class JointInference(JobBase): | |||||
| self.job_kind = K8sResourceKind.JOINT_INFERENCE_SERVICE.value | self.job_kind = K8sResourceKind.JOINT_INFERENCE_SERVICE.value | ||||
| self.local_ip = get_host_ip() | self.local_ip = get_host_ip() | ||||
| self.remote_ip = self.get_parameters( | self.remote_ip = self.get_parameters( | ||||
| "BIG_MODEL_BIND_IP", self.local_ip) | |||||
| self.port = int(self.get_parameters("BIG_MODEL_BIND_PORT", "5000")) | |||||
| "BIG_MODEL_IP", self.local_ip) | |||||
| self.port = int(self.get_parameters("BIG_MODEL_PORT", "5000")) | |||||
| report_msg = { | report_msg = { | ||||
| "name": self.worker_name, | "name": self.worker_name, | ||||
| @@ -93,7 +93,8 @@ class JointInference(JobBase): | |||||
| "results": [] | "results": [] | ||||
| } | } | ||||
| period_interval = int(self.get_parameters("LC_PERIOD", "30")) | period_interval = int(self.get_parameters("LC_PERIOD", "30")) | ||||
| self.lc_reporter = LCReporter(message=report_msg, | |||||
| self.lc_reporter = LCReporter(lc_server=self.config.lc_server, | |||||
| message=report_msg, | |||||
| period_interval=period_interval) | period_interval=period_interval) | ||||
| self.lc_reporter.setDaemon(True) | self.lc_reporter.setDaemon(True) | ||||
| self.lc_reporter.start() | self.lc_reporter.start() | ||||
| @@ -106,6 +107,7 @@ class JointInference(JobBase): | |||||
| self.estimator.load(self.model_path) | self.estimator.load(self.model_path) | ||||
| self.cloud = ModelClient(service_name=self.job_name, | self.cloud = ModelClient(service_name=self.job_name, | ||||
| host=self.remote_ip, port=self.port) | host=self.remote_ip, port=self.port) | ||||
| self.hard_example_mining_algorithm = self.initial_hem | |||||
| def train(self, train_data, | def train(self, train_data, | ||||
| valid_data=None, | valid_data=None, | ||||
| @@ -128,24 +130,12 @@ class JointInference(JobBase): | |||||
| res = callback_func(res) | res = callback_func(res) | ||||
| self.lc_reporter.update_for_edge_inference() | self.lc_reporter.update_for_edge_inference() | ||||
| hem = self.get_parameters("HEM_NAME") | |||||
| hem_parameters = self.get_parameters("HEM_PARAMETERS") | |||||
| if hem is None: | |||||
| hem = self.config.get("hem_name") or "IBT" | |||||
| if hem_parameters is None: | |||||
| hem_parameters = {} | |||||
| is_hard_example = False | is_hard_example = False | ||||
| cloud_result = None | cloud_result = None | ||||
| try: | |||||
| hard_example_mining_algorithm = ClassFactory.get_cls( | |||||
| ClassType.HEM, hem)() | |||||
| except ValueError as err: | |||||
| self.log.error("Joint Inference [HEM] : {}".format(err)) | |||||
| else: | |||||
| is_hard_example = hard_example_mining_algorithm( | |||||
| res, **hem_parameters) | |||||
| if self.hard_example_mining_algorithm: | |||||
| is_hard_example = self.hard_example_mining_algorithm(res) | |||||
| if is_hard_example: | if is_hard_example: | ||||
| cloud_result = self.cloud.inference( | cloud_result = self.cloud.inference( | ||||
| data.tolist(), post_process=post_process, **kwargs) | data.tolist(), post_process=post_process, **kwargs) | ||||
| @@ -130,7 +130,7 @@ class LifelongLearning(JobBase): | |||||
| self.log.error(f"KB update Fail !") | self.log.error(f"KB update Fail !") | ||||
| index_file = name | index_file = name | ||||
| FileOps.download(index_file, self.config.task_index) | |||||
| FileOps.upload(index_file, self.config.task_index) | |||||
| if os.path.isfile(name): | if os.path.isfile(name): | ||||
| os.close(fd) | os.close(fd) | ||||
| os.remove(name) | os.remove(name) | ||||
| @@ -184,7 +184,7 @@ class LifelongLearning(JobBase): | |||||
| index_file = str(index_url) | index_file = str(index_url) | ||||
| self.log.info( | self.log.info( | ||||
| f"upload kb index from {index_file} to {self.config.task_index}") | f"upload kb index from {index_file} to {self.config.task_index}") | ||||
| FileOps.download(index_file, self.config.task_index) | |||||
| FileOps.upload(index_file, self.config.task_index) | |||||
| task_info_res = self.estimator.model_info( | task_info_res = self.estimator.model_info( | ||||
| self.config.task_index, result=res, | self.config.task_index, result=res, | ||||
| relpath=self.config.data_path_prefix) | relpath=self.config.data_path_prefix) | ||||
| @@ -54,13 +54,14 @@ class LCReporter(threading.Thread): | |||||
| the lc. | the lc. | ||||
| """ | """ | ||||
| def __init__(self, message, period_interval=30): | |||||
| def __init__(self, lc_server, message, period_interval=30): | |||||
| threading.Thread.__init__(self) | threading.Thread.__init__(self) | ||||
| # the value of statistics | # the value of statistics | ||||
| self.inference_number = 0 | self.inference_number = 0 | ||||
| self.hard_example_number = 0 | self.hard_example_number = 0 | ||||
| self.period_interval = period_interval | self.period_interval = period_interval | ||||
| self.lc_server = lc_server | |||||
| # The system resets the period_increment after sending the messages to | # The system resets the period_increment after sending the messages to | ||||
| # the LC. If the period_increment is 0 in the current period, | # the LC. If the period_increment is 0 in the current period, | ||||
| # the system does not send the messages to the LC. | # the system does not send the messages to the LC. | ||||
| @@ -99,9 +100,10 @@ class LCReporter(threading.Thread): | |||||
| "hardExampleNumber": self.hard_example_number, | "hardExampleNumber": self.hard_example_number, | ||||
| "uploadCloudRatio": self.hard_example_number / | "uploadCloudRatio": self.hard_example_number / | ||||
| self.inference_number} | self.inference_number} | ||||
| message = deepcopy(self.message) | |||||
| message["ownerInfo"] = info | |||||
| LCClient.send(message["ownerName"], message["name"], message) | |||||
| self.message["ownerInfo"] = info | |||||
| LCClient.send(self.lc_server, | |||||
| self.message["name"], | |||||
| self.message) | |||||
| self.period_increment = 0 | self.period_increment = 0 | ||||
| @@ -178,18 +180,13 @@ class AggregationClient: | |||||
| raise | raise | ||||
| async def _send(self, data): | async def _send(self, data): | ||||
| error = "" | |||||
| for _ in range(self._retry): | for _ in range(self._retry): | ||||
| try: | try: | ||||
| await self.ws.send(data) | await self.ws.send(data) | ||||
| result = await self.ws.recv() | result = await self.ws.recv() | ||||
| return result | return result | ||||
| except Exception as e: | |||||
| error = e | |||||
| LOGGER.warning(f"send data error: {error}") | |||||
| except Exception: | |||||
| time.sleep(self._retry_interval_seconds) | time.sleep(self._retry_interval_seconds) | ||||
| LOGGER.error( | |||||
| f"websocket error: {error}, retry times: {self._retry}") | |||||
| return None | return None | ||||
| def send(self, data, msg_type="message", job_name=""): | def send(self, data, msg_type="message", job_name=""): | ||||
| @@ -200,7 +197,6 @@ class AggregationClient: | |||||
| }) | }) | ||||
| data_json = loop.run_until_complete(self._send(j)) | data_json = loop.run_until_complete(self._send(j)) | ||||
| if data_json is None: | if data_json is None: | ||||
| LOGGER.error(f"send {msg_type} to agg worker failed") | |||||
| return | return | ||||
| res = json.loads(data_json) | res = json.loads(data_json) | ||||
| return res | return res | ||||
| @@ -134,8 +134,12 @@ class KBServer(BaseServer): | |||||
| f"kb_index_{self.latest}.pkl") | f"kb_index_{self.latest}.pkl") | ||||
| task_info = joblib.load(_index_path) | task_info = joblib.load(_index_path) | ||||
| new_task_group = [] | new_task_group = [] | ||||
| default_task = task_info["task_groups"][0] | |||||
| # todo: get from transfer learning | |||||
| for task_group in task_info["task_groups"]: | for task_group in task_info["task_groups"]: | ||||
| if not ((task_group.entry in tasks) == deploy): | if not ((task_group.entry in tasks) == deploy): | ||||
| new_task_group.append(default_task) | |||||
| continue | continue | ||||
| new_task_group.append(task_group) | new_task_group.append(task_group) | ||||
| task_info["task_groups"] = new_task_group | task_info["task_groups"] = new_task_group | ||||