
[MNT] format code using black v23.1.0

tags/v0.3.2
Gene, 1 year ago
parent commit 3297847927
33 changed files with 194 additions and 136 deletions
  1. docs/conf.py (+2, -2)
  2. examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py (+1, -3)
  3. learnware/client/package_utils.py (+2, -2)
  4. learnware/learnware/__init__.py (+11, -7)
  5. learnware/market/__init__.py (+2, -4)
  6. learnware/market/easy/__init__.py (+1, -2)
  7. learnware/market/easy/searcher.py (+2, -4)
  8. learnware/market/heterogeneous/organizer/hetero_map/__init__.py (+2, -3)
  9. learnware/market/heterogeneous/searcher.py (+5, -1)
  10. learnware/market/module.py (+15, -6)
  11. learnware/reuse/__init__.py (+1, -1)
  12. learnware/reuse/ensemble_pruning.py (+12, -7)
  13. learnware/reuse/job_selector.py (+1, -2)
  14. learnware/reuse/utils.py (+1, -0)
  15. learnware/specification/__init__.py (+15, -6)
  16. learnware/specification/regular/table/__init__.py (+1, -2)
  17. learnware/specification/regular/text/rkme.py (+5, -3)
  18. learnware/specification/system/hetero_table.py (+3, -1)
  19. learnware/tests/__init__.py (+1, -1)
  20. learnware/tests/templates/__init__.py (+23, -15)
  21. learnware/tests/templates/pickle_model.py (+2, -3)
  22. learnware/tests/utils.py (+1, -1)
  23. learnware/utils/__init__.py (+1, -2)
  24. learnware/utils/file.py (+1, -0)
  25. learnware/utils/gpu.py (+1, -0)
  26. tests/test_learnware_client/test_container.py (+7, -4)
  27. tests/test_learnware_client/test_load_learnware.py (+6, -5)
  28. tests/test_specification/test_hetero_spec.py (+5, -5)
  29. tests/test_specification/test_image_rkme.py (+2, -1)
  30. tests/test_specification/test_table_rkme.py (+2, -1)
  31. tests/test_specification/test_text_rkme.py (+13, -13)
  32. tests/test_workflow/test_hetero_workflow.py (+32, -18)
  33. tests/test_workflow/test_workflow.py (+15, -11)
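For reference, a pass like the one in this commit can be replayed through black's Python API. A minimal sketch follows, assuming black==23.1.0 as stated in the commit message and a line length of 120, which is only an inference from how the reformatted lines wrap; the project's actual setting would live in its pyproject.toml, which is not part of this diff. The sample input is the first changed line in docs/conf.py below.

# Minimal sketch, not the project's tooling setup: replay one hunk of this
# commit with black's Python API. line_length=120 is an assumption.
import black

src = "html_static_path = ['_static']\n"
print(black.format_str(src, mode=black.Mode(line_length=120)), end="")
# -> html_static_path = ["_static"]

The command-line equivalent would be along the lines of `python -m black --line-length 120 .` run from the repository root.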

docs/conf.py (+2, -2)

@@ -100,12 +100,12 @@ html_logo = "_static/img/logo/logo1.png"

# These folders are copied to the documentation's HTML output
-html_static_path = ['_static']
+html_static_path = ["_static"]

# These paths are either relative to html_static_path
# or fully qualified paths (eg. https://...)
html_css_files = [
-    'css/custom_style.css',
+    "css/custom_style.css",
]

# -- Options for HTMLHelp output ------------------------------------------


examples/dataset_pfs_workflow/pfs/pfs_cross_transfer.py (+1, -3)

@@ -85,9 +85,7 @@ def get_split_errs(algo):
    split = train_xs.shape[0] - proportion_list[tmp]
    model.fit(
-        train_xs[
-            split:,
-        ],
+        train_xs[split:,],
        train_ys[split:],
        eval_set=[(val_xs, val_ys)],
        early_stopping_rounds=50,


learnware/client/package_utils.py (+2, -2)

@@ -86,7 +86,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]:
            pass
    except Exception as err:
        logger.error(err)
        return None

    exist_packages = []

@@ -101,7 +101,7 @@ def filter_nonexist_pip_packages(packages: list) -> Tuple[List[str], List[str]]:
            exist_packages.append(result)
        else:
            nonexist_packages.append(package)
    if len(nonexist_packages) > 0:
        logger.info(f"Filtered out {len(nonexist_packages)} non-exist pip packages.")
    return exist_packages, nonexist_packages


learnware/learnware/__init__.py (+11, -7)

@@ -13,7 +13,9 @@ from ..utils import read_yaml_to_dict
logger = get_module_logger("learnware.learnware")


-def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True) -> Optional[Learnware]:
+def get_learnware_from_dirpath(
+    id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True
+) -> Optional[Learnware]:
    """Get the learnware object from dirpath, and provide the manage interface tor Learnware class

    Parameters

@@ -46,11 +48,11 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath,
    }

    try:
        learnware_yaml_path = os.path.join(learnware_dirpath, C.learnware_folder_config["yaml_file"])
-        assert os.path.exists(learnware_yaml_path), f"learnware.yaml is not found for learnware_{id}, please check the learnware folder or zipfile."
+        assert os.path.exists(
+            learnware_yaml_path
+        ), f"learnware.yaml is not found for learnware_{id}, please check the learnware folder or zipfile."

        yaml_config = read_yaml_to_dict(learnware_yaml_path)

        if "name" in yaml_config:

@@ -67,8 +69,10 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath,
        for _stat_spec in learnware_config["stat_specifications"]:
            stat_spec = _stat_spec.copy()
            stat_spec_path = os.path.join(learnware_dirpath, stat_spec["file_name"])
-            assert os.path.exists(stat_spec_path), f"statistical specification file {stat_spec['file_name']} is not found for learnware_{id}, please check the learnware folder or zipfile."
+            assert os.path.exists(
+                stat_spec_path
+            ), f"statistical specification file {stat_spec['file_name']} is not found for learnware_{id}, please check the learnware folder or zipfile."

            stat_spec["file_name"] = stat_spec_path
            stat_spec_inst = get_stat_spec_from_config(stat_spec)
            learnware_spec.update_stat_spec(**{stat_spec_inst.type: stat_spec_inst})


learnware/market/__init__.py (+2, -4)

@@ -1,9 +1,7 @@
from .anchor import AnchoredOrganizer, AnchoredSearcher, AnchoredUserInfo
-from .base import (BaseChecker, BaseOrganizer, BaseSearcher, BaseUserInfo,
-    LearnwareMarket)
+from .base import BaseChecker, BaseOrganizer, BaseSearcher, BaseUserInfo, LearnwareMarket
from .classes import CondaChecker
-from .easy import (EasyOrganizer, EasySearcher, EasySemanticChecker,
-    EasyStatChecker)
+from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker
from .evolve import EvolvedOrganizer
from .evolve_anchor import EvolvedAnchoredOrganizer
from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher


learnware/market/easy/__init__.py (+1, -2)

@@ -11,5 +11,4 @@ if not is_torch_available(verbose=False):
    logger.error("EasySeacher and EasyChecker are not available because 'torch' is not installed!")
else:
    from .checker import EasySemanticChecker, EasyStatChecker
-    from .searcher import (EasyExactSemanticSearcher, EasyFuzzSemanticSearcher,
-        EasySearcher, EasyStatSearcher)
+    from .searcher import EasyExactSemanticSearcher, EasyFuzzSemanticSearcher, EasySearcher, EasyStatSearcher

learnware/market/easy/searcher.py (+2, -4)

@@ -6,13 +6,11 @@ import torch
from rapidfuzz import fuzz

from .organizer import EasyOrganizer
-from ..base import (BaseSearcher, BaseUserInfo, MultipleSearchItem,
-    SearchResults, SingleSearchItem)
+from ..base import BaseSearcher, BaseUserInfo, MultipleSearchItem, SearchResults, SingleSearchItem
from ..utils import parse_specification_type
from ...learnware import Learnware
from ...logger import get_module_logger
-from ...specification import (RKMEImageSpecification, RKMETableSpecification,
-    RKMETextSpecification, rkme_solve_qp)
+from ...specification import RKMEImageSpecification, RKMETableSpecification, RKMETextSpecification, rkme_solve_qp

logger = get_module_logger("easy_seacher")




learnware/market/heterogeneous/organizer/hetero_map/__init__.py (+2, -3)

@@ -8,8 +8,7 @@ from torch import nn

from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer
from .trainer import Trainer, TransTabCollatorForCL
-from .....specification import (HeteroMapTableSpecification,
-    RKMETableSpecification)
+from .....specification import HeteroMapTableSpecification, RKMETableSpecification
from .....utils import allocate_cuda_idx, choose_device

@@ -288,7 +287,7 @@ class HeteroMap(nn.Module):
        # go through transformers, get the first cls embedding
        encoder_output = self.encoder(**outputs)  # bs, seqlen+1, hidden_dim
        output_features = encoder_output[:, 0, :]
        del inputs, outputs, encoder_output
        torch.cuda.empty_cache()




learnware/market/heterogeneous/searcher.py (+5, -1)

@@ -11,7 +11,11 @@ logger = get_module_logger("hetero_searcher")

class HeteroSearcher(EasySearcher):
    def __call__(
-        self, user_info: BaseUserInfo, check_status: Optional[int] = None, max_search_num: int = 5, search_method: str = "greedy"
+        self,
+        user_info: BaseUserInfo,
+        check_status: Optional[int] = None,
+        max_search_num: int = 5,
+        search_method: str = "greedy",
    ) -> SearchResults:
        """Search learnwares based on user_info from learnwares with check_status.
        Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods.


learnware/market/module.py (+15, -6)

@@ -1,11 +1,12 @@
from .base import LearnwareMarket
from .classes import CondaChecker
-from .easy import (EasyOrganizer, EasySearcher, EasySemanticChecker,
-    EasyStatChecker)
+from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker
from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher


-def get_market_component(name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False):
+def get_market_component(
+    name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False
+):
    organizer_kwargs = {} if organizer_kwargs is None else organizer_kwargs
    searcher_kwargs = {} if searcher_kwargs is None else searcher_kwargs
    checker_kwargs = {} if checker_kwargs is None else checker_kwargs

@@ -13,7 +14,10 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search
    if name == "easy":
        easy_organizer = EasyOrganizer(market_id=market_id, rebuild=rebuild)
        easy_searcher = EasySearcher(organizer=easy_organizer)
-        easy_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())]
+        easy_checker_list = [
+            EasySemanticChecker(),
+            EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()),
+        ]
        market_component = {
            "organizer": easy_organizer,
            "searcher": easy_searcher,

@@ -22,7 +26,10 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search
    elif name == "hetero":
        hetero_organizer = HeteroMapTableOrganizer(market_id=market_id, rebuild=rebuild, **organizer_kwargs)
        hetero_searcher = HeteroSearcher(organizer=hetero_organizer)
-        hetero_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())]
+        hetero_checker_list = [
+            EasySemanticChecker(),
+            EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker()),
+        ]

        market_component = {
            "organizer": hetero_organizer,

@@ -45,7 +52,9 @@ def instantiate_learnware_market(
    conda_checker: bool = False,
    **kwargs,
):
-    market_componets = get_market_component(name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker)
+    market_componets = get_market_component(
+        name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker
+    )
    return LearnwareMarket(
        organizer=market_componets["organizer"],
        searcher=market_componets["searcher"],


learnware/reuse/__init__.py (+1, -1)

@@ -20,4 +20,4 @@ else:
    from .ensemble_pruning import EnsemblePruningReuser
    from .feature_augment import FeatureAugmentReuser
    from .hetero import FeatureAlignLearnware, HeteroMapAlignLearnware
    from .job_selector import JobSelectorReuser

learnware/reuse/ensemble_pruning.py (+12, -7)

@@ -54,13 +54,14 @@ class EnsemblePruningReuser(BaseReuser):
        np.ndarray
            Binary one-dimensional vector, 1 indicates that the corresponding model is selected.
        """

        try:
            import geatpy as ea
        except ModuleNotFoundError:
-            raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")
+            raise ModuleNotFoundError(
+                f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)."
+            )

        model_num = v_predict.shape[1]

        @ea.Problem.single

@@ -148,7 +149,9 @@ class EnsemblePruningReuser(BaseReuser):
        try:
            import geatpy as ea
        except ModuleNotFoundError:
-            raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")
+            raise ModuleNotFoundError(
+                f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)."
+            )

        if torch.is_tensor(v_true):
            v_true = v_true.detach().cpu().numpy()

@@ -270,8 +273,10 @@ class EnsemblePruningReuser(BaseReuser):
        try:
            import geatpy as ea
        except ModuleNotFoundError:
-            raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")
+            raise ModuleNotFoundError(
+                f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11)."
+            )

        model_num = v_predict.shape[1]
        v_predict[v_predict == 0.0] = -1
        v_true[v_true == 0.0] = -1


learnware/reuse/job_selector.py (+1, -2)

@@ -8,8 +8,7 @@ from .base import BaseReuser
from ..learnware import Learnware
from ..logger import get_module_logger
from ..market.utils import parse_specification_type
-from ..specification import (RKMETableSpecification, RKMETextSpecification,
-    generate_rkme_table_spec, rkme_solve_qp)
+from ..specification import RKMETableSpecification, RKMETextSpecification, generate_rkme_table_spec, rkme_solve_qp

logger = get_module_logger("job_selector_reuse")




learnware/reuse/utils.py (+1, -0)

@@ -4,6 +4,7 @@ from ..logger import get_module_logger

logger = get_module_logger("reuse_utils")

+
def fill_data_with_mean(X: np.ndarray) -> np.ndarray:
    """
    Fill missing data (NaN, Inf) in the input array with the mean of the column.


learnware/specification/__init__.py (+15, -6)

@@ -1,7 +1,12 @@
from .base import BaseStatSpecification, Specification
-from .regular import (RegularStatSpecification, RKMEImageSpecification,
-    RKMEStatSpecification, RKMETableSpecification,
-    RKMETextSpecification, rkme_solve_qp)
+from .regular import (
+    RegularStatSpecification,
+    RKMEImageSpecification,
+    RKMEStatSpecification,
+    RKMETableSpecification,
+    RKMETextSpecification,
+    rkme_solve_qp,
+)
from .system import HeteroMapTableSpecification
from ..utils import is_torch_available

@@ -12,6 +17,10 @@ if not is_torch_available(verbose=False):
    generate_rkme_text_spec = None
    generate_semantic_spec = None
else:
-    from .module import (generate_rkme_image_spec, generate_rkme_table_spec,
-        generate_rkme_text_spec, generate_semantic_spec,
-        generate_stat_spec)
+    from .module import (
+        generate_rkme_image_spec,
+        generate_rkme_table_spec,
+        generate_rkme_text_spec,
+        generate_semantic_spec,
+        generate_stat_spec,
+    )

learnware/specification/regular/table/__init__.py (+1, -2)

@@ -11,5 +11,4 @@ if not is_torch_available(verbose=False):
        f"RKMETableSpecification, RKMEStatSpecification and rkme_solve_qp are not available because 'torch' is not installed!"
    )
else:
-    from .rkme import (RKMEStatSpecification, RKMETableSpecification,
-        rkme_solve_qp)
+    from .rkme import RKMEStatSpecification, RKMETableSpecification, rkme_solve_qp

learnware/specification/regular/text/rkme.py (+5, -3)

@@ -87,12 +87,14 @@ class RKMETextSpecification(RKMETableSpecification):
            return np.array(miniLM_learnware.predict(X))

        logger.info("Load the necessary feature extractor for RKMETextSpecification.")
        try:
            from sentence_transformers import SentenceTransformer
        except ModuleNotFoundError:
-            raise ModuleNotFoundError(f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually.")
+            raise ModuleNotFoundError(
+                f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually."
+            )

        if os.path.exists(zip_path):
            X = _get_from_client(zip_path, X)
        else:


learnware/specification/system/hetero_table.py (+3, -1)

@@ -137,7 +137,9 @@ class HeteroMapTableSpecification(SystemStatSpecification):
        for d in self.get_states():
            if d in embedding_load.keys():
                if d == "type" and embedding_load[d] != self.type:
-                    raise TypeError(f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!")
+                    raise TypeError(
+                        f"The type of loaded RKME ({embedding_load[d]}) is different from the expected type ({self.type})!"
+                    )
                setattr(self, d, embedding_load[d])

    def save(self, filepath: str) -> bool:


learnware/tests/__init__.py (+1, -1)

@@ -1 +1 @@
from .utils import parametrize

learnware/tests/templates/__init__.py (+23, -15)

@@ -13,10 +13,13 @@ class ModelTemplate:
    class_name: str = field(init=False)
    template_path: str = field(init=False)
    model_kwargs: dict = field(init=False)


@dataclass
class PickleModelTemplate(ModelTemplate):
    model_kwargs: dict
    pickle_filepath: str

    def __post_init__(self):
        self.class_name = "PickleLoadedModel"
        self.template_path = os.path.join(C.package_path, "tests", "templates", "pickle_model.py")

@@ -29,13 +32,14 @@ class PickleModelTemplate(ModelTemplate):
        default_model_kwargs.update(self.model_kwargs)
        self.model_kwargs = default_model_kwargs


@dataclass
class StatSpecTemplate:
    filepath: str
    type: str = field(default="RKMETableSpecification")


class LearnwareTemplate:
    @staticmethod
    def generate_requirements(filepath, requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None):
        requirements = [] if requirements is None else requirements

@@ -49,14 +53,16 @@
                line_str = requirement[0].strip() + requirement[1].strip() + requirement[2].strip() + "\n"
            else:
                raise TypeError(f"requirement must be type str/tuple, rather than {type(requirement)}")
            requirements_str += line_str
        with open(filepath, "w") as fdout:
            fdout.write(requirements_str)

    @staticmethod
-    def generate_learnware_yaml(filepath, model_config: Optional[dict] = None, stat_spec_config: Optional[List[dict]] = None):
+    def generate_learnware_yaml(
+        filepath, model_config: Optional[dict] = None, stat_spec_config: Optional[List[dict]] = None
+    ):
        learnware_config = {}
        if model_config is not None:
            learnware_config["model"] = model_config

@@ -64,7 +70,7 @@
            learnware_config["stat_specifications"] = stat_spec_config

        save_dict_to_yaml(learnware_config, filepath)

    @staticmethod
    def generate_learnware_zipfile(
        learnware_zippath: str,

@@ -75,27 +81,29 @@
        with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir:
            requirement_filepath = os.path.join(tempdir, "requirements.txt")
            LearnwareTemplate.generate_requirements(requirement_filepath, requirements)
            model_filepath = os.path.join(tempdir, "__init__.py")
            copyfile(model_template.template_path, model_filepath)
            learnware_yaml_filepath = os.path.join(tempdir, "learnware.yaml")
            model_config = {
                "class_name": model_template.class_name,
                "kwargs": model_template.model_kwargs,
            }
            stat_spec_config = {
                "module_path": "learnware.specification",
                "class_name": stat_spec_template.type,
                "file_name": "stat_spec.json",
-                "kwargs": {}
+                "kwargs": {},
            }
            copyfile(stat_spec_template.filepath, os.path.join(tempdir, stat_spec_config["file_name"]))
-            LearnwareTemplate.generate_learnware_yaml(learnware_yaml_filepath, model_config, stat_spec_config=[stat_spec_config])
+            LearnwareTemplate.generate_learnware_yaml(
+                learnware_yaml_filepath, model_config, stat_spec_config=[stat_spec_config]
+            )

            if isinstance(model_template, PickleModelTemplate):
                pickle_filepath = os.path.join(tempdir, model_template.model_kwargs["pickle_filename"])
                copyfile(model_template.pickle_filepath, pickle_filepath)
            convert_folder_to_zipfile(tempdir, learnware_zippath)

learnware/tests/templates/pickle_model.py (+2, -3)

@@ -7,7 +7,6 @@ from learnware.model.base import BaseModel


class PickleLoadedModel(BaseModel):
    def __init__(
        self,
        input_shape,

@@ -25,10 +24,10 @@ class PickleLoadedModel(BaseModel):
        self.predict_method = predict_method
        self.fit_method = fit_method
        self.finetune_method = finetune_method
    def predict(self, X: np.ndarray) -> np.ndarray:
        return getattr(self.model, self.predict_method)(X)
    def fit(self, X: np.ndarray, y: np.ndarray):
        getattr(self.model, self.fit_method)(X, y)




learnware/tests/utils.py (+1, -1)

@@ -7,4 +7,4 @@ def parametrize(test_class, **kwargs):
    _suite = unittest.TestSuite()
    for name in test_names:
        _suite.addTest(test_class(name, **kwargs))
    return _suite

learnware/utils/__init__.py (+1, -2)

@@ -1,8 +1,7 @@
import os
import zipfile

-from .file import (convert_folder_to_zipfile, read_yaml_to_dict,
-    save_dict_to_yaml)
+from .file import convert_folder_to_zipfile, read_yaml_to_dict, save_dict_to_yaml
from .gpu import allocate_cuda_idx, choose_device, setup_seed
from .import_utils import is_torch_available
from .module import get_module_by_module_path


learnware/utils/file.py (+1, -0)

@@ -16,6 +16,7 @@ def read_yaml_to_dict(yaml_path: str):
        dict_value = yaml.load(file.read(), Loader=yaml.FullLoader)
        return dict_value

+
def convert_folder_to_zipfile(folder_path, zip_path):
    with zipfile.ZipFile(zip_path, "w") as zip_obj:
        for foldername, subfolders, filenames in os.walk(folder_path):


learnware/utils/gpu.py (+1, -0)

@@ -17,6 +17,7 @@ def setup_seed(seed):
    random.seed(seed)
    if is_torch_available(verbose=False):
        import torch
+
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True


tests/test_learnware_client/test_container.py (+7, -4)

@@ -4,15 +4,16 @@ import numpy as np
from learnware.client import LearnwareClient
from learnware.client.container import LearnwaresContainer


class TestContainer(unittest.TestCase):
-    def __init__(self, method_name='runTest', mode="all"):
+    def __init__(self, method_name="runTest", mode="all"):
        super(TestContainer, self).__init__(method_name)
        self.modes = []
        if mode in {"all", "conda"}:
            self.modes.append("conda")
        if mode in {"all", "docker"}:
            self.modes.append("docker")
    def setUp(self):
        self.client = LearnwareClient()

@@ -35,17 +36,19 @@ class TestContainer(unittest.TestCase):
    def test_container_with_pip(self):
        for mode in self.modes:
            self._test_container_with_pip(mode=mode)
    def test_container_with_conda(self):
        for mode in self.modes:
            self._test_container_with_conda(mode=mode)


def suite():
    _suite = unittest.TestSuite()
    _suite.addTest(TestContainer("test_container_with_pip", mode="all"))
    _suite.addTest(TestContainer("test_container_with_conda", mode="all"))
    return _suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())

tests/test_learnware_client/test_load_learnware.py (+6, -5)

@@ -5,8 +5,9 @@ import numpy as np
from learnware.client import LearnwareClient
from learnware.reuse import AveragingReuser


class TestLearnwareLoad(unittest.TestCase):
-    def __init__(self, method_name='runTest', mode="all"):
+    def __init__(self, method_name="runTest", mode="all"):
        super(TestLearnwareLoad, self).__init__(method_name)
        self.runnable_options = []
        if mode in {"all", "conda"}:

@@ -31,7 +32,6 @@ class TestLearnwareLoad(unittest.TestCase):
        for learnware in learnware_list:
            print(learnware.id, learnware.predict(input_array))

    def _test_load_learnware_by_id(self, runnable_option):
        learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option=runnable_option)
        reuser = AveragingReuser(learnware_list, mode="vote_by_label")

@@ -44,11 +44,11 @@ class TestLearnwareLoad(unittest.TestCase):
    def test_load_learnware_by_zippath(self):
        for runnable_option in self.runnable_options:
            self._test_load_learnware_by_zippath(runnable_option=runnable_option)
    def test_load_learnware_by_id(self):
        for runnable_option in self.runnable_options:
            self._test_load_learnware_by_id(runnable_option=runnable_option)

def suite():
    _suite = unittest.TestSuite()

@@ -56,6 +56,7 @@ def suite():
    _suite.addTest(TestLearnwareLoad("test_load_learnware_by_id", mode="all"))
    return _suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())

tests/test_specification/test_hetero_spec.py (+5, -5)

@@ -11,11 +11,11 @@ from learnware.specification import RKMETableSpecification, HeteroMapTableSpecif
from learnware.specification import generate_stat_spec
from learnware.market.heterogeneous.organizer import HeteroMap


class TestTableRKME(unittest.TestCase):
    def setUp(self):
        self.hetero_map = HeteroMap()
    def _test_hetero_spec(self, X):
        rkme: RKMETableSpecification = generate_stat_spec(type="table", X=X)
        hetero_spec = self.hetero_map.hetero_mapping(rkme_spec=rkme, features=dict())

@@ -30,14 +30,14 @@ class TestTableRKME(unittest.TestCase):
        rkme2 = HeteroMapTableSpecification()
        rkme2.load(rkme_path)
        assert rkme2.type == "HeteroMapTableSpecification"

    def test_hetero_rkme(self):
        self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5000, 200)))
        self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(10000, 100)))
        self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5, 20)))
        self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(1, 50)))
        self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(100, 150)))


if __name__ == "__main__":
    unittest.main()

tests/test_specification/test_image_rkme.py (+2, -1)

@@ -25,7 +25,7 @@ class TestImageRKME(unittest.TestCase):
        rkme2 = RKMEImageSpecification()
        rkme2.load(rkme_path)
        assert rkme2.type == "RKMEImageSpecification"
    def test_image_rkme(self):
        self._test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32)))
        self._test_image_rkme(np.random.randint(0, 255, size=(100, 1, 128, 128)))

@@ -34,5 +34,6 @@ class TestImageRKME(unittest.TestCase):
        self._test_image_rkme(torch.randint(0, 255, (20, 3, 128, 128)))
        self._test_image_rkme(torch.randint(0, 255, (1, 1, 128, 128)) / 255)


if __name__ == "__main__":
    unittest.main()

tests/test_specification/test_table_rkme.py (+2, -1)

@@ -24,7 +24,7 @@ class TestTableRKME(unittest.TestCase):
        rkme2 = RKMETableSpecification()
        rkme2.load(rkme_path)
        assert rkme2.type == "RKMETableSpecification"
    def test_table_rkme(self):
        self._test_table_rkme(np.random.uniform(-10000, 10000, size=(5000, 200)))
        self._test_table_rkme(np.random.uniform(-10000, 10000, size=(10000, 100)))

@@ -32,5 +32,6 @@ class TestTableRKME(unittest.TestCase):
        self._test_table_rkme(np.random.uniform(-10000, 10000, size=(1, 50)))
        self._test_table_rkme(np.random.uniform(-10000, 10000, size=(100, 150)))


if __name__ == "__main__":
    unittest.main()

tests/test_specification/test_text_rkme.py (+13, -13)

@@ -12,19 +12,19 @@ from learnware.specification import generate_stat_spec
class TestTextRKME(unittest.TestCase):
    @staticmethod
    def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000):
        text_list = []
        for i in range(num):
            length = random.randint(min_len, max_len)
            if text_type == "en":
                characters = string.ascii_letters + string.digits + string.punctuation
                result_str = "".join(random.choice(characters) for i in range(length))
                text_list.append(result_str)
            elif text_type == "zh":
                result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length))
                text_list.append(result_str)
            else:
                raise ValueError("Type should be en or zh")
        return text_list

    @staticmethod
    def _test_text_rkme(X):


tests/test_workflow/test_hetero_workflow.py (+32, -18)

@@ -11,6 +11,7 @@ from shutil import copyfile, rmtree
from sklearn.metrics import mean_squared_error

import learnware
+
learnware.init(logging_level=logging.WARNING)

from learnware.market import instantiate_learnware_market, BaseUserInfo

@@ -23,6 +24,7 @@ from hetero_config import input_shape_list, input_description_list, output_descr

curr_root = os.path.dirname(os.path.abspath(__file__))

+
class TestHeteroWorkflow(unittest.TestCase):
    universal_semantic_config = {
        "data_type": "Table",

@@ -46,10 +48,12 @@ class TestHeteroWorkflow(unittest.TestCase):
            learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero")
            os.makedirs(learnware_pool_dirpath, exist_ok=True)
            learnware_zippath = os.path.join(learnware_pool_dirpath, "ridge_%d.zip" % (i))
            print("Preparing Learnware: %d" % (i))

-            X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42)
+            X, y = make_regression(
+                n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42
+            )
            clf = Ridge(alpha=1.0)
            clf.fit(X, y)
            pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl")

@@ -62,14 +66,16 @@ class TestHeteroWorkflow(unittest.TestCase):

            LearnwareTemplate.generate_learnware_zipfile(
                learnware_zippath=learnware_zippath,
-                model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape":(input_shape_list[i % 2],), "output_shape": (1,)}),
+                model_template=PickleModelTemplate(
+                    pickle_filepath=pickle_filepath,
+                    model_kwargs={"input_shape": (input_shape_list[i % 2],), "output_shape": (1,)},
+                ),
                stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"),
                requirements=["scikit-learn==0.22"],
            )
            self.zip_path_list.append(learnware_zippath)

    def _upload_delete_learnware(self, hetero_market, learnware_num, delete):
        self.test_prepare_learnware_randomly(learnware_num)
        self.learnware_num = learnware_num

@@ -83,7 +89,7 @@ class TestHeteroWorkflow(unittest.TestCase):
                description=f"test_learnware_number_{idx}",
                input_description=input_description_list[idx % 2],
                output_description=output_description_list[idx % 2],
-                **self.universal_semantic_config
+                **self.universal_semantic_config,
            )
            hetero_market.add_learnware(zip_path, semantic_spec)

@@ -106,7 +112,7 @@ class TestHeteroWorkflow(unittest.TestCase):
            assert len(curr_inds) == 0, f"The market should be empty!"

        return hetero_market
    def test_upload_delete_learnware(self, learnware_num=5, delete=True):
        hetero_market = self._init_learnware_market()
        return self._upload_delete_learnware(hetero_market, learnware_num, delete)

@@ -129,7 +135,7 @@ class TestHeteroWorkflow(unittest.TestCase):
            name=f"learnware_{learnware_num - 1}",
            **self.universal_semantic_config,
        )
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

@@ -154,7 +160,7 @@ class TestHeteroWorkflow(unittest.TestCase):
    def test_hetero_stat_search(self, learnware_num=5):
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))
        user_dim = 15

        with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:

@@ -174,7 +180,10 @@ class TestHeteroWorkflow(unittest.TestCase):
                semantic_spec = generate_semantic_spec(
                    input_description={
                        "Dimension": user_dim,
-                        "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)},
+                        "Description": {
+                            str(key): input_description_list[idx % 2]["Description"][str(key)]
+                            for key in range(user_dim)
+                        },
                    },
                    **self.universal_semantic_config,
                )

@@ -182,7 +191,7 @@ class TestHeteroWorkflow(unittest.TestCase):
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()
                print(f"search result of user{idx}:")
                for single_item in single_result:
                    print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

@@ -215,7 +224,10 @@ class TestHeteroWorkflow(unittest.TestCase):
                semantic_spec = generate_semantic_spec(
                    input_description={
                        "Dimension": user_dim - 2,
-                        "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)},
+                        "Description": {
+                            str(key): input_description_list[idx % 2]["Description"][str(key)]
+                            for key in range(user_dim)
+                        },
                    },
                    **self.universal_semantic_config,
                )

@@ -228,7 +240,7 @@ class TestHeteroWorkflow(unittest.TestCase):
    def test_homo_stat_search(self, learnware_num=5):
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))
        with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
            for idx, zip_path in enumerate(self.zip_path_list):
                with zipfile.ZipFile(zip_path, "r") as zip_obj:

@@ -260,7 +272,9 @@ class TestHeteroWorkflow(unittest.TestCase):
                user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0)

                # generate specification
-                semantic_spec = generate_semantic_spec(input_description=user_description_list[0], **self.universal_semantic_config)
+                semantic_spec = generate_semantic_spec(
+                    input_description=user_description_list[0], **self.universal_semantic_config
+                )
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})

                # learnware market search

@@ -268,7 +282,7 @@ class TestHeteroWorkflow(unittest.TestCase):
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()
                # print search results
                for single_item in single_result:
                    print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

@@ -306,9 +320,9 @@ class TestHeteroWorkflow(unittest.TestCase):

def suite():
    _suite = unittest.TestSuite()
-    #_suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly"))
-    #_suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware"))
-    #_suite.addTest(TestHeteroWorkflow("test_train_market_model"))
+    # _suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly"))
+    # _suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware"))
+    # _suite.addTest(TestHeteroWorkflow("test_train_market_model"))
    _suite.addTest(TestHeteroWorkflow("test_search_semantics"))
    _suite.addTest(TestHeteroWorkflow("test_hetero_stat_search"))
    _suite.addTest(TestHeteroWorkflow("test_homo_stat_search"))


tests/test_workflow/test_workflow.py (+15, -11)

@@ -10,6 +10,7 @@ from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

import learnware
+
learnware.init(logging_level=logging.WARNING)

from learnware.market import instantiate_learnware_market, BaseUserInfo

@@ -19,8 +20,8 @@ from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, St

curr_root = os.path.dirname(os.path.abspath(__file__))


class TestWorkflow(unittest.TestCase):
    universal_semantic_config = {
        "data_type": "Table",
        "task_type": "Classification",

@@ -28,7 +29,7 @@ class TestWorkflow(unittest.TestCase):
        "scenarios": "Education",
        "license": "MIT",
    }
    def _init_learnware_market(self):
        """initialize learnware market"""
        easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True)

@@ -42,7 +43,7 @@ class TestWorkflow(unittest.TestCase):
            learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool")
            os.makedirs(learnware_pool_dirpath, exist_ok=True)
            learnware_zippath = os.path.join(learnware_pool_dirpath, "svm_%d.zip" % (i))
            print("Preparing Learnware: %d" % (i))
            data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True)
            clf = svm.SVC(kernel="linear", probability=True)

@@ -54,14 +55,17 @@ class TestWorkflow(unittest.TestCase):
            spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0)
            spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json")
            spec.save(spec_filepath)
            LearnwareTemplate.generate_learnware_zipfile(
                learnware_zippath=learnware_zippath,
-                model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape":(64,), "output_shape": (10,), "predict_method": "predict_proba"}),
+                model_template=PickleModelTemplate(
+                    pickle_filepath=pickle_filepath,
+                    model_kwargs={"input_shape": (64,), "output_shape": (10,), "predict_method": "predict_proba"},
+                ),
                stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"),
                requirements=["scikit-learn==0.22"],
            )
            self.zip_path_list.append(learnware_zippath)

    def test_upload_delete_learnware(self, learnware_num=5, delete=True):

@@ -87,7 +91,7 @@ class TestWorkflow(unittest.TestCase):
                    "Dimension": 10,
                    "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)},
                },
-                **self.universal_semantic_config
+                **self.universal_semantic_config,
            )
            easy_market.add_learnware(zip_path, semantic_spec)

@@ -113,7 +117,7 @@ class TestWorkflow(unittest.TestCase):
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))
        assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
        with tempfile.TemporaryDirectory(prefix="learnware_test_workflow") as test_folder:
            with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj:
                zip_obj.extractall(path=test_folder)

@@ -123,15 +127,15 @@ class TestWorkflow(unittest.TestCase):
            description=f"test_learnware_number_{learnware_num - 1}",
            **self.universal_semantic_config,
        )
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = easy_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print(f"Search result:")
        for search_item in single_result:
-            print("Choose learnware:",search_item.learnware.id)
+            print("Choose learnware:", search_item.learnware.id)
    def test_stat_search(self, learnware_num=5):
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))

