
test_hetero_workflow.py

import torch
import pickle
import unittest
import os
import logging
import tempfile
import zipfile
from shutil import copyfile, rmtree

from sklearn.linear_model import Ridge
from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error

import learnware

learnware.init(logging_level=logging.WARNING)

from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
from learnware.reuse import HeteroMapAlignLearnware, AveragingReuser, EnsemblePruningReuser
from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate

from hetero_config import input_shape_list, input_description_list, output_description_list, user_description_list

curr_root = os.path.dirname(os.path.abspath(__file__))
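
# The workflow below exercises the heterogeneous learnware market end to end: learnware
# preparation, upload and deletion, market model training, semantic and statistical (RKME)
# search, and heterogeneous model reuse.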


class TestHeteroWorkflow(unittest.TestCase):
    universal_semantic_config = {
        "data_type": "Table",
        "task_type": "Regression",
        "library_type": "Scikit-learn",
        "scenarios": "Education",
        "license": "MIT",
    }

    def _init_learnware_market(self, organizer_kwargs=None):
        """Initialize the learnware market."""
        hetero_market = instantiate_learnware_market(
            market_id="hetero_toy", name="hetero", rebuild=True, organizer_kwargs=organizer_kwargs
        )
        return hetero_market
    def test_prepare_learnware_randomly(self, learnware_num=5):
        self.zip_path_list = []
        for i in range(learnware_num):
            learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero")
            os.makedirs(learnware_pool_dirpath, exist_ok=True)
            learnware_zippath = os.path.join(learnware_pool_dirpath, "ridge_%d.zip" % (i))

            # Train a toy ridge model; learnwares alternate between two input dimensions.
            print("Preparing Learnware: %d" % (i))
            X, y = make_regression(
                n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42
            )
            clf = Ridge(alpha=1.0)
            clf.fit(X, y)

            pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl")
            with open(pickle_filepath, "wb") as fout:
                pickle.dump(clf, fout)

            # Generate the RKME statistical specification and pack everything into a learnware zip.
            spec = generate_rkme_table_spec(X=X, gamma=0.1)
            spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json")
            spec.save(spec_filepath)

            LearnwareTemplate.generate_learnware_zipfile(
                learnware_zippath=learnware_zippath,
                model_template=PickleModelTemplate(
                    pickle_filepath=pickle_filepath,
                    model_kwargs={"input_shape": (input_shape_list[i % 2],), "output_shape": (1,)},
                ),
                stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"),
                requirements=["scikit-learn==0.22"],
            )
            self.zip_path_list.append(learnware_zippath)
    def _upload_delete_learnware(self, hetero_market, learnware_num, delete):
        self.test_prepare_learnware_randomly(learnware_num)
        self.learnware_num = learnware_num

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == 0, "The market should be empty!"

        for idx, zip_path in enumerate(self.zip_path_list):
            semantic_spec = generate_semantic_spec(
                name=f"learnware_{idx}",
                description=f"test_learnware_number_{idx}",
                input_description=input_description_list[idx % 2],
                output_description=output_description_list[idx % 2],
                **self.universal_semantic_config,
            )
            hetero_market.add_learnware(zip_path, semantic_spec)

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        curr_inds = hetero_market.get_learnware_ids()
        print("Available ids After Uploading Learnwares:", curr_inds)
        assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        if delete:
            for learnware_id in curr_inds:
                hetero_market.delete_learnware(learnware_id)
                self.learnware_num -= 1
                assert (
                    len(hetero_market) == self.learnware_num
                ), f"The number of learnwares must be {self.learnware_num}!"

            curr_inds = hetero_market.get_learnware_ids()
            print("Available ids After Deleting Learnwares:", curr_inds)
            assert len(curr_inds) == 0, "The market should be empty!"

        return hetero_market

    def test_upload_delete_learnware(self, learnware_num=5, delete=True):
        hetero_market = self._init_learnware_market()
        return self._upload_delete_learnware(hetero_market, learnware_num, delete)
    def test_train_market_model(self, learnware_num=5, delete=False):
        hetero_market = self._init_learnware_market(
            organizer_kwargs={"auto_update": True, "auto_update_limit": learnware_num}
        )
        hetero_market = self._upload_delete_learnware(hetero_market, learnware_num, delete)
        # The organizer is configured to auto-update once learnware_num learnwares are uploaded;
        # the manual alternative is kept for reference:
        # organizer = hetero_market.learnware_organizer
        # organizer.train(hetero_market.learnware_organizer.learnware_list.values())
        return hetero_market
    def test_search_semantics(self, learnware_num=5):
        hetero_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        # Exact semantic search: the query name matches exactly one uploaded learnware.
        semantic_spec = generate_semantic_spec(
            name=f"learnware_{learnware_num - 1}",
            **self.universal_semantic_config,
        )
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print("Search result1:")
        assert len(single_result) == 1, "Exact semantic search failed!"
        for search_item in single_result:
            semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec()
            print("Choose learnware:", search_item.learnware.id)
            assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], "Exact semantic search failed!"

        # Fuzzy semantic search: the name is intentionally misspelled, so all learnwares should match.
        semantic_spec["Name"]["Values"] = "laernwaer"
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print("Search result2:")
        assert len(single_result) == self.learnware_num, "Fuzzy semantic search failed!"
        for search_item in single_result:
            print("Choose learnware:", search_item.learnware.id)
    def test_hetero_stat_search(self, learnware_num=5):
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))

        user_dim = 15
        with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
            for idx, zip_path in enumerate(self.zip_path_list):
                with zipfile.ZipFile(zip_path, "r") as zip_obj:
                    zip_obj.extractall(path=test_folder)

                # Build a user RKME specification truncated to the first user_dim features.
                user_spec = RKMETableSpecification()
                user_spec.load(os.path.join(test_folder, "stat_spec.json"))
                z = user_spec.get_z()
                z = z[:, :user_dim]
                device = user_spec.device
                z = torch.tensor(z, device=device)
                user_spec.z = z

                print(">> normal case test:")
                semantic_spec = generate_semantic_spec(
                    input_description={
                        "Dimension": user_dim,
                        "Description": {
                            str(key): input_description_list[idx % 2]["Description"][str(key)]
                            for key in range(user_dim)
                        },
                    },
                    **self.universal_semantic_config,
                )
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                print(f"search result of user{idx}:")
                for single_item in single_result:
                    print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
                for multiple_item in multiple_result:
                    print(
                        f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
                    )

                # Improper key "Task" in semantic_spec: the search falls back to homogeneous search
                # and reports the invalid semantic specification.
                print(">> test for key 'Task' with mismatched 'Values':")
                semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"}
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                assert len(single_result) == 0, "Statistical search failed!"

                # Delete key "Task" from semantic_spec: the search falls back to homogeneous search
                # and logs a WARNING with "User doesn't provide correct task type".
                print(">> delete key 'Task' test:")
                semantic_spec.pop("Task")
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                assert len(single_result) == 0, "Statistical search failed!"

                # Mismatched dimension in the semantic specification: the search falls back to homogeneous
                # search and prints "User data feature dimensions mismatch with semantic specification."
                print(">> mismatch dim test")
                semantic_spec = generate_semantic_spec(
                    input_description={
                        "Dimension": user_dim - 2,
                        "Description": {
                            str(key): input_description_list[idx % 2]["Description"][str(key)]
                            for key in range(user_dim)
                        },
                    },
                    **self.universal_semantic_config,
                )
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                assert len(single_result) == 0, "Statistical search failed!"
    def test_homo_stat_search(self, learnware_num=5):
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))

        with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
            for idx, zip_path in enumerate(self.zip_path_list):
                with zipfile.ZipFile(zip_path, "r") as zip_obj:
                    zip_obj.extractall(path=test_folder)

                user_spec = RKMETableSpecification()
                user_spec.load(os.path.join(test_folder, "stat_spec.json"))

                user_semantic = generate_semantic_spec(**self.universal_semantic_config)
                user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                assert len(single_result) >= 1, "Statistical search failed!"
                print(f"search result of user{idx}:")
                for single_item in single_result:
                    print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
                for multiple_item in multiple_result:
                    print(f"mixture_score: {multiple_item.score}\n")
                    mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares])
                    print(f"mixture_learnware: {mixture_id}\n")
    def test_model_reuse(self, learnware_num=5):
        # generate a toy regression problem
        X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0)

        # generate the user's RKME statistical specification
        user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0)

        # generate the user's semantic specification
        semantic_spec = generate_semantic_spec(input_description=user_description_list[0], **self.universal_semantic_config)
        user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})

        # learnware market search
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()
        multiple_result = search_result.get_multiple_results()

        # print search results
        for single_item in single_result:
            print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
        for multiple_item in multiple_result:
            print(
                f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
            )

        # single model reuse: align the best-matching learnware to the user's feature space
        hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression")
        hetero_learnware.align(user_spec, X[:100], y[:100])
        single_predict_y = hetero_learnware.predict(X)

        # multiple model reuse: align every learnware in the recommended mixture
        hetero_learnware_list = []
        for learnware in multiple_result[0].learnwares:
            hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression")
            hetero_learnware.align(user_spec, X[:100], y[:100])
            hetero_learnware_list.append(hetero_learnware)

        # use the averaging ensemble reuser on the aligned learnwares to make predictions
        reuse_ensemble = AveragingReuser(learnware_list=hetero_learnware_list, mode="mean")
        ensemble_predict_y = reuse_ensemble.predict(user_data=X)

        # use the ensemble pruning reuser on the aligned learnwares to make predictions
        reuse_ensemble = EnsemblePruningReuser(learnware_list=hetero_learnware_list, mode="regression")
        reuse_ensemble.fit(X[:100], y[:100])
        ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=X)

        print("Single model RMSE by finetune:", mean_squared_error(y, single_predict_y, squared=False))
        print("Averaging Reuser RMSE:", mean_squared_error(y, ensemble_predict_y, squared=False))
        print("Ensemble Pruning Reuser RMSE:", mean_squared_error(y, ensemble_pruning_predict_y, squared=False))
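
# Note: test_prepare_learnware_randomly, test_upload_delete_learnware and test_train_market_model
# are exercised indirectly by the enabled cases below, which rebuild the market themselves.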


def suite():
    _suite = unittest.TestSuite()
    # _suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly"))
    # _suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware"))
    # _suite.addTest(TestHeteroWorkflow("test_train_market_model"))
    _suite.addTest(TestHeteroWorkflow("test_search_semantics"))
    _suite.addTest(TestHeteroWorkflow("test_hetero_stat_search"))
    _suite.addTest(TestHeteroWorkflow("test_homo_stat_search"))
    _suite.addTest(TestHeteroWorkflow("test_model_reuse"))
    return _suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite())
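
To iterate on a single case without editing suite(), the standard unittest machinery can be driven directly. A minimal sketch, assuming it is run from the directory that contains test_hetero_workflow.py and hetero_config.py (the case name test_hetero_stat_search is just one of the methods defined above):

    import unittest
    from test_hetero_workflow import TestHeteroWorkflow

    # Build a one-case suite and run it with the same verbosity as the script's main block.
    single = unittest.TestSuite()
    single.addTest(TestHeteroWorkflow("test_hetero_stat_search"))
    unittest.TextTestRunner(verbosity=2).run(single)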