You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_workflow.py 11 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import sys
  2. import unittest
  3. import os
  4. import copy
  5. import joblib
  6. import zipfile
  7. import numpy as np
  8. from sklearn import svm
  9. from sklearn.datasets import load_digits
  10. from sklearn.model_selection import train_test_split
  11. from shutil import copyfile, rmtree
  12. import learnware
  13. from learnware.market import instantiate_learnware_market, BaseUserInfo
  14. from learnware.specification import RKMETableSpecification, generate_rkme_table_spec
  15. from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser
  16. curr_root = os.path.dirname(os.path.abspath(__file__))
  17. user_semantic = {
  18. "Data": {"Values": ["Table"], "Type": "Class"},
  19. "Task": {
  20. "Values": ["Classification"],
  21. "Type": "Class",
  22. },
  23. "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
  24. "Scenario": {"Values": ["Education"], "Type": "Tag"},
  25. "Description": {"Values": "", "Type": "String"},
  26. "Name": {"Values": "", "Type": "String"},
  27. }
class TestWorkflow(unittest.TestCase):
    """End-to-end workflow tests for the learnware market.

    Covers: building toy SVM learnwares from the sklearn digits dataset,
    uploading/deleting them in an "easy" market, semantic and statistical
    search, and the four reuse strategies (job selector, averaging,
    ensemble pruning, feature augmentation).

    NOTE(review): the test methods accept extra parameters and call each
    other directly; they are intended to be run via the `suite()` helper
    below rather than by unittest auto-discovery.
    """

    def _init_learnware_market(self):
        """initialize learnware market"""
        # rebuild=True wipes any previous on-disk state for this market id,
        # so every test starts from an empty market.
        easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True)
        return easy_market

    def test_prepare_learnware_randomly(self, learnware_num=5):
        """Build `learnware_num` zipped SVM learnwares under learnware_pool/.

        Each learnware contains a fitted SVC model (svm.pkl), an RKME
        statistical specification (svm.json), and the example __init__.py /
        learnware.yaml / environment.yaml files. The resulting zip paths are
        stored in self.zip_path_list for the other tests.
        """
        self.zip_path_list = []
        X, y = load_digits(return_X_y=True)
        for i in range(learnware_num):
            dir_path = os.path.join(curr_root, "learnware_pool", "svm_%d" % (i))
            os.makedirs(dir_path, exist_ok=True)
            print("Preparing Learnware: %d" % (i))
            # A fresh random 70% split per learnware, so each model (and its
            # RKME spec) is trained on a different subsample.
            data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True)
            clf = svm.SVC(kernel="linear", probability=True)
            clf.fit(data_X, data_y)
            joblib.dump(clf, os.path.join(dir_path, "svm.pkl"))
            # Statistical specification (RKME) generated from the training data.
            spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0)
            spec.save(os.path.join(dir_path, "svm.json"))
            init_file = os.path.join(dir_path, "__init__.py")
            copyfile(
                os.path.join(curr_root, "learnware_example/example_init.py"), init_file
            )  # cp example_init.py init_file
            yaml_file = os.path.join(dir_path, "learnware.yaml")
            copyfile(os.path.join(curr_root, "learnware_example/example.yaml"), yaml_file)  # cp example.yaml yaml_file
            env_file = os.path.join(dir_path, "environment.yaml")
            copyfile(os.path.join(curr_root, "learnware_example/environment.yaml"), env_file)
            zip_file = dir_path + ".zip"
            # zip -q -r -j zip_file dir_path
            # Archive entries use only the bare filename (flat layout) and are
            # stored uncompressed.
            with zipfile.ZipFile(zip_file, "w") as zip_obj:
                for foldername, subfolders, filenames in os.walk(dir_path):
                    for filename in filenames:
                        file_path = os.path.join(foldername, filename)
                        zip_info = zipfile.ZipInfo(filename)
                        zip_info.compress_type = zipfile.ZIP_STORED
                        with open(file_path, "rb") as file:
                            zip_obj.writestr(zip_info, file.read())
            rmtree(dir_path)  # rm -r dir_path
            self.zip_path_list.append(zip_file)

    def test_upload_delete_learnware(self, learnware_num=5, delete=True):
        """Upload the prepared learnwares into a fresh market; optionally delete them all.

        Returns the market so downstream tests can reuse the populated state
        (they call this with delete=False).
        """
        easy_market = self._init_learnware_market()
        self.test_prepare_learnware_randomly(learnware_num)
        self.learnware_num = learnware_num
        print("Total Item:", len(easy_market))
        assert len(easy_market) == 0, f"The market should be empty!"
        for idx, zip_path in enumerate(self.zip_path_list):
            # Each upload gets its own semantic spec derived from the shared
            # template, plus digits-specific Input/Output descriptions.
            semantic_spec = copy.deepcopy(user_semantic)
            semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
            semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
            semantic_spec["Input"] = {
                "Dimension": 64,
                "Description": {
                    f"{i}": f"The value in the grid {i // 8}{i % 8} of the image of hand-written digit."
                    for i in range(64)
                },
            }
            semantic_spec["Output"] = {
                "Dimension": 10,
                "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)},
            }
            easy_market.add_learnware(zip_path, semantic_spec)
        print("Total Item:", len(easy_market))
        assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
        curr_inds = easy_market.get_learnware_ids()
        print("Available ids After Uploading Learnwares:", curr_inds)
        assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
        if delete:
            # Delete one by one, checking the count decreases in lockstep.
            for learnware_id in curr_inds:
                easy_market.delete_learnware(learnware_id)
                self.learnware_num -= 1
                assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
            curr_inds = easy_market.get_learnware_ids()
            print("Available ids After Deleting Learnwares:", curr_inds)
            assert len(curr_inds) == 0, f"The market should be empty!"
        return easy_market

    def test_search_semantics(self, learnware_num=5):
        """Search the populated market by semantic specification only."""
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))
        assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
        test_folder = os.path.join(curr_root, "test_semantics")
        # unzip -o -q zip_path -d unzip_dir
        if os.path.exists(test_folder):
            rmtree(test_folder)
        os.makedirs(test_folder, exist_ok=True)
        with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj:
            zip_obj.extractall(path=test_folder)
        # Query with the name/description of the last uploaded learnware.
        semantic_spec = copy.deepcopy(user_semantic)
        semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}"
        semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}"
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = easy_market.search_learnware(user_info)
        single_result = search_result.get_single_results()
        print("User info:", user_info.get_semantic_spec())
        print(f"Search result:")
        for search_item in single_result:
            print("Choose learnware:", search_item.learnware.id, search_item.learnware.get_specification().get_semantic_spec())
        rmtree(test_folder)  # rm -r test_folder

    def test_stat_search(self, learnware_num=5):
        """Search with each learnware's own RKME spec as the user statistical spec.

        For every prepared zip, its svm.json is loaded back as the user's
        RKMETableSpecification; the single search must return at least one hit.
        """
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))
        test_folder = os.path.join(curr_root, "test_stat")
        for idx, zip_path in enumerate(self.zip_path_list):
            unzip_dir = os.path.join(test_folder, f"{idx}")
            # unzip -o -q zip_path -d unzip_dir
            if os.path.exists(unzip_dir):
                rmtree(unzip_dir)
            os.makedirs(unzip_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_obj:
                zip_obj.extractall(path=unzip_dir)
            user_spec = RKMETableSpecification()
            user_spec.load(os.path.join(unzip_dir, "svm.json"))
            user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
            search_results = easy_market.search_learnware(user_info)
            single_result = search_results.get_single_results()
            multiple_result = search_results.get_multiple_results()
            assert len(single_result) >= 1, f"Statistical search failed!"
            print(f"search result of user{idx}:")
            for search_item in single_result:
                print(f"score: {search_item.score}, learnware_id: {search_item.learnware.id}")
            for mixture_item in multiple_result:
                print(f"mixture_score: {mixture_item.score}\n")
                mixture_id = " ".join([learnware.id for learnware in mixture_item.learnwares])
                print(f"mixture_learnware: {mixture_id}\n")
        rmtree(test_folder)  # rm -r test_folder

    def test_learnware_reuse(self, learnware_num=5):
        """Reuse the searched learnware mixture with all four reuse strategies.

        Builds a held-out user task from the digits data, takes the first
        mixture from the statistical search, and prints the accuracy of each
        reuser on the user's 30% split.
        """
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))
        X, y = load_digits(return_X_y=True)
        train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True)
        stat_spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0)
        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec})
        search_results = easy_market.search_learnware(user_info)
        multiple_result = search_results.get_multiple_results()
        mixture_item = multiple_result[0]
        # Based on user information, the learnware market returns a list of learnwares (learnware_list)
        # Use jobselector reuser to reuse the searched learnwares to make prediction
        reuse_job_selector = JobSelectorReuser(learnware_list=mixture_item.learnwares)
        job_selector_predict_y = reuse_job_selector.predict(user_data=data_X)
        # Use averaging ensemble reuser to reuse the searched learnwares to make prediction
        reuse_ensemble = AveragingReuser(learnware_list=mixture_item.learnwares, mode="vote_by_prob")
        ensemble_predict_y = reuse_ensemble.predict(user_data=data_X)
        # Use ensemble pruning reuser to reuse the searched learnwares to make prediction
        # The pruning/augment reusers need labeled data; the last 200 training
        # samples serve as the user's small labeled set.
        reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_item.learnwares, mode="classification")
        reuse_ensemble.fit(train_X[-200:], train_y[-200:])
        ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=data_X)
        # Use feature augment reuser to reuse the searched learnwares to make prediction
        reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_item.learnwares, mode="classification")
        reuse_feature_augment.fit(train_X[-200:], train_y[-200:])
        feature_augment_predict_y = reuse_feature_augment.predict(user_data=data_X)
        # Job selector / averaging return per-class scores (argmax needed);
        # pruning / feature-augment return label predictions directly.
        print("Job Selector Acc:", np.sum(np.argmax(job_selector_predict_y, axis=1) == data_y) / len(data_y))
        print("Averaging Reuser Acc:", np.sum(np.argmax(ensemble_predict_y, axis=1) == data_y) / len(data_y))
        print("Ensemble Pruning Reuser Acc:", np.sum(ensemble_pruning_predict_y == data_y) / len(data_y))
        print("Feature Augment Reuser Acc:", np.sum(feature_augment_predict_y == data_y) / len(data_y))
  180. def suite():
  181. _suite = unittest.TestSuite()
  182. #_suite.addTest(TestWorkflow("test_prepare_learnware_randomly"))
  183. #_suite.addTest(TestWorkflow("test_upload_delete_learnware"))
  184. _suite.addTest(TestWorkflow("test_search_semantics"))
  185. _suite.addTest(TestWorkflow("test_stat_search"))
  186. _suite.addTest(TestWorkflow("test_learnware_reuse"))
  187. return _suite
  188. if __name__ == "__main__":
  189. runner = unittest.TextTestRunner()
  190. runner.run(suite())