You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_container.py 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import unittest
  2. import numpy as np
  3. from learnware.client import LearnwareClient
  4. from learnware.client.container import LearnwaresContainer
  5. class TestContainer(unittest.TestCase):
  6. def __init__(self, method_name="runTest", mode="all"):
  7. super(TestContainer, self).__init__(method_name)
  8. self.modes = []
  9. if mode in {"all", "conda"}:
  10. self.modes.append("conda")
  11. if mode in {"all", "docker"}:
  12. self.modes.append("docker")
  13. def setUp(self):
  14. self.client = LearnwareClient()
  15. def _test_container_with_pip(self, mode):
  16. learnware_id = "00000147"
  17. learnware = self.client.load_learnware(learnware_id=learnware_id)
  18. with LearnwaresContainer(learnware, ignore_error=False, mode=mode) as env_container:
  19. learnware = env_container.get_learnwares_with_container()[0]
  20. input_array = np.random.random(size=(20, 23))
  21. print(learnware.predict(input_array))
  22. def _test_container_with_conda(self, mode):
  23. learnware_id = "00000148"
  24. learnware = self.client.load_learnware(learnware_id=learnware_id)
  25. with LearnwaresContainer(learnware, ignore_error=False, mode=mode) as env_container:
  26. learnware = env_container.get_learnwares_with_container()[0]
  27. input_array = np.random.random(size=(20, 204))
  28. print(learnware.predict(input_array))
  29. def test_container_with_pip(self):
  30. for mode in self.modes:
  31. self._test_container_with_pip(mode=mode)
  32. def test_container_with_conda(self):
  33. for mode in self.modes:
  34. self._test_container_with_conda(mode=mode)
  35. def suite():
  36. _suite = unittest.TestSuite()
  37. _suite.addTest(TestContainer("test_container_with_pip", mode="all"))
  38. _suite.addTest(TestContainer("test_container_with_conda", mode="all"))
  39. return _suite
  40. if __name__ == "__main__":
  41. runner = unittest.TextTestRunner()
  42. runner.run(suite())