You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_workflow.py 9.5 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import sys
  2. import unittest
  3. import os
  4. import copy
  5. import joblib
  6. import zipfile
  7. import numpy as np
  8. from sklearn import svm
  9. from sklearn.datasets import load_digits
  10. from sklearn.model_selection import train_test_split
  11. from shutil import copyfile, rmtree
  12. import learnware
  13. from learnware.market import EasyMarket, BaseUserInfo
  14. from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
  15. import learnware.specification as specification
  # Directory that contains this test module; all fixture paths are built from it.
  curr_root = os.path.split(os.path.abspath(__file__))[0]

  # Semantic specification template shared by the tests below. Each entry maps a
  # field name to its "Values" and its "Type".
  user_semantic = {
      "Data": {"Values": ["Image"], "Type": "Class"},
      "Task": {"Values": ["Classification"], "Type": "Class"},
      "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
      "Scenario": {"Values": ["Education"], "Type": "Tag"},
      # The two free-text fields start empty and are filled in per learnware.
      "Description": {"Values": "", "Type": "String"},
      "Name": {"Values": "", "Type": "String"},
  }
  28. class TestAllWorkflow(unittest.TestCase):
  29. @classmethod
  30. def setUpClass(cls) -> None:
  31. np.random.seed(2023)
  32. learnware.init()
  33. def _init_learnware_market(self):
  34. """initialize learnware market"""
  35. easy_market = EasyMarket(market_id="sklearn_digits", rebuild=True)
  36. return easy_market
  37. def test_prepare_learnware_randomly(self, learnware_num=5):
  38. self.zip_path_list = []
  39. X, y = load_digits(return_X_y=True)
  40. for i in range(learnware_num):
  41. dir_path = os.path.join(curr_root, "learnware_pool", "svm_%d" % (i))
  42. os.makedirs(dir_path, exist_ok=True)
  43. print("Preparing Learnware: %d" % (i))
  44. data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True)
  45. clf = svm.SVC(kernel="linear", probability=True)
  46. clf.fit(data_X, data_y)
  47. joblib.dump(clf, os.path.join(dir_path, "svm.pkl"))
  48. spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
  49. spec.save(os.path.join(dir_path, "svm.json"))
  50. init_file = os.path.join(dir_path, "__init__.py")
  51. copyfile(
  52. os.path.join(curr_root, "learnware_example/example_init.py"), init_file
  53. ) # cp example_init.py init_file
  54. yaml_file = os.path.join(dir_path, "learnware.yaml")
  55. copyfile(os.path.join(curr_root, "learnware_example/example.yaml"), yaml_file) # cp example.yaml yaml_file
  56. env_file = os.path.join(dir_path, "environment.yaml")
  57. copyfile(os.path.join(curr_root, "learnware_example/environment.yaml"), env_file)
  58. zip_file = dir_path + ".zip"
  59. # zip -q -r -j zip_file dir_path
  60. with zipfile.ZipFile(zip_file, "w") as zip_obj:
  61. for foldername, subfolders, filenames in os.walk(dir_path):
  62. for filename in filenames:
  63. file_path = os.path.join(foldername, filename)
  64. zip_info = zipfile.ZipInfo(filename)
  65. zip_info.compress_type = zipfile.ZIP_STORED
  66. with open(file_path, "rb") as file:
  67. zip_obj.writestr(zip_info, file.read())
  68. rmtree(dir_path) # rm -r dir_path
  69. self.zip_path_list.append(zip_file)
  70. def test_upload_delete_learnware(self, learnware_num=5, delete=False):
  71. easy_market = self._init_learnware_market()
  72. self.test_prepare_learnware_randomly(learnware_num)
  73. print("Total Item:", len(easy_market))
  74. for idx, zip_path in enumerate(self.zip_path_list):
  75. semantic_spec = copy.deepcopy(user_semantic)
  76. semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
  77. semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
  78. semantic_spec["Output"] = {"Dimension": 1, "Description": {"0": "The label of the hand-written digit."}}
  79. easy_market.add_learnware(zip_path, semantic_spec)
  80. print("Total Item:", len(easy_market))
  81. curr_inds = easy_market._get_ids()
  82. print("Available ids After Uploading Learnwares:", curr_inds)
  83. if delete:
  84. for learnware_id in curr_inds:
  85. easy_market.delete_learnware(learnware_id)
  86. curr_inds = easy_market._get_ids()
  87. print("Available ids After Deleting Learnwares:", curr_inds)
  88. return easy_market
  89. def test_search_semantics(self, learnware_num=5):
  90. easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
  91. print("Total Item:", len(easy_market))
  92. test_folder = os.path.join(curr_root, "test_semantics")
  93. # unzip -o -q zip_path -d unzip_dir
  94. if os.path.exists(test_folder):
  95. rmtree(test_folder)
  96. os.makedirs(test_folder, exist_ok=True)
  97. with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj:
  98. zip_obj.extractall(path=test_folder)
  99. semantic_spec = copy.deepcopy(user_semantic)
  100. semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}"
  101. semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}"
  102. user_info = BaseUserInfo(semantic_spec=semantic_spec)
  103. _, single_learnware_list, _, _ = easy_market.search_learnware(user_info)
  104. print("User info:", user_info.get_semantic_spec())
  105. print(f"Search result:")
  106. for learnware in single_learnware_list:
  107. print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec())
  108. rmtree(test_folder) # rm -r test_folder
  109. def test_stat_search(self, learnware_num=5):
  110. easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
  111. print("Total Item:", len(easy_market))
  112. test_folder = os.path.join(curr_root, "test_stat")
  113. for idx, zip_path in enumerate(self.zip_path_list):
  114. unzip_dir = os.path.join(test_folder, f"{idx}")
  115. # unzip -o -q zip_path -d unzip_dir
  116. if os.path.exists(unzip_dir):
  117. rmtree(unzip_dir)
  118. os.makedirs(unzip_dir, exist_ok=True)
  119. with zipfile.ZipFile(zip_path, "r") as zip_obj:
  120. zip_obj.extractall(path=unzip_dir)
  121. user_spec = specification.RKMETableSpecification()
  122. user_spec.load(os.path.join(unzip_dir, "svm.json"))
  123. user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
  124. (
  125. sorted_score_list,
  126. single_learnware_list,
  127. mixture_score,
  128. mixture_learnware_list,
  129. ) = easy_market.search_learnware(user_info)
  130. print(f"search result of user{idx}:")
  131. for score, learnware in zip(sorted_score_list, single_learnware_list):
  132. print(f"score: {score}, learnware_id: {learnware.id}")
  133. print(f"mixture_score: {mixture_score}\n")
  134. mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list])
  135. print(f"mixture_learnware: {mixture_id}\n")
  136. rmtree(test_folder) # rm -r test_folder
  137. def test_learnware_reuse(self, learnware_num=5):
  138. easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
  139. print("Total Item:", len(easy_market))
  140. X, y = load_digits(return_X_y=True)
  141. train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True)
  142. stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
  143. user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec})
  144. _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info)
  145. # Based on user information, the learnware market returns a list of learnwares (learnware_list)
  146. # Use jobselector reuser to reuse the searched learnwares to make prediction
  147. reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list)
  148. job_selector_predict_y = reuse_job_selector.predict(user_data=data_X)
  149. # Use averaging ensemble reuser to reuse the searched learnwares to make prediction
  150. reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob")
  151. ensemble_predict_y = reuse_ensemble.predict(user_data=data_X)
  152. # Use ensemble pruning reuser to reuse the searched learnwares to make prediction
  153. reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification")
  154. reuse_ensemble.fit(train_X[-200:], train_y[-200:])
  155. ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=data_X)
  156. print("Job Selector Acc:", np.sum(np.argmax(job_selector_predict_y, axis=1) == data_y) / len(data_y))
  157. print("Averaging Reuser Acc:", np.sum(np.argmax(ensemble_predict_y, axis=1) == data_y) / len(data_y))
  158. print("Ensemble Pruning Reuser Acc:", np.sum(ensemble_pruning_predict_y == data_y) / len(data_y))
  def suite():
      """Return a ``TestSuite`` holding the workflow tests in pipeline order."""
      ordered_test_names = (
          "test_prepare_learnware_randomly",
          "test_upload_delete_learnware",
          "test_search_semantics",
          "test_stat_search",
          "test_learnware_reuse",
      )
      # TestSuite adds the given tests one by one, in iteration order.
      return unittest.TestSuite(TestAllWorkflow(name) for name in ordered_test_names)
  if __name__ == "__main__":
      runner = unittest.TextTestRunner()
      result = runner.run(suite())
      # Report the outcome through the exit status so that shells and CI jobs
      # can detect a failing run; previously the result was discarded.
      sys.exit(0 if result.wasSuccessful() else 1)