| @@ -70,7 +70,14 @@ class LearnwareClient: | |||
| self.tempdir_list = [] | |||
| self.login_status = False | |||
| atexit.register(self.cleanup) | |||
| def is_connected(self): | |||
| url = f"{self.host}/auth/login_by_token" | |||
| response = requests.post(url) | |||
| if response.status_code == 404: | |||
| return False | |||
| return True | |||
| def login(self, email, token): | |||
| url = f"{self.host}/auth/login_by_token" | |||
| @@ -172,7 +179,7 @@ class LearnwareClient: | |||
| if result["code"] != 0: | |||
| raise Exception("update failed: " + json.dumps(result)) | |||
| def download_learnware(self, learnware_id, save_path): | |||
| def download_learnware(self, learnware_id: str, save_path: str): | |||
| url = f"{self.host}/engine/download_learnware" | |||
| response = requests.get( | |||
| @@ -132,7 +132,7 @@ class LearnwareMarket: | |||
| def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool: | |||
| try: | |||
| final_status = BaseChecker.NONUSABLE_LEARNWARE | |||
| if len(checker_names): | |||
| if checker_names is not None and len(checker_names): | |||
| with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir: | |||
| with zipfile.ZipFile(zip_path, mode="r") as z_file: | |||
| z_file.extractall(tempdir) | |||
| @@ -245,7 +245,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): | |||
| ret = [] | |||
| for idx in ids: | |||
| spec = self.learnware_list[idx].get_specification() | |||
| if is_hetero(stat_specs=spec.get_stat_spec(), semantic_spec=spec.get_semantic_spec()): | |||
| if is_hetero(stat_specs=spec.get_stat_spec(), semantic_spec=spec.get_semantic_spec(), verbose=False): | |||
| ret.append(idx) | |||
| return ret | |||
| @@ -1,9 +1,10 @@ | |||
| import traceback | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("hetero_utils") | |||
| def is_hetero(stat_specs: dict, semantic_spec: dict) -> bool: | |||
| def is_hetero(stat_specs: dict, semantic_spec: dict, verbose=True) -> bool: | |||
| """Check if user_info satifies all the criteria required for enabling heterogeneous learnware search | |||
| Parameters | |||
| @@ -35,15 +36,17 @@ def is_hetero(stat_specs: dict, semantic_spec: dict) -> bool: | |||
| semantic_decription_feature_num = len(semantic_input_description["Description"]) | |||
| if semantic_decription_feature_num <= 0: | |||
| logger.warning("At least one of Input.Description in semantic spec should be provides.") | |||
| if verbose: | |||
| logger.warning("At least one of Input.Description in semantic spec should be provides.") | |||
| return False | |||
| if table_input_shape != semantic_description_dim: | |||
| logger.warning("User data feature dimensions mismatch with semantic specification.") | |||
| if verbose: | |||
| logger.warning("User data feature dimensions mismatch with semantic specification.") | |||
| return False | |||
| return True | |||
| except Exception as e: | |||
| logger.warning(f"Invalid heterogeneous search information provided due to {e}. Use homogeneous search instead.") | |||
| except Exception as err: | |||
| if verbose: | |||
| logger.warning(f"Invalid heterogeneous search information provided.") | |||
| return False | |||
| @@ -1,9 +1,10 @@ | |||
| from .base import LearnwareMarket | |||
| from .classes import CondaChecker | |||
| from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | |||
| from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher | |||
| def get_market_component(name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None): | |||
| def get_market_component(name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False): | |||
| organizer_kwargs = {} if organizer_kwargs is None else organizer_kwargs | |||
| searcher_kwargs = {} if searcher_kwargs is None else searcher_kwargs | |||
| checker_kwargs = {} if checker_kwargs is None else checker_kwargs | |||
| @@ -11,7 +12,7 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search | |||
| if name == "easy": | |||
| easy_organizer = EasyOrganizer(market_id=market_id, rebuild=rebuild) | |||
| easy_searcher = EasySearcher(organizer=easy_organizer) | |||
| easy_checker_list = [EasySemanticChecker(), EasyStatChecker()] | |||
| easy_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())] | |||
| market_component = { | |||
| "organizer": easy_organizer, | |||
| "searcher": easy_searcher, | |||
| @@ -20,7 +21,7 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search | |||
| elif name == "hetero": | |||
| hetero_organizer = HeteroMapTableOrganizer(market_id=market_id, rebuild=rebuild, **organizer_kwargs) | |||
| hetero_searcher = HeteroSearcher(organizer=hetero_organizer) | |||
| hetero_checker_list = [EasySemanticChecker(), EasyStatChecker()] | |||
| hetero_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())] | |||
| market_component = { | |||
| "organizer": hetero_organizer, | |||
| @@ -40,9 +41,10 @@ def instantiate_learnware_market( | |||
| organizer_kwargs: dict = None, | |||
| searcher_kwargs: dict = None, | |||
| checker_kwargs: dict = None, | |||
| conda_checker: bool = False, | |||
| **kwargs, | |||
| ): | |||
| market_componets = get_market_component(name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs) | |||
| market_componets = get_market_component(name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker) | |||
| return LearnwareMarket( | |||
| organizer=market_componets["organizer"], | |||
| searcher=market_componets["searcher"], | |||
| @@ -0,0 +1,84 @@ | |||
| import os | |||
| import unittest | |||
| import tempfile | |||
| import logging | |||
| import learnware | |||
| learnware.init(logging_level=logging.WARNING) | |||
| from learnware.learnware import Learnware | |||
| from learnware.client import LearnwareClient | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo, EasySemanticChecker | |||
| from learnware.config import C | |||
| class TestSearch(unittest.TestCase): | |||
| client = LearnwareClient() | |||
| @classmethod | |||
| def setUpClass(cls): | |||
| cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True) | |||
| if cls.client.is_connected(): | |||
| cls._build_learnware_market() | |||
| @classmethod | |||
| def _build_learnware_market(cls): | |||
| table_learnware_ids = ["00001951", "00001980", "00001987"] | |||
| image_learnware_ids = ["00000851", "00000858", "00000841"] | |||
| text_learnware_ids = ["00000652", "00000637"] | |||
| learnware_ids = table_learnware_ids + image_learnware_ids + text_learnware_ids | |||
| with tempfile.TemporaryDirectory(prefix="learnware_search_test") as tempdir: | |||
| for learnware_id in learnware_ids: | |||
| learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip") | |||
| try: | |||
| cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath) | |||
| semantic_spec = cls.client.load_learnware(learnware_path=learnware_zippath).get_specification().get_semantic_spec() | |||
| except Exception: | |||
| print("'learnware_id' is passed due to the network problem.") | |||
| cls.market.add_learnware(learnware_zippath, learnware_id=learnware_id, semantic_spec=semantic_spec, checker_names=["EasySemanticChecker"]) | |||
| @unittest.skipIf(not client.is_connected(), "Client can not connect!") | |||
| def test_image_search(self): | |||
| learnware_id = "00000619" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_image_search' is passed due to the network problem.") | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| @unittest.skipIf(not client.is_connected(), "Client can not connect!") | |||
| def test_text_search(self): | |||
| learnware_id = "00000653" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_text_search' is passed due to the network problem.") | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| @unittest.skipIf(not client.is_connected(), "Client can not connect!") | |||
| def test_table_search(self): | |||
| learnware_id = "00001950" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_table_search' is passed due to the network problem.") | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestSearch("test_image_search")) | |||
| _suite.addTest(TestSearch("test_text_search")) | |||
| _suite.addTest(TestSearch("test_table_search")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||