You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_search.py 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os
  2. import unittest
  3. import tempfile
  4. import logging
  5. import learnware
  6. learnware.init(logging_level=logging.WARNING)
  7. from learnware.learnware import Learnware
  8. from learnware.client import LearnwareClient
  9. from learnware.market import instantiate_learnware_market, BaseUserInfo, EasySemanticChecker
  10. from learnware.config import C
  11. class TestSearch(unittest.TestCase):
  12. client = LearnwareClient()
  13. @classmethod
  14. def setUpClass(cls):
  15. cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True)
  16. if cls.client.is_connected():
  17. cls._build_learnware_market()
  18. @classmethod
  19. def _build_learnware_market(cls):
  20. table_learnware_ids = ["00001951", "00001980", "00001987"]
  21. image_learnware_ids = ["00000851", "00000858", "00000841"]
  22. text_learnware_ids = ["00000652", "00000637"]
  23. learnware_ids = table_learnware_ids + image_learnware_ids + text_learnware_ids
  24. with tempfile.TemporaryDirectory(prefix="learnware_search_test") as tempdir:
  25. for learnware_id in learnware_ids:
  26. learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip")
  27. try:
  28. cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath)
  29. semantic_spec = cls.client.load_learnware(learnware_path=learnware_zippath).get_specification().get_semantic_spec()
  30. except Exception:
  31. print("'learnware_id' is passed due to the network problem.")
  32. cls.market.add_learnware(learnware_zippath, learnware_id=learnware_id, semantic_spec=semantic_spec, checker_names=["EasySemanticChecker"])
  33. @unittest.skipIf(not client.is_connected(), "Client can not connect!")
  34. def test_image_search(self):
  35. learnware_id = "00000619"
  36. try:
  37. learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
  38. except Exception:
  39. print("'test_image_search' is passed due to the network problem.")
  40. user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
  41. search_result = self.market.search_learnware(user_info)
  42. print("Single Search Results:", search_result.get_single_results())
  43. print("Multiple Search Results:", search_result.get_multiple_results())
  44. @unittest.skipIf(not client.is_connected(), "Client can not connect!")
  45. def test_text_search(self):
  46. learnware_id = "00000653"
  47. try:
  48. learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
  49. except Exception:
  50. print("'test_text_search' is passed due to the network problem.")
  51. user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
  52. search_result = self.market.search_learnware(user_info)
  53. print("Single Search Results:", search_result.get_single_results())
  54. print("Multiple Search Results:", search_result.get_multiple_results())
  55. @unittest.skipIf(not client.is_connected(), "Client can not connect!")
  56. def test_table_search(self):
  57. learnware_id = "00001950"
  58. try:
  59. learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
  60. except Exception:
  61. print("'test_table_search' is passed due to the network problem.")
  62. user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
  63. search_result = self.market.search_learnware(user_info)
  64. print("Single Search Results:", search_result.get_single_results())
  65. print("Multiple Search Results:", search_result.get_multiple_results())
  66. def suite():
  67. _suite = unittest.TestSuite()
  68. _suite.addTest(TestSearch("test_image_search"))
  69. _suite.addTest(TestSearch("test_text_search"))
  70. _suite.addTest(TestSearch("test_table_search"))
  71. return _suite
  72. if __name__ == "__main__":
  73. runner = unittest.TextTestRunner()
  74. runner.run(suite())