@@ -1,4 +1,4 @@
-__version__ = "0.2.0.5"
+__version__ = "0.2.0.7"
 import os
 import json
| @@ -21,7 +21,7 @@ def try_to_run(args, timeout=10, retry=3): | |||||
| return result.stdout.decode() | return result.stdout.decode() | ||||
| except subprocess.TimeoutExpired as e: | except subprocess.TimeoutExpired as e: | ||||
| pass | pass | ||||
| raise subprocess.TimeoutExpired(args, timeout) | raise subprocess.TimeoutExpired(args, timeout) | ||||
| @@ -30,18 +30,18 @@ def parse_pip_requirement(line: str): | |||||
| line = line.strip() | line = line.strip() | ||||
| if len(line) == 0 or line[0] in ("#", "-"): | if len(line) == 0 or line[0] in ("#", "-"): | ||||
| return None | |||||
| return None, None | |||||
| package_name, package_version = line, line | package_name, package_version = line, line | ||||
| for split_ch in ("=", ">", "<", "!", "~", " ", "="): | for split_ch in ("=", ">", "<", "!", "~", " ", "="): | ||||
| split_ch_index = package_name.find(split_ch) | split_ch_index = package_name.find(split_ch) | ||||
| if split_ch_index != -1: | if split_ch_index != -1: | ||||
| package_name = package_name[:split_ch_index] | package_name = package_name[:split_ch_index] | ||||
| split_ch_index = package_version.find(split_ch) | split_ch_index = package_version.find(split_ch) | ||||
| if split_ch_index != -1: | if split_ch_index != -1: | ||||
| package_version = package_version[split_ch_index + 1:] | |||||
| package_version = package_version[split_ch_index + 1 :] | |||||
| if package_version == package_name: | if package_version == package_name: | ||||
| package_version = "" | package_version = "" | ||||
| @@ -71,6 +71,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: | |||||
| exist_packages: list of exist packages | exist_packages: list of exist packages | ||||
| nonexist_packages: list of non-exist packages | nonexist_packages: list of non-exist packages | ||||
| """ | """ | ||||
| def _filter_nonexist_pip_package_worker(package): | def _filter_nonexist_pip_package_worker(package): | ||||
| # Return filtered package | # Return filtered package | ||||
| try: | try: | ||||
| @@ -83,13 +84,13 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: | |||||
| return package | return package | ||||
| except Exception as e: | except Exception as e: | ||||
| logger.error(e) | logger.error(e) | ||||
| return None | return None | ||||
| exist_packages = [] | exist_packages = [] | ||||
| nonexist_packages = [] | nonexist_packages = [] | ||||
| packages = [package for package in packages if package is not None] | packages = [package for package in packages if package is not None] | ||||
| with ThreadPoolExecutor(max_workers=max(os.cpu_count() // 5, 1)) as executor: | with ThreadPoolExecutor(max_workers=max(os.cpu_count() // 5, 1)) as executor: | ||||
| results = executor.map(_filter_nonexist_pip_package_worker, packages) | results = executor.map(_filter_nonexist_pip_package_worker, packages) | ||||
| @@ -98,7 +99,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]: | |||||
| exist_packages.append(result) | exist_packages.append(result) | ||||
| else: | else: | ||||
| nonexist_packages.append(package) | nonexist_packages.append(package) | ||||
| return exist_packages, nonexist_packages | return exist_packages, nonexist_packages | ||||
| @@ -129,7 +130,14 @@ def filter_nonexist_conda_packages(packages: list) -> Tuple[List[str], List[str] | |||||
| last_bracket = stdout.rfind("\n{") | last_bracket = stdout.rfind("\n{") | ||||
| if last_bracket != -1: | if last_bracket != -1: | ||||
| stdout = stdout[last_bracket:] | stdout = stdout[last_bracket:] | ||||
| return json.loads(stdout).get("bad_deps", []) | |||||
| stdout_json = json.loads(stdout) | |||||
| if "error" in stdout_json: | |||||
| if "bad_deps" in stdout_json: | |||||
| return stdout_json["bad_deps"] | |||||
| elif "packages" in stdout_json: | |||||
| return stdout_json["packages"] | |||||
| return [] | |||||
| org_yaml = { | org_yaml = { | ||||
| "channels": ["defaults"], | "channels": ["defaults"], | ||||
| @@ -95,8 +95,8 @@ class EasyStatChecker(BaseChecker): | |||||
| logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.") | logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.") | ||||
| return self.INVALID_LEARNWARE, traceback.format_exc() | return self.INVALID_LEARNWARE, traceback.format_exc() | ||||
| try: | try: | ||||
| learnware_model = learnware.get_model() | |||||
| # Check input shape | # Check input shape | ||||
| learnware_model = learnware.get_model() | |||||
| input_shape = learnware_model.input_shape | input_shape = learnware_model.input_shape | ||||
| if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != ( | if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != ( | ||||
| @@ -106,14 +106,18 @@ class EasyStatChecker(BaseChecker): | |||||
| logger.warning(message) | logger.warning(message) | ||||
| return self.INVALID_LEARNWARE, message | return self.INVALID_LEARNWARE, message | ||||
| # Check statistical specification | |||||
| spec_type = parse_specification_type(learnware.get_specification().stat_spec) | spec_type = parse_specification_type(learnware.get_specification().stat_spec) | ||||
| if spec_type is None: | if spec_type is None: | ||||
| message = f"No valid specification is found in stat spec {spec_type}" | message = f"No valid specification is found in stat spec {spec_type}" | ||||
| logger.warning(message) | logger.warning(message) | ||||
| return self.INVALID_LEARNWARE, message | return self.INVALID_LEARNWARE, message | ||||
| # Check if statistical specification is computable in dist() | |||||
| stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) | |||||
| stat_spec.dist(stat_spec) | |||||
| if spec_type == "RKMETableSpecification": | if spec_type == "RKMETableSpecification": | ||||
| stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type) | |||||
| if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): | if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): | ||||
| raise ValueError( | raise ValueError( | ||||
| f"For RKMETableSpecification, input_shape should be tuple of int, but got {input_shape}" | f"For RKMETableSpecification, input_shape should be tuple of int, but got {input_shape}" | ||||
| @@ -124,14 +128,17 @@ class EasyStatChecker(BaseChecker): | |||||
| logger.warning(message) | logger.warning(message) | ||||
| return self.INVALID_LEARNWARE, message | return self.INVALID_LEARNWARE, message | ||||
| inputs = np.random.randn(10, *input_shape) | inputs = np.random.randn(10, *input_shape) | ||||
| elif spec_type == "RKMETextSpecification": | elif spec_type == "RKMETextSpecification": | ||||
| inputs = EasyStatChecker._generate_random_text_list(10) | inputs = EasyStatChecker._generate_random_text_list(10) | ||||
| elif spec_type == "RKMEImageSpecification": | elif spec_type == "RKMEImageSpecification": | ||||
| if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): | if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape): | ||||
| raise ValueError( | raise ValueError( | ||||
| f"For RKMEImageSpecification, input_shape should be tuple of int, but got {input_shape}" | f"For RKMEImageSpecification, input_shape should be tuple of int, but got {input_shape}" | ||||
| ) | ) | ||||
| inputs = np.random.randint(0, 255, size=(10, *input_shape)) | inputs = np.random.randint(0, 255, size=(10, *input_shape)) | ||||
| else: | else: | ||||
| raise ValueError(f"not supported spec type for spec_type = {spec_type}") | raise ValueError(f"not supported spec type for spec_type = {spec_type}") | ||||
| @@ -1,2 +1,12 @@ | |||||
| from .organizer import HeteroMapTableOrganizer | |||||
| from .searcher import HeteroSearcher | |||||
| from ...utils import is_torch_available | |||||
| from ...logger import get_module_logger | |||||
| logger = get_module_logger("market_hetero") | |||||
| if not is_torch_available(verbose=False): | |||||
| HeteroMapTableOrganizer = None | |||||
| HeteroSearcher = None | |||||
| logger.error("HeteroMapTableOrganizer and HeteroSearcher are not available because 'torch' is not installed!") | |||||
| else: | |||||
| from .organizer import HeteroMapTableOrganizer | |||||
| from .searcher import HeteroSearcher | |||||
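With this guard, downstream code can detect the missing dependency by checking for the `None` placeholders instead of catching an `ImportError`. A small usage sketch follows; the absolute import path mirrors the package layout implied by the relative imports above and is an assumption.

```python
# Assumed import path; adjust if the actual package layout differs.
from learnware.market.heterogeneous import HeteroMapTableOrganizer, HeteroSearcher

if HeteroMapTableOrganizer is None or HeteroSearcher is None:
    # torch is not installed, so heterogeneous search is unavailable;
    # fall back to the homogeneous organizer/searcher here.
    print("Heterogeneous market components are disabled without torch.")
else:
    print("Heterogeneous market components are available.")
```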
| @@ -1,11 +1,9 @@ | |||||
| import traceback | |||||
| from typing import Tuple, List | |||||
| from typing import Optional | |||||
| from .utils import is_hetero | from .utils import is_hetero | ||||
| from ..base import BaseUserInfo, SearchResults | from ..base import BaseUserInfo, SearchResults | ||||
| from ..easy import EasySearcher | from ..easy import EasySearcher | ||||
| from ..utils import parse_specification_type | from ..utils import parse_specification_type | ||||
| from ...learnware import Learnware | |||||
| from ...logger import get_module_logger | from ...logger import get_module_logger | ||||
| @@ -14,7 +12,7 @@ logger = get_module_logger("hetero_searcher") | |||||
| class HeteroSearcher(EasySearcher): | class HeteroSearcher(EasySearcher): | ||||
| def __call__( | def __call__( | ||||
| self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy" | |||||
| self, user_info: BaseUserInfo, check_status: Optional[int] = None, max_search_num: int = 5, search_method: str = "greedy" | |||||
| ) -> SearchResults: | ) -> SearchResults: | ||||
| """Search learnwares based on user_info from learnwares with check_status. | """Search learnwares based on user_info from learnwares with check_status. | ||||
| Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods. | Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods. | ||||
| @@ -1,7 +1,6 @@ | |||||
| from __future__ import annotations | from __future__ import annotations | ||||
| import codecs | import codecs | ||||
| import copy | |||||
| import functools | import functools | ||||
| import json | import json | ||||
| import os | import os | ||||
| @@ -17,8 +16,11 @@ from tqdm import tqdm | |||||
| from . import cnn_gp | from . import cnn_gp | ||||
| from ..base import RegularStatSpecification | from ..base import RegularStatSpecification | ||||
| from ..table.rkme import rkme_solve_qp | from ..table.rkme import rkme_solve_qp | ||||
| from ....logger import get_module_logger | |||||
| from ....utils import choose_device, allocate_cuda_idx | from ....utils import choose_device, allocate_cuda_idx | ||||
| logger = get_module_logger("image_rkme") | |||||
| class RKMEImageSpecification(RegularStatSpecification): | class RKMEImageSpecification(RegularStatSpecification): | ||||
| # INNER_PRODUCT_COUNT = 0 | # INNER_PRODUCT_COUNT = 0 | ||||
| @@ -127,8 +129,10 @@ class RKMEImageSpecification(RegularStatSpecification): | |||||
| try: | try: | ||||
| from torchvision.transforms import Resize | from torchvision.transforms import Resize | ||||
| except ModuleNotFoundError: | except ModuleNotFoundError: | ||||
| raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." ) | |||||
| raise ModuleNotFoundError( | |||||
| f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." | |||||
| ) | |||||
| if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH: | if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH: | ||||
| X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=True)(X) | X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=True)(X) | ||||
| @@ -154,12 +158,14 @@ class RKMEImageSpecification(RegularStatSpecification): | |||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| x_features = self._generate_random_feature(X_train, random_models=random_models) | x_features = self._generate_random_feature(X_train, random_models=random_models) | ||||
| self._update_beta(x_features, nonnegative_beta, random_models=random_models) | self._update_beta(x_features, nonnegative_beta, random_models=random_models) | ||||
| try: | try: | ||||
| import torch_optimizer | import torch_optimizer | ||||
| except ModuleNotFoundError: | except ModuleNotFoundError: | ||||
| raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually.") | |||||
| raise ModuleNotFoundError( | |||||
| f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually." | |||||
| ) | |||||
| optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16) | optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16) | ||||
| for _ in tqdm(range(steps)) if verbose else range(steps): | for _ in tqdm(range(steps)) if verbose else range(steps): | ||||
| @@ -385,9 +391,14 @@ class RKMEImageSpecification(RegularStatSpecification): | |||||
| self.beta = self.beta.to(self._device) | self.beta = self.beta.to(self._device) | ||||
| self.z = self.z.to(self._device) | self.z = self.z.to(self._device) | ||||
| return True | |||||
| else: | |||||
| return False | |||||
| if self.type == self.__class__.__name__: | |||||
| return True | |||||
| else: | |||||
| logger.error( | |||||
| f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" | |||||
| ) | |||||
| return False | |||||
| def _get_zca_matrix(X, reg_coef=0.1): | def _get_zca_matrix(X, reg_coef=0.1): | ||||
| @@ -6,7 +6,7 @@ import json | |||||
| import codecs | import codecs | ||||
| import scipy | import scipy | ||||
| import numpy as np | import numpy as np | ||||
| from qpsolvers import solve_qp, Problem, solve_problem | |||||
| from qpsolvers import Problem, solve_problem | |||||
| from collections import Counter | from collections import Counter | ||||
| from typing import Any, Union | from typing import Any, Union | ||||
| @@ -140,15 +140,17 @@ class RKMETableSpecification(RegularStatSpecification): | |||||
| if isinstance(X, np.ndarray): | if isinstance(X, np.ndarray): | ||||
| X = X.astype("float32") | X = X.astype("float32") | ||||
| X = torch.from_numpy(X) | X = torch.from_numpy(X) | ||||
| X = X.to(self._device) | X = X.to(self._device) | ||||
| try: | try: | ||||
| from fast_pytorch_kmeans import KMeans | from fast_pytorch_kmeans import KMeans | ||||
| except ModuleNotFoundError: | except ModuleNotFoundError: | ||||
| raise ModuleNotFoundError(f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." ) | |||||
| raise ModuleNotFoundError( | |||||
| f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." | |||||
| ) | |||||
| kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0) | |||||
| kmeans = KMeans(n_clusters=K, mode="euclidean", max_iter=100, verbose=0) | |||||
| kmeans.fit(X) | kmeans.fit(X) | ||||
| self.z = kmeans.centroids.double() | self.z = kmeans.centroids.double() | ||||
| @@ -455,9 +457,15 @@ class RKMETableSpecification(RegularStatSpecification): | |||||
| for d in self.get_states(): | for d in self.get_states(): | ||||
| if d in rkme_load.keys(): | if d in rkme_load.keys(): | ||||
| setattr(self, d, rkme_load[d]) | setattr(self, d, rkme_load[d]) | ||||
| return True | |||||
| else: | |||||
| return False | |||||
| if self.type == self.__class__.__name__: | |||||
| return True | |||||
| else: | |||||
| logger.error( | |||||
| f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" | |||||
| ) | |||||
| return False | |||||
| class RKMEStatSpecification(RKMETableSpecification): | class RKMEStatSpecification(RKMETableSpecification): | ||||
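The same type guard is added to the image, table, and hetero-map specifications: `load()` now refuses a serialized specification whose recorded `type` does not match the loading class, logging an error instead of silently succeeding. Below is a hedged end-to-end sketch of that behavior; the `generate_stat_spec_from_data`, `save`, and `load` calls follow the library's usual RKME workflow, and the file name is arbitrary.

```python
# Hedged sketch of the new type check; exact generation arguments are assumptions.
import numpy as np
from learnware.specification import RKMEImageSpecification, RKMETableSpecification

table_spec = RKMETableSpecification()
table_spec.generate_stat_spec_from_data(X=np.random.randn(100, 8))
table_spec.save("table_spec.json")

image_spec = RKMEImageSpecification()
ok = image_spec.load("table_spec.json")
# The saved JSON records type "RKMETableSpecification", so loading it into an
# image specification now logs a type-mismatch error and returns False.
print(ok)  # False
```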
| @@ -1,7 +1,6 @@ | |||||
| from __future__ import annotations | from __future__ import annotations | ||||
| import os | import os | ||||
| import copy | |||||
| import json | import json | ||||
| import torch | import torch | ||||
| import codecs | import codecs | ||||
| @@ -10,8 +9,11 @@ import numpy as np | |||||
| from .base import SystemStatSpecification | from .base import SystemStatSpecification | ||||
| from ..regular import RKMETableSpecification | from ..regular import RKMETableSpecification | ||||
| from ..regular.table.rkme import torch_rbf_kernel | from ..regular.table.rkme import torch_rbf_kernel | ||||
| from ...logger import get_module_logger | |||||
| from ...utils import choose_device, allocate_cuda_idx | from ...utils import choose_device, allocate_cuda_idx | ||||
| logger = get_module_logger("hetero_map_table_spec") | |||||
| class HeteroMapTableSpecification(SystemStatSpecification): | class HeteroMapTableSpecification(SystemStatSpecification): | ||||
| """Heterogeneous Map-Table Specification""" | """Heterogeneous Map-Table Specification""" | ||||
| @@ -135,9 +137,14 @@ class HeteroMapTableSpecification(SystemStatSpecification): | |||||
| if d in embedding_load.keys(): | if d in embedding_load.keys(): | ||||
| setattr(self, d, embedding_load[d]) | setattr(self, d, embedding_load[d]) | ||||
| return True | |||||
| else: | |||||
| return False | |||||
| if self.type == self.__class__.__name__: | |||||
| return True | |||||
| else: | |||||
| logger.error( | |||||
| f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!" | |||||
| ) | |||||
| return False | |||||
| def save(self, filepath: str) -> bool: | def save(self, filepath: str) -> bool: | ||||
| """Save the computed HeteroMapTableSpecification to a specified path in JSON format. | """Save the computed HeteroMapTableSpecification to a specified path in JSON format. | ||||
| @@ -5,10 +5,10 @@ import copy | |||||
| import joblib | import joblib | ||||
| import zipfile | import zipfile | ||||
| import numpy as np | import numpy as np | ||||
| import multiprocessing | |||||
| from sklearn.linear_model import Ridge | from sklearn.linear_model import Ridge | ||||
| from sklearn.datasets import make_regression | from sklearn.datasets import make_regression | ||||
| from shutil import copyfile, rmtree | from shutil import copyfile, rmtree | ||||
| from multiprocessing import Pool | |||||
| from learnware.client import LearnwareClient | from learnware.client import LearnwareClient | ||||
| from sklearn.metrics import mean_squared_error | from sklearn.metrics import mean_squared_error | ||||
| @@ -121,7 +121,8 @@ class TestMarket(unittest.TestCase): | |||||
| dir_path = os.path.join(curr_root, "learnware_pool") | dir_path = os.path.join(curr_root, "learnware_pool") | ||||
| # Execute multi-process checking using Pool | # Execute multi-process checking using Pool | ||||
| with Pool() as pool: | |||||
| mp_context = multiprocessing.get_context("spawn") | |||||
| with mp_context.Pool() as pool: | |||||
| results = pool.starmap(check_learnware, [(name, dir_path) for name in os.listdir(dir_path)]) | results = pool.starmap(check_learnware, [(name, dir_path) for name in os.listdir(dir_path)]) | ||||
| # Use an assert statement to ensure that all checks return True | # Use an assert statement to ensure that all checks return True | ||||
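An explicit "spawn" context gives each worker a fresh interpreter; a common motivation for this switch (not stated in the diff itself) is avoiding state such as CUDA handles that the default "fork" start method on Linux would inherit from the parent test process. A minimal, self-contained sketch of the same pattern follows, with a placeholder worker standing in for the suite's `check_learnware`.

```python
import multiprocessing


def _square(x):
    # Workers must be importable at module level: "spawn" starts a fresh
    # interpreter and re-imports this module instead of forking the parent.
    return x * x


if __name__ == "__main__":
    mp_context = multiprocessing.get_context("spawn")
    with mp_context.Pool() as pool:
        print(pool.starmap(_square, [(i,) for i in range(5)]))  # [0, 1, 4, 9, 16]
```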