
# test_hetero_workflow.py

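"""End-to-end workflow test for heterogeneous table learnwares.

The test prepares toy Ridge learnwares with two different input dimensions, uploads them to a freshly
rebuilt "hetero" market, and then exercises semantic search, heterogeneous and homogeneous statistical
search, and model reuse (feature alignment, averaging ensemble, and ensemble pruning).
"""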
import torch
import pickle
import unittest
import os
import logging
import tempfile
import zipfile

from sklearn.linear_model import Ridge
from sklearn.datasets import make_regression
from shutil import copyfile, rmtree
from sklearn.metrics import mean_squared_error

import learnware

learnware.init(logging_level=logging.WARNING)

from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
from learnware.reuse import HeteroMapAlignLearnware, AveragingReuser, EnsemblePruningReuser
from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate

from hetero_config import input_shape_list, input_description_list, output_description_list, user_description_list

curr_root = os.path.dirname(os.path.abspath(__file__))
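# hetero_config is a sibling helper module (not shown here). From its usage below it is expected to
# provide: input_shape_list (two alternative feature dimensions for the toy learnwares),
# input_description_list / output_description_list (matching semantic descriptions, where an input
# description looks like {"Dimension": d, "Description": {"0": "...", ..., str(d - 1): "..."}}), and
# user_description_list (input descriptions for the simulated user task).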

class TestHeteroWorkflow(unittest.TestCase):
    universal_semantic_config = {
        "data_type": "Table",
        "task_type": "Regression",
        "library_type": "Scikit-learn",
        "scenarios": "Education",
        "license": "MIT",
    }
    def _init_learnware_market(self, organizer_kwargs=None):
        """initialize learnware market"""
        hetero_market = instantiate_learnware_market(
            market_id="hetero_toy", name="hetero", rebuild=True, organizer_kwargs=organizer_kwargs
        )
        return hetero_market
    def test_prepare_learnware_randomly(self, learnware_num=5):
        """Build toy Ridge learnware zip files, alternating between the two input dimensions in hetero_config."""
        self.zip_path_list = []
        for i in range(learnware_num):
            learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero")
            os.makedirs(learnware_pool_dirpath, exist_ok=True)
            learnware_zippath = os.path.join(learnware_pool_dirpath, "ridge_%d.zip" % (i))
            print("Preparing Learnware: %d" % (i))

            X, y = make_regression(
                n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42
            )
            clf = Ridge(alpha=1.0)
            clf.fit(X, y)

            pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl")
            with open(pickle_filepath, "wb") as fout:
                pickle.dump(clf, fout)

            spec = generate_rkme_table_spec(X=X, gamma=0.1)
            spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json")
            spec.save(spec_filepath)

            LearnwareTemplate.generate_learnware_zipfile(
                learnware_zippath=learnware_zippath,
                model_template=PickleModelTemplate(
                    pickle_filepath=pickle_filepath,
                    model_kwargs={"input_shape": (input_shape_list[i % 2],), "output_shape": (1,)},
                ),
                stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"),
                requirements=["scikit-learn==0.22"],
            )
            self.zip_path_list.append(learnware_zippath)
    def _upload_delete_learnware(self, hetero_market, learnware_num, delete):
        """Upload the prepared learnwares, verify the market size, and optionally delete them all again."""
        self.test_prepare_learnware_randomly(learnware_num)
        self.learnware_num = learnware_num

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == 0, f"The market should be empty!"

        for idx, zip_path in enumerate(self.zip_path_list):
            semantic_spec = generate_semantic_spec(
                name=f"learnware_{idx}",
                description=f"test_learnware_number_{idx}",
                input_description=input_description_list[idx % 2],
                output_description=output_description_list[idx % 2],
                **self.universal_semantic_config,
            )
            hetero_market.add_learnware(zip_path, semantic_spec)

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        curr_inds = hetero_market.get_learnware_ids()
        print("Available ids After Uploading Learnwares:", curr_inds)
        assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        if delete:
            for learnware_id in curr_inds:
                hetero_market.delete_learnware(learnware_id)
                self.learnware_num -= 1
                assert (
                    len(hetero_market) == self.learnware_num
                ), f"The number of learnwares must be {self.learnware_num}!"
            curr_inds = hetero_market.get_learnware_ids()
            print("Available ids After Deleting Learnwares:", curr_inds)
            assert len(curr_inds) == 0, f"The market should be empty!"

        return hetero_market
    def test_upload_delete_learnware(self, learnware_num=5, delete=True):
        hetero_market = self._init_learnware_market()
        return self._upload_delete_learnware(hetero_market, learnware_num, delete)

    def test_train_market_model(self, learnware_num=5, delete=False):
        # auto_update is enabled so the organizer handles market training itself;
        # the manual call below is kept only for reference.
        hetero_market = self._init_learnware_market(
            organizer_kwargs={"auto_update": True, "auto_update_limit": learnware_num}
        )
        hetero_market = self._upload_delete_learnware(hetero_market, learnware_num, delete)
        # organizer = hetero_market.learnware_organizer
        # organizer.train(hetero_market.learnware_organizer.learnware_list.values())
        return hetero_market
    def test_search_semantics(self, learnware_num=5):
        hetero_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        # exact semantic search by name should return exactly one learnware
        semantic_spec = generate_semantic_spec(
            name=f"learnware_{learnware_num - 1}",
            **self.universal_semantic_config,
        )
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print(f"Search result1:")
        assert len(single_result) == 1, f"Exact semantic search failed!"
        for search_item in single_result:
            semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec()
            print("Choose learnware:", search_item.learnware.id)
            assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], f"Exact semantic search failed!"

        # a deliberately misspelled name triggers fuzzy semantic search, which should return every learnware
        semantic_spec["Name"]["Values"] = "laernwaer"
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print(f"Search result2:")
        assert len(single_result) == self.learnware_num, f"Fuzzy semantic search failed!"
        for search_item in single_result:
            print("Choose learnware:", search_item.learnware.id)
    def test_hetero_stat_search(self, learnware_num=5):
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))

        user_dim = 15
        with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
            for idx, zip_path in enumerate(self.zip_path_list):
                with zipfile.ZipFile(zip_path, "r") as zip_obj:
                    zip_obj.extractall(path=test_folder)

                # truncate the stored RKME to the first user_dim feature dimensions to simulate a heterogeneous user
                user_spec = RKMETableSpecification()
                user_spec.load(os.path.join(test_folder, "stat_spec.json"))
                z = user_spec.get_z()
                z = z[:, :user_dim]
                device = user_spec.device
                z = torch.tensor(z, device=device)
                user_spec.z = z

                print(">> normal case test:")
                semantic_spec = generate_semantic_spec(
                    input_description={
                        "Dimension": user_dim,
                        "Description": {
                            str(key): input_description_list[idx % 2]["Description"][str(key)]
                            for key in range(user_dim)
                        },
                    },
                    **self.universal_semantic_config,
                )
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                print(f"search result of user{idx}:")
                for single_item in single_result:
                    print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
                for multiple_item in multiple_result:
                    print(
                        f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
                    )

                # improper value for key "Task" in semantic_spec: falls back to homogeneous search and prints the invalid semantic_spec
                print(">> test for key 'Task' with improper 'Values':")
                semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"}
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                assert len(single_result) == 0, f"Statistical search failed!"

                # delete key "Task" in semantic_spec: falls back to homogeneous search and logs a WARNING with "User doesn't provide correct task type"
                print(">> delete key 'Task' test:")
                semantic_spec.pop("Task")
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                assert len(single_result) == 0, f"Statistical search failed!"

                # mismatched dimension in the semantic info: falls back to homogeneous search and prints "User data feature dimensions mismatch with semantic specification."
                print(">> mismatch dim test")
                semantic_spec = generate_semantic_spec(
                    input_description={
                        "Dimension": user_dim - 2,
                        "Description": {
                            str(key): input_description_list[idx % 2]["Description"][str(key)]
                            for key in range(user_dim)
                        },
                    },
                    **self.universal_semantic_config,
                )
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                assert len(single_result) == 0, f"Statistical search failed!"
    def test_homo_stat_search(self, learnware_num=5):
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))

        with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
            for idx, zip_path in enumerate(self.zip_path_list):
                with zipfile.ZipFile(zip_path, "r") as zip_obj:
                    zip_obj.extractall(path=test_folder)

                user_spec = RKMETableSpecification()
                user_spec.load(os.path.join(test_folder, "stat_spec.json"))
                user_semantic = generate_semantic_spec(**self.universal_semantic_config)
                user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})

                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                assert len(single_result) >= 1, f"Statistical search failed!"
                print(f"search result of user{idx}:")
                for single_item in single_result:
                    print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
                for multiple_item in multiple_result:
                    print(f"mixture_score: {multiple_item.score}\n")
                    mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares])
                    print(f"mixture_learnware: {mixture_id}\n")
    def test_model_reuse(self, learnware_num=5):
        # generate a toy regression problem as the user task
        X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0)
        # generate the user's RKME statistical specification
        user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0)
        # generate the user's semantic specification
        semantic_spec = generate_semantic_spec(
            input_description=user_description_list[0], **self.universal_semantic_config
        )
        user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})

        # learnware market search
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()
        multiple_result = search_result.get_multiple_results()

        # print search results
        for single_item in single_result:
            print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
        for multiple_item in multiple_result:
            print(
                f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
            )

        # single model reuse: align the top single learnware to the user task with a few labeled samples
        hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression")
        hetero_learnware.align(user_spec, X[:100], y[:100])
        single_predict_y = hetero_learnware.predict(X)

        # multi model reuse: align every learnware in the top mixture
        hetero_learnware_list = []
        for learnware in multiple_result[0].learnwares:
            hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression")
            hetero_learnware.align(user_spec, X[:100], y[:100])
            hetero_learnware_list.append(hetero_learnware)

        # Use averaging ensemble reuser to reuse the searched learnwares to make predictions
        reuse_ensemble = AveragingReuser(learnware_list=hetero_learnware_list, mode="mean")
        ensemble_predict_y = reuse_ensemble.predict(user_data=X)

        # Use ensemble pruning reuser to reuse the searched learnwares to make predictions
        reuse_ensemble = EnsemblePruningReuser(learnware_list=hetero_learnware_list, mode="regression")
        reuse_ensemble.fit(X[:100], y[:100])
        ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=X)

        print("Single model RMSE by finetune:", mean_squared_error(y, single_predict_y, squared=False))
        print("Averaging Reuser RMSE:", mean_squared_error(y, ensemble_predict_y, squared=False))
        print("Ensemble Pruning Reuser RMSE:", mean_squared_error(y, ensemble_pruning_predict_y, squared=False))
def suite():
    _suite = unittest.TestSuite()
    # _suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly"))
    # _suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware"))
    # _suite.addTest(TestHeteroWorkflow("test_train_market_model"))
    _suite.addTest(TestHeteroWorkflow("test_search_semantics"))
    _suite.addTest(TestHeteroWorkflow("test_hetero_stat_search"))
    _suite.addTest(TestHeteroWorkflow("test_homo_stat_search"))
    _suite.addTest(TestHeteroWorkflow("test_model_reuse"))
    return _suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite())
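# Running this module directly executes the tests enabled in suite() with a verbose text runner.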