
main.py
import os
import pickle
import time
import zipfile
from shutil import copyfile, rmtree

import numpy as np

import learnware.specification as specification
from get_data import get_data
from learnware.logger import get_module_logger
from learnware.market import instantiate_learnware_market, BaseUserInfo
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
from utils import generate_uploader, generate_user, TextDataLoader, train, eval_prediction

logger = get_module_logger("text_test", level="INFO")
origin_data_root = "./data/origin_data"
processed_data_root = "./data/processed_data"
tmp_dir = "./data/tmp"
learnware_pool_dir = "./data/learnware_pool"
dataset = "ae"  # argumentative essays
n_uploaders = 7
n_users = 7
n_classes = 3

data_root = os.path.join(origin_data_root, dataset)
data_save_root = os.path.join(processed_data_root, dataset)
user_save_root = os.path.join(data_save_root, "user")
uploader_save_root = os.path.join(data_save_root, "uploader")
model_save_root = os.path.join(data_save_root, "uploader_model")
os.makedirs(data_root, exist_ok=True)
os.makedirs(user_save_root, exist_ok=True)
os.makedirs(uploader_save_root, exist_ok=True)
os.makedirs(model_save_root, exist_ok=True)
output_description = {
    "Dimension": 1,
    "Description": {
        "0": "classify as 0 (ineffective), 1 (effective), or 2 (adequate).",
    },
}

semantic_specs = [
    {
        "Data": {"Values": ["Text"], "Type": "Class"},
        "Task": {"Values": ["Classification"], "Type": "Class"},
        "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
        "Scenario": {"Values": ["Education"], "Type": "Tag"},
        "Description": {"Values": "", "Type": "String"},
        "Name": {"Values": "learnware_1", "Type": "String"},
        "Output": output_description,
    }
]

user_semantic = {
    "Data": {"Values": ["Text"], "Type": "Class"},
    "Task": {"Values": ["Classification"], "Type": "Class"},
    "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
    "Scenario": {"Values": ["Education"], "Type": "Tag"},
    "Description": {"Values": "", "Type": "String"},
    "Name": {"Values": "", "Type": "String"},
    "Output": output_description,
}
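
# Note: "Name" and "Description" in semantic_specs[0] are placeholders; they are
# overwritten per uploader in prepare_market() before each learnware is added.
# user_semantic leaves "Name" empty because users search by task, not by name.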


def prepare_data():
    # Split the raw dataset into per-uploader training shards and per-user test sets
    X_train, y_train, X_test, y_test = get_data(data_root)
    generate_uploader(X_train, y_train, n_uploaders=n_uploaders, data_save_root=uploader_save_root)
    generate_user(X_test, y_test, n_users=n_users, data_save_root=user_save_root)


def prepare_model():
    dataloader = TextDataLoader(data_save_root, train=True)
    for i in range(n_uploaders):
        logger.info("Train on uploader: %d" % (i))
        X, y = dataloader.get_idx_data(i)
        vectorizer, lgbm = train(X, y, out_classes=n_classes)
        modelv_save_path = os.path.join(model_save_root, "uploader_v_%d.pth" % (i))
        modell_save_path = os.path.join(model_save_root, "uploader_l_%d.pth" % (i))
        with open(modelv_save_path, "wb") as f:
            pickle.dump(vectorizer, f)
        with open(modell_save_path, "wb") as f:
            pickle.dump(lgbm, f)
        logger.info("Model saved to '%s' and '%s'" % (modelv_save_path, modell_save_path))
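
# train() (defined in utils) returns a fitted text vectorizer and a classifier
# (named "lgbm" above, presumably a LightGBM model). The two are pickled
# separately so that prepare_learnware() can package them as modelv.pth and
# modell.pth for the learnware's __init__.py to reload at prediction time.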


def prepare_learnware(
    data_path, modelv_path, modell_path, init_file_path, yaml_path, env_file_path, save_root, zip_name
):
    os.makedirs(save_root, exist_ok=True)
    tmp_spec_path = os.path.join(save_root, "rkme.json")
    tmp_modelv_path = os.path.join(save_root, "modelv.pth")
    tmp_modell_path = os.path.join(save_root, "modell.pth")
    tmp_yaml_path = os.path.join(save_root, "learnware.yaml")
    tmp_init_path = os.path.join(save_root, "__init__.py")
    tmp_env_path = os.path.join(save_root, "requirements.txt")
    with open(data_path, "rb") as f:
        X = pickle.load(f)

    # Generate the RKME statistical specification from the uploader's training data
    st = time.time()
    user_spec = specification.RKMETextSpecification()
    user_spec.generate_stat_spec_from_data(X=X)
    ed = time.time()
    logger.info("Stat spec generated in %.3f s" % (ed - st))
    user_spec.save(tmp_spec_path)

    copyfile(modelv_path, tmp_modelv_path)
    copyfile(modell_path, tmp_modell_path)
    copyfile(yaml_path, tmp_yaml_path)
    copyfile(init_file_path, tmp_init_path)
    copyfile(env_file_path, tmp_env_path)

    zip_file_name = os.path.join(learnware_pool_dir, "%s.zip" % (zip_name))
    with zipfile.ZipFile(zip_file_name, "w", compression=zipfile.ZIP_DEFLATED) as zip_obj:
        zip_obj.write(tmp_spec_path, "rkme.json")
        zip_obj.write(tmp_modelv_path, "modelv.pth")
        zip_obj.write(tmp_modell_path, "modell.pth")
        zip_obj.write(tmp_yaml_path, "learnware.yaml")
        zip_obj.write(tmp_init_path, "__init__.py")
        zip_obj.write(tmp_env_path, "requirements.txt")
    rmtree(save_root)
    logger.info("New Learnware Saved to %s" % (zip_file_name))
    return zip_file_name
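
# Each learnware zip produced above contains:
#   rkme.json        -- the RKME text statistical specification
#   modelv.pth       -- the pickled text vectorizer
#   modell.pth       -- the pickled classifier
#   learnware.yaml   -- learnware metadata (copied from ./example_files)
#   __init__.py      -- model entry point (copied from ./example_files)
#   requirements.txt -- runtime dependencies (copied from ./example_files)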


def prepare_market():
    text_market = instantiate_learnware_market(market_id="ae", rebuild=True)
    # Clear any learnware zips left over from previous runs
    rmtree(learnware_pool_dir, ignore_errors=True)
    os.makedirs(learnware_pool_dir, exist_ok=True)
    for i in range(n_uploaders):
        data_path = os.path.join(uploader_save_root, "uploader_%d_X.pkl" % (i))
        modelv_path = os.path.join(model_save_root, "uploader_v_%d.pth" % (i))
        modell_path = os.path.join(model_save_root, "uploader_l_%d.pth" % (i))
        init_file_path = "./example_files/example_init.py"
        yaml_file_path = "./example_files/example_yaml.yaml"
        env_file_path = "./example_files/requirements.txt"
        new_learnware_path = prepare_learnware(
            data_path,
            modelv_path,
            modell_path,
            init_file_path,
            yaml_file_path,
            env_file_path,
            tmp_dir,
            "%s_%d" % (dataset, i),
        )
        # Give each learnware a distinct name and description before adding it
        semantic_spec = semantic_specs[0]
        semantic_spec["Name"]["Values"] = "learnware_%d" % (i)
        semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (i)
        text_market.add_learnware(new_learnware_path, semantic_spec)
    logger.info("Total Item: %d" % (len(text_market)))


def test_search(load_market=True):
    if load_market:
        text_market = instantiate_learnware_market(market_id="ae")
    else:
        prepare_market()
        text_market = instantiate_learnware_market(market_id="ae")
    logger.info("Number of items in the market: %d" % len(text_market))

    select_list = []
    avg_list = []
    improve_list = []
    job_selector_score_list = []
    ensemble_score_list = []
    pruning_score_list = []
    for i in range(n_users):
        user_data_path = os.path.join(user_save_root, "user_%d_X.pkl" % (i))
        user_label_path = os.path.join(user_save_root, "user_%d_y.pkl" % (i))
        with open(user_data_path, "rb") as f:
            user_data = pickle.load(f)
        with open(user_label_path, "rb") as f:
            user_label = pickle.load(f)

        # Search the market with the user's semantic and statistical specifications
        user_stat_spec = specification.RKMETextSpecification()
        user_stat_spec.generate_stat_spec_from_data(X=user_data)
        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETextSpecification": user_stat_spec})
        logger.info("Searching Market for user: %d" % (i))
        sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = text_market.search_learnware(
            user_info
        )

        # Evaluate every returned single learnware on the user's labeled data
        acc_list = []
        for idx, (score, learnware) in enumerate(zip(sorted_score_list, single_learnware_list)):
            pred_y = learnware.predict(user_data)
            acc = eval_prediction(pred_y, user_label)
            acc_list.append(acc)
            logger.info("search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc))

        # Test reuse (job selector)
        reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
        reuse_predict = reuse_baseline.predict(user_data=user_data)
        reuse_score = eval_prediction(reuse_predict, user_label)
        job_selector_score_list.append(reuse_score)
        print(f"mixture reuse accuracy (job selector): {reuse_score}")

        # Test reuse (ensemble)
        reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_label")
        ensemble_predict_y = reuse_ensemble.predict(user_data=user_data)
        ensemble_score = eval_prediction(ensemble_predict_y, user_label)
        ensemble_score_list.append(ensemble_score)
        print(f"mixture reuse accuracy (ensemble): {ensemble_score}")

        # Test reuse (ensemble pruning)
        reuse_pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list)
        pruning_predict_y = reuse_pruning.predict(user_data=user_data)
        pruning_score = eval_prediction(pruning_predict_y, user_label)
        pruning_score_list.append(pruning_score)
        print(f"mixture reuse accuracy (ensemble pruning): {pruning_score}\n")

        select_list.append(acc_list[0])
        avg_list.append(np.mean(acc_list))
        improve_list.append((acc_list[0] - np.mean(acc_list)) / np.mean(acc_list))

    logger.info(
        "Accuracy of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f"
        % (np.mean(select_list), np.std(select_list), np.mean(avg_list), np.std(avg_list))
    )
    logger.info("Average performance improvement: %.3f" % (np.mean(improve_list)))
    logger.info(
        "Average Job Selector Reuse Performance: %.3f +/- %.3f"
        % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
    )
    logger.info(
        "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
        % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
    )
    logger.info(
        "Selective Ensemble Reuse Performance: %.3f +/- %.3f"
        % (np.mean(pruning_score_list), np.std(pruning_score_list))
    )


if __name__ == "__main__":
    prepare_data()
    prepare_model()
    test_search(load_market=False)
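
# To rerun only the search/reuse evaluation against an already-built market,
# without regenerating data, models, and learnware zips, call
# test_search(load_market=True) instead.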