You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_hetero_workflow.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. import logging
  2. import os
  3. import pickle
  4. import tempfile
  5. import unittest
  6. import zipfile
  7. import torch
  8. from hetero_config import input_description_list, input_shape_list, output_description_list, user_description_list
  9. from sklearn.datasets import make_regression
  10. from sklearn.linear_model import Ridge
  11. from sklearn.metrics import mean_squared_error
  12. import learnware
  13. from learnware.market import BaseUserInfo, instantiate_learnware_market
  14. from learnware.reuse import AveragingReuser, EnsemblePruningReuser, HeteroMapAlignLearnware
  15. from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
  16. from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate
  # Keep the learnware package quiet so only warnings and the test prints are shown.
  learnware.init(logging_level=logging.WARNING)
  # Absolute directory of this test file; the learnware pool directory is created beneath it.
  curr_root = os.path.split(os.path.abspath(__file__))[0]
  class TestHeteroWorkflow(unittest.TestCase):
      """End-to-end workflow tests for the heterogeneous ("hetero_toy") learnware market.

      Toy Ridge-regression learnwares with two alternating input dimensionalities
      (``input_shape_list[i % 2]``) are packaged into zip files, uploaded to a freshly
      rebuilt market, and then used to exercise semantic search, heterogeneous and
      homogeneous statistical search, and heterogeneous model reuse.
      """

      # Semantic fields shared by every uploaded learnware and every user query below.
      universal_semantic_config = {
          "data_type": "Table",
          "task_type": "Regression",
          "library_type": "Scikit-learn",
          "scenarios": "Education",
          "license": "MIT",
      }

      def _init_learnware_market(self, organizer_kwargs=None):
          """Instantiate the ``hetero_toy`` market with ``rebuild=True``.

          Args:
              organizer_kwargs: Optional keyword arguments forwarded to
                  ``instantiate_learnware_market``'s organizer.

          Returns:
              The market object returned by ``instantiate_learnware_market``.
          """
          hetero_market = instantiate_learnware_market(
              market_id="hetero_toy", name="hetero", rebuild=True, organizer_kwargs=organizer_kwargs
          )
          return hetero_market

      def test_prepare_learnware_randomly(self, learnware_num=5):
          """Build ``learnware_num`` toy Ridge learnware zip files and record their paths.

          Learnware ``i`` is trained on ``input_shape_list[i % 2]`` features, so the pool
          alternates between two dimensionalities. ``random_state`` is fixed at 42, so
          learnwares with the same dimensionality are trained on identical data. The zip
          paths are stored in ``self.zip_path_list`` for use by the other tests.
          """
          self.zip_path_list = []
          # The pool directory and the scratch model/spec paths are the same for every
          # learnware, so resolve them once instead of on every iteration.
          learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero")
          os.makedirs(learnware_pool_dirpath, exist_ok=True)
          pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl")
          spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json")

          for i in range(learnware_num):
              learnware_zippath = os.path.join(learnware_pool_dirpath, f"ridge_{i}.zip")
              print(f"Preparing Learnware: {i}")
              n_features = input_shape_list[i % 2]

              X, y = make_regression(
                  n_samples=5000, n_informative=15, n_features=n_features, noise=0.1, random_state=42
              )
              clf = Ridge(alpha=1.0)
              clf.fit(X, y)

              # ridge.pkl and stat_spec.json are overwritten on every iteration. Each pair is
              # packed into this iteration's zip before the next iteration replaces it.
              with open(pickle_filepath, "wb") as fout:
                  pickle.dump(clf, fout)
              spec = generate_rkme_table_spec(X=X, gamma=0.1)
              spec.save(spec_filepath)

              LearnwareTemplate.generate_learnware_zipfile(
                  learnware_zippath=learnware_zippath,
                  model_template=PickleModelTemplate(
                      pickle_filepath=pickle_filepath,
                      model_kwargs={"input_shape": (n_features,), "output_shape": (1,)},
                  ),
                  stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"),
                  requirements=["scikit-learn==0.22"],
              )
              self.zip_path_list.append(learnware_zippath)

      def _upload_delete_learnware(self, hetero_market, learnware_num, delete):
          """Upload freshly prepared learnwares to ``hetero_market`` and optionally delete them.

          Args:
              hetero_market: The market to populate; it is asserted to be empty on entry.
              learnware_num: Number of learnwares to prepare and upload.
              delete: If True, delete every uploaded learnware afterwards and assert the
                  market is empty again.

          Returns:
              ``hetero_market`` after the upload (and optional deletion).
          """
          self.test_prepare_learnware_randomly(learnware_num)
          self.learnware_num = learnware_num
          print("Total Item:", len(hetero_market))
          assert len(hetero_market) == 0, "The market should be empty!"

          for idx, zip_path in enumerate(self.zip_path_list):
              # Descriptions are chosen by index parity, like the feature dimensionality
              # in test_prepare_learnware_randomly.
              semantic_spec = generate_semantic_spec(
                  name=f"learnware_{idx}",
                  description=f"test_learnware_number_{idx}",
                  input_description=input_description_list[idx % 2],
                  output_description=output_description_list[idx % 2],
                  **self.universal_semantic_config,
              )
              hetero_market.add_learnware(zip_path, semantic_spec)

          print("Total Item:", len(hetero_market))
          assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
          curr_inds = hetero_market.get_learnware_ids()
          print("Available ids After Uploading Learnwares:", curr_inds)
          assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

          if delete:
              for learnware_id in curr_inds:
                  hetero_market.delete_learnware(learnware_id)
                  self.learnware_num -= 1
                  assert (
                      len(hetero_market) == self.learnware_num
                  ), f"The number of learnwares must be {self.learnware_num}!"
              curr_inds = hetero_market.get_learnware_ids()
              print("Available ids After Deleting Learnwares:", curr_inds)
              assert len(curr_inds) == 0, "The market should be empty!"

          return hetero_market

      def test_upload_delete_learnware(self, learnware_num=5, delete=True):
          """Upload ``learnware_num`` learnwares to a rebuilt market, then delete them by default.

          Returns:
              The market, so other tests can reuse it when ``delete=False``.
          """
          hetero_market = self._init_learnware_market()
          return self._upload_delete_learnware(hetero_market, learnware_num, delete)

      def test_train_market_model(self, learnware_num=5, delete=False):
          """Upload learnwares to a market whose organizer has auto-update enabled.

          The organizer is configured with ``auto_update=True`` and
          ``auto_update_limit=learnware_num``. Presumably this makes the organizer train
          its market model once that many learnwares are present; confirm against the
          organizer implementation.

          Returns:
              The populated market.
          """
          hetero_market = self._init_learnware_market(
              organizer_kwargs={"auto_update": True, "auto_update_limit": learnware_num}
          )
          return self._upload_delete_learnware(hetero_market, learnware_num, delete)

      def test_search_semantics(self, learnware_num=5):
          """Check exact-name semantic search (one hit) and fuzzy search with a misspelled name."""
          hetero_market = self.test_upload_delete_learnware(learnware_num, delete=False)
          print("Total Item:", len(hetero_market))
          assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

          # Exact search: the name of the last uploaded learnware should match exactly one item.
          semantic_spec = generate_semantic_spec(
              name=f"learnware_{learnware_num - 1}",
              **self.universal_semantic_config,
          )
          user_info = BaseUserInfo(semantic_spec=semantic_spec)
          search_result = hetero_market.search_learnware(user_info)
          single_result = search_result.get_single_results()

          print("Search result1:")
          assert len(single_result) == 1, "Exact semantic search failed!"
          for search_item in single_result:
              semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec()
              print("Choose learnware:", search_item.learnware.id)
              assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], "Exact semantic search failed!"

          # Fuzzy search: a misspelled name is expected to return every learnware.
          semantic_spec["Name"]["Values"] = "laernwaer"
          user_info = BaseUserInfo(semantic_spec=semantic_spec)
          search_result = hetero_market.search_learnware(user_info)
          single_result = search_result.get_single_results()

          print("Search result2:")
          assert len(single_result) == self.learnware_num, "Fuzzy semantic search failed!"
          for search_item in single_result:
              print("Choose learnware:", search_item.learnware.id)

      def test_hetero_stat_search(self, learnware_num=5):
          """Search with user specs whose RKME points are truncated to ``user_dim`` features.

          For each uploaded learnware, its own RKME spec is reloaded and cut down to the
          first ``user_dim`` columns. It is then searched with a matching semantic spec,
          with an invalid/missing "Task" entry, and with a mismatched declared dimension.
          The last three cases must yield no single result.
          """
          hetero_market = self.test_train_market_model(learnware_num, delete=False)
          print("Total Item:", len(hetero_market))
          user_dim = 15

          with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
              for idx, zip_path in enumerate(self.zip_path_list):
                  with zipfile.ZipFile(zip_path, "r") as zip_obj:
                      zip_obj.extractall(path=test_folder)

                  user_spec = RKMETableSpecification()
                  user_spec.load(os.path.join(test_folder, "stat_spec.json"))
                  # Keep only the first user_dim feature columns of the reduced set points.
                  z = user_spec.get_z()
                  z = z[:, :user_dim]
                  device = user_spec.device
                  z = torch.tensor(z, device=device)
                  user_spec.z = z

                  print(">> normal case test:")
                  semantic_spec = generate_semantic_spec(
                      input_description={
                          "Dimension": user_dim,
                          "Description": {
                              str(key): input_description_list[idx % 2]["Description"][str(key)]
                              for key in range(user_dim)
                          },
                      },
                      **self.universal_semantic_config,
                  )
                  user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                  search_result = hetero_market.search_learnware(user_info)
                  single_result = search_result.get_single_results()
                  multiple_result = search_result.get_multiple_results()

                  print(f"search result of user{idx}:")
                  for single_item in single_result:
                      print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
                  for multiple_item in multiple_result:
                      print(
                          f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
                      )

                  # Improper "Task" value in semantic_spec; intended to fall back to homo search
                  # and report the invalid semantic_spec. Only "no single result" is asserted.
                  print(">> test for key 'Task' has empty 'Values':")
                  semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"}
                  user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                  search_result = hetero_market.search_learnware(user_info)
                  single_result = search_result.get_single_results()
                  assert len(single_result) == 0, "Statistical search failed!"

                  # Missing "Task" key; intended to fall back to homo search with a
                  # "User doesn't provide correct task type" warning.
                  print(">> delete key 'Task' test:")
                  semantic_spec.pop("Task")
                  user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                  search_result = hetero_market.search_learnware(user_info)
                  single_result = search_result.get_single_results()
                  assert len(single_result) == 0, "Statistical search failed!"

                  # Declared "Dimension" disagrees with both the data and the description
                  # count; intended to fall back to homo search and report the mismatch.
                  print(">> mismatch dim test")
                  semantic_spec = generate_semantic_spec(
                      input_description={
                          "Dimension": user_dim - 2,
                          "Description": {
                              str(key): input_description_list[idx % 2]["Description"][str(key)]
                              for key in range(user_dim)
                          },
                      },
                      **self.universal_semantic_config,
                  )
                  user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                  search_result = hetero_market.search_learnware(user_info)
                  single_result = search_result.get_single_results()
                  assert len(single_result) == 0, "Statistical search failed!"

      def test_homo_stat_search(self, learnware_num=5):
          """Search with each learnware's own unmodified RKME spec; expect at least one single hit."""
          hetero_market = self.test_train_market_model(learnware_num, delete=False)
          print("Total Item:", len(hetero_market))

          with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
              for idx, zip_path in enumerate(self.zip_path_list):
                  with zipfile.ZipFile(zip_path, "r") as zip_obj:
                      zip_obj.extractall(path=test_folder)

                  user_spec = RKMETableSpecification()
                  user_spec.load(os.path.join(test_folder, "stat_spec.json"))
                  user_semantic = generate_semantic_spec(**self.universal_semantic_config)
                  user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
                  search_result = hetero_market.search_learnware(user_info)
                  single_result = search_result.get_single_results()
                  multiple_result = search_result.get_multiple_results()

                  assert len(single_result) >= 1, "Statistical search failed!"
                  print(f"search result of user{idx}:")
                  for single_item in single_result:
                      print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
                  for multiple_item in multiple_result:
                      print(f"mixture_score: {multiple_item.score}\n")
                      mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares])
                      print(f"mixture_learnware: {mixture_id}\n")

      def test_model_reuse(self, learnware_num=5):
          """Search for a 15-feature user task and reuse the results via HeteroMapAlignLearnware.

          The top single learnware and every learnware of the top mixture are aligned to
          the user's feature space using the first 100 labelled samples. They are then
          evaluated directly, with an averaging ensemble, and with ensemble pruning. RMSE
          values are printed but not asserted.
          """
          # Toy regression problem standing in for the user's data.
          X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0)
          # NOTE(review): cuda_idx=0 requests GPU 0; confirm learnware falls back to CPU when none exists.
          user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0)
          semantic_spec = generate_semantic_spec(
              input_description=user_description_list[0], **self.universal_semantic_config
          )
          user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})

          # Learnware market search.
          hetero_market = self.test_train_market_model(learnware_num, delete=False)
          search_result = hetero_market.search_learnware(user_info)
          single_result = search_result.get_single_results()
          multiple_result = search_result.get_multiple_results()

          for single_item in single_result:
              print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
          for multiple_item in multiple_result:
              print(
                  f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
              )

          # Fail with a clear message instead of an IndexError on the [0] lookups below.
          assert len(single_result) > 0, "Heterogeneous search returned no single learnware!"
          assert len(multiple_result) > 0, "Heterogeneous search returned no learnware mixture!"

          # Single model reuse: align the best single learnware to the user's feature space.
          hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression")
          hetero_learnware.align(user_spec, X[:100], y[:100])
          single_predict_y = hetero_learnware.predict(X)

          # Multi model reuse: align every learnware in the best mixture.
          hetero_learnware_list = []
          for org_learnware in multiple_result[0].learnwares:
              aligned_learnware = HeteroMapAlignLearnware(org_learnware, mode="regression")
              aligned_learnware.align(user_spec, X[:100], y[:100])
              hetero_learnware_list.append(aligned_learnware)

          # Averaging ensemble reuser.
          reuse_ensemble = AveragingReuser(learnware_list=hetero_learnware_list, mode="mean")
          ensemble_predict_y = reuse_ensemble.predict(user_data=X)

          # Ensemble pruning reuser, fitted on the same 100 labelled samples.
          reuse_ensemble = EnsemblePruningReuser(learnware_list=hetero_learnware_list, mode="regression")
          reuse_ensemble.fit(X[:100], y[:100])
          ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=X)

          # `mean_squared_error(..., squared=False)` was deprecated in scikit-learn 1.4 and
          # removed in 1.6. The square root of the MSE is the RMSE on every version.
          print("Single model RMSE by finetune:", mean_squared_error(y, single_predict_y) ** 0.5)
          print("Averaging Reuser RMSE:", mean_squared_error(y, ensemble_predict_y) ** 0.5)
          print("Ensemble Pruning Reuser RMSE:", mean_squared_error(y, ensemble_pruning_predict_y) ** 0.5)
  def suite():
      """Return the TestSuite run when this file is executed as a script.

      Upload/delete and market training are not listed separately; the tests below
      perform those steps themselves.
      """
      selected = (
          "test_search_semantics",
          "test_hetero_stat_search",
          "test_homo_stat_search",
          "test_model_reuse",
      )
      collected = unittest.TestSuite()
      collected.addTests(TestHeteroWorkflow(method_name) for method_name in selected)
      return collected
  if __name__ == "__main__":
      # Verbose runner so each selected test's name and outcome are listed.
      unittest.TextTestRunner(verbosity=2).run(suite())