You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_load_learnware.py 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
import os
import sys
import tempfile
import unittest

import numpy as np

from learnware.client import LearnwareClient
from learnware.reuse import AveragingReuser
  6. class TestLearnwareLoad(unittest.TestCase):
  7. def __init__(self, method_name="runTest", mode="all"):
  8. super(TestLearnwareLoad, self).__init__(method_name)
  9. self.runnable_options = []
  10. if mode in {"all", "conda"}:
  11. self.runnable_options.append("conda")
  12. if mode in {"all", "docker"}:
  13. self.runnable_options.append("docker")
  14. def setUp(self):
  15. self.client = LearnwareClient()
  16. root = os.path.dirname(__file__)
  17. self.learnware_ids = ["00000910", "00000899", "00000900"]
  18. self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]]
  19. def _test_load_learnware_by_zippath(self, runnable_option):
  20. for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths):
  21. self.client.download_learnware(learnware_id, zip_path)
  22. learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option=runnable_option)
  23. reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  24. input_array = np.random.random(size=(20, 13))
  25. print(reuser.predict(input_array))
  26. for learnware in learnware_list:
  27. print(learnware.id, learnware.predict(input_array))
  28. def _test_load_learnware_by_id(self, runnable_option):
  29. learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option=runnable_option)
  30. reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  31. input_array = np.random.random(size=(20, 13))
  32. print(reuser.predict(input_array))
  33. for learnware in learnware_list:
  34. print(learnware.id, learnware.predict(input_array))
  35. def test_load_learnware_by_zippath(self):
  36. for runnable_option in self.runnable_options:
  37. self._test_load_learnware_by_zippath(runnable_option=runnable_option)
  38. def test_load_learnware_by_id(self):
  39. for runnable_option in self.runnable_options:
  40. self._test_load_learnware_by_id(runnable_option=runnable_option)
  41. def suite():
  42. _suite = unittest.TestSuite()
  43. _suite.addTest(TestLearnwareLoad("test_load_learnware_by_zippath", mode="all"))
  44. _suite.addTest(TestLearnwareLoad("test_load_learnware_by_id", mode="all"))
  45. return _suite
  46. if __name__ == "__main__":
  47. runner = unittest.TextTestRunner()
  48. runner.run(suite())