You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_search.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import logging
  2. import os
  3. import tempfile
  4. import unittest
  5. import learnware
  6. from learnware.client import LearnwareClient
  7. from learnware.learnware import Learnware
  8. from learnware.market import BaseUserInfo, instantiate_learnware_market
# Initialise the learnware package once at import time, before the test class
# below creates its LearnwareClient; logging_level=WARNING presumably silences
# the library's INFO/DEBUG output during the tests (confirm against learnware.init).
learnware.init(logging_level=logging.WARNING)
  10. class TestSearch(unittest.TestCase):
  11. client = LearnwareClient()
  12. @classmethod
  13. def setUpClass(cls):
  14. cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True)
  15. if cls.client.is_connected():
  16. cls._build_learnware_market()
  17. @classmethod
  18. def _build_learnware_market(cls):
  19. table_learnware_ids = ["00001951", "00001980", "00001987"]
  20. image_learnware_ids = ["00000851", "00000858", "00000841"]
  21. text_learnware_ids = ["00000652", "00000637"]
  22. learnware_ids = table_learnware_ids + image_learnware_ids + text_learnware_ids
  23. with tempfile.TemporaryDirectory(prefix="learnware_search_test") as tempdir:
  24. for learnware_id in learnware_ids:
  25. learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip")
  26. try:
  27. cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath)
  28. semantic_spec = (
  29. cls.client.load_learnware(learnware_path=learnware_zippath)
  30. .get_specification()
  31. .get_semantic_spec()
  32. )
  33. except Exception:
  34. print("'learnware_id' is passed due to the network problem.")
  35. cls.market.add_learnware(
  36. learnware_zippath,
  37. learnware_id=learnware_id,
  38. semantic_spec=semantic_spec,
  39. checker_names=["EasySemanticChecker"],
  40. )
  41. def _skip_test(self):
  42. if not self.client.is_connected():
  43. print("Client can not connect!")
  44. return True
  45. return False
  46. def test_image_search(self):
  47. if not self._skip_test():
  48. learnware_id = "00000619"
  49. try:
  50. learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
  51. except Exception:
  52. print("'test_image_search' is passed due to the network problem.")
  53. user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
  54. search_result = self.market.search_learnware(user_info)
  55. print("Single Search Results:", search_result.get_single_results())
  56. print("Multiple Search Results:", search_result.get_multiple_results())
  57. def test_text_search(self):
  58. if not self._skip_test():
  59. learnware_id = "00000653"
  60. try:
  61. learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
  62. except Exception:
  63. print("'test_text_search' is passed due to the network problem.")
  64. user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
  65. search_result = self.market.search_learnware(user_info)
  66. print("Single Search Results:", search_result.get_single_results())
  67. print("Multiple Search Results:", search_result.get_multiple_results())
  68. def test_table_search(self):
  69. if not self._skip_test():
  70. learnware_id = "00001950"
  71. try:
  72. learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id)
  73. except Exception:
  74. print("'test_table_search' is passed due to the network problem.")
  75. user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec())
  76. search_result = self.market.search_learnware(user_info)
  77. print("Single Search Results:", search_result.get_single_results())
  78. print("Multiple Search Results:", search_result.get_multiple_results())
  79. def suite():
  80. _suite = unittest.TestSuite()
  81. _suite.addTest(TestSearch("test_image_search"))
  82. _suite.addTest(TestSearch("test_text_search"))
  83. _suite.addTest(TestSearch("test_table_search"))
  84. return _suite
  85. if __name__ == "__main__":
  86. runner = unittest.TextTestRunner()
  87. runner.run(suite())