| @@ -100,12 +100,12 @@ html_logo = "_static/img/logo/logo1.png" | |||||
| # These folders are copied to the documentation's HTML output | # These folders are copied to the documentation's HTML output | ||||
| html_static_path = ['_static'] | |||||
| html_static_path = ["_static"] | |||||
| # These paths are either relative to html_static_path | # These paths are either relative to html_static_path | ||||
| # or fully qualified paths (eg. https://...) | # or fully qualified paths (eg. https://...) | ||||
| html_css_files = [ | html_css_files = [ | ||||
| 'css/custom_style.css', | |||||
| "css/custom_style.css", | |||||
| ] | ] | ||||
| # -- Options for HTMLHelp output ------------------------------------------ | # -- Options for HTMLHelp output ------------------------------------------ | ||||
| @@ -85,9 +85,7 @@ def get_split_errs(algo): | |||||
| split = train_xs.shape[0] - proportion_list[tmp] | split = train_xs.shape[0] - proportion_list[tmp] | ||||
| model.fit( | model.fit( | ||||
| train_xs[ | |||||
| split:, | |||||
| ], | |||||
| train_xs[split:,], | |||||
| train_ys[split:], | train_ys[split:], | ||||
| eval_set=[(val_xs, val_ys)], | eval_set=[(val_xs, val_ys)], | ||||
| early_stopping_rounds=50, | early_stopping_rounds=50, | ||||
| @@ -86,7 +86,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: | |||||
| pass | pass | ||||
| except Exception as err: | except Exception as err: | ||||
| logger.error(err) | logger.error(err) | ||||
| return None | return None | ||||
| exist_packages = [] | exist_packages = [] | ||||
| @@ -101,7 +101,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: | |||||
| exist_packages.append(result) | exist_packages.append(result) | ||||
| else: | else: | ||||
| nonexist_packages.append(package) | nonexist_packages.append(package) | ||||
| if len(nonexist_packages) > 0: | if len(nonexist_packages) > 0: | ||||
| logger.info(f"Filtered out {len(nonexist_packages)} non-exist pip packages.") | logger.info(f"Filtered out {len(nonexist_packages)} non-exist pip packages.") | ||||
| return exist_packages, nonexist_packages | return exist_packages, nonexist_packages | ||||
| @@ -13,7 +13,9 @@ from ..utils import read_yaml_to_dict | |||||
| logger = get_module_logger("learnware.learnware") | logger = get_module_logger("learnware.learnware") | ||||
| def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True) -> Optional[Learnware]: | |||||
| def get_learnware_from_dirpath( | |||||
| id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True | |||||
| ) -> Optional[Learnware]: | |||||
| """Get the learnware object from dirpath, and provide the manage interface tor Learnware class | """Get the learnware object from dirpath, and provide the manage interface tor Learnware class | ||||
| Parameters | Parameters | ||||
| @@ -46,11 +48,11 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, | |||||
| } | } | ||||
| try: | try: | ||||
| learnware_yaml_path = os.path.join(learnware_dirpath, C.learnware_folder_config["yaml_file"]) | learnware_yaml_path = os.path.join(learnware_dirpath, C.learnware_folder_config["yaml_file"]) | ||||
| assert os.path.exists(learnware_yaml_path), f"learnware.yaml is not found for learnware_{id}, please check the learnware folder or zipfile." | |||||
| assert os.path.exists( | |||||
| learnware_yaml_path | |||||
| ), f"learnware.yaml is not found for learnware_{id}, please check the learnware folder or zipfile." | |||||
| yaml_config = read_yaml_to_dict(learnware_yaml_path) | yaml_config = read_yaml_to_dict(learnware_yaml_path) | ||||
| if "name" in yaml_config: | if "name" in yaml_config: | ||||
| @@ -67,8 +69,10 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, | |||||
| for _stat_spec in learnware_config["stat_specifications"]: | for _stat_spec in learnware_config["stat_specifications"]: | ||||
| stat_spec = _stat_spec.copy() | stat_spec = _stat_spec.copy() | ||||
| stat_spec_path = os.path.join(learnware_dirpath, stat_spec["file_name"]) | stat_spec_path = os.path.join(learnware_dirpath, stat_spec["file_name"]) | ||||
| assert os.path.exists(stat_spec_path), f"statistical specification file {stat_spec['file_name']} is not found for learnware_{id}, please check the learnware folder or zipfile." | |||||
| assert os.path.exists( | |||||
| stat_spec_path | |||||
| ), f"statistical specification file {stat_spec['file_name']} is not found for learnware_{id}, please check the learnware folder or zipfile." | |||||
| stat_spec["file_name"] = stat_spec_path | stat_spec["file_name"] = stat_spec_path | ||||
| stat_spec_inst = get_stat_spec_from_config(stat_spec) | stat_spec_inst = get_stat_spec_from_config(stat_spec) | ||||
| learnware_spec.update_stat_spec(**{stat_spec_inst.type: stat_spec_inst}) | learnware_spec.update_stat_spec(**{stat_spec_inst.type: stat_spec_inst}) | ||||
| @@ -1,9 +1,7 @@ | |||||
| from .anchor import AnchoredOrganizer, AnchoredSearcher, AnchoredUserInfo | from .anchor import AnchoredOrganizer, AnchoredSearcher, AnchoredUserInfo | ||||
| from .base import (BaseChecker, BaseOrganizer, BaseSearcher, BaseUserInfo, | |||||
| LearnwareMarket) | |||||
| from .base import BaseChecker, BaseOrganizer, BaseSearcher, BaseUserInfo, LearnwareMarket | |||||
| from .classes import CondaChecker | from .classes import CondaChecker | ||||
| from .easy import (EasyOrganizer, EasySearcher, EasySemanticChecker, | |||||
| EasyStatChecker) | |||||
| from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | |||||
| from .evolve import EvolvedOrganizer | from .evolve import EvolvedOrganizer | ||||
| from .evolve_anchor import EvolvedAnchoredOrganizer | from .evolve_anchor import EvolvedAnchoredOrganizer | ||||
| from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher | from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher | ||||
| @@ -11,5 +11,4 @@ if not is_torch_available(verbose=False): | |||||
| logger.error("EasySeacher and EasyChecker are not available because 'torch' is not installed!") | logger.error("EasySeacher and EasyChecker are not available because 'torch' is not installed!") | ||||
| else: | else: | ||||
| from .checker import EasySemanticChecker, EasyStatChecker | from .checker import EasySemanticChecker, EasyStatChecker | ||||
| from .searcher import (EasyExactSemanticSearcher, EasyFuzzSemanticSearcher, | |||||
| EasySearcher, EasyStatSearcher) | |||||
| from .searcher import EasyExactSemanticSearcher, EasyFuzzSemanticSearcher, EasySearcher, EasyStatSearcher | |||||
| @@ -6,13 +6,11 @@ import torch | |||||
| from rapidfuzz import fuzz | from rapidfuzz import fuzz | ||||
| from .organizer import EasyOrganizer | from .organizer import EasyOrganizer | ||||
| from ..base import (BaseSearcher, BaseUserInfo, MultipleSearchItem, | |||||
| SearchResults, SingleSearchItem) | |||||
| from ..base import BaseSearcher, BaseUserInfo, MultipleSearchItem, SearchResults, SingleSearchItem | |||||
| from ..utils import parse_specification_type | from ..utils import parse_specification_type | ||||
| from ...learnware import Learnware | from ...learnware import Learnware | ||||
| from ...logger import get_module_logger | from ...logger import get_module_logger | ||||
| from ...specification import (RKMEImageSpecification, RKMETableSpecification, | |||||
| RKMETextSpecification, rkme_solve_qp) | |||||
| from ...specification import RKMEImageSpecification, RKMETableSpecification, RKMETextSpecification, rkme_solve_qp | |||||
| logger = get_module_logger("easy_seacher") | logger = get_module_logger("easy_seacher") | ||||
| @@ -8,8 +8,7 @@ from torch import nn | |||||
| from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer | from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer | ||||
| from .trainer import Trainer, TransTabCollatorForCL | from .trainer import Trainer, TransTabCollatorForCL | ||||
| from .....specification import (HeteroMapTableSpecification, | |||||
| RKMETableSpecification) | |||||
| from .....specification import HeteroMapTableSpecification, RKMETableSpecification | |||||
| from .....utils import allocate_cuda_idx, choose_device | from .....utils import allocate_cuda_idx, choose_device | ||||
| @@ -288,7 +287,7 @@ class HeteroMap(nn.Module): | |||||
| # go through transformers, get the first cls embedding | # go through transformers, get the first cls embedding | ||||
| encoder_output = self.encoder(**outputs) # bs, seqlen+1, hidden_dim | encoder_output = self.encoder(**outputs) # bs, seqlen+1, hidden_dim | ||||
| output_features = encoder_output[:, 0, :] | output_features = encoder_output[:, 0, :] | ||||
| del inputs, outputs, encoder_output | del inputs, outputs, encoder_output | ||||
| torch.cuda.empty_cache() | torch.cuda.empty_cache() | ||||
| @@ -11,7 +11,11 @@ logger = get_module_logger("hetero_searcher") | |||||
| class HeteroSearcher(EasySearcher): | class HeteroSearcher(EasySearcher): | ||||
| def __call__( | def __call__( | ||||
| self, user_info: BaseUserInfo, check_status: Optional[int] = None, max_search_num: int = 5, search_method: str = "greedy" | |||||
| self, | |||||
| user_info: BaseUserInfo, | |||||
| check_status: Optional[int] = None, | |||||
| max_search_num: int = 5, | |||||
| search_method: str = "greedy", | |||||
| ) -> SearchResults: | ) -> SearchResults: | ||||
| """Search learnwares based on user_info from learnwares with check_status. | """Search learnwares based on user_info from learnwares with check_status. | ||||
| Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods. | Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods. | ||||
| @@ -1,11 +1,12 @@ | |||||
| from .base import LearnwareMarket | from .base import LearnwareMarket | ||||
| from .classes import CondaChecker | from .classes import CondaChecker | ||||
| from .easy import (EasyOrganizer, EasySearcher, EasySemanticChecker, | |||||
| EasyStatChecker) | |||||
| from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | |||||
| from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher | from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher | ||||
| def get_market_component(name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False): | |||||
| def get_market_component( | |||||
| name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False | |||||
| ): | |||||
| organizer_kwargs = {} if organizer_kwargs is None else organizer_kwargs | organizer_kwargs = {} if organizer_kwargs is None else organizer_kwargs | ||||
| searcher_kwargs = {} if searcher_kwargs is None else searcher_kwargs | searcher_kwargs = {} if searcher_kwargs is None else searcher_kwargs | ||||
| checker_kwargs = {} if checker_kwargs is None else checker_kwargs | checker_kwargs = {} if checker_kwargs is None else checker_kwargs | ||||
| @@ -13,7 +14,10 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search | |||||
| if name == "easy": | if name == "easy": | ||||
| easy_organizer = EasyOrganizer(market_id=market_id, rebuild=rebuild) | easy_organizer = EasyOrganizer(market_id=market_id, rebuild=rebuild) | ||||
| easy_searcher = EasySearcher(organizer=easy_organizer) | easy_searcher = EasySearcher(organizer=easy_organizer) | ||||
| easy_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())] | |||||
| easy_checker_list = [ | |||||
| EasySemanticChecker(), | |||||
| EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()), | |||||
| ] | |||||
| market_component = { | market_component = { | ||||
| "organizer": easy_organizer, | "organizer": easy_organizer, | ||||
| "searcher": easy_searcher, | "searcher": easy_searcher, | ||||
| @@ -22,7 +26,10 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search | |||||
| elif name == "hetero": | elif name == "hetero": | ||||
| hetero_organizer = HeteroMapTableOrganizer(market_id=market_id, rebuild=rebuild, **organizer_kwargs) | hetero_organizer = HeteroMapTableOrganizer(market_id=market_id, rebuild=rebuild, **organizer_kwargs) | ||||
| hetero_searcher = HeteroSearcher(organizer=hetero_organizer) | hetero_searcher = HeteroSearcher(organizer=hetero_organizer) | ||||
| hetero_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())] | |||||
| hetero_checker_list = [ | |||||
| EasySemanticChecker(), | |||||
| EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()), | |||||
| ] | |||||
| market_component = { | market_component = { | ||||
| "organizer": hetero_organizer, | "organizer": hetero_organizer, | ||||
| @@ -45,7 +52,9 @@ def instantiate_learnware_market( | |||||
| conda_checker: bool = False, | conda_checker: bool = False, | ||||
| **kwargs, | **kwargs, | ||||
| ): | ): | ||||
| market_componets = get_market_component(name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker) | |||||
| market_componets = get_market_component( | |||||
| name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker | |||||
| ) | |||||
| return LearnwareMarket( | return LearnwareMarket( | ||||
| organizer=market_componets["organizer"], | organizer=market_componets["organizer"], | ||||
| searcher=market_componets["searcher"], | searcher=market_componets["searcher"], | ||||
| @@ -20,4 +20,4 @@ else: | |||||
| from .ensemble_pruning import EnsemblePruningReuser | from .ensemble_pruning import EnsemblePruningReuser | ||||
| from .feature_augment import FeatureAugmentReuser | from .feature_augment import FeatureAugmentReuser | ||||
| from .hetero import FeatureAlignLearnware, HeteroMapAlignLearnware | from .hetero import FeatureAlignLearnware, HeteroMapAlignLearnware | ||||
| from .job_selector import JobSelectorReuser | |||||
| from .job_selector import JobSelectorReuser | |||||
| @@ -54,13 +54,14 @@ class EnsemblePruningReuser(BaseReuser): | |||||
| np.ndarray | np.ndarray | ||||
| Binary one-dimensional vector, 1 indicates that the corresponding model is selected. | Binary one-dimensional vector, 1 indicates that the corresponding model is selected. | ||||
| """ | """ | ||||
| try: | try: | ||||
| import geatpy as ea | import geatpy as ea | ||||
| except ModuleNotFoundError: | except ModuleNotFoundError: | ||||
| raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") | |||||
| raise ModuleNotFoundError( | |||||
| f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)." | |||||
| ) | |||||
| model_num = v_predict.shape[1] | model_num = v_predict.shape[1] | ||||
| @ea.Problem.single | @ea.Problem.single | ||||
| @@ -148,7 +149,9 @@ class EnsemblePruningReuser(BaseReuser): | |||||
| try: | try: | ||||
| import geatpy as ea | import geatpy as ea | ||||
| except ModuleNotFoundError: | except ModuleNotFoundError: | ||||
| raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") | |||||
| raise ModuleNotFoundError( | |||||
| f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)." | |||||
| ) | |||||
| if torch.is_tensor(v_true): | if torch.is_tensor(v_true): | ||||
| v_true = v_true.detach().cpu().numpy() | v_true = v_true.detach().cpu().numpy() | ||||
| @@ -270,8 +273,10 @@ class EnsemblePruningReuser(BaseReuser): | |||||
| try: | try: | ||||
| import geatpy as ea | import geatpy as ea | ||||
| except ModuleNotFoundError: | except ModuleNotFoundError: | ||||
| raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") | |||||
| raise ModuleNotFoundError( | |||||
| f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)." | |||||
| ) | |||||
| model_num = v_predict.shape[1] | model_num = v_predict.shape[1] | ||||
| v_predict[v_predict == 0.0] = -1 | v_predict[v_predict == 0.0] = -1 | ||||
| v_true[v_true == 0.0] = -1 | v_true[v_true == 0.0] = -1 | ||||
| @@ -8,8 +8,7 @@ from .base import BaseReuser | |||||
| from ..learnware import Learnware | from ..learnware import Learnware | ||||
| from ..logger import get_module_logger | from ..logger import get_module_logger | ||||
| from ..market.utils import parse_specification_type | from ..market.utils import parse_specification_type | ||||
| from ..specification import (RKMETableSpecification, RKMETextSpecification, | |||||
| generate_rkme_table_spec, rkme_solve_qp) | |||||
| from ..specification import RKMETableSpecification, RKMETextSpecification, generate_rkme_table_spec, rkme_solve_qp | |||||
| logger = get_module_logger("job_selector_reuse") | logger = get_module_logger("job_selector_reuse") | ||||
| @@ -4,6 +4,7 @@ from ..logger import get_module_logger | |||||
| logger = get_module_logger("reuse_utils") | logger = get_module_logger("reuse_utils") | ||||
| def fill_data_with_mean(X: np.ndarray) -> np.ndarray: | def fill_data_with_mean(X: np.ndarray) -> np.ndarray: | ||||
| """ | """ | ||||
| Fill missing data (NaN, Inf) in the input array with the mean of the column. | Fill missing data (NaN, Inf) in the input array with the mean of the column. | ||||
| @@ -1,7 +1,12 @@ | |||||
| from .base import BaseStatSpecification, Specification | from .base import BaseStatSpecification, Specification | ||||
| from .regular import (RegularStatSpecification, RKMEImageSpecification, | |||||
| RKMEStatSpecification, RKMETableSpecification, | |||||
| RKMETextSpecification, rkme_solve_qp) | |||||
| from .regular import ( | |||||
| RegularStatSpecification, | |||||
| RKMEImageSpecification, | |||||
| RKMEStatSpecification, | |||||
| RKMETableSpecification, | |||||
| RKMETextSpecification, | |||||
| rkme_solve_qp, | |||||
| ) | |||||
| from .system import HeteroMapTableSpecification | from .system import HeteroMapTableSpecification | ||||
| from ..utils import is_torch_available | from ..utils import is_torch_available | ||||
| @@ -12,6 +17,10 @@ if not is_torch_available(verbose=False): | |||||
| generate_rkme_text_spec = None | generate_rkme_text_spec = None | ||||
| generate_semantic_spec = None | generate_semantic_spec = None | ||||
| else: | else: | ||||
| from .module import (generate_rkme_image_spec, generate_rkme_table_spec, | |||||
| generate_rkme_text_spec, generate_semantic_spec, | |||||
| generate_stat_spec) | |||||
| from .module import ( | |||||
| generate_rkme_image_spec, | |||||
| generate_rkme_table_spec, | |||||
| generate_rkme_text_spec, | |||||
| generate_semantic_spec, | |||||
| generate_stat_spec, | |||||
| ) | |||||
| @@ -11,5 +11,4 @@ if not is_torch_available(verbose=False): | |||||
| f"RKMETableSpecification, RKMEStatSpecification and rkme_solve_qp are not available because 'torch' is not installed!" | f"RKMETableSpecification, RKMEStatSpecification and rkme_solve_qp are not available because 'torch' is not installed!" | ||||
| ) | ) | ||||
| else: | else: | ||||
| from .rkme import (RKMEStatSpecification, RKMETableSpecification, | |||||
| rkme_solve_qp) | |||||
| from .rkme import RKMEStatSpecification, RKMETableSpecification, rkme_solve_qp | |||||
| @@ -87,12 +87,14 @@ class RKMETextSpecification(RKMETableSpecification): | |||||
| return np.array(miniLM_learnware.predict(X)) | return np.array(miniLM_learnware.predict(X)) | ||||
| logger.info("Load the necessary feature extractor for RKMETextSpecification.") | logger.info("Load the necessary feature extractor for RKMETextSpecification.") | ||||
| try: | try: | ||||
| from sentence_transformers import SentenceTransformer | from sentence_transformers import SentenceTransformer | ||||
| except ModuleNotFoundError: | except ModuleNotFoundError: | ||||
| raise ModuleNotFoundError(f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually.") | |||||
| raise ModuleNotFoundError( | |||||
| f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually." | |||||
| ) | |||||
| if os.path.exists(zip_path): | if os.path.exists(zip_path): | ||||
| X = _get_from_client(zip_path, X) | X = _get_from_client(zip_path, X) | ||||
| else: | else: | ||||
| @@ -137,7 +137,9 @@ class HeteroMapTableSpecification(SystemStatSpecification): | |||||
| for d in self.get_states(): | for d in self.get_states(): | ||||
| if d in embedding_load.keys(): | if d in embedding_load.keys(): | ||||
| if d == "type" and embedding_load[d] != self.type: | if d == "type" and embedding_load[d] != self.type: | ||||
| raise TypeError(f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!") | |||||
| raise TypeError( | |||||
| f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!" | |||||
| ) | |||||
| setattr(self, d, embedding_load[d]) | setattr(self, d, embedding_load[d]) | ||||
| def save(self, filepath: str) -> bool: | def save(self, filepath: str) -> bool: | ||||
| @@ -1 +1 @@ | |||||
| from .utils import parametrize | |||||
| from .utils import parametrize | |||||
| @@ -13,10 +13,13 @@ class ModelTemplate: | |||||
| class_name: str = field(init=False) | class_name: str = field(init=False) | ||||
| template_path: str = field(init=False) | template_path: str = field(init=False) | ||||
| model_kwargs: dict = field(init=False) | model_kwargs: dict = field(init=False) | ||||
| @dataclass | @dataclass | ||||
| class PickleModelTemplate(ModelTemplate): | class PickleModelTemplate(ModelTemplate): | ||||
| model_kwargs: dict | model_kwargs: dict | ||||
| pickle_filepath: str | pickle_filepath: str | ||||
| def __post_init__(self): | def __post_init__(self): | ||||
| self.class_name = "PickleLoadedModel" | self.class_name = "PickleLoadedModel" | ||||
| self.template_path = os.path.join(C.package_path, "tests", "templates", "pickle_model.py") | self.template_path = os.path.join(C.package_path, "tests", "templates", "pickle_model.py") | ||||
| @@ -29,13 +32,14 @@ class PickleModelTemplate(ModelTemplate): | |||||
| default_model_kwargs.update(self.model_kwargs) | default_model_kwargs.update(self.model_kwargs) | ||||
| self.model_kwargs = default_model_kwargs | self.model_kwargs = default_model_kwargs | ||||
| @dataclass | @dataclass | ||||
| class StatSpecTemplate: | class StatSpecTemplate: | ||||
| filepath: str | filepath: str | ||||
| type: str = field(default="RKMETableSpecification") | type: str = field(default="RKMETableSpecification") | ||||
| class LearnwareTemplate: | |||||
| class LearnwareTemplate: | |||||
| @staticmethod | @staticmethod | ||||
| def generate_requirements(filepath, requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None): | def generate_requirements(filepath, requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None): | ||||
| requirements = [] if requirements is None else requirements | requirements = [] if requirements is None else requirements | ||||
| @@ -49,14 +53,16 @@ class LearnwareTemplate: | |||||
| line_str = requirement[0].strip() + requirement[1].strip() + requirement[2].strip() + "\n" | line_str = requirement[0].strip() + requirement[1].strip() + requirement[2].strip() + "\n" | ||||
| else: | else: | ||||
| raise TypeError(f"requirement must be type str/tuple, rather than {type(requirement)}") | raise TypeError(f"requirement must be type str/tuple, rather than {type(requirement)}") | ||||
| requirements_str += line_str | requirements_str += line_str | ||||
| with open(filepath, "w") as fdout: | with open(filepath, "w") as fdout: | ||||
| fdout.write(requirements_str) | fdout.write(requirements_str) | ||||
| @staticmethod | @staticmethod | ||||
| def generate_learnware_yaml(filepath, model_config: Optional[dict] = None, stat_spec_config: Optional[List[dict]] = None): | |||||
| def generate_learnware_yaml( | |||||
| filepath, model_config: Optional[dict] = None, stat_spec_config: Optional[List[dict]] = None | |||||
| ): | |||||
| learnware_config = {} | learnware_config = {} | ||||
| if model_config is not None: | if model_config is not None: | ||||
| learnware_config["model"] = model_config | learnware_config["model"] = model_config | ||||
| @@ -64,7 +70,7 @@ class LearnwareTemplate: | |||||
| learnware_config["stat_specifications"] = stat_spec_config | learnware_config["stat_specifications"] = stat_spec_config | ||||
| save_dict_to_yaml(learnware_config, filepath) | save_dict_to_yaml(learnware_config, filepath) | ||||
| @staticmethod | @staticmethod | ||||
| def generate_learnware_zipfile( | def generate_learnware_zipfile( | ||||
| learnware_zippath: str, | learnware_zippath: str, | ||||
| @@ -75,27 +81,29 @@ class LearnwareTemplate: | |||||
| with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir: | with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir: | ||||
| requirement_filepath = os.path.join(tempdir, "requirements.txt") | requirement_filepath = os.path.join(tempdir, "requirements.txt") | ||||
| LearnwareTemplate.generate_requirements(requirement_filepath, requirements) | LearnwareTemplate.generate_requirements(requirement_filepath, requirements) | ||||
| model_filepath = os.path.join(tempdir, "__init__.py") | |||||
| model_filepath = os.path.join(tempdir, "__init__.py") | |||||
| copyfile(model_template.template_path, model_filepath) | copyfile(model_template.template_path, model_filepath) | ||||
| learnware_yaml_filepath = os.path.join(tempdir, "learnware.yaml") | learnware_yaml_filepath = os.path.join(tempdir, "learnware.yaml") | ||||
| model_config = { | model_config = { | ||||
| "class_name": model_template.class_name, | "class_name": model_template.class_name, | ||||
| "kwargs": model_template.model_kwargs, | "kwargs": model_template.model_kwargs, | ||||
| } | } | ||||
| stat_spec_config = { | stat_spec_config = { | ||||
| "module_path": "learnware.specification", | "module_path": "learnware.specification", | ||||
| "class_name": stat_spec_template.type, | "class_name": stat_spec_template.type, | ||||
| "file_name": "stat_spec.json", | "file_name": "stat_spec.json", | ||||
| "kwargs": {} | |||||
| "kwargs": {}, | |||||
| } | } | ||||
| copyfile(stat_spec_template.filepath, os.path.join(tempdir, stat_spec_config["file_name"])) | copyfile(stat_spec_template.filepath, os.path.join(tempdir, stat_spec_config["file_name"])) | ||||
| LearnwareTemplate.generate_learnware_yaml(learnware_yaml_filepath, model_config, stat_spec_config=[stat_spec_config]) | |||||
| LearnwareTemplate.generate_learnware_yaml( | |||||
| learnware_yaml_filepath, model_config, stat_spec_config=[stat_spec_config] | |||||
| ) | |||||
| if isinstance(model_template, PickleModelTemplate): | if isinstance(model_template, PickleModelTemplate): | ||||
| pickle_filepath = os.path.join(tempdir, model_template.model_kwargs["pickle_filename"]) | pickle_filepath = os.path.join(tempdir, model_template.model_kwargs["pickle_filename"]) | ||||
| copyfile(model_template.pickle_filepath, pickle_filepath) | copyfile(model_template.pickle_filepath, pickle_filepath) | ||||
| convert_folder_to_zipfile(tempdir, learnware_zippath) | convert_folder_to_zipfile(tempdir, learnware_zippath) | ||||
| @@ -7,7 +7,6 @@ from learnware.model.base import BaseModel | |||||
| class PickleLoadedModel(BaseModel): | class PickleLoadedModel(BaseModel): | ||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| input_shape, | input_shape, | ||||
| @@ -25,10 +24,10 @@ class PickleLoadedModel(BaseModel): | |||||
| self.predict_method = predict_method | self.predict_method = predict_method | ||||
| self.fit_method = fit_method | self.fit_method = fit_method | ||||
| self.finetune_method = finetune_method | self.finetune_method = finetune_method | ||||
| def predict(self, X: np.ndarray) -> np.ndarray: | def predict(self, X: np.ndarray) -> np.ndarray: | ||||
| return getattr(self.model, self.predict_method)(X) | return getattr(self.model, self.predict_method)(X) | ||||
| def fit(self, X: np.ndarray, y: np.ndarray): | def fit(self, X: np.ndarray, y: np.ndarray): | ||||
| getattr(self.model, self.fit_method)(X, y) | getattr(self.model, self.fit_method)(X, y) | ||||
| @@ -7,4 +7,4 @@ def parametrize(test_class, **kwargs): | |||||
| _suite = unittest.TestSuite() | _suite = unittest.TestSuite() | ||||
| for name in test_names: | for name in test_names: | ||||
| _suite.addTest(test_class(name, **kwargs)) | _suite.addTest(test_class(name, **kwargs)) | ||||
| return _suite | |||||
| return _suite | |||||
| @@ -1,8 +1,7 @@ | |||||
| import os | import os | ||||
| import zipfile | import zipfile | ||||
| from .file import (convert_folder_to_zipfile, read_yaml_to_dict, | |||||
| save_dict_to_yaml) | |||||
| from .file import convert_folder_to_zipfile, read_yaml_to_dict, save_dict_to_yaml | |||||
| from .gpu import allocate_cuda_idx, choose_device, setup_seed | from .gpu import allocate_cuda_idx, choose_device, setup_seed | ||||
| from .import_utils import is_torch_available | from .import_utils import is_torch_available | ||||
| from .module import get_module_by_module_path | from .module import get_module_by_module_path | ||||
| @@ -16,6 +16,7 @@ def read_yaml_to_dict(yaml_path: str): | |||||
| dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) | dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) | ||||
| return dict_value | return dict_value | ||||
| def convert_folder_to_zipfile(folder_path, zip_path): | def convert_folder_to_zipfile(folder_path, zip_path): | ||||
| with zipfile.ZipFile(zip_path, "w") as zip_obj: | with zipfile.ZipFile(zip_path, "w") as zip_obj: | ||||
| for foldername, subfolders, filenames in os.walk(folder_path): | for foldername, subfolders, filenames in os.walk(folder_path): | ||||
| @@ -17,6 +17,7 @@ def setup_seed(seed): | |||||
| random.seed(seed) | random.seed(seed) | ||||
| if is_torch_available(verbose=False): | if is_torch_available(verbose=False): | ||||
| import torch | import torch | ||||
| torch.manual_seed(seed) | torch.manual_seed(seed) | ||||
| torch.cuda.manual_seed_all(seed) | torch.cuda.manual_seed_all(seed) | ||||
| torch.backends.cudnn.deterministic = True | torch.backends.cudnn.deterministic = True | ||||
| @@ -4,15 +4,16 @@ import numpy as np | |||||
| from learnware.client import LearnwareClient | from learnware.client import LearnwareClient | ||||
| from learnware.client.container import LearnwaresContainer | from learnware.client.container import LearnwaresContainer | ||||
| class TestContainer(unittest.TestCase): | class TestContainer(unittest.TestCase): | ||||
| def __init__(self, method_name='runTest', mode="all"): | |||||
| def __init__(self, method_name="runTest", mode="all"): | |||||
| super(TestContainer, self).__init__(method_name) | super(TestContainer, self).__init__(method_name) | ||||
| self.modes = [] | self.modes = [] | ||||
| if mode in {"all", "conda"}: | if mode in {"all", "conda"}: | ||||
| self.modes.append("conda") | self.modes.append("conda") | ||||
| if mode in {"all", "docker"}: | if mode in {"all", "docker"}: | ||||
| self.modes.append("docker") | self.modes.append("docker") | ||||
| def setUp(self): | def setUp(self): | ||||
| self.client = LearnwareClient() | self.client = LearnwareClient() | ||||
| @@ -35,17 +36,19 @@ class TestContainer(unittest.TestCase): | |||||
| def test_container_with_pip(self): | def test_container_with_pip(self): | ||||
| for mode in self.modes: | for mode in self.modes: | ||||
| self._test_container_with_pip(mode=mode) | self._test_container_with_pip(mode=mode) | ||||
| def test_container_with_conda(self): | def test_container_with_conda(self): | ||||
| for mode in self.modes: | for mode in self.modes: | ||||
| self._test_container_with_conda(mode=mode) | self._test_container_with_conda(mode=mode) | ||||
| def suite(): | def suite(): | ||||
| _suite = unittest.TestSuite() | _suite = unittest.TestSuite() | ||||
| _suite.addTest(TestContainer("test_container_with_pip", mode="all")) | _suite.addTest(TestContainer("test_container_with_pip", mode="all")) | ||||
| _suite.addTest(TestContainer("test_container_with_conda", mode="all")) | _suite.addTest(TestContainer("test_container_with_conda", mode="all")) | ||||
| return _suite | return _suite | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| runner = unittest.TextTestRunner() | runner = unittest.TextTestRunner() | ||||
| runner.run(suite()) | |||||
| runner.run(suite()) | |||||
| @@ -5,8 +5,9 @@ import numpy as np | |||||
| from learnware.client import LearnwareClient | from learnware.client import LearnwareClient | ||||
| from learnware.reuse import AveragingReuser | from learnware.reuse import AveragingReuser | ||||
| class TestLearnwareLoad(unittest.TestCase): | class TestLearnwareLoad(unittest.TestCase): | ||||
| def __init__(self, method_name='runTest', mode="all"): | |||||
| def __init__(self, method_name="runTest", mode="all"): | |||||
| super(TestLearnwareLoad, self).__init__(method_name) | super(TestLearnwareLoad, self).__init__(method_name) | ||||
| self.runnable_options = [] | self.runnable_options = [] | ||||
| if mode in {"all", "conda"}: | if mode in {"all", "conda"}: | ||||
| @@ -31,7 +32,6 @@ class TestLearnwareLoad(unittest.TestCase): | |||||
| for learnware in learnware_list: | for learnware in learnware_list: | ||||
| print(learnware.id, learnware.predict(input_array)) | print(learnware.id, learnware.predict(input_array)) | ||||
| def _test_load_learnware_by_id(self, runnable_option): | def _test_load_learnware_by_id(self, runnable_option): | ||||
| learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option=runnable_option) | learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option=runnable_option) | ||||
| reuser = AveragingReuser(learnware_list, mode="vote_by_label") | reuser = AveragingReuser(learnware_list, mode="vote_by_label") | ||||
| @@ -44,11 +44,11 @@ class TestLearnwareLoad(unittest.TestCase): | |||||
| def test_load_learnware_by_zippath(self): | def test_load_learnware_by_zippath(self): | ||||
| for runnable_option in self.runnable_options: | for runnable_option in self.runnable_options: | ||||
| self._test_load_learnware_by_zippath(runnable_option=runnable_option) | self._test_load_learnware_by_zippath(runnable_option=runnable_option) | ||||
| def test_load_learnware_by_id(self): | def test_load_learnware_by_id(self): | ||||
| for runnable_option in self.runnable_options: | for runnable_option in self.runnable_options: | ||||
| self._test_load_learnware_by_id(runnable_option=runnable_option) | self._test_load_learnware_by_id(runnable_option=runnable_option) | ||||
| def suite(): | def suite(): | ||||
| _suite = unittest.TestSuite() | _suite = unittest.TestSuite() | ||||
| @@ -56,6 +56,7 @@ def suite(): | |||||
| _suite.addTest(TestLearnwareLoad("test_load_learnware_by_id", mode="all")) | _suite.addTest(TestLearnwareLoad("test_load_learnware_by_id", mode="all")) | ||||
| return _suite | return _suite | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| runner = unittest.TextTestRunner() | runner = unittest.TextTestRunner() | ||||
| runner.run(suite()) | |||||
| runner.run(suite()) | |||||
| @@ -11,11 +11,11 @@ from learnware.specification import RKMETableSpecification, HeteroMapTableSpecif | |||||
| from learnware.specification import generate_stat_spec | from learnware.specification import generate_stat_spec | ||||
| from learnware.market.heterogeneous.organizer import HeteroMap | from learnware.market.heterogeneous.organizer import HeteroMap | ||||
| class TestTableRKME(unittest.TestCase): | class TestTableRKME(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| self.hetero_map = HeteroMap() | self.hetero_map = HeteroMap() | ||||
| def _test_hetero_spec(self, X): | def _test_hetero_spec(self, X): | ||||
| rkme: RKMETableSpecification = generate_stat_spec(type="table", X=X) | rkme: RKMETableSpecification = generate_stat_spec(type="table", X=X) | ||||
| hetero_spec = self.hetero_map.hetero_mapping(rkme_spec=rkme, features=dict()) | hetero_spec = self.hetero_map.hetero_mapping(rkme_spec=rkme, features=dict()) | ||||
| @@ -30,14 +30,14 @@ class TestTableRKME(unittest.TestCase): | |||||
| rkme2 = HeteroMapTableSpecification() | rkme2 = HeteroMapTableSpecification() | ||||
| rkme2.load(rkme_path) | rkme2.load(rkme_path) | ||||
| assert rkme2.type == "HeteroMapTableSpecification" | assert rkme2.type == "HeteroMapTableSpecification" | ||||
| def test_hetero_rkme(self): | def test_hetero_rkme(self): | ||||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5000, 200))) | self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5000, 200))) | ||||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(10000, 100))) | self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(10000, 100))) | ||||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5, 20))) | self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5, 20))) | ||||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(1, 50))) | self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(1, 50))) | ||||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(100, 150))) | self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(100, 150))) | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| unittest.main() | unittest.main() | ||||
| @@ -25,7 +25,7 @@ class TestImageRKME(unittest.TestCase): | |||||
| rkme2 = RKMEImageSpecification() | rkme2 = RKMEImageSpecification() | ||||
| rkme2.load(rkme_path) | rkme2.load(rkme_path) | ||||
| assert rkme2.type == "RKMEImageSpecification" | assert rkme2.type == "RKMEImageSpecification" | ||||
| def test_image_rkme(self): | def test_image_rkme(self): | ||||
| self._test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) | self._test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) | ||||
| self._test_image_rkme(np.random.randint(0, 255, size=(100, 1, 128, 128))) | self._test_image_rkme(np.random.randint(0, 255, size=(100, 1, 128, 128))) | ||||
| @@ -34,5 +34,6 @@ class TestImageRKME(unittest.TestCase): | |||||
| self._test_image_rkme(torch.randint(0, 255, (20, 3, 128, 128))) | self._test_image_rkme(torch.randint(0, 255, (20, 3, 128, 128))) | ||||
| self._test_image_rkme(torch.randint(0, 255, (1, 1, 128, 128)) / 255) | self._test_image_rkme(torch.randint(0, 255, (1, 1, 128, 128)) / 255) | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| unittest.main() | unittest.main() | ||||
| @@ -24,7 +24,7 @@ class TestTableRKME(unittest.TestCase): | |||||
| rkme2 = RKMETableSpecification() | rkme2 = RKMETableSpecification() | ||||
| rkme2.load(rkme_path) | rkme2.load(rkme_path) | ||||
| assert rkme2.type == "RKMETableSpecification" | assert rkme2.type == "RKMETableSpecification" | ||||
| def test_table_rkme(self): | def test_table_rkme(self): | ||||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(5000, 200))) | self._test_table_rkme(np.random.uniform(-10000, 10000, size=(5000, 200))) | ||||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(10000, 100))) | self._test_table_rkme(np.random.uniform(-10000, 10000, size=(10000, 100))) | ||||
| @@ -32,5 +32,6 @@ class TestTableRKME(unittest.TestCase): | |||||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(1, 50))) | self._test_table_rkme(np.random.uniform(-10000, 10000, size=(1, 50))) | ||||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(100, 150))) | self._test_table_rkme(np.random.uniform(-10000, 10000, size=(100, 150))) | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| unittest.main() | unittest.main() | ||||
| @@ -12,19 +12,19 @@ from learnware.specification import generate_stat_spec | |||||
| class TestTextRKME(unittest.TestCase): | class TestTextRKME(unittest.TestCase): | ||||
| @staticmethod | @staticmethod | ||||
| def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): | def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): | ||||
| text_list = [] | |||||
| for i in range(num): | |||||
| length = random.randint(min_len, max_len) | |||||
| if text_type == "en": | |||||
| characters = string.ascii_letters + string.digits + string.punctuation | |||||
| result_str = "".join(random.choice(characters) for i in range(length)) | |||||
| text_list.append(result_str) | |||||
| elif text_type == "zh": | |||||
| result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) | |||||
| text_list.append(result_str) | |||||
| else: | |||||
| raise ValueError("Type should be en or zh") | |||||
| return text_list | |||||
| text_list = [] | |||||
| for i in range(num): | |||||
| length = random.randint(min_len, max_len) | |||||
| if text_type == "en": | |||||
| characters = string.ascii_letters + string.digits + string.punctuation | |||||
| result_str = "".join(random.choice(characters) for i in range(length)) | |||||
| text_list.append(result_str) | |||||
| elif text_type == "zh": | |||||
| result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) | |||||
| text_list.append(result_str) | |||||
| else: | |||||
| raise ValueError("Type should be en or zh") | |||||
| return text_list | |||||
| @staticmethod | @staticmethod | ||||
| def _test_text_rkme(X): | def _test_text_rkme(X): | ||||
| @@ -11,6 +11,7 @@ from shutil import copyfile, rmtree | |||||
| from sklearn.metrics import mean_squared_error | from sklearn.metrics import mean_squared_error | ||||
| import learnware | import learnware | ||||
| learnware.init(logging_level=logging.WARNING) | learnware.init(logging_level=logging.WARNING) | ||||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | from learnware.market import instantiate_learnware_market, BaseUserInfo | ||||
| @@ -23,6 +24,7 @@ from hetero_config import input_shape_list, input_description_list, output_descr | |||||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | curr_root = os.path.dirname(os.path.abspath(__file__)) | ||||
| class TestHeteroWorkflow(unittest.TestCase): | class TestHeteroWorkflow(unittest.TestCase): | ||||
| universal_semantic_config = { | universal_semantic_config = { | ||||
| "data_type": "Table", | "data_type": "Table", | ||||
| @@ -46,10 +48,12 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero") | learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero") | ||||
| os.makedirs(learnware_pool_dirpath, exist_ok=True) | os.makedirs(learnware_pool_dirpath, exist_ok=True) | ||||
| learnware_zippath = os.path.join(learnware_pool_dirpath, "ridge_%d.zip" % (i)) | learnware_zippath = os.path.join(learnware_pool_dirpath, "ridge_%d.zip" % (i)) | ||||
| print("Preparing Learnware: %d" % (i)) | print("Preparing Learnware: %d" % (i)) | ||||
| X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42) | |||||
| X, y = make_regression( | |||||
| n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42 | |||||
| ) | |||||
| clf = Ridge(alpha=1.0) | clf = Ridge(alpha=1.0) | ||||
| clf.fit(X, y) | clf.fit(X, y) | ||||
| pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl") | pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl") | ||||
| @@ -62,14 +66,16 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| LearnwareTemplate.generate_learnware_zipfile( | LearnwareTemplate.generate_learnware_zipfile( | ||||
| learnware_zippath=learnware_zippath, | learnware_zippath=learnware_zippath, | ||||
| model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape":(input_shape_list[i % 2],), "output_shape": (1,)}), | |||||
| model_template=PickleModelTemplate( | |||||
| pickle_filepath=pickle_filepath, | |||||
| model_kwargs={"input_shape": (input_shape_list[i % 2],), "output_shape": (1,)}, | |||||
| ), | |||||
| stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"), | stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"), | ||||
| requirements=["scikit-learn==0.22"], | requirements=["scikit-learn==0.22"], | ||||
| ) | ) | ||||
| self.zip_path_list.append(learnware_zippath) | self.zip_path_list.append(learnware_zippath) | ||||
| def _upload_delete_learnware(self, hetero_market, learnware_num, delete): | def _upload_delete_learnware(self, hetero_market, learnware_num, delete): | ||||
| self.test_prepare_learnware_randomly(learnware_num) | self.test_prepare_learnware_randomly(learnware_num) | ||||
| self.learnware_num = learnware_num | self.learnware_num = learnware_num | ||||
| @@ -83,7 +89,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| description=f"test_learnware_number_{idx}", | description=f"test_learnware_number_{idx}", | ||||
| input_description=input_description_list[idx % 2], | input_description=input_description_list[idx % 2], | ||||
| output_description=output_description_list[idx % 2], | output_description=output_description_list[idx % 2], | ||||
| **self.universal_semantic_config | |||||
| **self.universal_semantic_config, | |||||
| ) | ) | ||||
| hetero_market.add_learnware(zip_path, semantic_spec) | hetero_market.add_learnware(zip_path, semantic_spec) | ||||
| @@ -106,7 +112,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| assert len(curr_inds) == 0, f"The market should be empty!" | assert len(curr_inds) == 0, f"The market should be empty!" | ||||
| return hetero_market | return hetero_market | ||||
| def test_upload_delete_learnware(self, learnware_num=5, delete=True): | def test_upload_delete_learnware(self, learnware_num=5, delete=True): | ||||
| hetero_market = self._init_learnware_market() | hetero_market = self._init_learnware_market() | ||||
| return self._upload_delete_learnware(hetero_market, learnware_num, delete) | return self._upload_delete_learnware(hetero_market, learnware_num, delete) | ||||
| @@ -129,7 +135,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| name=f"learnware_{learnware_num - 1}", | name=f"learnware_{learnware_num - 1}", | ||||
| **self.universal_semantic_config, | **self.universal_semantic_config, | ||||
| ) | ) | ||||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | user_info = BaseUserInfo(semantic_spec=semantic_spec) | ||||
| search_result = hetero_market.search_learnware(user_info) | search_result = hetero_market.search_learnware(user_info) | ||||
| single_result = search_result.get_single_results() | single_result = search_result.get_single_results() | ||||
| @@ -154,7 +160,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| def test_hetero_stat_search(self, learnware_num=5): | def test_hetero_stat_search(self, learnware_num=5): | ||||
| hetero_market = self.test_train_market_model(learnware_num, delete=False) | hetero_market = self.test_train_market_model(learnware_num, delete=False) | ||||
| print("Total Item:", len(hetero_market)) | print("Total Item:", len(hetero_market)) | ||||
| user_dim = 15 | user_dim = 15 | ||||
| with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder: | with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder: | ||||
| @@ -174,7 +180,10 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| semantic_spec = generate_semantic_spec( | semantic_spec = generate_semantic_spec( | ||||
| input_description={ | input_description={ | ||||
| "Dimension": user_dim, | "Dimension": user_dim, | ||||
| "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)}, | |||||
| "Description": { | |||||
| str(key): input_description_list[idx % 2]["Description"][str(key)] | |||||
| for key in range(user_dim) | |||||
| }, | |||||
| }, | }, | ||||
| **self.universal_semantic_config, | **self.universal_semantic_config, | ||||
| ) | ) | ||||
| @@ -182,7 +191,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| search_result = hetero_market.search_learnware(user_info) | search_result = hetero_market.search_learnware(user_info) | ||||
| single_result = search_result.get_single_results() | single_result = search_result.get_single_results() | ||||
| multiple_result = search_result.get_multiple_results() | multiple_result = search_result.get_multiple_results() | ||||
| print(f"search result of user{idx}:") | print(f"search result of user{idx}:") | ||||
| for single_item in single_result: | for single_item in single_result: | ||||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | ||||
| @@ -215,7 +224,10 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| semantic_spec = generate_semantic_spec( | semantic_spec = generate_semantic_spec( | ||||
| input_description={ | input_description={ | ||||
| "Dimension": user_dim - 2, | "Dimension": user_dim - 2, | ||||
| "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)}, | |||||
| "Description": { | |||||
| str(key): input_description_list[idx % 2]["Description"][str(key)] | |||||
| for key in range(user_dim) | |||||
| }, | |||||
| }, | }, | ||||
| **self.universal_semantic_config, | **self.universal_semantic_config, | ||||
| ) | ) | ||||
| @@ -228,7 +240,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| def test_homo_stat_search(self, learnware_num=5): | def test_homo_stat_search(self, learnware_num=5): | ||||
| hetero_market = self.test_train_market_model(learnware_num, delete=False) | hetero_market = self.test_train_market_model(learnware_num, delete=False) | ||||
| print("Total Item:", len(hetero_market)) | print("Total Item:", len(hetero_market)) | ||||
| with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder: | with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder: | ||||
| for idx, zip_path in enumerate(self.zip_path_list): | for idx, zip_path in enumerate(self.zip_path_list): | ||||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | with zipfile.ZipFile(zip_path, "r") as zip_obj: | ||||
| @@ -260,7 +272,9 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0) | user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0) | ||||
| # generate specification | # generate specification | ||||
| semantic_spec = generate_semantic_spec(input_description=user_description_list[0], **self.universal_semantic_config) | |||||
| semantic_spec = generate_semantic_spec( | |||||
| input_description=user_description_list[0], **self.universal_semantic_config | |||||
| ) | |||||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | ||||
| # learnware market search | # learnware market search | ||||
| @@ -268,7 +282,7 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| search_result = hetero_market.search_learnware(user_info) | search_result = hetero_market.search_learnware(user_info) | ||||
| single_result = search_result.get_single_results() | single_result = search_result.get_single_results() | ||||
| multiple_result = search_result.get_multiple_results() | multiple_result = search_result.get_multiple_results() | ||||
| # print search results | # print search results | ||||
| for single_item in single_result: | for single_item in single_result: | ||||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | ||||
| @@ -306,9 +320,9 @@ class TestHeteroWorkflow(unittest.TestCase): | |||||
| def suite(): | def suite(): | ||||
| _suite = unittest.TestSuite() | _suite = unittest.TestSuite() | ||||
| #_suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly")) | |||||
| #_suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware")) | |||||
| #_suite.addTest(TestHeteroWorkflow("test_train_market_model")) | |||||
| # _suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly")) | |||||
| # _suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware")) | |||||
| # _suite.addTest(TestHeteroWorkflow("test_train_market_model")) | |||||
| _suite.addTest(TestHeteroWorkflow("test_search_semantics")) | _suite.addTest(TestHeteroWorkflow("test_search_semantics")) | ||||
| _suite.addTest(TestHeteroWorkflow("test_hetero_stat_search")) | _suite.addTest(TestHeteroWorkflow("test_hetero_stat_search")) | ||||
| _suite.addTest(TestHeteroWorkflow("test_homo_stat_search")) | _suite.addTest(TestHeteroWorkflow("test_homo_stat_search")) | ||||
| @@ -10,6 +10,7 @@ from sklearn.datasets import load_digits | |||||
| from sklearn.model_selection import train_test_split | from sklearn.model_selection import train_test_split | ||||
| import learnware | import learnware | ||||
| learnware.init(logging_level=logging.WARNING) | learnware.init(logging_level=logging.WARNING) | ||||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | from learnware.market import instantiate_learnware_market, BaseUserInfo | ||||
| @@ -19,8 +20,8 @@ from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, St | |||||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | curr_root = os.path.dirname(os.path.abspath(__file__)) | ||||
| class TestWorkflow(unittest.TestCase): | class TestWorkflow(unittest.TestCase): | ||||
| universal_semantic_config = { | universal_semantic_config = { | ||||
| "data_type": "Table", | "data_type": "Table", | ||||
| "task_type": "Classification", | "task_type": "Classification", | ||||
| @@ -28,7 +29,7 @@ class TestWorkflow(unittest.TestCase): | |||||
| "scenarios": "Education", | "scenarios": "Education", | ||||
| "license": "MIT", | "license": "MIT", | ||||
| } | } | ||||
| def _init_learnware_market(self): | def _init_learnware_market(self): | ||||
| """initialize learnware market""" | """initialize learnware market""" | ||||
| easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True) | easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True) | ||||
| @@ -42,7 +43,7 @@ class TestWorkflow(unittest.TestCase): | |||||
| learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool") | learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool") | ||||
| os.makedirs(learnware_pool_dirpath, exist_ok=True) | os.makedirs(learnware_pool_dirpath, exist_ok=True) | ||||
| learnware_zippath = os.path.join(learnware_pool_dirpath, "svm_%d.zip" % (i)) | learnware_zippath = os.path.join(learnware_pool_dirpath, "svm_%d.zip" % (i)) | ||||
| print("Preparing Learnware: %d" % (i)) | print("Preparing Learnware: %d" % (i)) | ||||
| data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True) | data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True) | ||||
| clf = svm.SVC(kernel="linear", probability=True) | clf = svm.SVC(kernel="linear", probability=True) | ||||
| @@ -54,14 +55,17 @@ class TestWorkflow(unittest.TestCase): | |||||
| spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) | spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) | ||||
| spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json") | spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json") | ||||
| spec.save(spec_filepath) | spec.save(spec_filepath) | ||||
| LearnwareTemplate.generate_learnware_zipfile( | LearnwareTemplate.generate_learnware_zipfile( | ||||
| learnware_zippath=learnware_zippath, | learnware_zippath=learnware_zippath, | ||||
| model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape":(64,), "output_shape": (10,), "predict_method": "predict_proba"}), | |||||
| model_template=PickleModelTemplate( | |||||
| pickle_filepath=pickle_filepath, | |||||
| model_kwargs={"input_shape": (64,), "output_shape": (10,), "predict_method": "predict_proba"}, | |||||
| ), | |||||
| stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"), | stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"), | ||||
| requirements=["scikit-learn==0.22"], | requirements=["scikit-learn==0.22"], | ||||
| ) | ) | ||||
| self.zip_path_list.append(learnware_zippath) | self.zip_path_list.append(learnware_zippath) | ||||
| def test_upload_delete_learnware(self, learnware_num=5, delete=True): | def test_upload_delete_learnware(self, learnware_num=5, delete=True): | ||||
| @@ -87,7 +91,7 @@ class TestWorkflow(unittest.TestCase): | |||||
| "Dimension": 10, | "Dimension": 10, | ||||
| "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)}, | "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)}, | ||||
| }, | }, | ||||
| **self.universal_semantic_config | |||||
| **self.universal_semantic_config, | |||||
| ) | ) | ||||
| easy_market.add_learnware(zip_path, semantic_spec) | easy_market.add_learnware(zip_path, semantic_spec) | ||||
| @@ -113,7 +117,7 @@ class TestWorkflow(unittest.TestCase): | |||||
| easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | ||||
| print("Total Item:", len(easy_market)) | print("Total Item:", len(easy_market)) | ||||
| assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | ||||
| with tempfile.TemporaryDirectory(prefix="learnware_test_workflow") as test_folder: | with tempfile.TemporaryDirectory(prefix="learnware_test_workflow") as test_folder: | ||||
| with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj: | with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj: | ||||
| zip_obj.extractall(path=test_folder) | zip_obj.extractall(path=test_folder) | ||||
| @@ -123,15 +127,15 @@ class TestWorkflow(unittest.TestCase): | |||||
| description=f"test_learnware_number_{learnware_num - 1}", | description=f"test_learnware_number_{learnware_num - 1}", | ||||
| **self.universal_semantic_config, | **self.universal_semantic_config, | ||||
| ) | ) | ||||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | user_info = BaseUserInfo(semantic_spec=semantic_spec) | ||||
| search_result = easy_market.search_learnware(user_info) | search_result = easy_market.search_learnware(user_info) | ||||
| single_result = search_result.get_single_results() | single_result = search_result.get_single_results() | ||||
| print(f"Search result:") | print(f"Search result:") | ||||
| for search_item in single_result: | for search_item in single_result: | ||||
| print("Choose learnware:",search_item.learnware.id) | |||||
| print("Choose learnware:", search_item.learnware.id) | |||||
| def test_stat_search(self, learnware_num=5): | def test_stat_search(self, learnware_num=5): | ||||
| easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | ||||
| print("Total Item:", len(easy_market)) | print("Total Item:", len(easy_market)) | ||||