Browse Source

[MNT] for temp save

tags/v0.3.2
bxdd 1 year ago
parent
commit
76d24fc680
3 changed files with 38 additions and 12 deletions
  1. +25
    -10
      learnware/tests/templates/__init__.py
  2. +1
    -1
      learnware/utils/__init__.py
  3. +12
    -1
      learnware/utils/file.py

+ 25
- 10
learnware/tests/templates/__init__.py View File

@@ -1,17 +1,32 @@
import os
import tempfile
from dataclasses import dataclass, field
from shutil import copyfile
from typing import List, Tuple, Union, Optional

from ...utils import save_dict_to_yaml
from ...utils import save_dict_to_yaml, convert_folder_to_zipfile
from ...config import C

class LearnwareTemplate:

@dataclass
class ModelTemplate:
class_name: str = field(repr=False)
template_path: str = field(repr=False)
model_kwargs: dict
@dataclass
class PickleModelTemplate(ModelTemplate):
pickle_filepath: str
def __post_init__(self):
self.class_name = "PickleLoadedModel"
self.template_path = os.path.join(C.package_path, "tests", "templates", "pickle_model.py")
class TestTemplates:
def __init__(self):
self.model_templates = {
"pickle": {
"class_name": 'PickleLoadedModel',
"template_path": os.path.join(C.package_path, "tests", "templates", "pickle_model.py")
"template_path": os.path.join(C.package_path, "tests", "templates", "pickle_model.py"),
"parameters": {"pickle_filepath"},
}
}
@@ -47,10 +62,10 @@ class LearnwareTemplate:
def generate_learnware_zipfile(
self,
learnware_zippath: str,
model_template: str = "pickle",
model_kwargs: Optional[dict] = None,
model_template: Union[ModelTemplate, PickleModelTemplate],
stat_spec_config: Optional[List[dict]] = None,
requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None,
pickle_filepath: Optional[str] = None,
**kwargs,
):
with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir:
@@ -58,20 +73,20 @@ class LearnwareTemplate:
self.generate_requirements(requirement_filepath, requirements)
model_filepath = os.path.join(tempdir, "__init__.py")
copyfile(self.model_templates[model_template]["template_path"], model_filepath)
copyfile(model_template.template_path, model_filepath)
learnware_yaml_filepath = os.path.join(tempdir, "requirements.txt")
model_config = {
"class_name": self.model_templates[model_template]["class_name"],
"kwargs": {} if model_kwargs is None else model_kwargs
"class_name": model_template.class_name,
"kwargs": {} if model_template.model_kwargs is None else model_kwargs
}
self.generate_learnware_yaml(learnware_yaml_filepath, model_config, stat_spec_config)

if model_template == "pickle":
pickle_filepath = os.path.join(tempdir, model_config["kwargs"]["pickle_filepath"])
copyfile(kwargs["pickle_filepath"], pickle_filepath)

def generate_template_semantic_spec(self):
pass

+ 1
- 1
learnware/utils/__init__.py View File

@@ -3,7 +3,7 @@ import zipfile

from .import_utils import is_torch_available
from .module import get_module_by_module_path
from .file import read_yaml_to_dict, save_dict_to_yaml
from .file import read_yaml_to_dict, save_dict_to_yaml, convert_folder_to_zipfile
from .gpu import setup_seed, choose_device, allocate_cuda_idx
from ..config import get_platform, SystemType



+ 12
- 1
learnware/utils/file.py View File

@@ -1,5 +1,6 @@
import os
import yaml
import zipfile

def save_dict_to_yaml(dict_value: dict, save_path: str):
"""save dict object into yaml file"""
@@ -12,3 +13,13 @@ def read_yaml_to_dict(yaml_path: str):
with open(yaml_path, "r") as file:
dict_value = yaml.load(file.read(), Loader=yaml.FullLoader)
return dict_value

def convert_folder_to_zipfile(folder_path, zip_path):
with zipfile.ZipFile(zip_path, "w") as zip_obj:
for foldername, subfolders, filenames in os.walk(folder_path):
for filename in filenames:
file_path = os.path.join(foldername, filename)
zip_info = zipfile.ZipInfo(filename)
zip_info.compress_type = zipfile.ZIP_STORED
with open(file_path, "rb") as file:
zip_obj.writestr(zip_info, file.read())

Loading…
Cancel
Save