You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_load_learnware.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import os
  2. import unittest
  3. import argparse
  4. import numpy as np
  5. import learnware
  6. from learnware.learnware import get_learnware_from_dirpath
  7. from learnware.client import LearnwareClient
  8. from learnware.client.container import ModelCondaContainer, LearnwaresContainer
  9. from learnware.reuse import AveragingReuser
  10. class TestLearnwareLoadWithConda(unittest.TestCase):
  11. def setUp(self):
  12. self.client = LearnwareClient()
  13. root = os.path.dirname(__file__)
  14. self.learnware_ids = ["00000084", "00000154", "00000155"]
  15. self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]]
  16. self.runnable_option = "conda"
  17. #def test_load_single_learnware_by_zippath(self):
  18. # for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths):
  19. # self.client.download_learnware(learnware_id, zip_path)
  20. #
  21. # learnware_list = [
  22. # self.client.load_learnware(learnware_path=zippath, runnable_option=self.runnable_option) for zippath in self.zip_paths
  23. # ]
  24. # reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  25. # input_array = np.random.random(size=(20, 13))
  26. # print(reuser.predict(input_array))
  27. #
  28. # for learnware in learnware_list:
  29. # print(learnware.id, learnware.predict(input_array))
  30. #
  31. #def test_load_multi_learnware_by_zippath(self):
  32. # for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths):
  33. # self.client.download_learnware(learnware_id, zip_path)
  34. #
  35. # learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option=self.runnable_option)
  36. # reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  37. # input_array = np.random.random(size=(20, 13))
  38. # print(reuser.predict(input_array))
  39. #
  40. # for learnware in learnware_list:
  41. # print(learnware.id, learnware.predict(input_array))
  42. #
  43. #def test_load_single_learnware_by_id(self):
  44. # learnware_list = [
  45. # self.client.load_learnware(learnware_id=idx, runnable_option=self.runnable_option) for idx in self.learnware_ids
  46. # ]
  47. # reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  48. # input_array = np.random.random(size=(20, 13))
  49. # print(reuser.predict(input_array))
  50. #
  51. # for learnware in learnware_list:
  52. # print(learnware.id, learnware.predict(input_array))
  53. #
  54. #def test_load_multi_learnware_by_id(self):
  55. # learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option=self.runnable_option)
  56. # reuser = AveragingReuser(learnware_list, mode="vote_by_label")
  57. # input_array = np.random.random(size=(20, 13))
  58. # print(reuser.predict(input_array))
  59. #
  60. # for learnware in learnware_list:
  61. # print(learnware.id, learnware.predict(input_array))
  62. #
  63. def test_load_single_learnware_by_id_pip(self):
  64. learnware_id = "00000147"
  65. learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option=self.runnable_option)
  66. input_array = np.random.random(size=(20, 23))
  67. print(learnware.predict(input_array))
  68. #
  69. def test_load_single_learnware_by_id_conda(self):
  70. learnware_id = "00000148"
  71. learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option=self.runnable_option)
  72. input_array = np.random.random(size=(20, 204))
  73. print(learnware.predict(input_array))
  74. #
  75. #
  76. class TestLearnwareLoadWithDocker(TestLearnwareLoadWithConda):
  77. def setUp(self):
  78. super(TestLearnwareLoadWithDocker, self).setUp()
  79. self.runnable_option = "docker"
  80. def suite(mode):
  81. _suite = unittest.TestSuite()
  82. #_suite.addTest(TestLearnwareLoadWithDocker())
  83. if mode == "all" or mode == "conda":
  84. _suite.addTest(unittest.makeSuite(TestLearnwareLoadWithConda))
  85. if mode == "all" or mode == "docker":
  86. _suite.addTest(unittest.makeSuite(TestLearnwareLoadWithDocker))
  87. return _suite
  88. if __name__ == "__main__":
  89. parser = argparse.ArgumentParser()
  90. parser.add_argument("--mode", type=str, required=False, default="all", help="The mode to run load learnware, must be in ['all', 'conda', 'docker']")
  91. args = parser.parse_args()
  92. assert args.mode in {"all", "conda", "docker"}, f"The mode must be in ['all', 'conda', 'docker'], instead of '{args.mode}'"
  93. runner = unittest.TextTestRunner()
  94. runner.run(suite(args.mode))