
main.py

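"""Text-dataset workflow example built on the learnware package.

The script splits an argumentative-essay dataset ("ae") among several simulated
uploaders and users, trains one model per uploader, packages each model as a
learnware zip together with an RKME text specification, builds a learnware
market, and then evaluates single-learnware search and several reuse baselines
(job selector, voting ensemble, ensemble pruning) on the user data.
"""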
import os
import fire
import pickle
import time
import zipfile
from shutil import copyfile, rmtree

import numpy as np

import learnware.specification as specification
from get_data import get_data
from learnware.logger import get_module_logger
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
from utils import generate_uploader, generate_user, TextDataLoader, train, eval_prediction

logger = get_module_logger("text_test", level="INFO")

origin_data_root = "./data/origin_data"
processed_data_root = "./data/processed_data"
tmp_dir = "./data/tmp"
learnware_pool_dir = "./data/learnware_pool"
dataset = "ae"  # argumentative essays
n_uploaders = 7
n_users = 7
n_classes = 3
data_root = os.path.join(origin_data_root, dataset)
data_save_root = os.path.join(processed_data_root, dataset)
user_save_root = os.path.join(data_save_root, "user")
uploader_save_root = os.path.join(data_save_root, "uploader")
model_save_root = os.path.join(data_save_root, "uploader_model")
os.makedirs(data_root, exist_ok=True)
os.makedirs(user_save_root, exist_ok=True)
os.makedirs(uploader_save_root, exist_ok=True)
os.makedirs(model_save_root, exist_ok=True)
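
# Semantic information shared by every learnware in this example: the output
# description and the semantic specification fields (data type, task, library,
# scenario, name, description). The statistical RKME specification is generated
# from raw data later in the workflow.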
output_description = {
    "Dimension": 1,
    "Description": {
        "0": "classify as 0(ineffective), 1(effective), or 2(adequate).",
    },
}

semantic_specs = [
    {
        "Data": {"Values": ["Text"], "Type": "Class"},
        "Task": {"Values": ["Classification"], "Type": "Class"},
        "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
        "Scenario": {"Values": ["Education"], "Type": "Tag"},
        "Description": {"Values": "", "Type": "String"},
        "Name": {"Values": "learnware_1", "Type": "String"},
        "Output": output_description,
    }
]

user_semantic = {
    "Data": {"Values": ["Text"], "Type": "Class"},
    "Task": {"Values": ["Classification"], "Type": "Class"},
    "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
    "Scenario": {"Values": ["Education"], "Type": "Tag"},
    "Description": {"Values": "", "Type": "String"},
    "Name": {"Values": "", "Type": "String"},
    "Output": output_description,
}
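

# TextDatasetWorkflow bundles the whole experiment: data preparation, per-uploader
# model training, learnware packaging, market construction, and evaluation.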
class TextDatasetWorkflow:
    def _init_text_dataset(self):
        self._prepare_data()
        self._prepare_model()

    def _prepare_data(self):
        X_train, y_train, X_test, y_test = get_data(data_root)
        generate_uploader(X_train, y_train, n_uploaders=n_uploaders, data_save_root=uploader_save_root)
        generate_user(X_test, y_test, n_users=n_users, data_save_root=user_save_root)
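
    # Train one model per uploader. The helper `train` from utils is expected to
    # return a fitted text vectorizer plus a LightGBM-style classifier; both are
    # pickled separately so they can be packaged into a learnware zip later.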
    def _prepare_model(self):
        dataloader = TextDataLoader(data_save_root, train=True)
        for i in range(n_uploaders):
            logger.info("Train on uploader: %d" % (i))
            X, y = dataloader.get_idx_data(i)
            vectorizer, lgbm = train(X, y, out_classes=n_classes)
            modelv_save_path = os.path.join(model_save_root, "uploader_v_%d.pth" % (i))
            modell_save_path = os.path.join(model_save_root, "uploader_l_%d.pth" % (i))
            with open(modelv_save_path, "wb") as f:
                pickle.dump(vectorizer, f)
            with open(modell_save_path, "wb") as f:
                pickle.dump(lgbm, f)
            logger.info("Model saved to '%s' and '%s'" % (modelv_save_path, modell_save_path))
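
    # Package one uploader's artifacts into a learnware zip: an RKME text
    # specification generated from the uploader's data, the two pickled model
    # files, and the learnware.yaml, __init__.py, and requirements.txt copied
    # from example_files.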
    def _prepare_learnware(
        self, data_path, modelv_path, modell_path, init_file_path, yaml_path, env_file_path, save_root, zip_name
    ):
        os.makedirs(save_root, exist_ok=True)
        tmp_spec_path = os.path.join(save_root, "rkme.json")
        tmp_modelv_path = os.path.join(save_root, "modelv.pth")
        tmp_modell_path = os.path.join(save_root, "modell.pth")
        tmp_yaml_path = os.path.join(save_root, "learnware.yaml")
        tmp_init_path = os.path.join(save_root, "__init__.py")
        tmp_env_path = os.path.join(save_root, "requirements.txt")

        with open(data_path, "rb") as f:
            X = pickle.load(f)
        semantic_spec = semantic_specs[0]

        st = time.time()
        user_spec = specification.RKMETextSpecification()
        user_spec.generate_stat_spec_from_data(X=X)
        ed = time.time()
        logger.info("Stat spec generated in %.3f s" % (ed - st))
        user_spec.save(tmp_spec_path)

        copyfile(modelv_path, tmp_modelv_path)
        copyfile(modell_path, tmp_modell_path)
        copyfile(yaml_path, tmp_yaml_path)
        copyfile(init_file_path, tmp_init_path)
        copyfile(env_file_path, tmp_env_path)

        zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name))
        with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj:
            zip_obj.write(tmp_spec_path, "rkme.json")
            zip_obj.write(tmp_modelv_path, "modelv.pth")
            zip_obj.write(tmp_modell_path, "modell.pth")
            zip_obj.write(tmp_yaml_path, "learnware.yaml")
            zip_obj.write(tmp_init_path, "__init__.py")
            zip_obj.write(tmp_env_path, "requirements.txt")
        rmtree(save_root)
        logger.info("New Learnware Saved to %s" % (zip_file_name))
        return zip_file_name
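
    # Build (or rebuild) the learnware market and add one learnware per uploader,
    # giving each a distinct Name and Description in its semantic specification.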
    def prepare_market(self, regenerate_flag=False):
        if regenerate_flag:
            self._init_text_dataset()
        text_market = instantiate_learnware_market(market_id="ae", rebuild=True)
        try:
            rmtree(learnware_pool_dir)
        except:
            pass
        os.makedirs(learnware_pool_dir, exist_ok=True)
        for i in range(n_uploaders):
            data_path = os.path.join(uploader_save_root, "uploader_%d_X.pkl" % (i))
            modelv_path = os.path.join(model_save_root, "uploader_v_%d.pth" % (i))
            modell_path = os.path.join(model_save_root, "uploader_l_%d.pth" % (i))
            init_file_path = "./example_files/example_init.py"
            yaml_file_path = "./example_files/example_yaml.yaml"
            env_file_path = "./example_files/requirements.txt"
            new_learnware_path = self._prepare_learnware(
                data_path,
                modelv_path,
                modell_path,
                init_file_path,
                yaml_file_path,
                env_file_path,
                tmp_dir,
                "%s_%d" % (dataset, i),
            )
            semantic_spec = semantic_specs[0]
            semantic_spec["Name"]["Values"] = "learnware_%d" % (i)
            semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (i)
            text_market.add_learnware(new_learnware_path, semantic_spec)

        logger.info("Total Item: %d" % (len(text_market)))
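
    # Evaluate the market from each user's perspective: generate the user's RKME
    # statistical specification, search for matching learnwares, measure the
    # accuracy of the top-ranked single learnware, and compare three reuse
    # baselines on the returned mixture (job selector, label-voting ensemble,
    # ensemble pruning).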
    def test(self, regenerate_flag=False):
        self.prepare_market(regenerate_flag)
        text_market = instantiate_learnware_market(market_id="ae")
        print("Total Item: %d" % len(text_market))

        select_list = []
        avg_list = []
        improve_list = []
        job_selector_score_list = []
        ensemble_score_list = []
        pruning_score_list = []
        for i in range(n_users):
            user_data_path = os.path.join(user_save_root, "user_%d_X.pkl" % (i))
            user_label_path = os.path.join(user_save_root, "user_%d_y.pkl" % (i))
            with open(user_data_path, "rb") as f:
                user_data = pickle.load(f)
            with open(user_label_path, "rb") as f:
                user_label = pickle.load(f)

            user_stat_spec = specification.RKMETextSpecification()
            user_stat_spec.generate_stat_spec_from_data(X=user_data)
            user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETextSpecification": user_stat_spec})
            logger.info("Searching Market for user: %d" % (i))

            search_result = text_market.search_learnware(user_info)
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()

            print(f"search result of user{i}:")
            print(
                f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
            )
            acc_list = []
            for idx in range(len(single_result)):
                learnware = single_result[idx].learnware
                score = single_result[idx].score
                pred_y = learnware.predict(user_data)
                acc = eval_prediction(pred_y, user_label)
                acc_list.append(acc)
            print(
                f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, acc: {acc_list[0]}"
            )

            if len(multiple_result) > 0:
                mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                mixture_learnware_list = multiple_result[0].learnwares
            else:
                mixture_learnware_list = [single_result[0].learnware]
            # test reuse (job selector)
            reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
            reuse_predict = reuse_baseline.predict(user_data=user_data)
            reuse_score = eval_prediction(reuse_predict, user_label)
            job_selector_score_list.append(reuse_score)
            print(f"mixture reuse accuracy (job selector): {reuse_score}")

            # test reuse (ensemble)
            reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_label")
            ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
            ensemble_score = eval_prediction(ensemble_predict_y, user_label)
            ensemble_score_list.append(ensemble_score)
            print(f"mixture reuse accuracy (ensemble): {ensemble_score}")

            # test reuse (ensemble pruning)
            reuse_pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list)
            pruning_predict_y = reuse_pruning.predict(user_data=user_data)
            pruning_score = eval_prediction(pruning_predict_y, user_label)
            pruning_score_list.append(pruning_score)
            print(f"mixture reuse accuracy (ensemble pruning): {pruning_score}\n")
            select_list.append(acc_list[0])
            avg_list.append(np.mean(acc_list))
            improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list))

        logger.info(
            "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f"
            % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list))
        )
        logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
        logger.info(
            "Average Job Selector Reuse Performance: %.3f +/- %.3f"
            % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
        )
        logger.info(
            "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
            % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
        )
        logger.info(
            "Selective Ensemble Reuse Performance: %.3f +/- %.3f"
            % (np.mean(pruning_score_list), np.std(pruning_score_list))
        )
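

# The class is exposed as a command-line interface via python-fire, e.g.
# (assuming the data and example_files directories are in place):
#   python main.py test --regenerate_flag=True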
if __name__ == "__main__":
    fire.Fire(TextDatasetWorkflow)