| @@ -4,6 +4,7 @@ import tempfile | |||
| import logging | |||
| import learnware | |||
| learnware.init(logging_level=logging.WARNING) | |||
| from learnware.learnware import Learnware | |||
| @@ -11,9 +12,10 @@ from learnware.client import LearnwareClient | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo, EasySemanticChecker | |||
| from learnware.config import C | |||
| class TestSearch(unittest.TestCase): | |||
| client = LearnwareClient() | |||
| @classmethod | |||
| def setUpClass(cls): | |||
| cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True) | |||
| @@ -31,46 +33,62 @@ class TestSearch(unittest.TestCase): | |||
| learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip") | |||
| try: | |||
| cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath) | |||
| semantic_spec = cls.client.load_learnware(learnware_path=learnware_zippath).get_specification().get_semantic_spec() | |||
| semantic_spec = ( | |||
| cls.client.load_learnware(learnware_path=learnware_zippath) | |||
| .get_specification() | |||
| .get_semantic_spec() | |||
| ) | |||
| except Exception: | |||
| print("'learnware_id' is passed due to the network problem.") | |||
| cls.market.add_learnware(learnware_zippath, learnware_id=learnware_id, semantic_spec=semantic_spec, checker_names=["EasySemanticChecker"]) | |||
| @unittest.skipIf(not client.is_connected(), "Client can not connect!") | |||
| cls.market.add_learnware( | |||
| learnware_zippath, | |||
| learnware_id=learnware_id, | |||
| semantic_spec=semantic_spec, | |||
| checker_names=["EasySemanticChecker"], | |||
| ) | |||
| def _skip_test(self): | |||
| if not self.client.is_connected(): | |||
| print("Client can not connect!") | |||
| return True | |||
| return False | |||
| def test_image_search(self): | |||
| learnware_id = "00000619" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_image_search' is passed due to the network problem.") | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| @unittest.skipIf(not client.is_connected(), "Client can not connect!") | |||
| if not self._skip_test(): | |||
| learnware_id = "00000619" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_image_search' is passed due to the network problem.") | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| def test_text_search(self): | |||
| learnware_id = "00000653" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_text_search' is passed due to the network problem.") | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| @unittest.skipIf(not client.is_connected(), "Client can not connect!") | |||
| if not self._skip_test(): | |||
| learnware_id = "00000653" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_text_search' is passed due to the network problem.") | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| def test_table_search(self): | |||
| learnware_id = "00001950" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_table_search' is passed due to the network problem.") | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| if not self._skip_test(): | |||
| learnware_id = "00001950" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_table_search' is passed due to the network problem.") | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| @@ -79,6 +97,7 @@ def suite(): | |||
| _suite.addTest(TestSearch("test_table_search")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
| runner.run(suite()) | |||
| @@ -9,13 +9,14 @@ from learnware.client import LearnwareClient | |||
| from learnware.specification import generate_semantic_spec | |||
| from learnware.market import BaseUserInfo | |||
| class TestAllLearnware(unittest.TestCase): | |||
| client = LearnwareClient() | |||
| @classmethod | |||
| def setUpClass(cls) -> None: | |||
| config_path = os.path.join(os.path.dirname(__file__), "config.json") | |||
| if not os.path.exists(config_path): | |||
| data = {"email": None, "token": None} | |||
| with open(config_path, "w") as file: | |||
| @@ -25,40 +26,46 @@ class TestAllLearnware(unittest.TestCase): | |||
| data = json.load(file) | |||
| email = data.get("email") | |||
| token = data.get("token") | |||
| if email is None or token is None: | |||
| print("Please set email and token in config.json.") | |||
| else: | |||
| cls.client.login(email, token) | |||
| @unittest.skipIf(not client.is_login(), "Client doest not login!") | |||
| def _skip_test(self): | |||
| if not self.client.is_login(): | |||
| print("Client does not login!") | |||
| return True | |||
| return False | |||
| def test_all_learnware(self): | |||
| max_learnware_num = 2000 | |||
| semantic_spec = generate_semantic_spec() | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={}) | |||
| result = self.client.search_learnware(user_info, page_size=max_learnware_num) | |||
| learnware_ids = result["single"]["learnware_ids"] | |||
| keys = [key for key in result["single"]["semantic_specifications"][0]] | |||
| print(f"result size: {len(learnware_ids)}") | |||
| print(f"key in result: {keys}") | |||
| failed_ids = [] | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| for idx in learnware_ids: | |||
| zip_path = os.path.join(tempdir, f"test_{idx}.zip") | |||
| self.client.download_learnware(idx, zip_path) | |||
| with zipfile.ZipFile(zip_path, "r") as zip_file: | |||
| with zip_file.open("semantic_specification.json") as json_file: | |||
| semantic_spec = json.load(json_file) | |||
| try: | |||
| LearnwareClient.check_learnware(zip_path, semantic_spec) | |||
| print(f"check learnware {idx} succeed") | |||
| except: | |||
| failed_ids.append(idx) | |||
| print(f"check learnware {idx} failed!!!") | |||
| print(f"The currently failed learnware ids: {failed_ids}") | |||
| if not self._skip_test(): | |||
| max_learnware_num = 2000 | |||
| semantic_spec = generate_semantic_spec() | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={}) | |||
| result = self.client.search_learnware(user_info, page_size=max_learnware_num) | |||
| learnware_ids = result["single"]["learnware_ids"] | |||
| keys = [key for key in result["single"]["semantic_specifications"][0]] | |||
| print(f"result size: {len(learnware_ids)}") | |||
| print(f"key in result: {keys}") | |||
| failed_ids = [] | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| for idx in learnware_ids: | |||
| zip_path = os.path.join(tempdir, f"test_{idx}.zip") | |||
| self.client.download_learnware(idx, zip_path) | |||
| with zipfile.ZipFile(zip_path, "r") as zip_file: | |||
| with zip_file.open("semantic_specification.json") as json_file: | |||
| semantic_spec = json.load(json_file) | |||
| try: | |||
| LearnwareClient.check_learnware(zip_path, semantic_spec) | |||
| print(f"check learnware {idx} succeed") | |||
| except: | |||
| failed_ids.append(idx) | |||
| print(f"check learnware {idx} failed!!!") | |||
| print(f"The currently failed learnware ids: {failed_ids}") | |||
| def suite(): | |||
| @@ -66,6 +73,7 @@ def suite(): | |||
| _suite.addTest(TestAllLearnware("test_all_learnware")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
| @@ -6,13 +6,14 @@ import tempfile | |||
| from learnware.client import LearnwareClient | |||
| from learnware.specification import generate_semantic_spec | |||
| class TestUpload(unittest.TestCase): | |||
| client = LearnwareClient() | |||
| @classmethod | |||
| def setUpClass(cls) -> None: | |||
| config_path = os.path.join(os.path.dirname(__file__), "config.json") | |||
| if not os.path.exists(config_path): | |||
| data = {"email": None, "token": None} | |||
| with open(config_path, "w") as file: | |||
| @@ -22,50 +23,55 @@ class TestUpload(unittest.TestCase): | |||
| data = json.load(file) | |||
| email = data.get("email") | |||
| token = data.get("token") | |||
| if email is None or token is None: | |||
| print("Please set email and token in config.json.") | |||
| else: | |||
| cls.client.login(email, token) | |||
| @unittest.skipIf(not client.is_login(), "Client doest not login!") | |||
| def _skip_test(self): | |||
| if not self.client.is_login(): | |||
| print("Client does not login!") | |||
| return True | |||
| return False | |||
| def test_upload(self): | |||
| input_description = { | |||
| "Dimension": 13, | |||
| "Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"}, | |||
| } | |||
| output_description = { | |||
| "Dimension": 1, | |||
| "Description": { | |||
| "0": "the probability of being a cat", | |||
| }, | |||
| } | |||
| semantic_spec = generate_semantic_spec( | |||
| name="learnware_example", | |||
| description="Just a example for uploading a learnware", | |||
| data_type="Table", | |||
| task_type="Classification", | |||
| library_type="Scikit-learn", | |||
| scenarios=["Business", "Financial"], | |||
| input_description=input_description, | |||
| output_description=output_description, | |||
| ) | |||
| assert isinstance(semantic_spec, dict) | |||
| download_learnware_id = "00000084" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| zip_path = os.path.join(tempdir, f"test.zip") | |||
| self.client.download_learnware(download_learnware_id, zip_path) | |||
| learnware_id = self.client.upload_learnware( | |||
| learnware_zip_path=zip_path, semantic_specification=semantic_spec | |||
| if not self._skip_test(): | |||
| input_description = { | |||
| "Dimension": 13, | |||
| "Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"}, | |||
| } | |||
| output_description = { | |||
| "Dimension": 2, | |||
| "Description": {"0": "cat", "1": "not cat"}, | |||
| } | |||
| semantic_spec = generate_semantic_spec( | |||
| name="learnware_example", | |||
| description="Just a example for uploading a learnware", | |||
| data_type="Table", | |||
| task_type="Classification", | |||
| library_type="Scikit-learn", | |||
| scenarios=["Business", "Financial"], | |||
| license="MIT", | |||
| input_description=input_description, | |||
| output_description=output_description, | |||
| ) | |||
| assert isinstance(semantic_spec, dict) | |||
| download_learnware_id = "00000084" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| zip_path = os.path.join(tempdir, f"test.zip") | |||
| self.client.download_learnware(download_learnware_id, zip_path) | |||
| learnware_id = self.client.upload_learnware( | |||
| learnware_zip_path=zip_path, semantic_specification=semantic_spec | |||
| ) | |||
| uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] | |||
| assert learnware_id in uploaded_ids | |||
| uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] | |||
| assert learnware_id in uploaded_ids | |||
| self.client.delete_learnware(learnware_id) | |||
| uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] | |||
| assert learnware_id not in uploaded_ids | |||
| self.client.delete_learnware(learnware_id) | |||
| uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] | |||
| assert learnware_id not in uploaded_ids | |||
| def suite(): | |||
| @@ -73,6 +79,7 @@ def suite(): | |||
| _suite.addTest(TestUpload("test_upload")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||