| @@ -0,0 +1,84 @@ | |||||
| import os | |||||
| import unittest | |||||
| import tempfile | |||||
| import logging | |||||
| import learnware | |||||
| learnware.init(logging_level=logging.WARNING) | |||||
| from learnware.learnware import Learnware | |||||
| from learnware.client import LearnwareClient | |||||
| from learnware.market import instantiate_learnware_market, BaseUserInfo, EasySemanticChecker | |||||
| from learnware.config import C | |||||
| class TestSearch(unittest.TestCase): | |||||
| client = LearnwareClient() | |||||
| @classmethod | |||||
| def setUpClass(cls): | |||||
| cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True) | |||||
| if cls.client.is_connected(): | |||||
| cls._build_learnware_market() | |||||
| @classmethod | |||||
| def _build_learnware_market(cls): | |||||
| table_learnware_ids = ["00001951", "00001980", "00001987"] | |||||
| image_learnware_ids = ["00000851", "00000858", "00000841"] | |||||
| text_learnware_ids = ["00000652", "00000637"] | |||||
| learnware_ids = table_learnware_ids + image_learnware_ids + text_learnware_ids | |||||
| with tempfile.TemporaryDirectory(prefix="learnware_search_test") as tempdir: | |||||
| for learnware_id in learnware_ids: | |||||
| learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip") | |||||
| try: | |||||
| cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath) | |||||
| semantic_spec = cls.client.load_learnware(learnware_path=learnware_zippath).get_specification().get_semantic_spec() | |||||
| except Exception: | |||||
| print("'learnware_id' is passed due to the network problem.") | |||||
| cls.market.add_learnware(learnware_zippath, learnware_id=learnware_id, semantic_spec=semantic_spec, checker_names=["EasySemanticChecker"]) | |||||
| @unittest.skipIf(not client.is_connected(), "Client can not connect!") | |||||
| def test_image_search(self): | |||||
| learnware_id = "00000619" | |||||
| try: | |||||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||||
| except Exception: | |||||
| print("'test_image_search' is passed due to the network problem.") | |||||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||||
| search_result = self.market.search_learnware(user_info) | |||||
| print("Single Search Results:", search_result.get_single_results()) | |||||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||||
| @unittest.skipIf(not client.is_connected(), "Client can not connect!") | |||||
| def test_text_search(self): | |||||
| learnware_id = "00000653" | |||||
| try: | |||||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||||
| except Exception: | |||||
| print("'test_text_search' is passed due to the network problem.") | |||||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||||
| search_result = self.market.search_learnware(user_info) | |||||
| print("Single Search Results:", search_result.get_single_results()) | |||||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||||
| @unittest.skipIf(not client.is_connected(), "Client can not connect!") | |||||
| def test_table_search(self): | |||||
| learnware_id = "00001950" | |||||
| try: | |||||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||||
| except Exception: | |||||
| print("'test_table_search' is passed due to the network problem.") | |||||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||||
| search_result = self.market.search_learnware(user_info) | |||||
| print("Single Search Results:", search_result.get_single_results()) | |||||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||||
| def suite(): | |||||
| _suite = unittest.TestSuite() | |||||
| _suite.addTest(TestSearch("test_image_search")) | |||||
| _suite.addTest(TestSearch("test_text_search")) | |||||
| _suite.addTest(TestSearch("test_table_search")) | |||||
| return _suite | |||||
| if __name__ == "__main__": | |||||
| runner = unittest.TextTestRunner() | |||||
| runner.run(suite()) | |||||