From 3297847927e330026aea86369eebb4a5800345d2 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 11 Jan 2024 21:14:41 +0800 Subject: [PATCH] [MNT] format code using black v23.1.0 --- docs/conf.py | 4 +- .../pfs/pfs_cross_transfer.py | 4 +- learnware/client/package_utils.py | 4 +- learnware/learnware/__init__.py | 18 ++++--- learnware/market/__init__.py | 6 +-- learnware/market/easy/__init__.py | 3 +- learnware/market/easy/searcher.py | 6 +-- .../organizer/hetero_map/__init__.py | 5 +- learnware/market/heterogeneous/searcher.py | 6 ++- learnware/market/module.py | 21 +++++--- learnware/reuse/__init__.py | 2 +- learnware/reuse/ensemble_pruning.py | 19 ++++--- learnware/reuse/job_selector.py | 3 +- learnware/reuse/utils.py | 1 + learnware/specification/__init__.py | 21 +++++--- .../specification/regular/table/__init__.py | 3 +- learnware/specification/regular/text/rkme.py | 8 +-- .../specification/system/hetero_table.py | 4 +- learnware/tests/__init__.py | 2 +- learnware/tests/templates/__init__.py | 38 ++++++++------ learnware/tests/templates/pickle_model.py | 5 +- learnware/tests/utils.py | 2 +- learnware/utils/__init__.py | 3 +- learnware/utils/file.py | 1 + learnware/utils/gpu.py | 1 + tests/test_learnware_client/test_container.py | 11 ++-- .../test_load_learnware.py | 11 ++-- tests/test_specification/test_hetero_spec.py | 10 ++-- tests/test_specification/test_image_rkme.py | 3 +- tests/test_specification/test_table_rkme.py | 3 +- tests/test_specification/test_text_rkme.py | 26 +++++----- tests/test_workflow/test_hetero_workflow.py | 50 ++++++++++++------- tests/test_workflow/test_workflow.py | 26 ++++++---- 33 files changed, 194 insertions(+), 136 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 155d20a..b8507b4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -100,12 +100,12 @@ html_logo = "_static/img/logo/logo1.png" # These folders are copied to the documentation's HTML output -html_static_path = ['_static'] +html_static_path = ["_static"] # These paths are either relative to html_static_path # or fully qualified paths (eg. https://...) html_css_files = [ - 'css/custom_style.css', + "css/custom_style.css", ] # -- Options for HTMLHelp output ------------------------------------------ diff --git a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py index 5f69127..93a3fa3 100644 --- a/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py +++ b/examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py @@ -85,9 +85,7 @@ def get_split_errs(algo): split = train_xs.shape[0] - proportion_list[tmp] model.fit( - train_xs[ - split:, - ], + train_xs[split:,], train_ys[split:], eval_set=[(val_xs, val_ys)], early_stopping_rounds=50, diff --git a/learnware/client/package_utils.py b/learnware/client/package_utils.py index 13467ab..cc145d2 100644 --- a/learnware/client/package_utils.py +++ b/learnware/client/package_utils.py @@ -86,7 +86,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: pass except Exception as err: logger.error(err) - + return None exist_packages = [] @@ -101,7 +101,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: exist_packages.append(result) else: nonexist_packages.append(package) - + if len(nonexist_packages) > 0: logger.info(f"Filtered out {len(nonexist_packages)} non-exist pip packages.") return exist_packages, nonexist_packages diff --git a/learnware/learnware/__init__.py b/learnware/learnware/__init__.py index fca213a..60996a7 100644 --- a/learnware/learnware/__init__.py +++ b/learnware/learnware/__init__.py @@ -13,7 +13,9 @@ from ..utils import read_yaml_to_dict logger = get_module_logger("learnware.learnware") -def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True) -> Optional[Learnware]: +def get_learnware_from_dirpath( + id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True +) -> Optional[Learnware]: """Get the learnware object from dirpath, and provide the manage interface tor Learnware class Parameters @@ -46,11 +48,11 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, } try: - learnware_yaml_path = os.path.join(learnware_dirpath, C.learnware_folder_config["yaml_file"]) - assert os.path.exists(learnware_yaml_path), f"learnware.yaml is not found for learnware_{id}, please check the learnware folder or zipfile." - - + assert os.path.exists( + learnware_yaml_path + ), f"learnware.yaml is not found for learnware_{id}, please check the learnware folder or zipfile." + yaml_config = read_yaml_to_dict(learnware_yaml_path) if "name" in yaml_config: @@ -67,8 +69,10 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, for _stat_spec in learnware_config["stat_specifications"]: stat_spec = _stat_spec.copy() stat_spec_path = os.path.join(learnware_dirpath, stat_spec["file_name"]) - assert os.path.exists(stat_spec_path), f"statistical specification file {stat_spec['file_name']} is not found for learnware_{id}, please check the learnware folder or zipfile." - + assert os.path.exists( + stat_spec_path + ), f"statistical specification file {stat_spec['file_name']} is not found for learnware_{id}, please check the learnware folder or zipfile." + stat_spec["file_name"] = stat_spec_path stat_spec_inst = get_stat_spec_from_config(stat_spec) learnware_spec.update_stat_spec(**{stat_spec_inst.type: stat_spec_inst}) diff --git a/learnware/market/__init__.py b/learnware/market/__init__.py index 0d2fd4c..fba3552 100644 --- a/learnware/market/__init__.py +++ b/learnware/market/__init__.py @@ -1,9 +1,7 @@ from .anchor import AnchoredOrganizer, AnchoredSearcher, AnchoredUserInfo -from .base import (BaseChecker, BaseOrganizer, BaseSearcher, BaseUserInfo, - LearnwareMarket) +from .base import BaseChecker, BaseOrganizer, BaseSearcher, BaseUserInfo, LearnwareMarket from .classes import CondaChecker -from .easy import (EasyOrganizer, EasySearcher, EasySemanticChecker, - EasyStatChecker) +from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker from .evolve import EvolvedOrganizer from .evolve_anchor import EvolvedAnchoredOrganizer from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher diff --git a/learnware/market/easy/__init__.py b/learnware/market/easy/__init__.py index 88b574e..e2d58e5 100644 --- a/learnware/market/easy/__init__.py +++ b/learnware/market/easy/__init__.py @@ -11,5 +11,4 @@ if not is_torch_available(verbose=False): logger.error("EasySeacher and EasyChecker are not available because 'torch' is not installed!") else: from .checker import EasySemanticChecker, EasyStatChecker - from .searcher import (EasyExactSemanticSearcher, EasyFuzzSemanticSearcher, - EasySearcher, EasyStatSearcher) + from .searcher import EasyExactSemanticSearcher, EasyFuzzSemanticSearcher, EasySearcher, EasyStatSearcher diff --git a/learnware/market/easy/searcher.py b/learnware/market/easy/searcher.py index b6f9ede..4225e7a 100644 --- a/learnware/market/easy/searcher.py +++ b/learnware/market/easy/searcher.py @@ -6,13 +6,11 @@ import torch from rapidfuzz import fuzz from .organizer import EasyOrganizer -from ..base import (BaseSearcher, BaseUserInfo, MultipleSearchItem, - SearchResults, SingleSearchItem) +from ..base import BaseSearcher, BaseUserInfo, MultipleSearchItem, SearchResults, SingleSearchItem from ..utils import parse_specification_type from ...learnware import Learnware from ...logger import get_module_logger -from ...specification import (RKMEImageSpecification, RKMETableSpecification, - RKMETextSpecification, rkme_solve_qp) +from ...specification import RKMEImageSpecification, RKMETableSpecification, RKMETextSpecification, rkme_solve_qp logger = get_module_logger("easy_seacher") diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index 273cac9..47978cf 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -8,8 +8,7 @@ from torch import nn from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer from .trainer import Trainer, TransTabCollatorForCL -from .....specification import (HeteroMapTableSpecification, - RKMETableSpecification) +from .....specification import HeteroMapTableSpecification, RKMETableSpecification from .....utils import allocate_cuda_idx, choose_device @@ -288,7 +287,7 @@ class HeteroMap(nn.Module): # go through transformers, get the first cls embedding encoder_output = self.encoder(**outputs) # bs, seqlen+1, hidden_dim output_features = encoder_output[:, 0, :] - + del inputs, outputs, encoder_output torch.cuda.empty_cache() diff --git a/learnware/market/heterogeneous/searcher.py b/learnware/market/heterogeneous/searcher.py index 8a97dba..5a10ac0 100644 --- a/learnware/market/heterogeneous/searcher.py +++ b/learnware/market/heterogeneous/searcher.py @@ -11,7 +11,11 @@ logger = get_module_logger("hetero_searcher") class HeteroSearcher(EasySearcher): def __call__( - self, user_info: BaseUserInfo, check_status: Optional[int] = None, max_search_num: int = 5, search_method: str = "greedy" + self, + user_info: BaseUserInfo, + check_status: Optional[int] = None, + max_search_num: int = 5, + search_method: str = "greedy", ) -> SearchResults: """Search learnwares based on user_info from learnwares with check_status. Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods. diff --git a/learnware/market/module.py b/learnware/market/module.py index c5c64f1..cdc13e7 100644 --- a/learnware/market/module.py +++ b/learnware/market/module.py @@ -1,11 +1,12 @@ from .base import LearnwareMarket from .classes import CondaChecker -from .easy import (EasyOrganizer, EasySearcher, EasySemanticChecker, - EasyStatChecker) +from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher -def get_market_component(name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False): +def get_market_component( + name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False +): organizer_kwargs = {} if organizer_kwargs is None else organizer_kwargs searcher_kwargs = {} if searcher_kwargs is None else searcher_kwargs checker_kwargs = {} if checker_kwargs is None else checker_kwargs @@ -13,7 +14,10 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search if name == "easy": easy_organizer = EasyOrganizer(market_id=market_id, rebuild=rebuild) easy_searcher = EasySearcher(organizer=easy_organizer) - easy_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())] + easy_checker_list = [ + EasySemanticChecker(), + EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()), + ] market_component = { "organizer": easy_organizer, "searcher": easy_searcher, @@ -22,7 +26,10 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search elif name == "hetero": hetero_organizer = HeteroMapTableOrganizer(market_id=market_id, rebuild=rebuild, **organizer_kwargs) hetero_searcher = HeteroSearcher(organizer=hetero_organizer) - hetero_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())] + hetero_checker_list = [ + EasySemanticChecker(), + EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()), + ] market_component = { "organizer": hetero_organizer, @@ -45,7 +52,9 @@ def instantiate_learnware_market( conda_checker: bool = False, **kwargs, ): - market_componets = get_market_component(name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker) + market_componets = get_market_component( + name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker + ) return LearnwareMarket( organizer=market_componets["organizer"], searcher=market_componets["searcher"], diff --git a/learnware/reuse/__init__.py b/learnware/reuse/__init__.py index 7296ad1..7a8d185 100644 --- a/learnware/reuse/__init__.py +++ b/learnware/reuse/__init__.py @@ -20,4 +20,4 @@ else: from .ensemble_pruning import EnsemblePruningReuser from .feature_augment import FeatureAugmentReuser from .hetero import FeatureAlignLearnware, HeteroMapAlignLearnware - from .job_selector import JobSelectorReuser \ No newline at end of file + from .job_selector import JobSelectorReuser diff --git a/learnware/reuse/ensemble_pruning.py b/learnware/reuse/ensemble_pruning.py index d182c9f..a8eb607 100644 --- a/learnware/reuse/ensemble_pruning.py +++ b/learnware/reuse/ensemble_pruning.py @@ -54,13 +54,14 @@ class EnsemblePruningReuser(BaseReuser): np.ndarray Binary one-dimensional vector, 1 indicates that the corresponding model is selected. """ - - + try: import geatpy as ea except ModuleNotFoundError: - raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") - + raise ModuleNotFoundError( + f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)." + ) + model_num = v_predict.shape[1] @ea.Problem.single @@ -148,7 +149,9 @@ class EnsemblePruningReuser(BaseReuser): try: import geatpy as ea except ModuleNotFoundError: - raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") + raise ModuleNotFoundError( + f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)." + ) if torch.is_tensor(v_true): v_true = v_true.detach().cpu().numpy() @@ -270,8 +273,10 @@ class EnsemblePruningReuser(BaseReuser): try: import geatpy as ea except ModuleNotFoundError: - raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") - + raise ModuleNotFoundError( + f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)." + ) + model_num = v_predict.shape[1] v_predict[v_predict == 0.0] = -1 v_true[v_true == 0.0] = -1 diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index 355849e..49689ed 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -8,8 +8,7 @@ from .base import BaseReuser from ..learnware import Learnware from ..logger import get_module_logger from ..market.utils import parse_specification_type -from ..specification import (RKMETableSpecification, RKMETextSpecification, - generate_rkme_table_spec, rkme_solve_qp) +from ..specification import RKMETableSpecification, RKMETextSpecification, generate_rkme_table_spec, rkme_solve_qp logger = get_module_logger("job_selector_reuse") diff --git a/learnware/reuse/utils.py b/learnware/reuse/utils.py index 49bb2f2..075cc20 100644 --- a/learnware/reuse/utils.py +++ b/learnware/reuse/utils.py @@ -4,6 +4,7 @@ from ..logger import get_module_logger logger = get_module_logger("reuse_utils") + def fill_data_with_mean(X: np.ndarray) -> np.ndarray: """ Fill missing data (NaN, Inf) in the input array with the mean of the column. diff --git a/learnware/specification/__init__.py b/learnware/specification/__init__.py index 4667548..6f50627 100644 --- a/learnware/specification/__init__.py +++ b/learnware/specification/__init__.py @@ -1,7 +1,12 @@ from .base import BaseStatSpecification, Specification -from .regular import (RegularStatSpecification, RKMEImageSpecification, - RKMEStatSpecification, RKMETableSpecification, - RKMETextSpecification, rkme_solve_qp) +from .regular import ( + RegularStatSpecification, + RKMEImageSpecification, + RKMEStatSpecification, + RKMETableSpecification, + RKMETextSpecification, + rkme_solve_qp, +) from .system import HeteroMapTableSpecification from ..utils import is_torch_available @@ -12,6 +17,10 @@ if not is_torch_available(verbose=False): generate_rkme_text_spec = None generate_semantic_spec = None else: - from .module import (generate_rkme_image_spec, generate_rkme_table_spec, - generate_rkme_text_spec, generate_semantic_spec, - generate_stat_spec) + from .module import ( + generate_rkme_image_spec, + generate_rkme_table_spec, + generate_rkme_text_spec, + generate_semantic_spec, + generate_stat_spec, + ) diff --git a/learnware/specification/regular/table/__init__.py b/learnware/specification/regular/table/__init__.py index 681d7ae..7f2b04c 100644 --- a/learnware/specification/regular/table/__init__.py +++ b/learnware/specification/regular/table/__init__.py @@ -11,5 +11,4 @@ if not is_torch_available(verbose=False): f"RKMETableSpecification, RKMEStatSpecification and rkme_solve_qp are not available because 'torch' is not installed!" ) else: - from .rkme import (RKMEStatSpecification, RKMETableSpecification, - rkme_solve_qp) + from .rkme import RKMEStatSpecification, RKMETableSpecification, rkme_solve_qp diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index ab5e237..3427e67 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -87,12 +87,14 @@ class RKMETextSpecification(RKMETableSpecification): return np.array(miniLM_learnware.predict(X)) logger.info("Load the necessary feature extractor for RKMETextSpecification.") - + try: from sentence_transformers import SentenceTransformer except ModuleNotFoundError: - raise ModuleNotFoundError(f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually.") - + raise ModuleNotFoundError( + f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually." + ) + if os.path.exists(zip_path): X = _get_from_client(zip_path, X) else: diff --git a/learnware/specification/system/hetero_table.py b/learnware/specification/system/hetero_table.py index 52602e6..65b6d3f 100644 --- a/learnware/specification/system/hetero_table.py +++ b/learnware/specification/system/hetero_table.py @@ -137,7 +137,9 @@ class HeteroMapTableSpecification(SystemStatSpecification): for d in self.get_states(): if d in embedding_load.keys(): if d == "type" and embedding_load[d] != self.type: - raise TypeError(f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!") + raise TypeError( + f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!" + ) setattr(self, d, embedding_load[d]) def save(self, filepath: str) -> bool: diff --git a/learnware/tests/__init__.py b/learnware/tests/__init__.py index e8ee37e..7ba38d5 100644 --- a/learnware/tests/__init__.py +++ b/learnware/tests/__init__.py @@ -1 +1 @@ -from .utils import parametrize \ No newline at end of file +from .utils import parametrize diff --git a/learnware/tests/templates/__init__.py b/learnware/tests/templates/__init__.py index 69237f9..d2f016f 100644 --- a/learnware/tests/templates/__init__.py +++ b/learnware/tests/templates/__init__.py @@ -13,10 +13,13 @@ class ModelTemplate: class_name: str = field(init=False) template_path: str = field(init=False) model_kwargs: dict = field(init=False) + + @dataclass class PickleModelTemplate(ModelTemplate): model_kwargs: dict pickle_filepath: str + def __post_init__(self): self.class_name = "PickleLoadedModel" self.template_path = os.path.join(C.package_path, "tests", "templates", "pickle_model.py") @@ -29,13 +32,14 @@ class PickleModelTemplate(ModelTemplate): default_model_kwargs.update(self.model_kwargs) self.model_kwargs = default_model_kwargs + @dataclass class StatSpecTemplate: filepath: str type: str = field(default="RKMETableSpecification") - -class LearnwareTemplate: + +class LearnwareTemplate: @staticmethod def generate_requirements(filepath, requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None): requirements = [] if requirements is None else requirements @@ -49,14 +53,16 @@ class LearnwareTemplate: line_str = requirement[0].strip() + requirement[1].strip() + requirement[2].strip() + "\n" else: raise TypeError(f"requirement must be type str/tuple, rather than {type(requirement)}") - + requirements_str += line_str - + with open(filepath, "w") as fdout: fdout.write(requirements_str) - + @staticmethod - def generate_learnware_yaml(filepath, model_config: Optional[dict] = None, stat_spec_config: Optional[List[dict]] = None): + def generate_learnware_yaml( + filepath, model_config: Optional[dict] = None, stat_spec_config: Optional[List[dict]] = None + ): learnware_config = {} if model_config is not None: learnware_config["model"] = model_config @@ -64,7 +70,7 @@ class LearnwareTemplate: learnware_config["stat_specifications"] = stat_spec_config save_dict_to_yaml(learnware_config, filepath) - + @staticmethod def generate_learnware_zipfile( learnware_zippath: str, @@ -75,27 +81,29 @@ class LearnwareTemplate: with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir: requirement_filepath = os.path.join(tempdir, "requirements.txt") LearnwareTemplate.generate_requirements(requirement_filepath, requirements) - - model_filepath = os.path.join(tempdir, "__init__.py") + + model_filepath = os.path.join(tempdir, "__init__.py") copyfile(model_template.template_path, model_filepath) - + learnware_yaml_filepath = os.path.join(tempdir, "learnware.yaml") model_config = { "class_name": model_template.class_name, "kwargs": model_template.model_kwargs, } - + stat_spec_config = { "module_path": "learnware.specification", "class_name": stat_spec_template.type, "file_name": "stat_spec.json", - "kwargs": {} + "kwargs": {}, } copyfile(stat_spec_template.filepath, os.path.join(tempdir, stat_spec_config["file_name"])) - LearnwareTemplate.generate_learnware_yaml(learnware_yaml_filepath, model_config, stat_spec_config=[stat_spec_config]) - + LearnwareTemplate.generate_learnware_yaml( + learnware_yaml_filepath, model_config, stat_spec_config=[stat_spec_config] + ) + if isinstance(model_template, PickleModelTemplate): pickle_filepath = os.path.join(tempdir, model_template.model_kwargs["pickle_filename"]) copyfile(model_template.pickle_filepath, pickle_filepath) - + convert_folder_to_zipfile(tempdir, learnware_zippath) diff --git a/learnware/tests/templates/pickle_model.py b/learnware/tests/templates/pickle_model.py index 8ec7f44..e031d8a 100644 --- a/learnware/tests/templates/pickle_model.py +++ b/learnware/tests/templates/pickle_model.py @@ -7,7 +7,6 @@ from learnware.model.base import BaseModel class PickleLoadedModel(BaseModel): - def __init__( self, input_shape, @@ -25,10 +24,10 @@ class PickleLoadedModel(BaseModel): self.predict_method = predict_method self.fit_method = fit_method self.finetune_method = finetune_method - + def predict(self, X: np.ndarray) -> np.ndarray: return getattr(self.model, self.predict_method)(X) - + def fit(self, X: np.ndarray, y: np.ndarray): getattr(self.model, self.fit_method)(X, y) diff --git a/learnware/tests/utils.py b/learnware/tests/utils.py index 5486bf4..b5738cc 100644 --- a/learnware/tests/utils.py +++ b/learnware/tests/utils.py @@ -7,4 +7,4 @@ def parametrize(test_class, **kwargs): _suite = unittest.TestSuite() for name in test_names: _suite.addTest(test_class(name, **kwargs)) - return _suite \ No newline at end of file + return _suite diff --git a/learnware/utils/__init__.py b/learnware/utils/__init__.py index d7b666a..bde7f09 100644 --- a/learnware/utils/__init__.py +++ b/learnware/utils/__init__.py @@ -1,8 +1,7 @@ import os import zipfile -from .file import (convert_folder_to_zipfile, read_yaml_to_dict, - save_dict_to_yaml) +from .file import convert_folder_to_zipfile, read_yaml_to_dict, save_dict_to_yaml from .gpu import allocate_cuda_idx, choose_device, setup_seed from .import_utils import is_torch_available from .module import get_module_by_module_path diff --git a/learnware/utils/file.py b/learnware/utils/file.py index 4108b49..4366206 100644 --- a/learnware/utils/file.py +++ b/learnware/utils/file.py @@ -16,6 +16,7 @@ def read_yaml_to_dict(yaml_path: str): dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) return dict_value + def convert_folder_to_zipfile(folder_path, zip_path): with zipfile.ZipFile(zip_path, "w") as zip_obj: for foldername, subfolders, filenames in os.walk(folder_path): diff --git a/learnware/utils/gpu.py b/learnware/utils/gpu.py index 23330a5..7423009 100644 --- a/learnware/utils/gpu.py +++ b/learnware/utils/gpu.py @@ -17,6 +17,7 @@ def setup_seed(seed): random.seed(seed) if is_torch_available(verbose=False): import torch + torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True diff --git a/tests/test_learnware_client/test_container.py b/tests/test_learnware_client/test_container.py index 861e0eb..4e1f1f4 100644 --- a/tests/test_learnware_client/test_container.py +++ b/tests/test_learnware_client/test_container.py @@ -4,15 +4,16 @@ import numpy as np from learnware.client import LearnwareClient from learnware.client.container import LearnwaresContainer + class TestContainer(unittest.TestCase): - def __init__(self, method_name='runTest', mode="all"): + def __init__(self, method_name="runTest", mode="all"): super(TestContainer, self).__init__(method_name) self.modes = [] if mode in {"all", "conda"}: self.modes.append("conda") if mode in {"all", "docker"}: self.modes.append("docker") - + def setUp(self): self.client = LearnwareClient() @@ -35,17 +36,19 @@ class TestContainer(unittest.TestCase): def test_container_with_pip(self): for mode in self.modes: self._test_container_with_pip(mode=mode) - + def test_container_with_conda(self): for mode in self.modes: self._test_container_with_conda(mode=mode) + def suite(): _suite = unittest.TestSuite() _suite.addTest(TestContainer("test_container_with_pip", mode="all")) _suite.addTest(TestContainer("test_container_with_conda", mode="all")) return _suite + if __name__ == "__main__": runner = unittest.TextTestRunner() - runner.run(suite()) \ No newline at end of file + runner.run(suite()) diff --git a/tests/test_learnware_client/test_load_learnware.py b/tests/test_learnware_client/test_load_learnware.py index 63f9856..1ce2250 100644 --- a/tests/test_learnware_client/test_load_learnware.py +++ b/tests/test_learnware_client/test_load_learnware.py @@ -5,8 +5,9 @@ import numpy as np from learnware.client import LearnwareClient from learnware.reuse import AveragingReuser + class TestLearnwareLoad(unittest.TestCase): - def __init__(self, method_name='runTest', mode="all"): + def __init__(self, method_name="runTest", mode="all"): super(TestLearnwareLoad, self).__init__(method_name) self.runnable_options = [] if mode in {"all", "conda"}: @@ -31,7 +32,6 @@ class TestLearnwareLoad(unittest.TestCase): for learnware in learnware_list: print(learnware.id, learnware.predict(input_array)) - def _test_load_learnware_by_id(self, runnable_option): learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option=runnable_option) reuser = AveragingReuser(learnware_list, mode="vote_by_label") @@ -44,11 +44,11 @@ class TestLearnwareLoad(unittest.TestCase): def test_load_learnware_by_zippath(self): for runnable_option in self.runnable_options: self._test_load_learnware_by_zippath(runnable_option=runnable_option) - + def test_load_learnware_by_id(self): for runnable_option in self.runnable_options: self._test_load_learnware_by_id(runnable_option=runnable_option) - + def suite(): _suite = unittest.TestSuite() @@ -56,6 +56,7 @@ def suite(): _suite.addTest(TestLearnwareLoad("test_load_learnware_by_id", mode="all")) return _suite + if __name__ == "__main__": runner = unittest.TextTestRunner() - runner.run(suite()) \ No newline at end of file + runner.run(suite()) diff --git a/tests/test_specification/test_hetero_spec.py b/tests/test_specification/test_hetero_spec.py index 21563b3..b0f7e87 100644 --- a/tests/test_specification/test_hetero_spec.py +++ b/tests/test_specification/test_hetero_spec.py @@ -11,11 +11,11 @@ from learnware.specification import RKMETableSpecification, HeteroMapTableSpecif from learnware.specification import generate_stat_spec from learnware.market.heterogeneous.organizer import HeteroMap + class TestTableRKME(unittest.TestCase): - def setUp(self): self.hetero_map = HeteroMap() - + def _test_hetero_spec(self, X): rkme: RKMETableSpecification = generate_stat_spec(type="table", X=X) hetero_spec = self.hetero_map.hetero_mapping(rkme_spec=rkme, features=dict()) @@ -30,14 +30,14 @@ class TestTableRKME(unittest.TestCase): rkme2 = HeteroMapTableSpecification() rkme2.load(rkme_path) assert rkme2.type == "HeteroMapTableSpecification" - - + def test_hetero_rkme(self): self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5000, 200))) self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(10000, 100))) self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5, 20))) self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(1, 50))) self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(100, 150))) - + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_specification/test_image_rkme.py b/tests/test_specification/test_image_rkme.py index 29312bf..4bd71b5 100644 --- a/tests/test_specification/test_image_rkme.py +++ b/tests/test_specification/test_image_rkme.py @@ -25,7 +25,7 @@ class TestImageRKME(unittest.TestCase): rkme2 = RKMEImageSpecification() rkme2.load(rkme_path) assert rkme2.type == "RKMEImageSpecification" - + def test_image_rkme(self): self._test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) self._test_image_rkme(np.random.randint(0, 255, size=(100, 1, 128, 128))) @@ -34,5 +34,6 @@ class TestImageRKME(unittest.TestCase): self._test_image_rkme(torch.randint(0, 255, (20, 3, 128, 128))) self._test_image_rkme(torch.randint(0, 255, (1, 1, 128, 128)) / 255) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_specification/test_table_rkme.py b/tests/test_specification/test_table_rkme.py index 9c314f1..2be9113 100644 --- a/tests/test_specification/test_table_rkme.py +++ b/tests/test_specification/test_table_rkme.py @@ -24,7 +24,7 @@ class TestTableRKME(unittest.TestCase): rkme2 = RKMETableSpecification() rkme2.load(rkme_path) assert rkme2.type == "RKMETableSpecification" - + def test_table_rkme(self): self._test_table_rkme(np.random.uniform(-10000, 10000, size=(5000, 200))) self._test_table_rkme(np.random.uniform(-10000, 10000, size=(10000, 100))) @@ -32,5 +32,6 @@ class TestTableRKME(unittest.TestCase): self._test_table_rkme(np.random.uniform(-10000, 10000, size=(1, 50))) self._test_table_rkme(np.random.uniform(-10000, 10000, size=(100, 150))) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_specification/test_text_rkme.py b/tests/test_specification/test_text_rkme.py index 0409d98..6675cf4 100644 --- a/tests/test_specification/test_text_rkme.py +++ b/tests/test_specification/test_text_rkme.py @@ -12,19 +12,19 @@ from learnware.specification import generate_stat_spec class TestTextRKME(unittest.TestCase): @staticmethod def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): - text_list = [] - for i in range(num): - length = random.randint(min_len, max_len) - if text_type == "en": - characters = string.ascii_letters + string.digits + string.punctuation - result_str = "".join(random.choice(characters) for i in range(length)) - text_list.append(result_str) - elif text_type == "zh": - result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) - text_list.append(result_str) - else: - raise ValueError("Type should be en or zh") - return text_list + text_list = [] + for i in range(num): + length = random.randint(min_len, max_len) + if text_type == "en": + characters = string.ascii_letters + string.digits + string.punctuation + result_str = "".join(random.choice(characters) for i in range(length)) + text_list.append(result_str) + elif text_type == "zh": + result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) + text_list.append(result_str) + else: + raise ValueError("Type should be en or zh") + return text_list @staticmethod def _test_text_rkme(X): diff --git a/tests/test_workflow/test_hetero_workflow.py b/tests/test_workflow/test_hetero_workflow.py index 3276bdc..245fc4c 100644 --- a/tests/test_workflow/test_hetero_workflow.py +++ b/tests/test_workflow/test_hetero_workflow.py @@ -11,6 +11,7 @@ from shutil import copyfile, rmtree from sklearn.metrics import mean_squared_error import learnware + learnware.init(logging_level=logging.WARNING) from learnware.market import instantiate_learnware_market, BaseUserInfo @@ -23,6 +24,7 @@ from hetero_config import input_shape_list, input_description_list, output_descr curr_root = os.path.dirname(os.path.abspath(__file__)) + class TestHeteroWorkflow(unittest.TestCase): universal_semantic_config = { "data_type": "Table", @@ -46,10 +48,12 @@ class TestHeteroWorkflow(unittest.TestCase): learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero") os.makedirs(learnware_pool_dirpath, exist_ok=True) learnware_zippath = os.path.join(learnware_pool_dirpath, "ridge_%d.zip" % (i)) - + print("Preparing Learnware: %d" % (i)) - X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42) + X, y = make_regression( + n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42 + ) clf = Ridge(alpha=1.0) clf.fit(X, y) pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl") @@ -62,14 +66,16 @@ class TestHeteroWorkflow(unittest.TestCase): LearnwareTemplate.generate_learnware_zipfile( learnware_zippath=learnware_zippath, - model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape":(input_shape_list[i % 2],), "output_shape": (1,)}), + model_template=PickleModelTemplate( + pickle_filepath=pickle_filepath, + model_kwargs={"input_shape": (input_shape_list[i % 2],), "output_shape": (1,)}, + ), stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"), requirements=["scikit-learn==0.22"], ) - + self.zip_path_list.append(learnware_zippath) - def _upload_delete_learnware(self, hetero_market, learnware_num, delete): self.test_prepare_learnware_randomly(learnware_num) self.learnware_num = learnware_num @@ -83,7 +89,7 @@ class TestHeteroWorkflow(unittest.TestCase): description=f"test_learnware_number_{idx}", input_description=input_description_list[idx % 2], output_description=output_description_list[idx % 2], - **self.universal_semantic_config + **self.universal_semantic_config, ) hetero_market.add_learnware(zip_path, semantic_spec) @@ -106,7 +112,7 @@ class TestHeteroWorkflow(unittest.TestCase): assert len(curr_inds) == 0, f"The market should be empty!" return hetero_market - + def test_upload_delete_learnware(self, learnware_num=5, delete=True): hetero_market = self._init_learnware_market() return self._upload_delete_learnware(hetero_market, learnware_num, delete) @@ -129,7 +135,7 @@ class TestHeteroWorkflow(unittest.TestCase): name=f"learnware_{learnware_num - 1}", **self.universal_semantic_config, ) - + user_info = BaseUserInfo(semantic_spec=semantic_spec) search_result = hetero_market.search_learnware(user_info) single_result = search_result.get_single_results() @@ -154,7 +160,7 @@ class TestHeteroWorkflow(unittest.TestCase): def test_hetero_stat_search(self, learnware_num=5): hetero_market = self.test_train_market_model(learnware_num, delete=False) print("Total Item:", len(hetero_market)) - + user_dim = 15 with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder: @@ -174,7 +180,10 @@ class TestHeteroWorkflow(unittest.TestCase): semantic_spec = generate_semantic_spec( input_description={ "Dimension": user_dim, - "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)}, + "Description": { + str(key): input_description_list[idx % 2]["Description"][str(key)] + for key in range(user_dim) + }, }, **self.universal_semantic_config, ) @@ -182,7 +191,7 @@ class TestHeteroWorkflow(unittest.TestCase): search_result = hetero_market.search_learnware(user_info) single_result = search_result.get_single_results() multiple_result = search_result.get_multiple_results() - + print(f"search result of user{idx}:") for single_item in single_result: print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") @@ -215,7 +224,10 @@ class TestHeteroWorkflow(unittest.TestCase): semantic_spec = generate_semantic_spec( input_description={ "Dimension": user_dim - 2, - "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)}, + "Description": { + str(key): input_description_list[idx % 2]["Description"][str(key)] + for key in range(user_dim) + }, }, **self.universal_semantic_config, ) @@ -228,7 +240,7 @@ class TestHeteroWorkflow(unittest.TestCase): def test_homo_stat_search(self, learnware_num=5): hetero_market = self.test_train_market_model(learnware_num, delete=False) print("Total Item:", len(hetero_market)) - + with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder: for idx, zip_path in enumerate(self.zip_path_list): with zipfile.ZipFile(zip_path, "r") as zip_obj: @@ -260,7 +272,9 @@ class TestHeteroWorkflow(unittest.TestCase): user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0) # generate specification - semantic_spec = generate_semantic_spec(input_description=user_description_list[0], **self.universal_semantic_config) + semantic_spec = generate_semantic_spec( + input_description=user_description_list[0], **self.universal_semantic_config + ) user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) # learnware market search @@ -268,7 +282,7 @@ class TestHeteroWorkflow(unittest.TestCase): search_result = hetero_market.search_learnware(user_info) single_result = search_result.get_single_results() multiple_result = search_result.get_multiple_results() - + # print search results for single_item in single_result: print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") @@ -306,9 +320,9 @@ class TestHeteroWorkflow(unittest.TestCase): def suite(): _suite = unittest.TestSuite() - #_suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly")) - #_suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware")) - #_suite.addTest(TestHeteroWorkflow("test_train_market_model")) + # _suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly")) + # _suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware")) + # _suite.addTest(TestHeteroWorkflow("test_train_market_model")) _suite.addTest(TestHeteroWorkflow("test_search_semantics")) _suite.addTest(TestHeteroWorkflow("test_hetero_stat_search")) _suite.addTest(TestHeteroWorkflow("test_homo_stat_search")) diff --git a/tests/test_workflow/test_workflow.py b/tests/test_workflow/test_workflow.py index c7a5bc5..bbd6038 100644 --- a/tests/test_workflow/test_workflow.py +++ b/tests/test_workflow/test_workflow.py @@ -10,6 +10,7 @@ from sklearn.datasets import load_digits from sklearn.model_selection import train_test_split import learnware + learnware.init(logging_level=logging.WARNING) from learnware.market import instantiate_learnware_market, BaseUserInfo @@ -19,8 +20,8 @@ from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, St curr_root = os.path.dirname(os.path.abspath(__file__)) + class TestWorkflow(unittest.TestCase): - universal_semantic_config = { "data_type": "Table", "task_type": "Classification", @@ -28,7 +29,7 @@ class TestWorkflow(unittest.TestCase): "scenarios": "Education", "license": "MIT", } - + def _init_learnware_market(self): """initialize learnware market""" easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True) @@ -42,7 +43,7 @@ class TestWorkflow(unittest.TestCase): learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool") os.makedirs(learnware_pool_dirpath, exist_ok=True) learnware_zippath = os.path.join(learnware_pool_dirpath, "svm_%d.zip" % (i)) - + print("Preparing Learnware: %d" % (i)) data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True) clf = svm.SVC(kernel="linear", probability=True) @@ -54,14 +55,17 @@ class TestWorkflow(unittest.TestCase): spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json") spec.save(spec_filepath) - + LearnwareTemplate.generate_learnware_zipfile( learnware_zippath=learnware_zippath, - model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape":(64,), "output_shape": (10,), "predict_method": "predict_proba"}), + model_template=PickleModelTemplate( + pickle_filepath=pickle_filepath, + model_kwargs={"input_shape": (64,), "output_shape": (10,), "predict_method": "predict_proba"}, + ), stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"), requirements=["scikit-learn==0.22"], ) - + self.zip_path_list.append(learnware_zippath) def test_upload_delete_learnware(self, learnware_num=5, delete=True): @@ -87,7 +91,7 @@ class TestWorkflow(unittest.TestCase): "Dimension": 10, "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)}, }, - **self.universal_semantic_config + **self.universal_semantic_config, ) easy_market.add_learnware(zip_path, semantic_spec) @@ -113,7 +117,7 @@ class TestWorkflow(unittest.TestCase): easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) print("Total Item:", len(easy_market)) assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" - + with tempfile.TemporaryDirectory(prefix="learnware_test_workflow") as test_folder: with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj: zip_obj.extractall(path=test_folder) @@ -123,15 +127,15 @@ class TestWorkflow(unittest.TestCase): description=f"test_learnware_number_{learnware_num - 1}", **self.universal_semantic_config, ) - + user_info = BaseUserInfo(semantic_spec=semantic_spec) search_result = easy_market.search_learnware(user_info) single_result = search_result.get_single_results() print(f"Search result:") for search_item in single_result: - print("Choose learnware:",search_item.learnware.id) - + print("Choose learnware:", search_item.learnware.id) + def test_stat_search(self, learnware_num=5): easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) print("Total Item:", len(easy_market))