import os
import unittest
import tempfile
import logging

import learnware

# Kept ahead of the remaining learnware imports, matching the original statement order.
learnware.init(logging_level=logging.WARNING)

from learnware.learnware import Learnware
from learnware.client import LearnwareClient
from learnware.market import instantiate_learnware_market, BaseUserInfo, EasySemanticChecker
from learnware.config import C


class TestSearch(unittest.TestCase):
    """Search tests run against a local "hetero" market seeded through ``LearnwareClient``.

    ``setUpClass`` rebuilds the ``search_test`` market and, when the client
    reports a connection, fills it with a fixed list of downloaded learnwares.
    Each test loads one more learnware by id, uses its statistical
    specification as the user requirement, and prints the search results.
    """

    # Shared client; also evaluated by the ``skipIf`` decorators below at
    # class-definition time.
    client = LearnwareClient()

    @classmethod
    def setUpClass(cls):
        """Create a freshly rebuilt market and populate it if the client is connected."""
        cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True)
        if cls.client.is_connected():
            cls._build_learnware_market()

    @classmethod
    def _build_learnware_market(cls):
        """Download the fixed set of learnwares and add each one to ``cls.market``.

        A learnware whose download or loading raises is reported and skipped;
        the remaining ids are still processed.
        """
        table_learnware_ids = ["00001951", "00001980", "00001987"]
        image_learnware_ids = ["00000851", "00000858", "00000841"]
        text_learnware_ids = ["00000652", "00000637"]
        learnware_ids = table_learnware_ids + image_learnware_ids + text_learnware_ids

        with tempfile.TemporaryDirectory(prefix="learnware_search_test") as tempdir:
            for learnware_id in learnware_ids:
                learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip")
                try:
                    cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath)
                    semantic_spec = (
                        cls.client.load_learnware(learnware_path=learnware_zippath)
                        .get_specification()
                        .get_semantic_spec()
                    )
                except Exception:
                    # Best effort: without a usable zip and semantic spec there
                    # is nothing valid to add, so move on to the next id
                    # instead of reusing a stale or unbound ``semantic_spec``.
                    print(f"'{learnware_id}' is passed due to the network problem.")
                    continue
                else:
                    cls.market.add_learnware(
                        learnware_zippath,
                        learnware_id=learnware_id,
                        semantic_spec=semantic_spec,
                        checker_names=["EasySemanticChecker"],
                    )

    def _run_search(self, learnware_id, test_name):
        """Load ``learnware_id``, search the market with its stat spec, and print the results.

        Args:
            learnware_id: Id of the learnware whose statistical specification
                is used as the user requirement.
            test_name: Name of the calling test, used in the skip message.

        The test is skipped when the learnware cannot be loaded.
        """
        try:
            user_learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
        except Exception:
            # Without the learnware there is no stat spec to search with;
            # skip rather than fail later on an unbound variable.
            self.skipTest(f"'{test_name}' is passed due to the network problem.")

        user_info = BaseUserInfo(stat_info=user_learnware.get_specification().get_stat_spec())
        search_result = self.market.search_learnware(user_info)

        print("Single Search Results:", search_result.get_single_results())
        print("Multiple Search Results:", search_result.get_multiple_results())

    @unittest.skipIf(not client.is_connected(), "Client can not connect!")
    def test_image_search(self):
        """Search using the stat spec of learnware ``00000619``."""
        self._run_search("00000619", "test_image_search")

    @unittest.skipIf(not client.is_connected(), "Client can not connect!")
    def test_text_search(self):
        """Search using the stat spec of learnware ``00000653``."""
        self._run_search("00000653", "test_text_search")

    @unittest.skipIf(not client.is_connected(), "Client can not connect!")
    def test_table_search(self):
        """Search using the stat spec of learnware ``00001950``."""
        self._run_search("00001950", "test_table_search")


def suite():
    """Return a suite that runs the image, text and table search tests in that order."""
    _suite = unittest.TestSuite()
    _suite.addTest(TestSearch("test_image_search"))
    _suite.addTest(TestSearch("test_text_search"))
    _suite.addTest(TestSearch("test_table_search"))
    return _suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())