| @@ -1,4 +1,4 @@ | |||
| __version__ = "0.2.0.5" | |||
| __version__ = "0.2.0.7" | |||
| import os | |||
| import json | |||
| @@ -21,7 +21,7 @@ def try_to_run(args, timeout=10, retry=3): | |||
| return result.stdout.decode() | |||
| except subprocess.TimeoutExpired as e: | |||
| pass | |||
| raise subprocess.TimeoutExpired(args, timeout) | |||
| @@ -30,18 +30,18 @@ def parse_pip_requirement(line: str): | |||
| line = line.strip() | |||
| if len(line) == 0 or line[0] in ("#", "-"): | |||
| return None | |||
| return None, None | |||
| package_name, package_version = line, line | |||
| for split_ch in ("=", ">", "<", "!", "~", " ", "="): | |||
| split_ch_index = package_name.find(split_ch) | |||
| if split_ch_index != -1: | |||
| package_name = package_name[:split_ch_index] | |||
| split_ch_index = package_version.find(split_ch) | |||
| if split_ch_index != -1: | |||
| package_version = package_version[split_ch_index + 1:] | |||
| package_version = package_version[split_ch_index + 1 :] | |||
| if package_version == package_name: | |||
| package_version = "" | |||
| @@ -71,6 +71,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: | |||
| exist_packages: list of exist packages | |||
| nonexist_packages: list of non-exist packages | |||
| """ | |||
| def _filter_nonexist_pip_package_worker(package): | |||
| # Return filtered package | |||
| try: | |||
| @@ -83,13 +84,13 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: | |||
| return package | |||
| except Exception as e: | |||
| logger.error(e) | |||
| return None | |||
| exist_packages = [] | |||
| nonexist_packages = [] | |||
| packages = [package for package in packages if package is not None] | |||
| with ThreadPoolExecutor(max_workers=max(os.cpu_count() // 5, 1)) as executor: | |||
| results = executor.map(_filter_nonexist_pip_package_worker, packages) | |||
| @@ -98,7 +99,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: | |||
| exist_packages.append(result) | |||
| else: | |||
| nonexist_packages.append(package) | |||
| return exist_packages, nonexist_packages | |||
| @@ -129,7 +130,14 @@ def filter_nonexist_conda_packages(packages: list) -> Tuple[List[str], List[str] | |||
| last_bracket = stdout.rfind("\n{") | |||
| if last_bracket != -1: | |||
| stdout = stdout[last_bracket:] | |||
| return json.loads(stdout).get("bad_deps", []) | |||
| stdout_json = json.loads(stdout) | |||
| if "error" in stdout_json: | |||
| if "bad_deps" in stdout_json: | |||
| return stdout_json["bad_deps"] | |||
| elif "packages" in stdout_json: | |||
| return stdout_json["packages"] | |||
| return [] | |||
| org_yaml = { | |||
| "channels": ["defaults"], | |||
| @@ -95,8 +95,8 @@ class EasyStatChecker(BaseChecker): | |||
| logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.") | |||
| return self.INVALID_LEARNWARE, traceback.format_exc() | |||
| try: | |||
| learnware_model = learnware.get_model() | |||
| # Check input shape | |||
| learnware_model = learnware.get_model() | |||
| input_shape = learnware_model.input_shape | |||
| if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != ( | |||
| @@ -106,14 +106,18 @@ class EasyStatChecker(BaseChecker): | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| # Check statistical specification | |||
| spec_type = parse_specification_type(learnware.get_specification().stat_spec) | |||
| if spec_type is None: | |||
| message = f"No valid specification is found in stat spec {spec_type}" | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| # Check if statistical specification is computable in dist() | |||
| stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) | |||
| stat_spec.dist(stat_spec) | |||
| if spec_type == "RKMETableSpecification": | |||
| stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) | |||
| if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): | |||
| raise ValueError( | |||
| f"For RKMETableSpecification, input_shape should be tuple of int, but got {input_shape}" | |||
| @@ -124,14 +128,17 @@ class EasyStatChecker(BaseChecker): | |||
| logger.warning(message) | |||
| return self.INVALID_LEARNWARE, message | |||
| inputs = np.random.randn(10, *input_shape) | |||
| elif spec_type == "RKMETextSpecification": | |||
| inputs = EasyStatChecker._generate_random_text_list(10) | |||
| elif spec_type == "RKMEImageSpecification": | |||
| if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): | |||
| raise ValueError( | |||
| f"For RKMEImageSpecification, input_shape should be tuple of int, but got {input_shape}" | |||
| ) | |||
| inputs = np.random.randint(0, 255, size=(10, *input_shape)) | |||
| else: | |||
| raise ValueError(f"not supported spec type for spec_type = {spec_type}") | |||
| @@ -1,2 +1,12 @@ | |||
| from .organizer import HeteroMapTableOrganizer | |||
| from .searcher import HeteroSearcher | |||
| from ...utils import is_torch_available | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("market_hetero") | |||
# Expose the heterogeneous-market components only when 'torch' can be used;
# otherwise bind the public names to None so they still exist on the package,
# and log why they are unavailable.
if is_torch_available(verbose=False):
    from .organizer import HeteroMapTableOrganizer
    from .searcher import HeteroSearcher
else:
    HeteroMapTableOrganizer = None
    HeteroSearcher = None
    logger.error("HeteroMapTableOrganizer and HeteroSearcher are not available because 'torch' is not installed!")
| @@ -1,11 +1,9 @@ | |||
| import traceback | |||
| from typing import Tuple, List | |||
| from typing import Optional | |||
| from .utils import is_hetero | |||
| from ..base import BaseUserInfo, SearchResults | |||
| from ..easy import EasySearcher | |||
| from ..utils import parse_specification_type | |||
| from ...learnware import Learnware | |||
| from ...logger import get_module_logger | |||
| @@ -14,7 +12,7 @@ logger = get_module_logger("hetero_searcher") | |||
| class HeteroSearcher(EasySearcher): | |||
| def __call__( | |||
| self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy" | |||
| self, user_info: BaseUserInfo, check_status: Optional[int] = None, max_search_num: int = 5, search_method: str = "greedy" | |||
| ) -> SearchResults: | |||
| """Search learnwares based on user_info from learnwares with check_status. | |||
| Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods. | |||
| @@ -1,7 +1,6 @@ | |||
| from __future__ import annotations | |||
| import codecs | |||
| import copy | |||
| import functools | |||
| import json | |||
| import os | |||
| @@ -17,8 +16,11 @@ from tqdm import tqdm | |||
| from . import cnn_gp | |||
| from ..base import RegularStatSpecification | |||
| from ..table.rkme import rkme_solve_qp | |||
| from ....logger import get_module_logger | |||
| from ....utils import choose_device, allocate_cuda_idx | |||
| logger = get_module_logger("image_rkme") | |||
| class RKMEImageSpecification(RegularStatSpecification): | |||
| # INNER_PRODUCT_COUNT = 0 | |||
| @@ -127,8 +129,10 @@ class RKMEImageSpecification(RegularStatSpecification): | |||
| try: | |||
| from torchvision.transforms import Resize | |||
| except ModuleNotFoundError: | |||
| raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." ) | |||
| raise ModuleNotFoundError( | |||
| f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." | |||
| ) | |||
| if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH: | |||
| X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=True)(X) | |||
| @@ -154,12 +158,14 @@ class RKMEImageSpecification(RegularStatSpecification): | |||
| with torch.no_grad(): | |||
| x_features = self._generate_random_feature(X_train, random_models=random_models) | |||
| self._update_beta(x_features, nonnegative_beta, random_models=random_models) | |||
| try: | |||
| import torch_optimizer | |||
| except ModuleNotFoundError: | |||
| raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually.") | |||
| raise ModuleNotFoundError( | |||
| f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually." | |||
| ) | |||
| optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16) | |||
| for _ in tqdm(range(steps)) if verbose else range(steps): | |||
| @@ -385,9 +391,14 @@ class RKMEImageSpecification(RegularStatSpecification): | |||
| self.beta = self.beta.to(self._device) | |||
| self.z = self.z.to(self._device) | |||
| return True | |||
| else: | |||
| return False | |||
| if self.type == self.__class__.__name__: | |||
| return True | |||
| else: | |||
| logger.error( | |||
| f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" | |||
| ) | |||
| return False | |||
| def _get_zca_matrix(X, reg_coef=0.1): | |||
| @@ -6,7 +6,7 @@ import json | |||
| import codecs | |||
| import scipy | |||
| import numpy as np | |||
| from qpsolvers import solve_qp, Problem, solve_problem | |||
| from qpsolvers import Problem, solve_problem | |||
| from collections import Counter | |||
| from typing import Any, Union | |||
| @@ -140,15 +140,17 @@ class RKMETableSpecification(RegularStatSpecification): | |||
| if isinstance(X, np.ndarray): | |||
| X = X.astype("float32") | |||
| X = torch.from_numpy(X) | |||
| X = X.to(self._device) | |||
| try: | |||
| from fast_pytorch_kmeans import KMeans | |||
| except ModuleNotFoundError: | |||
| raise ModuleNotFoundError(f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." ) | |||
| raise ModuleNotFoundError( | |||
| f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." | |||
| ) | |||
| kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0) | |||
| kmeans = KMeans(n_clusters=K, mode="euclidean", max_iter=100, verbose=0) | |||
| kmeans.fit(X) | |||
| self.z = kmeans.centroids.double() | |||
| @@ -455,9 +457,15 @@ class RKMETableSpecification(RegularStatSpecification): | |||
| for d in self.get_states(): | |||
| if d in rkme_load.keys(): | |||
| setattr(self, d, rkme_load[d]) | |||
| return True | |||
| else: | |||
| return False | |||
| if self.type == self.__class__.__name__: | |||
| return True | |||
| else: | |||
| logger.error( | |||
| f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" | |||
| ) | |||
| return False | |||
| class RKMEStatSpecification(RKMETableSpecification): | |||
| @@ -1,7 +1,6 @@ | |||
| from __future__ import annotations | |||
| import os | |||
| import copy | |||
| import json | |||
| import torch | |||
| import codecs | |||
| @@ -10,8 +9,11 @@ import numpy as np | |||
| from .base import SystemStatSpecification | |||
| from ..regular import RKMETableSpecification | |||
| from ..regular.table.rkme import torch_rbf_kernel | |||
| from ...logger import get_module_logger | |||
| from ...utils import choose_device, allocate_cuda_idx | |||
| logger = get_module_logger("hetero_map_table_spec") | |||
| class HeteroMapTableSpecification(SystemStatSpecification): | |||
| """Heterogeneous Map-Table Specification""" | |||
| @@ -135,9 +137,14 @@ class HeteroMapTableSpecification(SystemStatSpecification): | |||
| if d in embedding_load.keys(): | |||
| setattr(self, d, embedding_load[d]) | |||
| return True | |||
| else: | |||
| return False | |||
| if self.type == self.__class__.__name__: | |||
| return True | |||
| else: | |||
| logger.error( | |||
| f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" | |||
| ) | |||
| return False | |||
| def save(self, filepath: str) -> bool: | |||
| """Save the computed HeteroMapTableSpecification to a specified path in JSON format. | |||
| @@ -5,10 +5,10 @@ import copy | |||
| import joblib | |||
| import zipfile | |||
| import numpy as np | |||
| import multiprocessing | |||
| from sklearn.linear_model import Ridge | |||
| from sklearn.datasets import make_regression | |||
| from shutil import copyfile, rmtree | |||
| from multiprocessing import Pool | |||
| from learnware.client import LearnwareClient | |||
| from sklearn.metrics import mean_squared_error | |||
| @@ -121,7 +121,8 @@ class TestMarket(unittest.TestCase): | |||
| dir_path = os.path.join(curr_root, "learnware_pool") | |||
| # Execute multi-process checking using Pool | |||
| with Pool() as pool: | |||
| mp_context = multiprocessing.get_context("spawn") | |||
| with mp_context.Pool() as pool: | |||
| results = pool.starmap(check_learnware, [(name, dir_path) for name in os.listdir(dir_path)]) | |||
| # Use an assert statement to ensure that all checks return True | |||