You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_upload.py 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import os
  2. import json
  3. import unittest
  4. import tempfile
  5. from learnware.client import LearnwareClient
  6. from learnware.specification import generate_semantic_spec
  7. class TestUpload(unittest.TestCase):
  8. client = LearnwareClient()
  9. @classmethod
  10. def setUpClass(cls) -> None:
  11. config_path = os.path.join(os.path.dirname(__file__), "config.json")
  12. if not os.path.exists(config_path):
  13. data = {"email": None, "token": None}
  14. with open(config_path, "w") as file:
  15. json.dump(data, file)
  16. with open(config_path, "r") as file:
  17. data = json.load(file)
  18. email = data.get("email")
  19. token = data.get("token")
  20. if email is None or token is None:
  21. print("Please set email and token in config.json.")
  22. else:
  23. cls.client.login(email, token)
  24. @unittest.skipIf(not client.is_login(), "Client doest not login!")
  25. def test_upload(self):
  26. input_description = {
  27. "Dimension": 13,
  28. "Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"},
  29. }
  30. output_description = {
  31. "Dimension": 1,
  32. "Description": {
  33. "0": "the probability of being a cat",
  34. },
  35. }
  36. semantic_spec = generate_semantic_spec(
  37. name="learnware_example",
  38. description="Just a example for uploading a learnware",
  39. data_type="Table",
  40. task_type="Classification",
  41. library_type="Scikit-learn",
  42. scenarios=["Business", "Financial"],
  43. input_description=input_description,
  44. output_description=output_description,
  45. )
  46. assert isinstance(semantic_spec, dict)
  47. download_learnware_id = "00000084"
  48. with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
  49. zip_path = os.path.join(tempdir, f"test.zip")
  50. self.client.download_learnware(download_learnware_id, zip_path)
  51. learnware_id = self.client.upload_learnware(
  52. learnware_zip_path=zip_path, semantic_specification=semantic_spec
  53. )
  54. uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
  55. assert learnware_id in uploaded_ids
  56. self.client.delete_learnware(learnware_id)
  57. uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
  58. assert learnware_id not in uploaded_ids
  59. def suite():
  60. _suite = unittest.TestSuite()
  61. _suite.addTest(TestUpload("test_upload"))
  62. return _suite
  63. if __name__ == "__main__":
  64. runner = unittest.TextTestRunner()
  65. runner.run(suite())