|
- import logging
- import os
- import tempfile
- import unittest
-
- import learnware
- from learnware.client import LearnwareClient
- from learnware.learnware import Learnware
- from learnware.market import BaseUserInfo, instantiate_learnware_market
-
# Initialise the learnware package once at import time, with its logging
# level set to WARNING for the whole test module.
learnware.init(
    logging_level=logging.WARNING,
)
-
-
class TestSearch(unittest.TestCase):
    """Search tests against a locally built heterogeneous ("hetero") learnware market.

    ``setUpClass`` instantiates the ``search_test`` market with ``rebuild=True`` and,
    when the shared client reports a connection, fills it with a fixed set of table,
    image and text learnwares downloaded through the client.  Each test loads one
    further learnware, uses its statistical specification as the user query, and
    prints the single and multiple search results.  A test is skipped (not passed)
    when the client is offline or the query learnware cannot be loaded.
    """

    # Shared by setUpClass and every test; created once when the class body executes.
    client = LearnwareClient()

    @classmethod
    def setUpClass(cls):
        cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True)
        # Populating the market requires the client; when offline the tests skip themselves.
        if cls.client.is_connected():
            cls._build_learnware_market()

    @classmethod
    def _build_learnware_market(cls):
        """Download a fixed list of learnware ids and add each one to ``cls.market``.

        Each zip is saved into a temporary directory, loaded once to read its
        semantic specification, and then added with the ``EasySemanticChecker``.
        An id that fails to download or load is reported and skipped, so one bad
        id does not abort the whole build.
        """
        table_learnware_ids = ["00001951", "00001980", "00001987"]
        image_learnware_ids = ["00000851", "00000858", "00000841"]
        text_learnware_ids = ["00000652", "00000637"]
        learnware_ids = table_learnware_ids + image_learnware_ids + text_learnware_ids
        with tempfile.TemporaryDirectory(prefix="learnware_search_test") as tempdir:
            for learnware_id in learnware_ids:
                learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip")
                try:
                    cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath)
                    semantic_spec = (
                        cls.client.load_learnware(learnware_path=learnware_zippath)
                        .get_specification()
                        .get_semantic_spec()
                    )
                except Exception as err:
                    # Skip this id: without the zip and its semantic spec there is nothing
                    # valid to add, and falling through would reuse a stale (or undefined)
                    # ``semantic_spec`` from an earlier iteration.
                    print(f"'{learnware_id}' is passed due to the network problem: {err}")
                    continue
                cls.market.add_learnware(
                    learnware_zippath,
                    learnware_id=learnware_id,
                    semantic_spec=semantic_spec,
                    checker_names=["EasySemanticChecker"],
                )

    def _skip_test(self):
        """Return True (after printing a notice) when the shared client is not connected."""
        if not self.client.is_connected():
            print("Client can not connect!")
            return True
        return False

    def _run_search(self, learnware_id, test_name):
        """Query the market with the stat spec of ``learnware_id`` and print the results.

        Args:
            learnware_id: id of the learnware whose statistical specification is used
                as the user query.
            test_name: name of the calling test, used in the skip message.

        Skips the calling test when the client is offline or the learnware cannot be
        loaded, instead of continuing with an unbound ``learnware``.
        """
        if self._skip_test():
            self.skipTest("Client can not connect!")
        try:
            learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
        except Exception as err:
            # skipTest raises, so execution never reaches the code below on failure.
            self.skipTest(f"'{test_name}' is passed due to the network problem: {err}")
        user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
        search_result = self.market.search_learnware(user_info)
        print("Single Search Results:", search_result.get_single_results())
        print("Multiple Search Results:", search_result.get_multiple_results())

    def test_image_search(self):
        self._run_search("00000619", "test_image_search")

    def test_text_search(self):
        self._run_search("00000653", "test_text_search")

    def test_table_search(self):
        self._run_search("00001950", "test_table_search")
-
-
def suite():
    """Build a TestSuite running the image, text and table search tests, in that order."""
    ordered_names = ("test_image_search", "test_text_search", "test_table_search")
    collected = unittest.TestSuite()
    collected.addTests(TestSearch(name) for name in ordered_names)
    return collected
-
-
if __name__ == "__main__":
    import sys

    # Run the ordered suite and propagate failures/errors to the process exit
    # status, so shells and CI can detect a failing run.
    result = unittest.TextTestRunner().run(suite())
    sys.exit(0 if result.wasSuccessful() else 1)
|