From 76d24fc680de7cf3890e4842e86d37dcab6cdbcd Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 4 Dec 2023 14:45:22 +0800 Subject: [PATCH] [MNT] for temp save --- learnware/tests/templates/__init__.py | 35 +++++++++++++++++++-------- learnware/utils/__init__.py | 2 +- learnware/utils/file.py | 13 +++++++++- 3 files changed, 38 insertions(+), 12 deletions(-) diff --git a/learnware/tests/templates/__init__.py b/learnware/tests/templates/__init__.py index 874cb1b..2d81821 100644 --- a/learnware/tests/templates/__init__.py +++ b/learnware/tests/templates/__init__.py @@ -1,17 +1,32 @@ import os import tempfile +from dataclasses import dataclass, field from shutil import copyfile from typing import List, Tuple, Union, Optional -from ...utils import save_dict_to_yaml +from ...utils import save_dict_to_yaml, convert_folder_to_zipfile from ...config import C -class LearnwareTemplate: + +@dataclass +class ModelTemplate: + class_name: str = field(repr=False) + template_path: str = field(repr=False) + model_kwargs: dict +@dataclass +class PickleModelTemplate(ModelTemplate): + pickle_filepath: str + def __post_init__(self): + self.class_name = "PickleLoadedModel" + self.template_path = os.path.join(C.package_path, "tests", "templates", "pickle_model.py") + +class TestTemplates: def __init__(self): self.model_templates = { "pickle": { "class_name": 'PickleLoadedModel', - "template_path": os.path.join(C.package_path, "tests", "templates", "pickle_model.py") + "template_path": os.path.join(C.package_path, "tests", "templates", "pickle_model.py"), + "parameters": {"pickle_filepath"}, } } @@ -47,10 +62,10 @@ class LearnwareTemplate: def generate_learnware_zipfile( self, learnware_zippath: str, - model_template: str = "pickle", - model_kwargs: Optional[dict] = None, + model_template: Union[ModelTemplate, PickleModelTemplate], stat_spec_config: Optional[List[dict]] = None, requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None, + pickle_filepath: Optional[str] = None, **kwargs, ): with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir: @@ -58,20 +73,20 @@ class LearnwareTemplate: self.generate_requirements(requirement_filepath, requirements) model_filepath = os.path.join(tempdir, "__init__.py") - copyfile(self.model_templates[model_template]["template_path"], model_filepath) + copyfile(model_template.template_path, model_filepath) learnware_yaml_filepath = os.path.join(tempdir, "requirements.txt") model_config = { - "class_name": self.model_templates[model_template]["class_name"], - "kwargs": {} if model_kwargs is None else model_kwargs + "class_name": model_template.class_name, + "kwargs": {} if model_template.model_kwargs is None else model_kwargs } self.generate_learnware_yaml(learnware_yaml_filepath, model_config, stat_spec_config) if model_template == "pickle": pickle_filepath = os.path.join(tempdir, model_config["kwargs"]["pickle_filepath"]) copyfile(kwargs["pickle_filepath"], pickle_filepath) - - + + def generate_template_semantic_spec(self): pass \ No newline at end of file diff --git a/learnware/utils/__init__.py b/learnware/utils/__init__.py index 5357aaf..b43d763 100644 --- a/learnware/utils/__init__.py +++ b/learnware/utils/__init__.py @@ -3,7 +3,7 @@ import zipfile from .import_utils import is_torch_available from .module import get_module_by_module_path -from .file import read_yaml_to_dict, save_dict_to_yaml +from .file import read_yaml_to_dict, save_dict_to_yaml, convert_folder_to_zipfile from .gpu import setup_seed, choose_device, allocate_cuda_idx from ..config import get_platform, SystemType diff --git a/learnware/utils/file.py b/learnware/utils/file.py index 27ba5f5..c0b4f77 100644 --- a/learnware/utils/file.py +++ b/learnware/utils/file.py @@ -1,5 +1,6 @@ +import os import yaml - +import zipfile def save_dict_to_yaml(dict_value: dict, save_path: str): """save dict object into yaml file""" @@ -12,3 +13,13 @@ def read_yaml_to_dict(yaml_path: str): with open(yaml_path, "r") as file: dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) return dict_value + +def convert_folder_to_zipfile(folder_path, zip_path): + with zipfile.ZipFile(zip_path, "w") as zip_obj: + for foldername, subfolders, filenames in os.walk(folder_path): + for filename in filenames: + file_path = os.path.join(foldername, filename) + zip_info = zipfile.ZipInfo(filename) + zip_info.compress_type = zipfile.ZIP_STORED + with open(file_path, "rb") as file: + zip_obj.writestr(zip_info, file.read())