| @@ -100,12 +100,12 @@ html_logo = "_static/img/logo/logo1.png" | |||
| # These folders are copied to the documentation's HTML output | |||
| html_static_path = ['_static'] | |||
| html_static_path = ["_static"] | |||
| # These paths are either relative to html_static_path | |||
| # or fully qualified paths (eg. https://...) | |||
| html_css_files = [ | |||
| 'css/custom_style.css', | |||
| "css/custom_style.css", | |||
| ] | |||
| # -- Options for HTMLHelp output ------------------------------------------ | |||
| @@ -85,9 +85,7 @@ def get_split_errs(algo): | |||
| split = train_xs.shape[0] - proportion_list[tmp] | |||
| model.fit( | |||
| train_xs[ | |||
| split:, | |||
| ], | |||
| train_xs[split:,], | |||
| train_ys[split:], | |||
| eval_set=[(val_xs, val_ys)], | |||
| early_stopping_rounds=50, | |||
| @@ -86,7 +86,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: | |||
| pass | |||
| except Exception as err: | |||
| logger.error(err) | |||
| return None | |||
| exist_packages = [] | |||
| @@ -101,7 +101,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: | |||
| exist_packages.append(result) | |||
| else: | |||
| nonexist_packages.append(package) | |||
| if len(nonexist_packages) > 0: | |||
| logger.info(f"Filtered out {len(nonexist_packages)} non-exist pip packages.") | |||
| return exist_packages, nonexist_packages | |||
| @@ -13,7 +13,9 @@ from ..utils import read_yaml_to_dict | |||
| logger = get_module_logger("learnware.learnware") | |||
| def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True) -> Optional[Learnware]: | |||
| def get_learnware_from_dirpath( | |||
| id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True | |||
| ) -> Optional[Learnware]: | |||
| """Get the learnware object from dirpath, and provide the manage interface tor Learnware class | |||
| Parameters | |||
| @@ -46,11 +48,11 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, | |||
| } | |||
| try: | |||
| learnware_yaml_path = os.path.join(learnware_dirpath, C.learnware_folder_config["yaml_file"]) | |||
| assert os.path.exists(learnware_yaml_path), f"learnware.yaml is not found for learnware_{id}, please check the learnware folder or zipfile." | |||
| assert os.path.exists( | |||
| learnware_yaml_path | |||
| ), f"learnware.yaml is not found for learnware_{id}, please check the learnware folder or zipfile." | |||
| yaml_config = read_yaml_to_dict(learnware_yaml_path) | |||
| if "name" in yaml_config: | |||
| @@ -67,8 +69,10 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, | |||
| for _stat_spec in learnware_config["stat_specifications"]: | |||
| stat_spec = _stat_spec.copy() | |||
| stat_spec_path = os.path.join(learnware_dirpath, stat_spec["file_name"]) | |||
| assert os.path.exists(stat_spec_path), f"statistical specification file {stat_spec['file_name']} is not found for learnware_{id}, please check the learnware folder or zipfile." | |||
| assert os.path.exists( | |||
| stat_spec_path | |||
| ), f"statistical specification file {stat_spec['file_name']} is not found for learnware_{id}, please check the learnware folder or zipfile." | |||
| stat_spec["file_name"] = stat_spec_path | |||
| stat_spec_inst = get_stat_spec_from_config(stat_spec) | |||
| learnware_spec.update_stat_spec(**{stat_spec_inst.type: stat_spec_inst}) | |||
| @@ -1,9 +1,7 @@ | |||
| from .anchor import AnchoredOrganizer, AnchoredSearcher, AnchoredUserInfo | |||
| from .base import (BaseChecker, BaseOrganizer, BaseSearcher, BaseUserInfo, | |||
| LearnwareMarket) | |||
| from .base import BaseChecker, BaseOrganizer, BaseSearcher, BaseUserInfo, LearnwareMarket | |||
| from .classes import CondaChecker | |||
| from .easy import (EasyOrganizer, EasySearcher, EasySemanticChecker, | |||
| EasyStatChecker) | |||
| from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | |||
| from .evolve import EvolvedOrganizer | |||
| from .evolve_anchor import EvolvedAnchoredOrganizer | |||
| from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher | |||
| @@ -11,5 +11,4 @@ if not is_torch_available(verbose=False): | |||
| logger.error("EasySeacher and EasyChecker are not available because 'torch' is not installed!") | |||
| else: | |||
| from .checker import EasySemanticChecker, EasyStatChecker | |||
| from .searcher import (EasyExactSemanticSearcher, EasyFuzzSemanticSearcher, | |||
| EasySearcher, EasyStatSearcher) | |||
| from .searcher import EasyExactSemanticSearcher, EasyFuzzSemanticSearcher, EasySearcher, EasyStatSearcher | |||
| @@ -6,13 +6,11 @@ import torch | |||
| from rapidfuzz import fuzz | |||
| from .organizer import EasyOrganizer | |||
| from ..base import (BaseSearcher, BaseUserInfo, MultipleSearchItem, | |||
| SearchResults, SingleSearchItem) | |||
| from ..base import BaseSearcher, BaseUserInfo, MultipleSearchItem, SearchResults, SingleSearchItem | |||
| from ..utils import parse_specification_type | |||
| from ...learnware import Learnware | |||
| from ...logger import get_module_logger | |||
| from ...specification import (RKMEImageSpecification, RKMETableSpecification, | |||
| RKMETextSpecification, rkme_solve_qp) | |||
| from ...specification import RKMEImageSpecification, RKMETableSpecification, RKMETextSpecification, rkme_solve_qp | |||
| logger = get_module_logger("easy_seacher") | |||
| @@ -8,8 +8,7 @@ from torch import nn | |||
| from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer | |||
| from .trainer import Trainer, TransTabCollatorForCL | |||
| from .....specification import (HeteroMapTableSpecification, | |||
| RKMETableSpecification) | |||
| from .....specification import HeteroMapTableSpecification, RKMETableSpecification | |||
| from .....utils import allocate_cuda_idx, choose_device | |||
| @@ -288,7 +287,7 @@ class HeteroMap(nn.Module): | |||
| # go through transformers, get the first cls embedding | |||
| encoder_output = self.encoder(**outputs) # bs, seqlen+1, hidden_dim | |||
| output_features = encoder_output[:, 0, :] | |||
| del inputs, outputs, encoder_output | |||
| torch.cuda.empty_cache() | |||
| @@ -11,7 +11,11 @@ logger = get_module_logger("hetero_searcher") | |||
| class HeteroSearcher(EasySearcher): | |||
| def __call__( | |||
| self, user_info: BaseUserInfo, check_status: Optional[int] = None, max_search_num: int = 5, search_method: str = "greedy" | |||
| self, | |||
| user_info: BaseUserInfo, | |||
| check_status: Optional[int] = None, | |||
| max_search_num: int = 5, | |||
| search_method: str = "greedy", | |||
| ) -> SearchResults: | |||
| """Search learnwares based on user_info from learnwares with check_status. | |||
| Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods. | |||
| @@ -1,11 +1,12 @@ | |||
| from .base import LearnwareMarket | |||
| from .classes import CondaChecker | |||
| from .easy import (EasyOrganizer, EasySearcher, EasySemanticChecker, | |||
| EasyStatChecker) | |||
| from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | |||
| from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher | |||
| def get_market_component(name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False): | |||
| def get_market_component( | |||
| name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False | |||
| ): | |||
| organizer_kwargs = {} if organizer_kwargs is None else organizer_kwargs | |||
| searcher_kwargs = {} if searcher_kwargs is None else searcher_kwargs | |||
| checker_kwargs = {} if checker_kwargs is None else checker_kwargs | |||
| @@ -13,7 +14,10 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search | |||
| if name == "easy": | |||
| easy_organizer = EasyOrganizer(market_id=market_id, rebuild=rebuild) | |||
| easy_searcher = EasySearcher(organizer=easy_organizer) | |||
| easy_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())] | |||
| easy_checker_list = [ | |||
| EasySemanticChecker(), | |||
| EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()), | |||
| ] | |||
| market_component = { | |||
| "organizer": easy_organizer, | |||
| "searcher": easy_searcher, | |||
| @@ -22,7 +26,10 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search | |||
| elif name == "hetero": | |||
| hetero_organizer = HeteroMapTableOrganizer(market_id=market_id, rebuild=rebuild, **organizer_kwargs) | |||
| hetero_searcher = HeteroSearcher(organizer=hetero_organizer) | |||
| hetero_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())] | |||
| hetero_checker_list = [ | |||
| EasySemanticChecker(), | |||
| EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()), | |||
| ] | |||
| market_component = { | |||
| "organizer": hetero_organizer, | |||
| @@ -45,7 +52,9 @@ def instantiate_learnware_market( | |||
| conda_checker: bool = False, | |||
| **kwargs, | |||
| ): | |||
| market_componets = get_market_component(name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker) | |||
| market_componets = get_market_component( | |||
| name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker | |||
| ) | |||
| return LearnwareMarket( | |||
| organizer=market_componets["organizer"], | |||
| searcher=market_componets["searcher"], | |||
| @@ -20,4 +20,4 @@ else: | |||
| from .ensemble_pruning import EnsemblePruningReuser | |||
| from .feature_augment import FeatureAugmentReuser | |||
| from .hetero import FeatureAlignLearnware, HeteroMapAlignLearnware | |||
| from .job_selector import JobSelectorReuser | |||
| from .job_selector import JobSelectorReuser | |||
| @@ -54,13 +54,14 @@ class EnsemblePruningReuser(BaseReuser): | |||
| np.ndarray | |||
| Binary one-dimensional vector, 1 indicates that the corresponding model is selected. | |||
| """ | |||
| try: | |||
| import geatpy as ea | |||
| except ModuleNotFoundError: | |||
| raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") | |||
| raise ModuleNotFoundError( | |||
| f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)." | |||
| ) | |||
| model_num = v_predict.shape[1] | |||
| @ea.Problem.single | |||
| @@ -148,7 +149,9 @@ class EnsemblePruningReuser(BaseReuser): | |||
| try: | |||
| import geatpy as ea | |||
| except ModuleNotFoundError: | |||
| raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") | |||
| raise ModuleNotFoundError( | |||
| f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)." | |||
| ) | |||
| if torch.is_tensor(v_true): | |||
| v_true = v_true.detach().cpu().numpy() | |||
| @@ -270,8 +273,10 @@ class EnsemblePruningReuser(BaseReuser): | |||
| try: | |||
| import geatpy as ea | |||
| except ModuleNotFoundError: | |||
| raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") | |||
| raise ModuleNotFoundError( | |||
| f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)." | |||
| ) | |||
| model_num = v_predict.shape[1] | |||
| v_predict[v_predict == 0.0] = -1 | |||
| v_true[v_true == 0.0] = -1 | |||
| @@ -8,8 +8,7 @@ from .base import BaseReuser | |||
| from ..learnware import Learnware | |||
| from ..logger import get_module_logger | |||
| from ..market.utils import parse_specification_type | |||
| from ..specification import (RKMETableSpecification, RKMETextSpecification, | |||
| generate_rkme_table_spec, rkme_solve_qp) | |||
| from ..specification import RKMETableSpecification, RKMETextSpecification, generate_rkme_table_spec, rkme_solve_qp | |||
| logger = get_module_logger("job_selector_reuse") | |||
| @@ -4,6 +4,7 @@ from ..logger import get_module_logger | |||
| logger = get_module_logger("reuse_utils") | |||
| def fill_data_with_mean(X: np.ndarray) -> np.ndarray: | |||
| """ | |||
| Fill missing data (NaN, Inf) in the input array with the mean of the column. | |||
| @@ -1,7 +1,12 @@ | |||
| from .base import BaseStatSpecification, Specification | |||
| from .regular import (RegularStatSpecification, RKMEImageSpecification, | |||
| RKMEStatSpecification, RKMETableSpecification, | |||
| RKMETextSpecification, rkme_solve_qp) | |||
| from .regular import ( | |||
| RegularStatSpecification, | |||
| RKMEImageSpecification, | |||
| RKMEStatSpecification, | |||
| RKMETableSpecification, | |||
| RKMETextSpecification, | |||
| rkme_solve_qp, | |||
| ) | |||
| from .system import HeteroMapTableSpecification | |||
| from ..utils import is_torch_available | |||
| @@ -12,6 +17,10 @@ if not is_torch_available(verbose=False): | |||
| generate_rkme_text_spec = None | |||
| generate_semantic_spec = None | |||
| else: | |||
| from .module import (generate_rkme_image_spec, generate_rkme_table_spec, | |||
| generate_rkme_text_spec, generate_semantic_spec, | |||
| generate_stat_spec) | |||
| from .module import ( | |||
| generate_rkme_image_spec, | |||
| generate_rkme_table_spec, | |||
| generate_rkme_text_spec, | |||
| generate_semantic_spec, | |||
| generate_stat_spec, | |||
| ) | |||
| @@ -11,5 +11,4 @@ if not is_torch_available(verbose=False): | |||
| f"RKMETableSpecification, RKMEStatSpecification and rkme_solve_qp are not available because 'torch' is not installed!" | |||
| ) | |||
| else: | |||
| from .rkme import (RKMEStatSpecification, RKMETableSpecification, | |||
| rkme_solve_qp) | |||
| from .rkme import RKMEStatSpecification, RKMETableSpecification, rkme_solve_qp | |||
| @@ -87,12 +87,14 @@ class RKMETextSpecification(RKMETableSpecification): | |||
| return np.array(miniLM_learnware.predict(X)) | |||
| logger.info("Load the necessary feature extractor for RKMETextSpecification.") | |||
| try: | |||
| from sentence_transformers import SentenceTransformer | |||
| except ModuleNotFoundError: | |||
| raise ModuleNotFoundError(f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually.") | |||
| raise ModuleNotFoundError( | |||
| f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually." | |||
| ) | |||
| if os.path.exists(zip_path): | |||
| X = _get_from_client(zip_path, X) | |||
| else: | |||
| @@ -137,7 +137,9 @@ class HeteroMapTableSpecification(SystemStatSpecification): | |||
| for d in self.get_states(): | |||
| if d in embedding_load.keys(): | |||
| if d == "type" and embedding_load[d] != self.type: | |||
| raise TypeError(f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!") | |||
| raise TypeError( | |||
| f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!" | |||
| ) | |||
| setattr(self, d, embedding_load[d]) | |||
| def save(self, filepath: str) -> bool: | |||
| @@ -1 +1 @@ | |||
| from .utils import parametrize | |||
| from .utils import parametrize | |||
| @@ -13,10 +13,13 @@ class ModelTemplate: | |||
| class_name: str = field(init=False) | |||
| template_path: str = field(init=False) | |||
| model_kwargs: dict = field(init=False) | |||
| @dataclass | |||
| class PickleModelTemplate(ModelTemplate): | |||
| model_kwargs: dict | |||
| pickle_filepath: str | |||
| def __post_init__(self): | |||
| self.class_name = "PickleLoadedModel" | |||
| self.template_path = os.path.join(C.package_path, "tests", "templates", "pickle_model.py") | |||
| @@ -29,13 +32,14 @@ class PickleModelTemplate(ModelTemplate): | |||
| default_model_kwargs.update(self.model_kwargs) | |||
| self.model_kwargs = default_model_kwargs | |||
| @dataclass | |||
| class StatSpecTemplate: | |||
| filepath: str | |||
| type: str = field(default="RKMETableSpecification") | |||
| class LearnwareTemplate: | |||
| class LearnwareTemplate: | |||
| @staticmethod | |||
| def generate_requirements(filepath, requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None): | |||
| requirements = [] if requirements is None else requirements | |||
| @@ -49,14 +53,16 @@ class LearnwareTemplate: | |||
| line_str = requirement[0].strip() + requirement[1].strip() + requirement[2].strip() + "\n" | |||
| else: | |||
| raise TypeError(f"requirement must be type str/tuple, rather than {type(requirement)}") | |||
| requirements_str += line_str | |||
| with open(filepath, "w") as fdout: | |||
| fdout.write(requirements_str) | |||
| @staticmethod | |||
| def generate_learnware_yaml(filepath, model_config: Optional[dict] = None, stat_spec_config: Optional[List[dict]] = None): | |||
| def generate_learnware_yaml( | |||
| filepath, model_config: Optional[dict] = None, stat_spec_config: Optional[List[dict]] = None | |||
| ): | |||
| learnware_config = {} | |||
| if model_config is not None: | |||
| learnware_config["model"] = model_config | |||
| @@ -64,7 +70,7 @@ class LearnwareTemplate: | |||
| learnware_config["stat_specifications"] = stat_spec_config | |||
| save_dict_to_yaml(learnware_config, filepath) | |||
| @staticmethod | |||
| def generate_learnware_zipfile( | |||
| learnware_zippath: str, | |||
| @@ -75,27 +81,29 @@ class LearnwareTemplate: | |||
| with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir: | |||
| requirement_filepath = os.path.join(tempdir, "requirements.txt") | |||
| LearnwareTemplate.generate_requirements(requirement_filepath, requirements) | |||
| model_filepath = os.path.join(tempdir, "__init__.py") | |||
| model_filepath = os.path.join(tempdir, "__init__.py") | |||
| copyfile(model_template.template_path, model_filepath) | |||
| learnware_yaml_filepath = os.path.join(tempdir, "learnware.yaml") | |||
| model_config = { | |||
| "class_name": model_template.class_name, | |||
| "kwargs": model_template.model_kwargs, | |||
| } | |||
| stat_spec_config = { | |||
| "module_path": "learnware.specification", | |||
| "class_name": stat_spec_template.type, | |||
| "file_name": "stat_spec.json", | |||
| "kwargs": {} | |||
| "kwargs": {}, | |||
| } | |||
| copyfile(stat_spec_template.filepath, os.path.join(tempdir, stat_spec_config["file_name"])) | |||
| LearnwareTemplate.generate_learnware_yaml(learnware_yaml_filepath, model_config, stat_spec_config=[stat_spec_config]) | |||
| LearnwareTemplate.generate_learnware_yaml( | |||
| learnware_yaml_filepath, model_config, stat_spec_config=[stat_spec_config] | |||
| ) | |||
| if isinstance(model_template, PickleModelTemplate): | |||
| pickle_filepath = os.path.join(tempdir, model_template.model_kwargs["pickle_filename"]) | |||
| copyfile(model_template.pickle_filepath, pickle_filepath) | |||
| convert_folder_to_zipfile(tempdir, learnware_zippath) | |||
| @@ -7,7 +7,6 @@ from learnware.model.base import BaseModel | |||
| class PickleLoadedModel(BaseModel): | |||
| def __init__( | |||
| self, | |||
| input_shape, | |||
| @@ -25,10 +24,10 @@ class PickleLoadedModel(BaseModel): | |||
| self.predict_method = predict_method | |||
| self.fit_method = fit_method | |||
| self.finetune_method = finetune_method | |||
| def predict(self, X: np.ndarray) -> np.ndarray: | |||
| return getattr(self.model, self.predict_method)(X) | |||
| def fit(self, X: np.ndarray, y: np.ndarray): | |||
| getattr(self.model, self.fit_method)(X, y) | |||
| @@ -7,4 +7,4 @@ def parametrize(test_class, **kwargs): | |||
| _suite = unittest.TestSuite() | |||
| for name in test_names: | |||
| _suite.addTest(test_class(name, **kwargs)) | |||
| return _suite | |||
| return _suite | |||
| @@ -1,8 +1,7 @@ | |||
| import os | |||
| import zipfile | |||
| from .file import (convert_folder_to_zipfile, read_yaml_to_dict, | |||
| save_dict_to_yaml) | |||
| from .file import convert_folder_to_zipfile, read_yaml_to_dict, save_dict_to_yaml | |||
| from .gpu import allocate_cuda_idx, choose_device, setup_seed | |||
| from .import_utils import is_torch_available | |||
| from .module import get_module_by_module_path | |||
| @@ -16,6 +16,7 @@ def read_yaml_to_dict(yaml_path: str): | |||
| dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) | |||
| return dict_value | |||
| def convert_folder_to_zipfile(folder_path, zip_path): | |||
| with zipfile.ZipFile(zip_path, "w") as zip_obj: | |||
| for foldername, subfolders, filenames in os.walk(folder_path): | |||
| @@ -17,6 +17,7 @@ def setup_seed(seed): | |||
| random.seed(seed) | |||
| if is_torch_available(verbose=False): | |||
| import torch | |||
| torch.manual_seed(seed) | |||
| torch.cuda.manual_seed_all(seed) | |||
| torch.backends.cudnn.deterministic = True | |||
| @@ -4,15 +4,16 @@ import numpy as np | |||
| from learnware.client import LearnwareClient | |||
| from learnware.client.container import LearnwaresContainer | |||
| class TestContainer(unittest.TestCase): | |||
| def __init__(self, method_name='runTest', mode="all"): | |||
| def __init__(self, method_name="runTest", mode="all"): | |||
| super(TestContainer, self).__init__(method_name) | |||
| self.modes = [] | |||
| if mode in {"all", "conda"}: | |||
| self.modes.append("conda") | |||
| if mode in {"all", "docker"}: | |||
| self.modes.append("docker") | |||
| def setUp(self): | |||
| self.client = LearnwareClient() | |||
| @@ -35,17 +36,19 @@ class TestContainer(unittest.TestCase): | |||
| def test_container_with_pip(self): | |||
| for mode in self.modes: | |||
| self._test_container_with_pip(mode=mode) | |||
| def test_container_with_conda(self): | |||
| for mode in self.modes: | |||
| self._test_container_with_conda(mode=mode) | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestContainer("test_container_with_pip", mode="all")) | |||
| _suite.addTest(TestContainer("test_container_with_conda", mode="all")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
| runner.run(suite()) | |||
| @@ -5,8 +5,9 @@ import numpy as np | |||
| from learnware.client import LearnwareClient | |||
| from learnware.reuse import AveragingReuser | |||
| class TestLearnwareLoad(unittest.TestCase): | |||
| def __init__(self, method_name='runTest', mode="all"): | |||
| def __init__(self, method_name="runTest", mode="all"): | |||
| super(TestLearnwareLoad, self).__init__(method_name) | |||
| self.runnable_options = [] | |||
| if mode in {"all", "conda"}: | |||
| @@ -31,7 +32,6 @@ class TestLearnwareLoad(unittest.TestCase): | |||
| for learnware in learnware_list: | |||
| print(learnware.id, learnware.predict(input_array)) | |||
| def _test_load_learnware_by_id(self, runnable_option): | |||
| learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option=runnable_option) | |||
| reuser = AveragingReuser(learnware_list, mode="vote_by_label") | |||
| @@ -44,11 +44,11 @@ class TestLearnwareLoad(unittest.TestCase): | |||
| def test_load_learnware_by_zippath(self): | |||
| for runnable_option in self.runnable_options: | |||
| self._test_load_learnware_by_zippath(runnable_option=runnable_option) | |||
| def test_load_learnware_by_id(self): | |||
| for runnable_option in self.runnable_options: | |||
| self._test_load_learnware_by_id(runnable_option=runnable_option) | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| @@ -56,6 +56,7 @@ def suite(): | |||
| _suite.addTest(TestLearnwareLoad("test_load_learnware_by_id", mode="all")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
| runner.run(suite()) | |||
| @@ -11,11 +11,11 @@ from learnware.specification import RKMETableSpecification, HeteroMapTableSpecif | |||
| from learnware.specification import generate_stat_spec | |||
| from learnware.market.heterogeneous.organizer import HeteroMap | |||
| class TestTableRKME(unittest.TestCase): | |||
| def setUp(self): | |||
| self.hetero_map = HeteroMap() | |||
| def _test_hetero_spec(self, X): | |||
| rkme: RKMETableSpecification = generate_stat_spec(type="table", X=X) | |||
| hetero_spec = self.hetero_map.hetero_mapping(rkme_spec=rkme, features=dict()) | |||
| @@ -30,14 +30,14 @@ class TestTableRKME(unittest.TestCase): | |||
| rkme2 = HeteroMapTableSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "HeteroMapTableSpecification" | |||
| def test_hetero_rkme(self): | |||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5000, 200))) | |||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(10000, 100))) | |||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5, 20))) | |||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(1, 50))) | |||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(100, 150))) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -25,7 +25,7 @@ class TestImageRKME(unittest.TestCase): | |||
| rkme2 = RKMEImageSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "RKMEImageSpecification" | |||
| def test_image_rkme(self): | |||
| self._test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) | |||
| self._test_image_rkme(np.random.randint(0, 255, size=(100, 1, 128, 128))) | |||
| @@ -34,5 +34,6 @@ class TestImageRKME(unittest.TestCase): | |||
| self._test_image_rkme(torch.randint(0, 255, (20, 3, 128, 128))) | |||
| self._test_image_rkme(torch.randint(0, 255, (1, 1, 128, 128)) / 255) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -24,7 +24,7 @@ class TestTableRKME(unittest.TestCase): | |||
| rkme2 = RKMETableSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "RKMETableSpecification" | |||
| def test_table_rkme(self): | |||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(5000, 200))) | |||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(10000, 100))) | |||
| @@ -32,5 +32,6 @@ class TestTableRKME(unittest.TestCase): | |||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(1, 50))) | |||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(100, 150))) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -12,19 +12,19 @@ from learnware.specification import generate_stat_spec | |||
| class TestTextRKME(unittest.TestCase): | |||
| @staticmethod | |||
| def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): | |||
| text_list = [] | |||
| for i in range(num): | |||
| length = random.randint(min_len, max_len) | |||
| if text_type == "en": | |||
| characters = string.ascii_letters + string.digits + string.punctuation | |||
| result_str = "".join(random.choice(characters) for i in range(length)) | |||
| text_list.append(result_str) | |||
| elif text_type == "zh": | |||
| result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) | |||
| text_list.append(result_str) | |||
| else: | |||
| raise ValueError("Type should be en or zh") | |||
| return text_list | |||
| text_list = [] | |||
| for i in range(num): | |||
| length = random.randint(min_len, max_len) | |||
| if text_type == "en": | |||
| characters = string.ascii_letters + string.digits + string.punctuation | |||
| result_str = "".join(random.choice(characters) for i in range(length)) | |||
| text_list.append(result_str) | |||
| elif text_type == "zh": | |||
| result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) | |||
| text_list.append(result_str) | |||
| else: | |||
| raise ValueError("Type should be en or zh") | |||
| return text_list | |||
| @staticmethod | |||
| def _test_text_rkme(X): | |||
| @@ -11,6 +11,7 @@ from shutil import copyfile, rmtree | |||
| from sklearn.metrics import mean_squared_error | |||
| import learnware | |||
| learnware.init(logging_level=logging.WARNING) | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| @@ -23,6 +24,7 @@ from hetero_config import input_shape_list, input_description_list, output_descr | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| class TestHeteroWorkflow(unittest.TestCase): | |||
| universal_semantic_config = { | |||
| "data_type": "Table", | |||
| @@ -46,10 +48,12 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero") | |||
| os.makedirs(learnware_pool_dirpath, exist_ok=True) | |||
| learnware_zippath = os.path.join(learnware_pool_dirpath, "ridge_%d.zip" % (i)) | |||
| print("Preparing Learnware: %d" % (i)) | |||
| X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42) | |||
| X, y = make_regression( | |||
| n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42 | |||
| ) | |||
| clf = Ridge(alpha=1.0) | |||
| clf.fit(X, y) | |||
| pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl") | |||
| @@ -62,14 +66,16 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| LearnwareTemplate.generate_learnware_zipfile( | |||
| learnware_zippath=learnware_zippath, | |||
| model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape":(input_shape_list[i % 2],), "output_shape": (1,)}), | |||
| model_template=PickleModelTemplate( | |||
| pickle_filepath=pickle_filepath, | |||
| model_kwargs={"input_shape": (input_shape_list[i % 2],), "output_shape": (1,)}, | |||
| ), | |||
| stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"), | |||
| requirements=["scikit-learn==0.22"], | |||
| ) | |||
| self.zip_path_list.append(learnware_zippath) | |||
| def _upload_delete_learnware(self, hetero_market, learnware_num, delete): | |||
| self.test_prepare_learnware_randomly(learnware_num) | |||
| self.learnware_num = learnware_num | |||
| @@ -83,7 +89,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| description=f"test_learnware_number_{idx}", | |||
| input_description=input_description_list[idx % 2], | |||
| output_description=output_description_list[idx % 2], | |||
| **self.universal_semantic_config | |||
| **self.universal_semantic_config, | |||
| ) | |||
| hetero_market.add_learnware(zip_path, semantic_spec) | |||
| @@ -106,7 +112,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| assert len(curr_inds) == 0, f"The market should be empty!" | |||
| return hetero_market | |||
| def test_upload_delete_learnware(self, learnware_num=5, delete=True): | |||
| hetero_market = self._init_learnware_market() | |||
| return self._upload_delete_learnware(hetero_market, learnware_num, delete) | |||
| @@ -129,7 +135,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| name=f"learnware_{learnware_num - 1}", | |||
| **self.universal_semantic_config, | |||
| ) | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| @@ -154,7 +160,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| def test_hetero_stat_search(self, learnware_num=5): | |||
| hetero_market = self.test_train_market_model(learnware_num, delete=False) | |||
| print("Total Item:", len(hetero_market)) | |||
| user_dim = 15 | |||
| with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder: | |||
| @@ -174,7 +180,10 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| semantic_spec = generate_semantic_spec( | |||
| input_description={ | |||
| "Dimension": user_dim, | |||
| "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)}, | |||
| "Description": { | |||
| str(key): input_description_list[idx % 2]["Description"][str(key)] | |||
| for key in range(user_dim) | |||
| }, | |||
| }, | |||
| **self.universal_semantic_config, | |||
| ) | |||
| @@ -182,7 +191,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| print(f"search result of user{idx}:") | |||
| for single_item in single_result: | |||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||
| @@ -215,7 +224,10 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| semantic_spec = generate_semantic_spec( | |||
| input_description={ | |||
| "Dimension": user_dim - 2, | |||
| "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)}, | |||
| "Description": { | |||
| str(key): input_description_list[idx % 2]["Description"][str(key)] | |||
| for key in range(user_dim) | |||
| }, | |||
| }, | |||
| **self.universal_semantic_config, | |||
| ) | |||
| @@ -228,7 +240,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| def test_homo_stat_search(self, learnware_num=5): | |||
| hetero_market = self.test_train_market_model(learnware_num, delete=False) | |||
| print("Total Item:", len(hetero_market)) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder: | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | |||
| @@ -260,7 +272,9 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0) | |||
| # generate specification | |||
| semantic_spec = generate_semantic_spec(input_description=user_description_list[0], **self.universal_semantic_config) | |||
| semantic_spec = generate_semantic_spec( | |||
| input_description=user_description_list[0], **self.universal_semantic_config | |||
| ) | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| # learnware market search | |||
| @@ -268,7 +282,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| # print search results | |||
| for single_item in single_result: | |||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||
| @@ -306,9 +320,9 @@ class TestHeteroWorkflow(unittest.TestCase): | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| #_suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly")) | |||
| #_suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware")) | |||
| #_suite.addTest(TestHeteroWorkflow("test_train_market_model")) | |||
| # _suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly")) | |||
| # _suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware")) | |||
| # _suite.addTest(TestHeteroWorkflow("test_train_market_model")) | |||
| _suite.addTest(TestHeteroWorkflow("test_search_semantics")) | |||
| _suite.addTest(TestHeteroWorkflow("test_hetero_stat_search")) | |||
| _suite.addTest(TestHeteroWorkflow("test_homo_stat_search")) | |||
| @@ -10,6 +10,7 @@ from sklearn.datasets import load_digits | |||
| from sklearn.model_selection import train_test_split | |||
| import learnware | |||
| learnware.init(logging_level=logging.WARNING) | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| @@ -19,8 +20,8 @@ from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, St | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| class TestWorkflow(unittest.TestCase): | |||
| universal_semantic_config = { | |||
| "data_type": "Table", | |||
| "task_type": "Classification", | |||
| @@ -28,7 +29,7 @@ class TestWorkflow(unittest.TestCase): | |||
| "scenarios": "Education", | |||
| "license": "MIT", | |||
| } | |||
| def _init_learnware_market(self): | |||
| """initialize learnware market""" | |||
| easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True) | |||
| @@ -42,7 +43,7 @@ class TestWorkflow(unittest.TestCase): | |||
| learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool") | |||
| os.makedirs(learnware_pool_dirpath, exist_ok=True) | |||
| learnware_zippath = os.path.join(learnware_pool_dirpath, "svm_%d.zip" % (i)) | |||
| print("Preparing Learnware: %d" % (i)) | |||
| data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True) | |||
| clf = svm.SVC(kernel="linear", probability=True) | |||
| @@ -54,14 +55,17 @@ class TestWorkflow(unittest.TestCase): | |||
| spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) | |||
| spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json") | |||
| spec.save(spec_filepath) | |||
| LearnwareTemplate.generate_learnware_zipfile( | |||
| learnware_zippath=learnware_zippath, | |||
| model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape":(64,), "output_shape": (10,), "predict_method": "predict_proba"}), | |||
| model_template=PickleModelTemplate( | |||
| pickle_filepath=pickle_filepath, | |||
| model_kwargs={"input_shape": (64,), "output_shape": (10,), "predict_method": "predict_proba"}, | |||
| ), | |||
| stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"), | |||
| requirements=["scikit-learn==0.22"], | |||
| ) | |||
| self.zip_path_list.append(learnware_zippath) | |||
| def test_upload_delete_learnware(self, learnware_num=5, delete=True): | |||
| @@ -87,7 +91,7 @@ class TestWorkflow(unittest.TestCase): | |||
| "Dimension": 10, | |||
| "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)}, | |||
| }, | |||
| **self.universal_semantic_config | |||
| **self.universal_semantic_config, | |||
| ) | |||
| easy_market.add_learnware(zip_path, semantic_spec) | |||
| @@ -113,7 +117,7 @@ class TestWorkflow(unittest.TestCase): | |||
| easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | |||
| print("Total Item:", len(easy_market)) | |||
| assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_test_workflow") as test_folder: | |||
| with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj: | |||
| zip_obj.extractall(path=test_folder) | |||
| @@ -123,15 +127,15 @@ class TestWorkflow(unittest.TestCase): | |||
| description=f"test_learnware_number_{learnware_num - 1}", | |||
| **self.universal_semantic_config, | |||
| ) | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| search_result = easy_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print(f"Search result:") | |||
| for search_item in single_result: | |||
| print("Choose learnware:",search_item.learnware.id) | |||
| print("Choose learnware:", search_item.learnware.id) | |||
| def test_stat_search(self, learnware_num=5): | |||
| easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | |||
| print("Total Item:", len(easy_market)) | |||