You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_workflow.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import logging
  2. import os
  3. import pickle
  4. import tempfile
  5. import unittest
  6. import zipfile
  7. import numpy as np
  8. from sklearn import svm
  9. from sklearn.datasets import load_digits
  10. from sklearn.model_selection import train_test_split
  11. import learnware
  12. from learnware.market import BaseUserInfo, instantiate_learnware_market
  13. from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, JobSelectorReuser
  14. from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
  15. from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate
# Initialize the learnware package once, with WARNING-level logging so the
# test output is not flooded with library logs.
learnware.init(logging_level=logging.WARNING)
# Absolute directory of this test file; generated learnware zips are placed
# in a "learnware_pool" subdirectory next to it.
curr_root = os.path.dirname(os.path.abspath(__file__))
class TestWorkflow(unittest.TestCase):
    """End-to-end workflow tests for the learnware market.

    Covers preparing learnware zip files, uploading and deleting them,
    semantic and statistical search, and the reuse strategies.
    """

    # Semantic-specification fields shared by every learnware and every
    # simulated user in these tests (passed to generate_semantic_spec).
    universal_semantic_config = {
        "data_type": "Table",
        "task_type": "Classification",
        "library_type": "Scikit-learn",
        "scenarios": "Education",
        "license": "MIT",
    }
  26. def _init_learnware_market(self):
  27. """initialize learnware market"""
  28. easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True)
  29. return easy_market
  30. def test_prepare_learnware_randomly(self, learnware_num=5):
  31. self.zip_path_list = []
  32. X, y = load_digits(return_X_y=True)
  33. for i in range(learnware_num):
  34. learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool")
  35. os.makedirs(learnware_pool_dirpath, exist_ok=True)
  36. learnware_zippath = os.path.join(learnware_pool_dirpath, "svm_%d.zip" % (i))
  37. print("Preparing Learnware: %d" % (i))
  38. data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True)
  39. clf = svm.SVC(kernel="linear", probability=True)
  40. clf.fit(data_X, data_y)
  41. pickle_filepath = os.path.join(learnware_pool_dirpath, "model.pkl")
  42. with open(pickle_filepath, "wb") as fout:
  43. pickle.dump(clf, fout)
  44. spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0)
  45. spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json")
  46. spec.save(spec_filepath)
  47. LearnwareTemplate.generate_learnware_zipfile(
  48. learnware_zippath=learnware_zippath,
  49. model_template=PickleModelTemplate(
  50. pickle_filepath=pickle_filepath,
  51. model_kwargs={"input_shape": (64,), "output_shape": (10,), "predict_method": "predict_proba"},
  52. ),
  53. stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"),
  54. requirements=["scikit-learn==0.22"],
  55. )
  56. self.zip_path_list.append(learnware_zippath)
  57. def test_upload_delete_learnware(self, learnware_num=5, delete=True):
  58. easy_market = self._init_learnware_market()
  59. self.test_prepare_learnware_randomly(learnware_num)
  60. self.learnware_num = learnware_num
  61. print("Total Item:", len(easy_market))
  62. assert len(easy_market) == 0, "The market should be empty!"
  63. for idx, zip_path in enumerate(self.zip_path_list):
  64. semantic_spec = generate_semantic_spec(
  65. name=f"learnware_{idx}",
  66. description=f"test_learnware_number_{idx}",
  67. input_description={
  68. "Dimension": 64,
  69. "Description": {
  70. f"{i}": f"The value in the grid {i // 8}{i % 8} of the image of hand-written digit."
  71. for i in range(64)
  72. },
  73. },
  74. output_description={
  75. "Dimension": 10,
  76. "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)},
  77. },
  78. **self.universal_semantic_config,
  79. )
  80. easy_market.add_learnware(zip_path, semantic_spec)
  81. print("Total Item:", len(easy_market))
  82. assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
  83. curr_inds = easy_market.get_learnware_ids()
  84. print("Available ids After Uploading Learnwares:", curr_inds)
  85. assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
  86. if delete:
  87. for learnware_id in curr_inds:
  88. easy_market.delete_learnware(learnware_id)
  89. self.learnware_num -= 1
  90. assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
  91. curr_inds = easy_market.get_learnware_ids()
  92. print("Available ids After Deleting Learnwares:", curr_inds)
  93. assert len(curr_inds) == 0, "The market should be empty!"
  94. return easy_market
  95. def test_search_semantics(self, learnware_num=5):
  96. easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
  97. print("Total Item:", len(easy_market))
  98. assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
  99. with tempfile.TemporaryDirectory(prefix="learnware_test_workflow") as test_folder:
  100. with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj:
  101. zip_obj.extractall(path=test_folder)
  102. semantic_spec = generate_semantic_spec(
  103. name=f"learnware_{learnware_num - 1}",
  104. description=f"test_learnware_number_{learnware_num - 1}",
  105. **self.universal_semantic_config,
  106. )
  107. user_info = BaseUserInfo(semantic_spec=semantic_spec)
  108. search_result = easy_market.search_learnware(user_info)
  109. single_result = search_result.get_single_results()
  110. print("Search result:")
  111. for search_item in single_result:
  112. print("Choose learnware:", search_item.learnware.id)
  113. def test_stat_search(self, learnware_num=5):
  114. easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
  115. print("Total Item:", len(easy_market))
  116. with tempfile.TemporaryDirectory(prefix="learnware_test_workflow") as test_folder:
  117. for idx, zip_path in enumerate(self.zip_path_list):
  118. with zipfile.ZipFile(zip_path, "r") as zip_obj:
  119. zip_obj.extractall(path=test_folder)
  120. user_spec = RKMETableSpecification()
  121. user_spec.load(os.path.join(test_folder, "stat_spec.json"))
  122. user_semantic = generate_semantic_spec(**self.universal_semantic_config)
  123. user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
  124. search_results = easy_market.search_learnware(user_info)
  125. single_result = search_results.get_single_results()
  126. multiple_result = search_results.get_multiple_results()
  127. assert len(single_result) >= 1, "Statistical search failed!"
  128. print(f"search result of user{idx}:")
  129. for search_item in single_result:
  130. print(f"score: {search_item.score}, learnware_id: {search_item.learnware.id}")
  131. for mixture_item in multiple_result:
  132. print(f"mixture_score: {mixture_item.score}\n")
  133. mixture_id = " ".join([learnware.id for learnware in mixture_item.learnwares])
  134. print(f"mixture_learnware: {mixture_id}\n")
  135. def test_learnware_reuse(self, learnware_num=5):
  136. easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
  137. print("Total Item:", len(easy_market))
  138. X, y = load_digits(return_X_y=True)
  139. train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True)
  140. stat_spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0)
  141. user_semantic = generate_semantic_spec(**self.universal_semantic_config)
  142. user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec})
  143. search_results = easy_market.search_learnware(user_info)
  144. multiple_result = search_results.get_multiple_results()
  145. mixture_item = multiple_result[0]
  146. # Based on user information, the learnware market returns a list of learnwares (learnware_list)
  147. # Use jobselector reuser to reuse the searched learnwares to make prediction
  148. reuse_job_selector = JobSelectorReuser(learnware_list=mixture_item.learnwares)
  149. job_selector_predict_y = reuse_job_selector.predict(user_data=data_X)
  150. # Use averaging ensemble reuser to reuse the searched learnwares to make prediction
  151. reuse_ensemble = AveragingReuser(learnware_list=mixture_item.learnwares, mode="vote_by_prob")
  152. ensemble_predict_y = reuse_ensemble.predict(user_data=data_X)
  153. # Use ensemble pruning reuser to reuse the searched learnwares to make prediction
  154. reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_item.learnwares, mode="classification")
  155. reuse_ensemble.fit(train_X[-200:], train_y[-200:])
  156. ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=data_X)
  157. # Use feature augment reuser to reuse the searched learnwares to make prediction
  158. reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_item.learnwares, mode="classification")
  159. reuse_feature_augment.fit(train_X[-200:], train_y[-200:])
  160. feature_augment_predict_y = reuse_feature_augment.predict(user_data=data_X)
  161. print("Job Selector Acc:", np.sum(np.argmax(job_selector_predict_y, axis=1) == data_y) / len(data_y))
  162. print("Averaging Reuser Acc:", np.sum(np.argmax(ensemble_predict_y, axis=1) == data_y) / len(data_y))
  163. print("Ensemble Pruning Reuser Acc:", np.sum(ensemble_pruning_predict_y == data_y) / len(data_y))
  164. print("Feature Augment Reuser Acc:", np.sum(feature_augment_predict_y == data_y) / len(data_y))
  165. def suite():
  166. _suite = unittest.TestSuite()
  167. # _suite.addTest(TestWorkflow("test_prepare_learnware_randomly"))
  168. # _suite.addTest(TestWorkflow("test_upload_delete_learnware"))
  169. _suite.addTest(TestWorkflow("test_search_semantics"))
  170. _suite.addTest(TestWorkflow("test_stat_search"))
  171. _suite.addTest(TestWorkflow("test_learnware_reuse"))
  172. return _suite
# Script entry point: run the curated suite with verbose per-test output.
if __name__ == "__main__":
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite())