You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_upload.py 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import os
  2. import json
  3. import unittest
  4. import tempfile
  5. from learnware.client import LearnwareClient
  6. from learnware.specification import generate_semantic_spec
  7. class TestUpload(unittest.TestCase):
  8. client = LearnwareClient()
  9. @classmethod
  10. def setUpClass(cls) -> None:
  11. config_path = os.path.join(os.path.dirname(__file__), "config.json")
  12. if not os.path.exists(config_path):
  13. data = {"email": None, "token": None}
  14. with open(config_path, "w") as file:
  15. json.dump(data, file)
  16. with open(config_path, "r") as file:
  17. data = json.load(file)
  18. email = data.get("email")
  19. token = data.get("token")
  20. if email is None or token is None:
  21. print("Please set email and token in config.json.")
  22. else:
  23. cls.client.login(email, token)
  24. def _skip_test(self):
  25. if not self.client.is_login():
  26. print("Client does not login!")
  27. return True
  28. return False
  29. def test_upload(self):
  30. if not self._skip_test():
  31. input_description = {
  32. "Dimension": 13,
  33. "Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"},
  34. }
  35. output_description = {
  36. "Dimension": 2,
  37. "Description": {"0": "cat", "1": "not cat"},
  38. }
  39. semantic_spec = generate_semantic_spec(
  40. name="learnware_example",
  41. description="Just a example for uploading a learnware",
  42. data_type="Table",
  43. task_type="Classification",
  44. library_type="Scikit-learn",
  45. scenarios=["Business", "Financial"],
  46. license="MIT",
  47. input_description=input_description,
  48. output_description=output_description,
  49. )
  50. assert isinstance(semantic_spec, dict)
  51. download_learnware_id = "00000084"
  52. with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
  53. zip_path = os.path.join(tempdir, "test.zip")
  54. self.client.download_learnware(download_learnware_id, zip_path)
  55. learnware_id = self.client.upload_learnware(
  56. learnware_zip_path=zip_path, semantic_specification=semantic_spec
  57. )
  58. uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
  59. assert learnware_id in uploaded_ids
  60. self.client.delete_learnware(learnware_id)
  61. uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()]
  62. assert learnware_id not in uploaded_ids
  63. def suite():
  64. _suite = unittest.TestSuite()
  65. _suite.addTest(TestUpload("test_upload"))
  66. return _suite
  67. if __name__ == "__main__":
  68. runner = unittest.TextTestRunner()
  69. runner.run(suite())