You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_all_learnware.py 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import os
  2. import json
  3. import zipfile
  4. import unittest
  5. import tempfile
  6. import argparse
  7. from learnware.client import LearnwareClient
  8. from learnware.specification import generate_semantic_spec
  9. from learnware.market import BaseUserInfo
  10. class TestAllLearnware(unittest.TestCase):
  11. client = LearnwareClient()
  12. @classmethod
  13. def setUpClass(cls) -> None:
  14. config_path = os.path.join(os.path.dirname(__file__), "config.json")
  15. if not os.path.exists(config_path):
  16. data = {"email": None, "token": None}
  17. with open(config_path, "w") as file:
  18. json.dump(data, file)
  19. with open(config_path, "r") as file:
  20. data = json.load(file)
  21. email = data.get("email")
  22. token = data.get("token")
  23. if email is None or token is None:
  24. print("Please set email and token in config.json.")
  25. else:
  26. cls.client.login(email, token)
  27. def _skip_test(self):
  28. if not self.client.is_login():
  29. print("Client does not login!")
  30. return True
  31. return False
  32. def test_all_learnware(self):
  33. if not self._skip_test():
  34. max_learnware_num = 2000
  35. semantic_spec = generate_semantic_spec()
  36. user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={})
  37. result = self.client.search_learnware(user_info, page_size=max_learnware_num)
  38. learnware_ids = result["single"]["learnware_ids"]
  39. keys = [key for key in result["single"]["semantic_specifications"][0]]
  40. print(f"result size: {len(learnware_ids)}")
  41. print(f"key in result: {keys}")
  42. failed_ids = []
  43. with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
  44. for idx in learnware_ids:
  45. zip_path = os.path.join(tempdir, f"test_{idx}.zip")
  46. self.client.download_learnware(idx, zip_path)
  47. with zipfile.ZipFile(zip_path, "r") as zip_file:
  48. with zip_file.open("semantic_specification.json") as json_file:
  49. semantic_spec = json.load(json_file)
  50. try:
  51. LearnwareClient.check_learnware(zip_path, semantic_spec)
  52. print(f"check learnware {idx} succeed")
  53. except:
  54. failed_ids.append(idx)
  55. print(f"check learnware {idx} failed!!!")
  56. print(f"The currently failed learnware ids: {failed_ids}")
  57. def suite():
  58. _suite = unittest.TestSuite()
  59. _suite.addTest(TestAllLearnware("test_all_learnware"))
  60. return _suite
  61. if __name__ == "__main__":
  62. runner = unittest.TextTestRunner()
  63. runner.run(suite())