Browse Source

Merge branch 'fix_spec_type' of https://github.com/Learnware-LAMDA/Learnware into search_result

tags/v0.3.2
bxdd 1 year ago
parent
commit
8073446007
9 changed files with 92 additions and 42 deletions
  1. +1
    -1
      learnware/__init__.py
  2. +18
    -10
      learnware/client/package_utils.py
  3. +9
    -2
      learnware/market/easy/checker.py
  4. +12
    -2
      learnware/market/heterogeneous/__init__.py
  5. +2
    -4
      learnware/market/heterogeneous/searcher.py
  6. +20
    -9
      learnware/specification/regular/image/rkme.py
  7. +16
    -8
      learnware/specification/regular/table/rkme.py
  8. +11
    -4
      learnware/specification/system/hetero_table.py
  9. +3
    -2
      tests/test_hetero_market/test_hetero.py

+ 1
- 1
learnware/__init__.py View File

@@ -1,4 +1,4 @@
__version__ = "0.2.0.5"
__version__ = "0.2.0.7"

import os
import json


+ 18
- 10
learnware/client/package_utils.py View File

@@ -21,7 +21,7 @@ def try_to_run(args, timeout=10, retry=3):
return result.stdout.decode()
except subprocess.TimeoutExpired as e:
pass
raise subprocess.TimeoutExpired(args, timeout)


@@ -30,18 +30,18 @@ def parse_pip_requirement(line: str):

line = line.strip()
if len(line) == 0 or line[0] in ("#", "-"):
return None
return None, None

package_name, package_version = line, line
for split_ch in ("=", ">", "<", "!", "~", " ", "="):
split_ch_index = package_name.find(split_ch)
if split_ch_index != -1:
package_name = package_name[:split_ch_index]
split_ch_index = package_version.find(split_ch)
if split_ch_index != -1:
package_version = package_version[split_ch_index + 1:]
package_version = package_version[split_ch_index + 1 :]
if package_version == package_name:
package_version = ""

@@ -71,6 +71,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]:
exist_packages: list of exist packages
nonexist_packages: list of non-exist packages
"""

def _filter_nonexist_pip_package_worker(package):
# Return filtered package
try:
@@ -83,13 +84,13 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]:
return package
except Exception as e:
logger.error(e)
return None
exist_packages = []
nonexist_packages = []
packages = [package for package in packages if package is not None]
with ThreadPoolExecutor(max_workers=max(os.cpu_count() // 5, 1)) as executor:
results = executor.map(_filter_nonexist_pip_package_worker, packages)

@@ -98,7 +99,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]:
exist_packages.append(result)
else:
nonexist_packages.append(package)
return exist_packages, nonexist_packages


@@ -129,7 +130,14 @@ def filter_nonexist_conda_packages(packages: list) -> Tuple[List[str], List[str]
last_bracket = stdout.rfind("\n{")
if last_bracket != -1:
stdout = stdout[last_bracket:]
return json.loads(stdout).get("bad_deps", [])

stdout_json = json.loads(stdout)
if "error" in stdout_json:
if "bad_deps" in stdout_json:
return stdout_json["bad_deps"]
elif "packages" in stdout_json:
return stdout_json["packages"]
return []

org_yaml = {
"channels": ["defaults"],


+ 9
- 2
learnware/market/easy/checker.py View File

@@ -95,8 +95,8 @@ class EasyStatChecker(BaseChecker):
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.")
return self.INVALID_LEARNWARE, traceback.format_exc()
try:
learnware_model = learnware.get_model()
# Check input shape
learnware_model = learnware.get_model()
input_shape = learnware_model.input_shape

if semantic_spec["Data"]["Values"][0] == "Table" and input_shape != (
@@ -106,14 +106,18 @@ class EasyStatChecker(BaseChecker):
logger.warning(message)
return self.INVALID_LEARNWARE, message

# Check statistical specification
spec_type = parse_specification_type(learnware.get_specification().stat_spec)
if spec_type is None:
message = f"No valid specification is found in stat spec {spec_type}"
logger.warning(message)
return self.INVALID_LEARNWARE, message

# Check if statistical specification is computable in dist()
stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type)
stat_spec.dist(stat_spec)

if spec_type == "RKMETableSpecification":
stat_spec = learnware.get_specification().get_stat_spec_by_name(spec_type)
if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape):
raise ValueError(
f"For RKMETableSpecification, input_shape should be tuple of int, but got {input_shape}"
@@ -124,14 +128,17 @@ class EasyStatChecker(BaseChecker):
logger.warning(message)
return self.INVALID_LEARNWARE, message
inputs = np.random.randn(10, *input_shape)

elif spec_type == "RKMETextSpecification":
inputs = EasyStatChecker._generate_random_text_list(10)

elif spec_type == "RKMEImageSpecification":
if not isinstance(input_shape, tuple) or not all(isinstance(item, int) for item in input_shape):
raise ValueError(
f"For RKMEImageSpecification, input_shape should be tuple of int, but got {input_shape}"
)
inputs = np.random.randint(0, 255, size=(10, *input_shape))

else:
raise ValueError(f"not supported spec type for spec_type = {spec_type}")



+ 12
- 2
learnware/market/heterogeneous/__init__.py View File

@@ -1,2 +1,12 @@
from .organizer import HeteroMapTableOrganizer
from .searcher import HeteroSearcher
from ...utils import is_torch_available
from ...logger import get_module_logger

logger = get_module_logger("market_hetero")

if not is_torch_available(verbose=False):
HeteroMapTableOrganizer = None
HeteroSearcher = None
logger.error("HeteroMapTableOrganizer and HeteroSearcher are not available because 'torch' is not installed!")
else:
from .organizer import HeteroMapTableOrganizer
from .searcher import HeteroSearcher

+ 2
- 4
learnware/market/heterogeneous/searcher.py View File

@@ -1,11 +1,9 @@
import traceback
from typing import Tuple, List
from typing import Optional

from .utils import is_hetero
from ..base import BaseUserInfo, SearchResults
from ..easy import EasySearcher
from ..utils import parse_specification_type
from ...learnware import Learnware
from ...logger import get_module_logger


@@ -14,7 +12,7 @@ logger = get_module_logger("hetero_searcher")

class HeteroSearcher(EasySearcher):
def __call__(
self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy"
self, user_info: BaseUserInfo, check_status: Optional[int] = None, max_search_num: int = 5, search_method: str = "greedy"
) -> SearchResults:
"""Search learnwares based on user_info from learnwares with check_status.
Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods.


+ 20
- 9
learnware/specification/regular/image/rkme.py View File

@@ -1,7 +1,6 @@
from __future__ import annotations

import codecs
import copy
import functools
import json
import os
@@ -17,8 +16,11 @@ from tqdm import tqdm
from . import cnn_gp
from ..base import RegularStatSpecification
from ..table.rkme import rkme_solve_qp
from ....logger import get_module_logger
from ....utils import choose_device, allocate_cuda_idx

logger = get_module_logger("image_rkme")


class RKMEImageSpecification(RegularStatSpecification):
# INNER_PRODUCT_COUNT = 0
@@ -127,8 +129,10 @@ class RKMEImageSpecification(RegularStatSpecification):
try:
from torchvision.transforms import Resize
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." )
raise ModuleNotFoundError(
f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually."
)

if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH:
X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=True)(X)

@@ -154,12 +158,14 @@ class RKMEImageSpecification(RegularStatSpecification):
with torch.no_grad():
x_features = self._generate_random_feature(X_train, random_models=random_models)
self._update_beta(x_features, nonnegative_beta, random_models=random_models)
try:
import torch_optimizer
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually.")
raise ModuleNotFoundError(
f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually."
)

optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16)

for _ in tqdm(range(steps)) if verbose else range(steps):
@@ -385,9 +391,14 @@ class RKMEImageSpecification(RegularStatSpecification):
self.beta = self.beta.to(self._device)
self.z = self.z.to(self._device)

return True
else:
return False
if self.type == self.__class__.__name__:
return True
else:
logger.error(
f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!"
)

return False


def _get_zca_matrix(X, reg_coef=0.1):


+ 16
- 8
learnware/specification/regular/table/rkme.py View File

@@ -6,7 +6,7 @@ import json
import codecs
import scipy
import numpy as np
from qpsolvers import solve_qp, Problem, solve_problem
from qpsolvers import Problem, solve_problem
from collections import Counter
from typing import Any, Union

@@ -140,15 +140,17 @@ class RKMETableSpecification(RegularStatSpecification):
if isinstance(X, np.ndarray):
X = X.astype("float32")
X = torch.from_numpy(X)
X = X.to(self._device)
try:
from fast_pytorch_kmeans import KMeans
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." )
raise ModuleNotFoundError(
f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually."
)

kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0)
kmeans = KMeans(n_clusters=K, mode="euclidean", max_iter=100, verbose=0)
kmeans.fit(X)
self.z = kmeans.centroids.double()

@@ -455,9 +457,15 @@ class RKMETableSpecification(RegularStatSpecification):
for d in self.get_states():
if d in rkme_load.keys():
setattr(self, d, rkme_load[d])
return True
else:
return False

if self.type == self.__class__.__name__:
return True
else:
logger.error(
f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!"
)

return False


class RKMEStatSpecification(RKMETableSpecification):


+ 11
- 4
learnware/specification/system/hetero_table.py View File

@@ -1,7 +1,6 @@
from __future__ import annotations

import os
import copy
import json
import torch
import codecs
@@ -10,8 +9,11 @@ import numpy as np
from .base import SystemStatSpecification
from ..regular import RKMETableSpecification
from ..regular.table.rkme import torch_rbf_kernel
from ...logger import get_module_logger
from ...utils import choose_device, allocate_cuda_idx

logger = get_module_logger("hetero_map_table_spec")


class HeteroMapTableSpecification(SystemStatSpecification):
"""Heterogeneous Map-Table Specification"""
@@ -135,9 +137,14 @@ class HeteroMapTableSpecification(SystemStatSpecification):
if d in embedding_load.keys():
setattr(self, d, embedding_load[d])

return True
else:
return False
if self.type == self.__class__.__name__:
return True
else:
logger.error(
f"The type of loaded RKME ({self.type}) is different from the expected type ({self.__class__.__name__})!"
)

return False

def save(self, filepath: str) -> bool:
"""Save the computed HeteroMapTableSpecification to a specified path in JSON format.


+ 3
- 2
tests/test_hetero_market/test_hetero.py View File

@@ -5,10 +5,10 @@ import copy
import joblib
import zipfile
import numpy as np
import multiprocessing
from sklearn.linear_model import Ridge
from sklearn.datasets import make_regression
from shutil import copyfile, rmtree
from multiprocessing import Pool
from learnware.client import LearnwareClient
from sklearn.metrics import mean_squared_error

@@ -121,7 +121,8 @@ class TestMarket(unittest.TestCase):
dir_path = os.path.join(curr_root, "learnware_pool")

# Execute multi-process checking using Pool
with Pool() as pool:
mp_context = multiprocessing.get_context("spawn")
with mp_context.Pool() as pool:
results = pool.starmap(check_learnware, [(name, dir_path) for name in os.listdir(dir_path)])

# Use an assert statement to ensure that all checks return True


Loading…
Cancel
Save