You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_load_docker.py 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import os
  2. import unittest
  3. import zipfile
  4. import numpy as np
  5. import learnware
  6. from learnware.learnware import get_learnware_from_dirpath
  7. from learnware.client import LearnwareClient
  8. from learnware.client.container import ModelCondaContainer, LearnwaresContainer
  9. from learnware.reuse import AveragingReuser
  10. class TestLearnwareLoad(unittest.TestCase):
  11. def setUp(self):
  12. unittest.TestCase.setUpClass()
  13. self.client = LearnwareClient()
  14. root = os.path.dirname(__file__)
  15. self.learnware_ids = ["00000084", "00000154", "00000155"]
  16. self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]]
  17. def test_load_multi_learnware_by_zippath(self):
  18. for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths):
  19. self.client.download_learnware(learnware_id, zip_path)
  20. learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="docker")
  21. reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  22. input_array = np.random.random(size=(20, 13))
  23. print(reuser.predict(input_array))
  24. for learnware in learnware_list:
  25. print(learnware.id, learnware.predict(input_array))
  26. def test_load_multi_learnware_by_id(self):
  27. learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option="docker")
  28. reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  29. input_array = np.random.random(size=(20, 13))
  30. print(reuser.predict(input_array))
  31. for learnware in learnware_list:
  32. print(learnware.id, learnware.predict(input_array))
  33. def test_load_single_learnware_by_id_pip(self):
  34. learnware_id = "00000147"
  35. learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="docker")
  36. input_array = np.random.random(size=(20, 23))
  37. print(learnware.predict(input_array))
  38. def test_load_single_learnware_by_id_conda(self):
  39. learnware_id = "00000148"
  40. learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="docker")
  41. input_array = np.random.random(size=(20, 204))
  42. print(learnware.predict(input_array))
  43. if __name__ == "__main__":
  44. unittest.main()