
homo.py 8.4 kB

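# Workflow for the homogeneous table scenario: for each benchmark user, search the
# learnware market, evaluate the recommended single learnware, and compare several
# reuse strategies (job selector, averaging ensemble, ensemble pruning), both without
# and with labeled user data. Results are recorded and plotted as performance curves.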
import os
import warnings

import numpy as np

warnings.filterwarnings("ignore")

from learnware.market import BaseUserInfo
from learnware.logger import get_module_logger
from learnware.specification import generate_stat_spec
from learnware.reuse import AveragingReuser, JobSelectorReuser, EnsemblePruningReuser

from methods import *
from base import TableWorkflow
from config import n_labeled_list
from utils import Recorder, plot_performance_curves

logger = get_module_logger("homo_table", level="INFO")
class CorporacionDatasetWorkflow(TableWorkflow):
    def unlabeled_homo_table_example(self):
        logger.info("Total Item: %d" % (len(self.market)))
        learnware_rmse_list = []
        single_score_list = []
        job_selector_score_list = []
        ensemble_score_list = []
        pruning_score_list = []
        all_learnwares = self.market.get_learnwares()
        user = self.benchmark.name

        for idx in range(self.benchmark.user_num):
            test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
            test_x, test_y = test_x.values, test_y.values

            # build the user's statistical specification and search the market
            user_stat_spec = generate_stat_spec(type="table", X=test_x)
            user_info = BaseUserInfo(
                semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec}
            )
            logger.info(f"Searching Market for user: {user}_{idx}")
            search_result = self.market.search_learnware(user_info)
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()

            logger.info(f"search result of user {user}_{idx}:")
            logger.info(
                f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
            )

            # evaluate the top-ranked single learnware on the user's test data
            pred_y = single_result[0].learnware.predict(test_x)
            single_score_list.append(loss_func_rmse(pred_y, test_y))

            # evaluate every learnware in the market for the average/oracle baselines
            rmse_list = []
            for learnware in all_learnwares:
                pred_y = learnware.predict(test_x)
                rmse_list.append(loss_func_rmse(pred_y, test_y))

            logger.info(
                f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[-1]}"
            )

            if len(multiple_result) > 0:
                mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                mixture_learnware_list = multiple_result[0].learnwares
            else:
                mixture_learnware_list = [single_result[0].learnware]

            # test reuse (job selector)
            reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
            reuse_predict = reuse_baseline.predict(user_data=test_x)
            reuse_score = loss_func_rmse(reuse_predict, test_y)
            job_selector_score_list.append(reuse_score)
            logger.info(f"mixture reuse rmse (job selector): {reuse_score}")

            # test reuse (averaging ensemble)
            reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="mean")
            ensemble_predict_y = reuse_ensemble.predict(user_data=test_x)
            ensemble_score = loss_func_rmse(ensemble_predict_y, test_y)
            ensemble_score_list.append(ensemble_score)
            logger.info(f"mixture reuse rmse (ensemble): {ensemble_score}")

            # test reuse (ensemble pruning)
            reuse_pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="regression")
            pruning_predict_y = reuse_pruning.predict(user_data=test_x)
            pruning_score = loss_func_rmse(pruning_predict_y, test_y)
            pruning_score_list.append(pruning_score)
            logger.info(f"mixture reuse rmse (ensemble pruning): {pruning_score}\n")

            learnware_rmse_list.append(rmse_list)

        # summarize results across all users
        single_list = np.array(learnware_rmse_list)
        avg_score_list = [np.mean(lst, axis=0) for lst in single_list]
        oracle_score_list = [np.min(lst, axis=0) for lst in single_list]
        logger.info(
            "RMSE of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Oracle performance: %.3f +/- %.3f"
            % (
                np.mean(single_score_list),
                np.std(single_score_list),
                np.mean(avg_score_list),
                np.std(avg_score_list),
                np.mean(oracle_score_list),
                np.std(oracle_score_list),
            )
        )
        logger.info(
            "Average Job Selector Reuse Performance: %.3f +/- %.3f"
            % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
        )
        logger.info(
            "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
            % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
        )
        logger.info(
            "Selective Ensemble Reuse Performance: %.3f +/- %.3f"
            % (np.mean(pruning_score_list), np.std(pruning_score_list))
        )
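    # Labeled-data setting: for each user, compare the baseline methods on increasing
    # numbers of labeled training examples (n_labeled_list) and record the resulting
    # performance curves.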
    def labeled_homo_table_example(self):
        logger.info("Total Item: %d" % (len(self.market)))
        methods = ["user_model", "homo_single_aug", "homo_multiple_aug", "homo_multiple_avg", "homo_ensemble_pruning"]
        methods_to_retest = []
        recorders = {method: Recorder() for method in methods}
        user = self.benchmark.name

        for idx in range(self.benchmark.user_num):
            test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
            test_x, test_y = test_x.values, test_y.values
            train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
            train_x, train_y = train_x.values, train_y.values
            train_subsets = self.get_train_subsets(idx, train_x, train_y)

            user_stat_spec = generate_stat_spec(type="table", X=test_x)
            user_info = BaseUserInfo(
                semantic_spec=self.user_semantic, stat_info={"RKMETableSpecification": user_stat_spec}
            )
            logger.info(f"Searching Market for user: {user}_{idx}")
            search_result = self.market.search_learnware(user_info)
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()

            logger.info(f"search result of user {user}_{idx}:")
            logger.info(
                f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
            )

            if len(multiple_result) > 0:
                mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                mixture_learnware_list = multiple_result[0].learnwares
            else:
                mixture_learnware_list = [single_result[0].learnware]

            test_info = {"user": user, "idx": idx, "train_subsets": train_subsets, "test_x": test_x, "test_y": test_y}
            common_config = {"learnwares": mixture_learnware_list}
            method_configs = {
                "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
                "homo_single_aug": {"single_learnware": [single_result[0].learnware]},
                "homo_multiple_aug": common_config,
                "homo_multiple_avg": common_config,
                "homo_ensemble_pruning": common_config,
            }

            # run each method on this user's data; results accumulate in the recorders
            for method_name in methods:
                logger.info(f"Testing method {method_name}")
                test_info["method_name"] = method_name
                test_info["force"] = method_name in methods_to_retest
                test_info.update(method_configs[method_name])
                self.test_method(test_info, recorders, loss_func=loss_func_rmse)

            for method, recorder in recorders.items():
                recorder.save(os.path.join(self.curves_result_path, f"{user}_{method}_performance.json"))

        methods_to_plot = ["user_model", "homo_single_aug", "homo_ensemble_pruning"]
        plot_performance_curves(
            user,
            {method: recorders[method] for method in methods_to_plot},
            task="Homo",
            n_labeled_list=n_labeled_list,
        )
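
The excerpt above ends before the file does (the viewer reports 175 lines), so the script's entry point is not shown. A minimal, hypothetical invocation sketch, assuming the TableWorkflow base class accepts a benchmark name and prepares self.benchmark and self.market; the constructor argument below is an assumption, not part of the original file:

# Hypothetical entry point, not shown in the excerpt above.
if __name__ == "__main__":
    workflow = CorporacionDatasetWorkflow("corporacion")  # constructor signature assumed
    workflow.unlabeled_homo_table_example()
    workflow.labeled_homo_table_example()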