Browse Source

[MNT] format code using black v23.1.0

tags/v0.3.2
Gene 1 year ago
parent
commit
3297847927
33 changed files with 194 additions and 136 deletions
  1. +2
    -2
      docs/conf.py
  2. +1
    -3
      examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py
  3. +2
    -2
      learnware/client/package_utils.py
  4. +11
    -7
      learnware/learnware/__init__.py
  5. +2
    -4
      learnware/market/__init__.py
  6. +1
    -2
      learnware/market/easy/__init__.py
  7. +2
    -4
      learnware/market/easy/searcher.py
  8. +2
    -3
      learnware/market/heterogeneous/organizer/hetero_map/__init__.py
  9. +5
    -1
      learnware/market/heterogeneous/searcher.py
  10. +15
    -6
      learnware/market/module.py
  11. +1
    -1
      learnware/reuse/__init__.py
  12. +12
    -7
      learnware/reuse/ensemble_pruning.py
  13. +1
    -2
      learnware/reuse/job_selector.py
  14. +1
    -0
      learnware/reuse/utils.py
  15. +15
    -6
      learnware/specification/__init__.py
  16. +1
    -2
      learnware/specification/regular/table/__init__.py
  17. +5
    -3
      learnware/specification/regular/text/rkme.py
  18. +3
    -1
      learnware/specification/system/hetero_table.py
  19. +1
    -1
      learnware/tests/__init__.py
  20. +23
    -15
      learnware/tests/templates/__init__.py
  21. +2
    -3
      learnware/tests/templates/pickle_model.py
  22. +1
    -1
      learnware/tests/utils.py
  23. +1
    -2
      learnware/utils/__init__.py
  24. +1
    -0
      learnware/utils/file.py
  25. +1
    -0
      learnware/utils/gpu.py
  26. +7
    -4
      tests/test_learnware_client/test_container.py
  27. +6
    -5
      tests/test_learnware_client/test_load_learnware.py
  28. +5
    -5
      tests/test_specification/test_hetero_spec.py
  29. +2
    -1
      tests/test_specification/test_image_rkme.py
  30. +2
    -1
      tests/test_specification/test_table_rkme.py
  31. +13
    -13
      tests/test_specification/test_text_rkme.py
  32. +32
    -18
      tests/test_workflow/test_hetero_workflow.py
  33. +15
    -11
      tests/test_workflow/test_workflow.py

+ 2
- 2
docs/conf.py View File

@@ -100,12 +100,12 @@ html_logo = "_static/img/logo/logo1.png"


# These folders are copied to the documentation's HTML output
html_static_path = ['_static']
html_static_path = ["_static"]

# These paths are either relative to html_static_path
# or fully qualified paths (eg. https://...)
html_css_files = [
'css/custom_style.css',
"css/custom_style.css",
]

# -- Options for HTMLHelp output ------------------------------------------


+ 1
- 3
examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py View File

@@ -85,9 +85,7 @@ def get_split_errs(algo):
split = train_xs.shape[0] - proportion_list[tmp]
model.fit(
train_xs[
split:,
],
train_xs[split:,],
train_ys[split:],
eval_set=[(val_xs, val_ys)],
early_stopping_rounds=50,


+ 2
- 2
learnware/client/package_utils.py View File

@@ -86,7 +86,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]:
pass
except Exception as err:
logger.error(err)
return None

exist_packages = []
@@ -101,7 +101,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]:
exist_packages.append(result)
else:
nonexist_packages.append(package)
if len(nonexist_packages) > 0:
logger.info(f"Filtered out {len(nonexist_packages)} non-exist pip packages.")
return exist_packages, nonexist_packages


+ 11
- 7
learnware/learnware/__init__.py View File

@@ -13,7 +13,9 @@ from ..utils import read_yaml_to_dict
logger = get_module_logger("learnware.learnware")


def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True) -> Optional[Learnware]:
def get_learnware_from_dirpath(
id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True
) -> Optional[Learnware]:
"""Get the learnware object from dirpath, and provide the manage interface tor Learnware class

Parameters
@@ -46,11 +48,11 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath,
}

try:
learnware_yaml_path = os.path.join(learnware_dirpath, C.learnware_folder_config["yaml_file"])
assert os.path.exists(learnware_yaml_path), f"learnware.yaml is not found for learnware_{id}, please check the learnware folder or zipfile."
assert os.path.exists(
learnware_yaml_path
), f"learnware.yaml is not found for learnware_{id}, please check the learnware folder or zipfile."

yaml_config = read_yaml_to_dict(learnware_yaml_path)

if "name" in yaml_config:
@@ -67,8 +69,10 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath,
for _stat_spec in learnware_config["stat_specifications"]:
stat_spec = _stat_spec.copy()
stat_spec_path = os.path.join(learnware_dirpath, stat_spec["file_name"])
assert os.path.exists(stat_spec_path), f"statistical specification file {stat_spec['file_name']} is not found for learnware_{id}, please check the learnware folder or zipfile."
assert os.path.exists(
stat_spec_path
), f"statistical specification file {stat_spec['file_name']} is not found for learnware_{id}, please check the learnware folder or zipfile."

stat_spec["file_name"] = stat_spec_path
stat_spec_inst = get_stat_spec_from_config(stat_spec)
learnware_spec.update_stat_spec(**{stat_spec_inst.type: stat_spec_inst})


+ 2
- 4
learnware/market/__init__.py View File

@@ -1,9 +1,7 @@
from .anchor import AnchoredOrganizer, AnchoredSearcher, AnchoredUserInfo
from .base import (BaseChecker, BaseOrganizer, BaseSearcher, BaseUserInfo,
LearnwareMarket)
from .base import BaseChecker, BaseOrganizer, BaseSearcher, BaseUserInfo, LearnwareMarket
from .classes import CondaChecker
from .easy import (EasyOrganizer, EasySearcher, EasySemanticChecker,
EasyStatChecker)
from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker
from .evolve import EvolvedOrganizer
from .evolve_anchor import EvolvedAnchoredOrganizer
from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher


+ 1
- 2
learnware/market/easy/__init__.py View File

@@ -11,5 +11,4 @@ if not is_torch_available(verbose=False):
logger.error("EasySeacher and EasyChecker are not available because 'torch' is not installed!")
else:
from .checker import EasySemanticChecker, EasyStatChecker
from .searcher import (EasyExactSemanticSearcher, EasyFuzzSemanticSearcher,
EasySearcher, EasyStatSearcher)
from .searcher import EasyExactSemanticSearcher, EasyFuzzSemanticSearcher, EasySearcher, EasyStatSearcher

+ 2
- 4
learnware/market/easy/searcher.py View File

@@ -6,13 +6,11 @@ import torch
from rapidfuzz import fuzz

from .organizer import EasyOrganizer
from ..base import (BaseSearcher, BaseUserInfo, MultipleSearchItem,
SearchResults, SingleSearchItem)
from ..base import BaseSearcher, BaseUserInfo, MultipleSearchItem, SearchResults, SingleSearchItem
from ..utils import parse_specification_type
from ...learnware import Learnware
from ...logger import get_module_logger
from ...specification import (RKMEImageSpecification, RKMETableSpecification,
RKMETextSpecification, rkme_solve_qp)
from ...specification import RKMEImageSpecification, RKMETableSpecification, RKMETextSpecification, rkme_solve_qp

logger = get_module_logger("easy_seacher")



+ 2
- 3
learnware/market/heterogeneous/organizer/hetero_map/__init__.py View File

@@ -8,8 +8,7 @@ from torch import nn

from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer
from .trainer import Trainer, TransTabCollatorForCL
from .....specification import (HeteroMapTableSpecification,
RKMETableSpecification)
from .....specification import HeteroMapTableSpecification, RKMETableSpecification
from .....utils import allocate_cuda_idx, choose_device


@@ -288,7 +287,7 @@ class HeteroMap(nn.Module):
# go through transformers, get the first cls embedding
encoder_output = self.encoder(**outputs) # bs, seqlen+1, hidden_dim
output_features = encoder_output[:, 0, :]
del inputs, outputs, encoder_output
torch.cuda.empty_cache()



+ 5
- 1
learnware/market/heterogeneous/searcher.py View File

@@ -11,7 +11,11 @@ logger = get_module_logger("hetero_searcher")

class HeteroSearcher(EasySearcher):
def __call__(
self, user_info: BaseUserInfo, check_status: Optional[int] = None, max_search_num: int = 5, search_method: str = "greedy"
self,
user_info: BaseUserInfo,
check_status: Optional[int] = None,
max_search_num: int = 5,
search_method: str = "greedy",
) -> SearchResults:
"""Search learnwares based on user_info from learnwares with check_status.
Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods.


+ 15
- 6
learnware/market/module.py View File

@@ -1,11 +1,12 @@
from .base import LearnwareMarket
from .classes import CondaChecker
from .easy import (EasyOrganizer, EasySearcher, EasySemanticChecker,
EasyStatChecker)
from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker
from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher


def get_market_component(name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False):
def get_market_component(
name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False
):
organizer_kwargs = {} if organizer_kwargs is None else organizer_kwargs
searcher_kwargs = {} if searcher_kwargs is None else searcher_kwargs
checker_kwargs = {} if checker_kwargs is None else checker_kwargs
@@ -13,7 +14,10 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search
if name == "easy":
easy_organizer = EasyOrganizer(market_id=market_id, rebuild=rebuild)
easy_searcher = EasySearcher(organizer=easy_organizer)
easy_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())]
easy_checker_list = [
EasySemanticChecker(),
EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()),
]
market_component = {
"organizer": easy_organizer,
"searcher": easy_searcher,
@@ -22,7 +26,10 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search
elif name == "hetero":
hetero_organizer = HeteroMapTableOrganizer(market_id=market_id, rebuild=rebuild, **organizer_kwargs)
hetero_searcher = HeteroSearcher(organizer=hetero_organizer)
hetero_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())]
hetero_checker_list = [
EasySemanticChecker(),
EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()),
]

market_component = {
"organizer": hetero_organizer,
@@ -45,7 +52,9 @@ def instantiate_learnware_market(
conda_checker: bool = False,
**kwargs,
):
market_componets = get_market_component(name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker)
market_componets = get_market_component(
name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker
)
return LearnwareMarket(
organizer=market_componets["organizer"],
searcher=market_componets["searcher"],


+ 1
- 1
learnware/reuse/__init__.py View File

@@ -20,4 +20,4 @@ else:
from .ensemble_pruning import EnsemblePruningReuser
from .feature_augment import FeatureAugmentReuser
from .hetero import FeatureAlignLearnware, HeteroMapAlignLearnware
from .job_selector import JobSelectorReuser
from .job_selector import JobSelectorReuser

+ 12
- 7
learnware/reuse/ensemble_pruning.py View File

@@ -54,13 +54,14 @@ class EnsemblePruningReuser(BaseReuser):
np.ndarray
Binary one-dimensional vector, 1 indicates that the corresponding model is selected.
"""

try:
import geatpy as ea
except ModuleNotFoundError:
raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")
raise ModuleNotFoundError(
f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)."
)

model_num = v_predict.shape[1]

@ea.Problem.single
@@ -148,7 +149,9 @@ class EnsemblePruningReuser(BaseReuser):
try:
import geatpy as ea
except ModuleNotFoundError:
raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")
raise ModuleNotFoundError(
f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)."
)

if torch.is_tensor(v_true):
v_true = v_true.detach().cpu().numpy()
@@ -270,8 +273,10 @@ class EnsemblePruningReuser(BaseReuser):
try:
import geatpy as ea
except ModuleNotFoundError:
raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")
raise ModuleNotFoundError(
f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)."
)

model_num = v_predict.shape[1]
v_predict[v_predict == 0.0] = -1
v_true[v_true == 0.0] = -1


+ 1
- 2
learnware/reuse/job_selector.py View File

@@ -8,8 +8,7 @@ from .base import BaseReuser
from ..learnware import Learnware
from ..logger import get_module_logger
from ..market.utils import parse_specification_type
from ..specification import (RKMETableSpecification, RKMETextSpecification,
generate_rkme_table_spec, rkme_solve_qp)
from ..specification import RKMETableSpecification, RKMETextSpecification, generate_rkme_table_spec, rkme_solve_qp

logger = get_module_logger("job_selector_reuse")



+ 1
- 0
learnware/reuse/utils.py View File

@@ -4,6 +4,7 @@ from ..logger import get_module_logger

logger = get_module_logger("reuse_utils")


def fill_data_with_mean(X: np.ndarray) -> np.ndarray:
"""
Fill missing data (NaN, Inf) in the input array with the mean of the column.


+ 15
- 6
learnware/specification/__init__.py View File

@@ -1,7 +1,12 @@
from .base import BaseStatSpecification, Specification
from .regular import (RegularStatSpecification, RKMEImageSpecification,
RKMEStatSpecification, RKMETableSpecification,
RKMETextSpecification, rkme_solve_qp)
from .regular import (
RegularStatSpecification,
RKMEImageSpecification,
RKMEStatSpecification,
RKMETableSpecification,
RKMETextSpecification,
rkme_solve_qp,
)
from .system import HeteroMapTableSpecification
from ..utils import is_torch_available

@@ -12,6 +17,10 @@ if not is_torch_available(verbose=False):
generate_rkme_text_spec = None
generate_semantic_spec = None
else:
from .module import (generate_rkme_image_spec, generate_rkme_table_spec,
generate_rkme_text_spec, generate_semantic_spec,
generate_stat_spec)
from .module import (
generate_rkme_image_spec,
generate_rkme_table_spec,
generate_rkme_text_spec,
generate_semantic_spec,
generate_stat_spec,
)

+ 1
- 2
learnware/specification/regular/table/__init__.py View File

@@ -11,5 +11,4 @@ if not is_torch_available(verbose=False):
f"RKMETableSpecification, RKMEStatSpecification and rkme_solve_qp are not available because 'torch' is not installed!"
)
else:
from .rkme import (RKMEStatSpecification, RKMETableSpecification,
rkme_solve_qp)
from .rkme import RKMEStatSpecification, RKMETableSpecification, rkme_solve_qp

+ 5
- 3
learnware/specification/regular/text/rkme.py View File

@@ -87,12 +87,14 @@ class RKMETextSpecification(RKMETableSpecification):
return np.array(miniLM_learnware.predict(X))

logger.info("Load the necessary feature extractor for RKMETextSpecification.")
try:
from sentence_transformers import SentenceTransformer
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually.")
raise ModuleNotFoundError(
f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually."
)

if os.path.exists(zip_path):
X = _get_from_client(zip_path, X)
else:


+ 3
- 1
learnware/specification/system/hetero_table.py View File

@@ -137,7 +137,9 @@ class HeteroMapTableSpecification(SystemStatSpecification):
for d in self.get_states():
if d in embedding_load.keys():
if d == "type" and embedding_load[d] != self.type:
raise TypeError(f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!")
raise TypeError(
f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!"
)
setattr(self, d, embedding_load[d])

def save(self, filepath: str) -> bool:


+ 1
- 1
learnware/tests/__init__.py View File

@@ -1 +1 @@
from .utils import parametrize
from .utils import parametrize

+ 23
- 15
learnware/tests/templates/__init__.py View File

@@ -13,10 +13,13 @@ class ModelTemplate:
class_name: str = field(init=False)
template_path: str = field(init=False)
model_kwargs: dict = field(init=False)


@dataclass
class PickleModelTemplate(ModelTemplate):
model_kwargs: dict
pickle_filepath: str

def __post_init__(self):
self.class_name = "PickleLoadedModel"
self.template_path = os.path.join(C.package_path, "tests", "templates", "pickle_model.py")
@@ -29,13 +32,14 @@ class PickleModelTemplate(ModelTemplate):
default_model_kwargs.update(self.model_kwargs)
self.model_kwargs = default_model_kwargs


@dataclass
class StatSpecTemplate:
filepath: str
type: str = field(default="RKMETableSpecification")
class LearnwareTemplate:


class LearnwareTemplate:
@staticmethod
def generate_requirements(filepath, requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None):
requirements = [] if requirements is None else requirements
@@ -49,14 +53,16 @@ class LearnwareTemplate:
line_str = requirement[0].strip() + requirement[1].strip() + requirement[2].strip() + "\n"
else:
raise TypeError(f"requirement must be type str/tuple, rather than {type(requirement)}")
requirements_str += line_str
with open(filepath, "w") as fdout:
fdout.write(requirements_str)
@staticmethod
def generate_learnware_yaml(filepath, model_config: Optional[dict] = None, stat_spec_config: Optional[List[dict]] = None):
def generate_learnware_yaml(
filepath, model_config: Optional[dict] = None, stat_spec_config: Optional[List[dict]] = None
):
learnware_config = {}
if model_config is not None:
learnware_config["model"] = model_config
@@ -64,7 +70,7 @@ class LearnwareTemplate:
learnware_config["stat_specifications"] = stat_spec_config

save_dict_to_yaml(learnware_config, filepath)
@staticmethod
def generate_learnware_zipfile(
learnware_zippath: str,
@@ -75,27 +81,29 @@ class LearnwareTemplate:
with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir:
requirement_filepath = os.path.join(tempdir, "requirements.txt")
LearnwareTemplate.generate_requirements(requirement_filepath, requirements)
model_filepath = os.path.join(tempdir, "__init__.py")
model_filepath = os.path.join(tempdir, "__init__.py")
copyfile(model_template.template_path, model_filepath)
learnware_yaml_filepath = os.path.join(tempdir, "learnware.yaml")
model_config = {
"class_name": model_template.class_name,
"kwargs": model_template.model_kwargs,
}
stat_spec_config = {
"module_path": "learnware.specification",
"class_name": stat_spec_template.type,
"file_name": "stat_spec.json",
"kwargs": {}
"kwargs": {},
}
copyfile(stat_spec_template.filepath, os.path.join(tempdir, stat_spec_config["file_name"]))
LearnwareTemplate.generate_learnware_yaml(learnware_yaml_filepath, model_config, stat_spec_config=[stat_spec_config])
LearnwareTemplate.generate_learnware_yaml(
learnware_yaml_filepath, model_config, stat_spec_config=[stat_spec_config]
)

if isinstance(model_template, PickleModelTemplate):
pickle_filepath = os.path.join(tempdir, model_template.model_kwargs["pickle_filename"])
copyfile(model_template.pickle_filepath, pickle_filepath)
convert_folder_to_zipfile(tempdir, learnware_zippath)

+ 2
- 3
learnware/tests/templates/pickle_model.py View File

@@ -7,7 +7,6 @@ from learnware.model.base import BaseModel


class PickleLoadedModel(BaseModel):
def __init__(
self,
input_shape,
@@ -25,10 +24,10 @@ class PickleLoadedModel(BaseModel):
self.predict_method = predict_method
self.fit_method = fit_method
self.finetune_method = finetune_method
def predict(self, X: np.ndarray) -> np.ndarray:
return getattr(self.model, self.predict_method)(X)
def fit(self, X: np.ndarray, y: np.ndarray):
getattr(self.model, self.fit_method)(X, y)



+ 1
- 1
learnware/tests/utils.py View File

@@ -7,4 +7,4 @@ def parametrize(test_class, **kwargs):
_suite = unittest.TestSuite()
for name in test_names:
_suite.addTest(test_class(name, **kwargs))
return _suite
return _suite

+ 1
- 2
learnware/utils/__init__.py View File

@@ -1,8 +1,7 @@
import os
import zipfile

from .file import (convert_folder_to_zipfile, read_yaml_to_dict,
save_dict_to_yaml)
from .file import convert_folder_to_zipfile, read_yaml_to_dict, save_dict_to_yaml
from .gpu import allocate_cuda_idx, choose_device, setup_seed
from .import_utils import is_torch_available
from .module import get_module_by_module_path


+ 1
- 0
learnware/utils/file.py View File

@@ -16,6 +16,7 @@ def read_yaml_to_dict(yaml_path: str):
dict_value = yaml.load(file.read(), Loader=yaml.FullLoader)
return dict_value


def convert_folder_to_zipfile(folder_path, zip_path):
with zipfile.ZipFile(zip_path, "w") as zip_obj:
for foldername, subfolders, filenames in os.walk(folder_path):


+ 1
- 0
learnware/utils/gpu.py View File

@@ -17,6 +17,7 @@ def setup_seed(seed):
random.seed(seed)
if is_torch_available(verbose=False):
import torch

torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True


+ 7
- 4
tests/test_learnware_client/test_container.py View File

@@ -4,15 +4,16 @@ import numpy as np
from learnware.client import LearnwareClient
from learnware.client.container import LearnwaresContainer


class TestContainer(unittest.TestCase):
def __init__(self, method_name='runTest', mode="all"):
def __init__(self, method_name="runTest", mode="all"):
super(TestContainer, self).__init__(method_name)
self.modes = []
if mode in {"all", "conda"}:
self.modes.append("conda")
if mode in {"all", "docker"}:
self.modes.append("docker")
def setUp(self):
self.client = LearnwareClient()

@@ -35,17 +36,19 @@ class TestContainer(unittest.TestCase):
def test_container_with_pip(self):
for mode in self.modes:
self._test_container_with_pip(mode=mode)
def test_container_with_conda(self):
for mode in self.modes:
self._test_container_with_conda(mode=mode)


def suite():
_suite = unittest.TestSuite()
_suite.addTest(TestContainer("test_container_with_pip", mode="all"))
_suite.addTest(TestContainer("test_container_with_conda", mode="all"))
return _suite


if __name__ == "__main__":
runner = unittest.TextTestRunner()
runner.run(suite())
runner.run(suite())

+ 6
- 5
tests/test_learnware_client/test_load_learnware.py View File

@@ -5,8 +5,9 @@ import numpy as np
from learnware.client import LearnwareClient
from learnware.reuse import AveragingReuser


class TestLearnwareLoad(unittest.TestCase):
def __init__(self, method_name='runTest', mode="all"):
def __init__(self, method_name="runTest", mode="all"):
super(TestLearnwareLoad, self).__init__(method_name)
self.runnable_options = []
if mode in {"all", "conda"}:
@@ -31,7 +32,6 @@ class TestLearnwareLoad(unittest.TestCase):
for learnware in learnware_list:
print(learnware.id, learnware.predict(input_array))


def _test_load_learnware_by_id(self, runnable_option):
learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option=runnable_option)
reuser = AveragingReuser(learnware_list, mode="vote_by_label")
@@ -44,11 +44,11 @@ class TestLearnwareLoad(unittest.TestCase):
def test_load_learnware_by_zippath(self):
for runnable_option in self.runnable_options:
self._test_load_learnware_by_zippath(runnable_option=runnable_option)
def test_load_learnware_by_id(self):
for runnable_option in self.runnable_options:
self._test_load_learnware_by_id(runnable_option=runnable_option)

def suite():
_suite = unittest.TestSuite()
@@ -56,6 +56,7 @@ def suite():
_suite.addTest(TestLearnwareLoad("test_load_learnware_by_id", mode="all"))
return _suite


if __name__ == "__main__":
runner = unittest.TextTestRunner()
runner.run(suite())
runner.run(suite())

+ 5
- 5
tests/test_specification/test_hetero_spec.py View File

@@ -11,11 +11,11 @@ from learnware.specification import RKMETableSpecification, HeteroMapTableSpecif
from learnware.specification import generate_stat_spec
from learnware.market.heterogeneous.organizer import HeteroMap


class TestTableRKME(unittest.TestCase):
def setUp(self):
self.hetero_map = HeteroMap()
def _test_hetero_spec(self, X):
rkme: RKMETableSpecification = generate_stat_spec(type="table", X=X)
hetero_spec = self.hetero_map.hetero_mapping(rkme_spec=rkme, features=dict())
@@ -30,14 +30,14 @@ class TestTableRKME(unittest.TestCase):
rkme2 = HeteroMapTableSpecification()
rkme2.load(rkme_path)
assert rkme2.type == "HeteroMapTableSpecification"

def test_hetero_rkme(self):
self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5000, 200)))
self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(10000, 100)))
self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5, 20)))
self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(1, 50)))
self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(100, 150)))


if __name__ == "__main__":
unittest.main()

+ 2
- 1
tests/test_specification/test_image_rkme.py View File

@@ -25,7 +25,7 @@ class TestImageRKME(unittest.TestCase):
rkme2 = RKMEImageSpecification()
rkme2.load(rkme_path)
assert rkme2.type == "RKMEImageSpecification"
def test_image_rkme(self):
self._test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32)))
self._test_image_rkme(np.random.randint(0, 255, size=(100, 1, 128, 128)))
@@ -34,5 +34,6 @@ class TestImageRKME(unittest.TestCase):
self._test_image_rkme(torch.randint(0, 255, (20, 3, 128, 128)))
self._test_image_rkme(torch.randint(0, 255, (1, 1, 128, 128)) / 255)


if __name__ == "__main__":
unittest.main()

+ 2
- 1
tests/test_specification/test_table_rkme.py View File

@@ -24,7 +24,7 @@ class TestTableRKME(unittest.TestCase):
rkme2 = RKMETableSpecification()
rkme2.load(rkme_path)
assert rkme2.type == "RKMETableSpecification"
def test_table_rkme(self):
self._test_table_rkme(np.random.uniform(-10000, 10000, size=(5000, 200)))
self._test_table_rkme(np.random.uniform(-10000, 10000, size=(10000, 100)))
@@ -32,5 +32,6 @@ class TestTableRKME(unittest.TestCase):
self._test_table_rkme(np.random.uniform(-10000, 10000, size=(1, 50)))
self._test_table_rkme(np.random.uniform(-10000, 10000, size=(100, 150)))


if __name__ == "__main__":
unittest.main()

+ 13
- 13
tests/test_specification/test_text_rkme.py View File

@@ -12,19 +12,19 @@ from learnware.specification import generate_stat_spec
class TestTextRKME(unittest.TestCase):
@staticmethod
def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000):
text_list = []
for i in range(num):
length = random.randint(min_len, max_len)
if text_type == "en":
characters = string.ascii_letters + string.digits + string.punctuation
result_str = "".join(random.choice(characters) for i in range(length))
text_list.append(result_str)
elif text_type == "zh":
result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length))
text_list.append(result_str)
else:
raise ValueError("Type should be en or zh")
return text_list
text_list = []
for i in range(num):
length = random.randint(min_len, max_len)
if text_type == "en":
characters = string.ascii_letters + string.digits + string.punctuation
result_str = "".join(random.choice(characters) for i in range(length))
text_list.append(result_str)
elif text_type == "zh":
result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length))
text_list.append(result_str)
else:
raise ValueError("Type should be en or zh")
return text_list

@staticmethod
def _test_text_rkme(X):


+ 32
- 18
tests/test_workflow/test_hetero_workflow.py View File

@@ -11,6 +11,7 @@ from shutil import copyfile, rmtree
from sklearn.metrics import mean_squared_error

import learnware

learnware.init(logging_level=logging.WARNING)

from learnware.market import instantiate_learnware_market, BaseUserInfo
@@ -23,6 +24,7 @@ from hetero_config import input_shape_list, input_description_list, output_descr

curr_root = os.path.dirname(os.path.abspath(__file__))


class TestHeteroWorkflow(unittest.TestCase):
universal_semantic_config = {
"data_type": "Table",
@@ -46,10 +48,12 @@ class TestHeteroWorkflow(unittest.TestCase):
learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero")
os.makedirs(learnware_pool_dirpath, exist_ok=True)
learnware_zippath = os.path.join(learnware_pool_dirpath, "ridge_%d.zip" % (i))
print("Preparing Learnware: %d" % (i))

X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42)
X, y = make_regression(
n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42
)
clf = Ridge(alpha=1.0)
clf.fit(X, y)
pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl")
@@ -62,14 +66,16 @@ class TestHeteroWorkflow(unittest.TestCase):

LearnwareTemplate.generate_learnware_zipfile(
learnware_zippath=learnware_zippath,
model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape":(input_shape_list[i % 2],), "output_shape": (1,)}),
model_template=PickleModelTemplate(
pickle_filepath=pickle_filepath,
model_kwargs={"input_shape": (input_shape_list[i % 2],), "output_shape": (1,)},
),
stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"),
requirements=["scikit-learn==0.22"],
)
self.zip_path_list.append(learnware_zippath)

def _upload_delete_learnware(self, hetero_market, learnware_num, delete):
self.test_prepare_learnware_randomly(learnware_num)
self.learnware_num = learnware_num
@@ -83,7 +89,7 @@ class TestHeteroWorkflow(unittest.TestCase):
description=f"test_learnware_number_{idx}",
input_description=input_description_list[idx % 2],
output_description=output_description_list[idx % 2],
**self.universal_semantic_config
**self.universal_semantic_config,
)
hetero_market.add_learnware(zip_path, semantic_spec)

@@ -106,7 +112,7 @@ class TestHeteroWorkflow(unittest.TestCase):
assert len(curr_inds) == 0, f"The market should be empty!"

return hetero_market
def test_upload_delete_learnware(self, learnware_num=5, delete=True):
hetero_market = self._init_learnware_market()
return self._upload_delete_learnware(hetero_market, learnware_num, delete)
@@ -129,7 +135,7 @@ class TestHeteroWorkflow(unittest.TestCase):
name=f"learnware_{learnware_num - 1}",
**self.universal_semantic_config,
)
user_info = BaseUserInfo(semantic_spec=semantic_spec)
search_result = hetero_market.search_learnware(user_info)
single_result = search_result.get_single_results()
@@ -154,7 +160,7 @@ class TestHeteroWorkflow(unittest.TestCase):
def test_hetero_stat_search(self, learnware_num=5):
hetero_market = self.test_train_market_model(learnware_num, delete=False)
print("Total Item:", len(hetero_market))
user_dim = 15

with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
@@ -174,7 +180,10 @@ class TestHeteroWorkflow(unittest.TestCase):
semantic_spec = generate_semantic_spec(
input_description={
"Dimension": user_dim,
"Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)},
"Description": {
str(key): input_description_list[idx % 2]["Description"][str(key)]
for key in range(user_dim)
},
},
**self.universal_semantic_config,
)
@@ -182,7 +191,7 @@ class TestHeteroWorkflow(unittest.TestCase):
search_result = hetero_market.search_learnware(user_info)
single_result = search_result.get_single_results()
multiple_result = search_result.get_multiple_results()
print(f"search result of user{idx}:")
for single_item in single_result:
print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
@@ -215,7 +224,10 @@ class TestHeteroWorkflow(unittest.TestCase):
semantic_spec = generate_semantic_spec(
input_description={
"Dimension": user_dim - 2,
"Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)},
"Description": {
str(key): input_description_list[idx % 2]["Description"][str(key)]
for key in range(user_dim)
},
},
**self.universal_semantic_config,
)
@@ -228,7 +240,7 @@ class TestHeteroWorkflow(unittest.TestCase):
def test_homo_stat_search(self, learnware_num=5):
hetero_market = self.test_train_market_model(learnware_num, delete=False)
print("Total Item:", len(hetero_market))
with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
for idx, zip_path in enumerate(self.zip_path_list):
with zipfile.ZipFile(zip_path, "r") as zip_obj:
@@ -260,7 +272,9 @@ class TestHeteroWorkflow(unittest.TestCase):
user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0)

# generate specification
semantic_spec = generate_semantic_spec(input_description=user_description_list[0], **self.universal_semantic_config)
semantic_spec = generate_semantic_spec(
input_description=user_description_list[0], **self.universal_semantic_config
)
user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})

# learnware market search
@@ -268,7 +282,7 @@ class TestHeteroWorkflow(unittest.TestCase):
search_result = hetero_market.search_learnware(user_info)
single_result = search_result.get_single_results()
multiple_result = search_result.get_multiple_results()
# print search results
for single_item in single_result:
print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
@@ -306,9 +320,9 @@ class TestHeteroWorkflow(unittest.TestCase):

def suite():
_suite = unittest.TestSuite()
#_suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly"))
#_suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware"))
#_suite.addTest(TestHeteroWorkflow("test_train_market_model"))
# _suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly"))
# _suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware"))
# _suite.addTest(TestHeteroWorkflow("test_train_market_model"))
_suite.addTest(TestHeteroWorkflow("test_search_semantics"))
_suite.addTest(TestHeteroWorkflow("test_hetero_stat_search"))
_suite.addTest(TestHeteroWorkflow("test_homo_stat_search"))


+ 15
- 11
tests/test_workflow/test_workflow.py View File

@@ -10,6 +10,7 @@ from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

import learnware

learnware.init(logging_level=logging.WARNING)

from learnware.market import instantiate_learnware_market, BaseUserInfo
@@ -19,8 +20,8 @@ from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, St

curr_root = os.path.dirname(os.path.abspath(__file__))


class TestWorkflow(unittest.TestCase):
universal_semantic_config = {
"data_type": "Table",
"task_type": "Classification",
@@ -28,7 +29,7 @@ class TestWorkflow(unittest.TestCase):
"scenarios": "Education",
"license": "MIT",
}
def _init_learnware_market(self):
"""initialize learnware market"""
easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True)
@@ -42,7 +43,7 @@ class TestWorkflow(unittest.TestCase):
learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool")
os.makedirs(learnware_pool_dirpath, exist_ok=True)
learnware_zippath = os.path.join(learnware_pool_dirpath, "svm_%d.zip" % (i))
print("Preparing Learnware: %d" % (i))
data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True)
clf = svm.SVC(kernel="linear", probability=True)
@@ -54,14 +55,17 @@ class TestWorkflow(unittest.TestCase):
spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0)
spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json")
spec.save(spec_filepath)
LearnwareTemplate.generate_learnware_zipfile(
learnware_zippath=learnware_zippath,
model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape":(64,), "output_shape": (10,), "predict_method": "predict_proba"}),
model_template=PickleModelTemplate(
pickle_filepath=pickle_filepath,
model_kwargs={"input_shape": (64,), "output_shape": (10,), "predict_method": "predict_proba"},
),
stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"),
requirements=["scikit-learn==0.22"],
)
self.zip_path_list.append(learnware_zippath)

def test_upload_delete_learnware(self, learnware_num=5, delete=True):
@@ -87,7 +91,7 @@ class TestWorkflow(unittest.TestCase):
"Dimension": 10,
"Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)},
},
**self.universal_semantic_config
**self.universal_semantic_config,
)
easy_market.add_learnware(zip_path, semantic_spec)

@@ -113,7 +117,7 @@ class TestWorkflow(unittest.TestCase):
easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
print("Total Item:", len(easy_market))
assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
with tempfile.TemporaryDirectory(prefix="learnware_test_workflow") as test_folder:
with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj:
zip_obj.extractall(path=test_folder)
@@ -123,15 +127,15 @@ class TestWorkflow(unittest.TestCase):
description=f"test_learnware_number_{learnware_num - 1}",
**self.universal_semantic_config,
)
user_info = BaseUserInfo(semantic_spec=semantic_spec)
search_result = easy_market.search_learnware(user_info)
single_result = search_result.get_single_results()

print(f"Search result:")
for search_item in single_result:
print("Choose learnware:",search_item.learnware.id)
print("Choose learnware:", search_item.learnware.id)
def test_stat_search(self, learnware_num=5):
easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
print("Total Item:", len(easy_market))


Loading…
Cancel
Save