diff --git a/tests/test_function/test_search.py b/tests/test_function/test_search.py index 9aa1b42..c006b0d 100644 --- a/tests/test_function/test_search.py +++ b/tests/test_function/test_search.py @@ -4,6 +4,7 @@ import tempfile import logging import learnware + learnware.init(logging_level=logging.WARNING) from learnware.learnware import Learnware @@ -11,9 +12,10 @@ from learnware.client import LearnwareClient from learnware.market import instantiate_learnware_market, BaseUserInfo, EasySemanticChecker from learnware.config import C + class TestSearch(unittest.TestCase): client = LearnwareClient() - + @classmethod def setUpClass(cls): cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True) @@ -31,46 +33,62 @@ class TestSearch(unittest.TestCase): learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip") try: cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath) - semantic_spec = cls.client.load_learnware(learnware_path=learnware_zippath).get_specification().get_semantic_spec() + semantic_spec = ( + cls.client.load_learnware(learnware_path=learnware_zippath) + .get_specification() + .get_semantic_spec() + ) except Exception: print("'learnware_id' is passed due to the network problem.") - cls.market.add_learnware(learnware_zippath, learnware_id=learnware_id, semantic_spec=semantic_spec, checker_names=["EasySemanticChecker"]) - - @unittest.skipIf(not client.is_connected(), "Client can not connect!") + cls.market.add_learnware( + learnware_zippath, + learnware_id=learnware_id, + semantic_spec=semantic_spec, + checker_names=["EasySemanticChecker"], + ) + + def _skip_test(self): + if not self.client.is_connected(): + print("Client can not connect!") + return True + return False + def test_image_search(self): - learnware_id = "00000619" - try: - learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) - except Exception: - print("'test_image_search' is passed due to the network problem.") - user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) - search_result = self.market.search_learnware(user_info) - print("Single Search Results:", search_result.get_single_results()) - print("Multiple Search Results:", search_result.get_multiple_results()) - - @unittest.skipIf(not client.is_connected(), "Client can not connect!") + if not self._skip_test(): + learnware_id = "00000619" + try: + learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) + except Exception: + print("'test_image_search' is passed due to the network problem.") + user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) + search_result = self.market.search_learnware(user_info) + print("Single Search Results:", search_result.get_single_results()) + print("Multiple Search Results:", search_result.get_multiple_results()) + def test_text_search(self): - learnware_id = "00000653" - try: - learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) - except Exception: - print("'test_text_search' is passed due to the network problem.") - user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) - search_result = self.market.search_learnware(user_info) - print("Single Search Results:", search_result.get_single_results()) - print("Multiple Search Results:", search_result.get_multiple_results()) - - @unittest.skipIf(not client.is_connected(), "Client can not connect!") + if not self._skip_test(): + learnware_id = "00000653" + try: + learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) + except Exception: + print("'test_text_search' is passed due to the network problem.") + user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) + search_result = self.market.search_learnware(user_info) + print("Single Search Results:", search_result.get_single_results()) + print("Multiple Search Results:", search_result.get_multiple_results()) + def test_table_search(self): - learnware_id = "00001950" - try: - learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) - except Exception: - print("'test_table_search' is passed due to the network problem.") - user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) - search_result = self.market.search_learnware(user_info) - print("Single Search Results:", search_result.get_single_results()) - print("Multiple Search Results:", search_result.get_multiple_results()) + if not self._skip_test(): + learnware_id = "00001950" + try: + learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) + except Exception: + print("'test_table_search' is passed due to the network problem.") + user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) + search_result = self.market.search_learnware(user_info) + print("Single Search Results:", search_result.get_single_results()) + print("Multiple Search Results:", search_result.get_multiple_results()) + def suite(): _suite = unittest.TestSuite() @@ -79,6 +97,7 @@ def suite(): _suite.addTest(TestSearch("test_table_search")) return _suite + if __name__ == "__main__": runner = unittest.TextTestRunner() - runner.run(suite()) \ No newline at end of file + runner.run(suite()) diff --git a/tests/test_learnware_client/test_all_learnware.py b/tests/test_learnware_client/test_all_learnware.py index 94432ed..9fbcc41 100644 --- a/tests/test_learnware_client/test_all_learnware.py +++ b/tests/test_learnware_client/test_all_learnware.py @@ -9,13 +9,14 @@ from learnware.client import LearnwareClient from learnware.specification import generate_semantic_spec from learnware.market import BaseUserInfo + class TestAllLearnware(unittest.TestCase): client = LearnwareClient() - + @classmethod def setUpClass(cls) -> None: config_path = os.path.join(os.path.dirname(__file__), "config.json") - + if not os.path.exists(config_path): data = {"email": None, "token": None} with open(config_path, "w") as file: @@ -25,40 +26,46 @@ class TestAllLearnware(unittest.TestCase): data = json.load(file) email = data.get("email") token = data.get("token") - + if email is None or token is None: print("Please set email and token in config.json.") else: cls.client.login(email, token) - @unittest.skipIf(not client.is_login(), "Client doest not login!") + def _skip_test(self): + if not self.client.is_login(): + print("Client does not login!") + return True + return False + def test_all_learnware(self): - max_learnware_num = 2000 - semantic_spec = generate_semantic_spec() - user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={}) - result = self.client.search_learnware(user_info, page_size=max_learnware_num) - - learnware_ids = result["single"]["learnware_ids"] - keys = [key for key in result["single"]["semantic_specifications"][0]] - print(f"result size: {len(learnware_ids)}") - print(f"key in result: {keys}") - - failed_ids = [] - with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: - for idx in learnware_ids: - zip_path = os.path.join(tempdir, f"test_{idx}.zip") - self.client.download_learnware(idx, zip_path) - with zipfile.ZipFile(zip_path, "r") as zip_file: - with zip_file.open("semantic_specification.json") as json_file: - semantic_spec = json.load(json_file) - try: - LearnwareClient.check_learnware(zip_path, semantic_spec) - print(f"check learnware {idx} succeed") - except: - failed_ids.append(idx) - print(f"check learnware {idx} failed!!!") - - print(f"The currently failed learnware ids: {failed_ids}") + if not self._skip_test(): + max_learnware_num = 2000 + semantic_spec = generate_semantic_spec() + user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={}) + result = self.client.search_learnware(user_info, page_size=max_learnware_num) + + learnware_ids = result["single"]["learnware_ids"] + keys = [key for key in result["single"]["semantic_specifications"][0]] + print(f"result size: {len(learnware_ids)}") + print(f"key in result: {keys}") + + failed_ids = [] + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + for idx in learnware_ids: + zip_path = os.path.join(tempdir, f"test_{idx}.zip") + self.client.download_learnware(idx, zip_path) + with zipfile.ZipFile(zip_path, "r") as zip_file: + with zip_file.open("semantic_specification.json") as json_file: + semantic_spec = json.load(json_file) + try: + LearnwareClient.check_learnware(zip_path, semantic_spec) + print(f"check learnware {idx} succeed") + except: + failed_ids.append(idx) + print(f"check learnware {idx} failed!!!") + + print(f"The currently failed learnware ids: {failed_ids}") def suite(): @@ -66,6 +73,7 @@ def suite(): _suite.addTest(TestAllLearnware("test_all_learnware")) return _suite + if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite()) diff --git a/tests/test_learnware_client/test_upload.py b/tests/test_learnware_client/test_upload.py index c2f9015..18e4055 100644 --- a/tests/test_learnware_client/test_upload.py +++ b/tests/test_learnware_client/test_upload.py @@ -6,13 +6,14 @@ import tempfile from learnware.client import LearnwareClient from learnware.specification import generate_semantic_spec + class TestUpload(unittest.TestCase): client = LearnwareClient() - + @classmethod def setUpClass(cls) -> None: config_path = os.path.join(os.path.dirname(__file__), "config.json") - + if not os.path.exists(config_path): data = {"email": None, "token": None} with open(config_path, "w") as file: @@ -22,50 +23,55 @@ class TestUpload(unittest.TestCase): data = json.load(file) email = data.get("email") token = data.get("token") - + if email is None or token is None: print("Please set email and token in config.json.") else: cls.client.login(email, token) - @unittest.skipIf(not client.is_login(), "Client doest not login!") + def _skip_test(self): + if not self.client.is_login(): + print("Client does not login!") + return True + return False + def test_upload(self): - input_description = { - "Dimension": 13, - "Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"}, - } - output_description = { - "Dimension": 1, - "Description": { - "0": "the probability of being a cat", - }, - } - semantic_spec = generate_semantic_spec( - name="learnware_example", - description="Just a example for uploading a learnware", - data_type="Table", - task_type="Classification", - library_type="Scikit-learn", - scenarios=["Business", "Financial"], - input_description=input_description, - output_description=output_description, - ) - assert isinstance(semantic_spec, dict) - - download_learnware_id = "00000084" - with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: - zip_path = os.path.join(tempdir, f"test.zip") - self.client.download_learnware(download_learnware_id, zip_path) - learnware_id = self.client.upload_learnware( - learnware_zip_path=zip_path, semantic_specification=semantic_spec + if not self._skip_test(): + input_description = { + "Dimension": 13, + "Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"}, + } + output_description = { + "Dimension": 2, + "Description": {"0": "cat", "1": "not cat"}, + } + semantic_spec = generate_semantic_spec( + name="learnware_example", + description="Just a example for uploading a learnware", + data_type="Table", + task_type="Classification", + library_type="Scikit-learn", + scenarios=["Business", "Financial"], + license="MIT", + input_description=input_description, + output_description=output_description, ) + assert isinstance(semantic_spec, dict) + + download_learnware_id = "00000084" + with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: + zip_path = os.path.join(tempdir, f"test.zip") + self.client.download_learnware(download_learnware_id, zip_path) + learnware_id = self.client.upload_learnware( + learnware_zip_path=zip_path, semantic_specification=semantic_spec + ) - uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] - assert learnware_id in uploaded_ids + uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] + assert learnware_id in uploaded_ids - self.client.delete_learnware(learnware_id) - uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] - assert learnware_id not in uploaded_ids + self.client.delete_learnware(learnware_id) + uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] + assert learnware_id not in uploaded_ids def suite(): @@ -73,6 +79,7 @@ def suite(): _suite.addTest(TestUpload("test_upload")) return _suite + if __name__ == "__main__": runner = unittest.TextTestRunner() runner.run(suite())