Browse Source

[MNT] for temp save

tags/v0.3.2
bxdd 1 year ago
parent
commit
76d24fc680
3 changed files with 38 additions and 12 deletions
  1. +25
    -10
      learnware/tests/templates/__init__.py
  2. +1
    -1
      learnware/utils/__init__.py
  3. +12
    -1
      learnware/utils/file.py

+ 25
- 10
learnware/tests/templates/__init__.py View File

@@ -1,17 +1,32 @@
import os import os
import tempfile import tempfile
from dataclasses import dataclass, field
from shutil import copyfile from shutil import copyfile
from typing import List, Tuple, Union, Optional from typing import List, Tuple, Union, Optional


from ...utils import save_dict_to_yaml
from ...utils import save_dict_to_yaml, convert_folder_to_zipfile
from ...config import C from ...config import C


class LearnwareTemplate:

@dataclass
class ModelTemplate:
class_name: str = field(repr=False)
template_path: str = field(repr=False)
model_kwargs: dict
@dataclass
class PickleModelTemplate(ModelTemplate):
pickle_filepath: str
def __post_init__(self):
self.class_name = "PickleLoadedModel"
self.template_path = os.path.join(C.package_path, "tests", "templates", "pickle_model.py")
class TestTemplates:
def __init__(self): def __init__(self):
self.model_templates = { self.model_templates = {
"pickle": { "pickle": {
"class_name": 'PickleLoadedModel', "class_name": 'PickleLoadedModel',
"template_path": os.path.join(C.package_path, "tests", "templates", "pickle_model.py")
"template_path": os.path.join(C.package_path, "tests", "templates", "pickle_model.py"),
"parameters": {"pickle_filepath"},
} }
} }
@@ -47,10 +62,10 @@ class LearnwareTemplate:
def generate_learnware_zipfile( def generate_learnware_zipfile(
self, self,
learnware_zippath: str, learnware_zippath: str,
model_template: str = "pickle",
model_kwargs: Optional[dict] = None,
model_template: Union[ModelTemplate, PickleModelTemplate],
stat_spec_config: Optional[List[dict]] = None, stat_spec_config: Optional[List[dict]] = None,
requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None, requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None,
pickle_filepath: Optional[str] = None,
**kwargs, **kwargs,
): ):
with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir: with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir:
@@ -58,20 +73,20 @@ class LearnwareTemplate:
self.generate_requirements(requirement_filepath, requirements) self.generate_requirements(requirement_filepath, requirements)
model_filepath = os.path.join(tempdir, "__init__.py") model_filepath = os.path.join(tempdir, "__init__.py")
copyfile(self.model_templates[model_template]["template_path"], model_filepath)
copyfile(model_template.template_path, model_filepath)
learnware_yaml_filepath = os.path.join(tempdir, "requirements.txt") learnware_yaml_filepath = os.path.join(tempdir, "requirements.txt")
model_config = { model_config = {
"class_name": self.model_templates[model_template]["class_name"],
"kwargs": {} if model_kwargs is None else model_kwargs
"class_name": model_template.class_name,
"kwargs": {} if model_template.model_kwargs is None else model_kwargs
} }
self.generate_learnware_yaml(learnware_yaml_filepath, model_config, stat_spec_config) self.generate_learnware_yaml(learnware_yaml_filepath, model_config, stat_spec_config)


if model_template == "pickle": if model_template == "pickle":
pickle_filepath = os.path.join(tempdir, model_config["kwargs"]["pickle_filepath"]) pickle_filepath = os.path.join(tempdir, model_config["kwargs"]["pickle_filepath"])
copyfile(kwargs["pickle_filepath"], pickle_filepath) copyfile(kwargs["pickle_filepath"], pickle_filepath)


def generate_template_semantic_spec(self): def generate_template_semantic_spec(self):
pass pass

+ 1
- 1
learnware/utils/__init__.py View File

@@ -3,7 +3,7 @@ import zipfile


from .import_utils import is_torch_available from .import_utils import is_torch_available
from .module import get_module_by_module_path from .module import get_module_by_module_path
from .file import read_yaml_to_dict, save_dict_to_yaml
from .file import read_yaml_to_dict, save_dict_to_yaml, convert_folder_to_zipfile
from .gpu import setup_seed, choose_device, allocate_cuda_idx from .gpu import setup_seed, choose_device, allocate_cuda_idx
from ..config import get_platform, SystemType from ..config import get_platform, SystemType




+ 12
- 1
learnware/utils/file.py View File

@@ -1,5 +1,6 @@
import os
import yaml import yaml
import zipfile


def save_dict_to_yaml(dict_value: dict, save_path: str): def save_dict_to_yaml(dict_value: dict, save_path: str):
"""save dict object into yaml file""" """save dict object into yaml file"""
@@ -12,3 +13,13 @@ def read_yaml_to_dict(yaml_path: str):
with open(yaml_path, "r") as file: with open(yaml_path, "r") as file:
dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) dict_value = yaml.load(file.read(), Loader=yaml.FullLoader)
return dict_value return dict_value

def convert_folder_to_zipfile(folder_path, zip_path):
with zipfile.ZipFile(zip_path, "w") as zip_obj:
for foldername, subfolders, filenames in os.walk(folder_path):
for filename in filenames:
file_path = os.path.join(foldername, filename)
zip_info = zipfile.ZipInfo(filename)
zip_info.compress_type = zipfile.ZIP_STORED
with open(file_path, "rb") as file:
zip_obj.writestr(zip_info, file.read())

Loading…
Cancel
Save