|
|
|
@@ -1,17 +1,32 @@ |
|
|
|
import os |
|
|
|
import tempfile |
|
|
|
from dataclasses import dataclass, field |
|
|
|
from shutil import copyfile |
|
|
|
from typing import List, Tuple, Union, Optional |
|
|
|
|
|
|
|
from ...utils import save_dict_to_yaml |
|
|
|
from ...utils import save_dict_to_yaml, convert_folder_to_zipfile |
|
|
|
from ...config import C |
|
|
|
|
|
|
|
class LearnwareTemplate: |
|
|
|
|
|
|
|
@dataclass |
|
|
|
class ModelTemplate: |
|
|
|
class_name: str = field(repr=False) |
|
|
|
template_path: str = field(repr=False) |
|
|
|
model_kwargs: dict |
|
|
|
@dataclass |
|
|
|
class PickleModelTemplate(ModelTemplate): |
|
|
|
pickle_filepath: str |
|
|
|
def __post_init__(self): |
|
|
|
self.class_name = "PickleLoadedModel" |
|
|
|
self.template_path = os.path.join(C.package_path, "tests", "templates", "pickle_model.py") |
|
|
|
|
|
|
|
class TestTemplates: |
|
|
|
def __init__(self): |
|
|
|
self.model_templates = { |
|
|
|
"pickle": { |
|
|
|
"class_name": 'PickleLoadedModel', |
|
|
|
"template_path": os.path.join(C.package_path, "tests", "templates", "pickle_model.py") |
|
|
|
"template_path": os.path.join(C.package_path, "tests", "templates", "pickle_model.py"), |
|
|
|
"parameters": {"pickle_filepath"}, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -47,10 +62,10 @@ class LearnwareTemplate: |
|
|
|
def generate_learnware_zipfile( |
|
|
|
self, |
|
|
|
learnware_zippath: str, |
|
|
|
model_template: str = "pickle", |
|
|
|
model_kwargs: Optional[dict] = None, |
|
|
|
model_template: Union[ModelTemplate, PickleModelTemplate], |
|
|
|
stat_spec_config: Optional[List[dict]] = None, |
|
|
|
requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None, |
|
|
|
pickle_filepath: Optional[str] = None, |
|
|
|
**kwargs, |
|
|
|
): |
|
|
|
with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir: |
|
|
|
@@ -58,20 +73,20 @@ class LearnwareTemplate: |
|
|
|
self.generate_requirements(requirement_filepath, requirements) |
|
|
|
|
|
|
|
model_filepath = os.path.join(tempdir, "__init__.py") |
|
|
|
copyfile(self.model_templates[model_template]["template_path"], model_filepath) |
|
|
|
copyfile(model_template.template_path, model_filepath) |
|
|
|
|
|
|
|
learnware_yaml_filepath = os.path.join(tempdir, "requirements.txt") |
|
|
|
model_config = { |
|
|
|
"class_name": self.model_templates[model_template]["class_name"], |
|
|
|
"kwargs": {} if model_kwargs is None else model_kwargs |
|
|
|
"class_name": model_template.class_name, |
|
|
|
"kwargs": {} if model_template.model_kwargs is None else model_kwargs |
|
|
|
} |
|
|
|
self.generate_learnware_yaml(learnware_yaml_filepath, model_config, stat_spec_config) |
|
|
|
|
|
|
|
if model_template == "pickle": |
|
|
|
pickle_filepath = os.path.join(tempdir, model_config["kwargs"]["pickle_filepath"]) |
|
|
|
copyfile(kwargs["pickle_filepath"], pickle_filepath) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_template_semantic_spec(self): |
|
|
|
pass |