You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_load_conda.py 3.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import os
  2. import unittest
  3. import zipfile
  4. import numpy as np
  5. import learnware
  6. from learnware.learnware import get_learnware_from_dirpath
  7. from learnware.client import LearnwareClient
  8. from learnware.client.container import ModelCondaContainer, LearnwaresContainer
  9. from learnware.reuse import AveragingReuser
  class TestLearnwareLoad(unittest.TestCase):
      """Tests that learnwares load with ``runnable_option="conda"`` and can predict.

      Learnwares are obtained through ``LearnwareClient``. They are either
      downloaded to zip files first and loaded from those paths, or loaded
      directly by id. These tests therefore presumably need access to the
      learnware market and a working conda installation.
      """

      # Number of random samples fed to every prediction call.
      num_samples = 20

      def setUp(self):
          """Create a client, the learnware ids under test and temporary zip paths."""
          import tempfile  # local import: only this fixture needs it

          self.client = LearnwareClient()
          self.learnware_ids = ["00000084", "00000154", "00000155"]
          # Download into a per-test temporary directory that is removed after the
          # test, rather than next to this file, so runs leave no zip files behind.
          tmp_dir = tempfile.TemporaryDirectory()
          self.addCleanup(tmp_dir.cleanup)
          self.zip_paths = [os.path.join(tmp_dir.name, x) for x in ["1.zip", "2.zip", "3.zip"]]
          # Seeded generator so that inputs behind a failure can be reproduced.
          self.rng = np.random.default_rng(0)

      def _download_zips(self):
          """Download each id in ``self.learnware_ids`` to the matching ``self.zip_paths`` entry."""
          for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths):
              self.client.download_learnware(learnware_id, zip_path)
              # Fail here with a clear message rather than later inside load_learnware.
              self.assertTrue(os.path.isfile(zip_path), f"download of learnware {learnware_id} produced no file")

      def _check_reuse(self, learnware_list, num_features=13):
          """Predict with an ``AveragingReuser`` over ``learnware_list`` and with each member.

          Asserts that one learnware was loaded per id in ``self.learnware_ids``
          and that every prediction has one entry per input sample. Predictions
          are also printed for manual inspection.

          Args:
              learnware_list: Loaded learnwares, one per id in ``self.learnware_ids``.
              num_features: Number of columns in the random input matrix.
          """
          self.assertEqual(len(learnware_list), len(self.learnware_ids))
          input_array = self.rng.random((self.num_samples, num_features))
          reuser = AveragingReuser(learnware_list, mode="vote_by_label")
          ensemble_pred = reuser.predict(input_array)
          print(ensemble_pred)
          self.assertEqual(len(ensemble_pred), self.num_samples)
          # ``model`` rather than ``learnware`` so the imported ``learnware`` package is not shadowed.
          for model in learnware_list:
              pred = model.predict(input_array)
              print(model.id, pred)
              self.assertEqual(len(pred), self.num_samples)

      def _check_single(self, learnware_id, num_features):
          """Load one learnware by id with ``runnable_option="conda"`` and check its prediction length."""
          model = self.client.load_learnware(learnware_id=learnware_id, runnable_option="conda")
          input_array = self.rng.random((self.num_samples, num_features))
          pred = model.predict(input_array)
          print(pred)
          self.assertEqual(len(pred), self.num_samples)

      def test_load_single_learnware_by_zippath(self):
          """Load each downloaded zip with its own ``load_learnware`` call."""
          self._download_zips()
          learnware_list = [
              self.client.load_learnware(learnware_path=zip_path, runnable_option="conda")
              for zip_path in self.zip_paths
          ]
          self._check_reuse(learnware_list)

      def test_load_multi_learnware_by_zippath(self):
          """Load all downloaded zips with a single ``load_learnware`` call."""
          self._download_zips()
          learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="conda")
          self._check_reuse(learnware_list)

      def test_load_single_learnware_by_id(self):
          """Load each learnware id with its own ``load_learnware`` call."""
          learnware_list = [
              self.client.load_learnware(learnware_id=idx, runnable_option="conda") for idx in self.learnware_ids
          ]
          self._check_reuse(learnware_list)

      def test_load_multi_learnware_by_id(self):
          """Load all learnware ids with a single ``load_learnware`` call."""
          learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option="conda")
          self._check_reuse(learnware_list)

      def test_load_single_learnware_by_id_pip(self):
          """Load learnware 00000147 by id and predict on 23-feature inputs."""
          # NOTE(review): despite the ``_pip`` suffix this also loads with
          # runnable_option="conda"; presumably the suffix describes how learnware
          # 00000147 declares its environment -- confirm.
          self._check_single("00000147", num_features=23)

      def test_load_single_learnware_by_id_conda(self):
          """Load learnware 00000148 by id and predict on 204-feature inputs."""
          self._check_single("00000148", num_features=204)
  # Run the suite when this file is executed directly (e.g. `python test_load_conda.py`);
  # importing the module (as test runners do) runs nothing.
  if "__main__" == __name__:
      unittest.main(module="__main__")