
homo.py 7.9 kB

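# Homogeneous-table workflow: every learnware in the market shares the feature
# space of the user tasks, so searched learnwares can be applied to user data
# directly. `base`, `config`, `methods`, and `utils` are local modules of this
# example (workflow base class, experiment constants, loss functions, and
# result recording/plotting helpers, respectively).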
import os
import warnings

import numpy as np

from base import TableWorkflow
from config import homo_n_labeled_list, homo_n_repeat_list
from methods import loss_func_rmse
from utils import Recorder, plot_performance_curves

from learnware.logger import get_module_logger
from learnware.market import BaseUserInfo
from learnware.reuse import AveragingReuser, JobSelectorReuser
from learnware.specification import generate_stat_spec

warnings.filterwarnings("ignore")

logger = get_module_logger("homo_table", level="INFO")
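

# The workflow covers two settings: users with only unlabeled test data
# (unlabeled_homo_table_example) and users who also hold a small labeled
# training set (labeled_homo_table_example).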
class HomogeneousDatasetWorkflow(TableWorkflow):
    def unlabeled_homo_table_example(self):
        logger.info("Total Item: %d" % (len(self.market)))
        learnware_rmse_list = []
        single_score_list = []
        job_selector_score_list = []
        ensemble_score_list = []
        all_learnwares = self.market.get_learnwares()
        user = self.benchmark.name
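
        # Evaluate each benchmark user independently: search the market with
        # the user's statistical specification, then score the returned
        # learnwares and the reuse strategies on that user's test data.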
        for idx in range(self.benchmark.user_num):
            test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
            test_x, test_y = test_x.values, test_y.values
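
            # Build the user's statistical specification (an RKME table
            # specification) from the raw test features; no labels are needed.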
            user_stat_spec = generate_stat_spec(type="table", X=test_x)
            user_info = BaseUserInfo(semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec})
            logger.info(f"Searching Market for user: {user}_{idx}")

            search_result = self.market.search_learnware(user_info, max_search_num=2)
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()

            logger.info(f"search result of user {user}_{idx}:")
            logger.info(
                f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
            )

            pred_y = single_result[0].learnware.predict(test_x)
            single_score_list.append(loss_func_rmse(pred_y, test_y))
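
            # Score every market learnware whose input dimension matches this
            # user's data; these per-learnware RMSEs later yield the "average"
            # (random-pick) and "oracle" (best-possible-pick) baselines.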
            rmse_list = []
            for learnware in all_learnwares:
                semantic_spec = learnware.specification.get_semantic_spec()
                if semantic_spec["Input"]["Dimension"] == test_x.shape[1]:
                    pred_y = learnware.predict(test_x)
                    rmse_list.append(loss_func_rmse(pred_y, test_y))

            logger.info(
                f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[-1]}"
            )

            if len(multiple_result) > 0:
                mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                mixture_learnware_list = multiple_result[0].learnwares
            else:
                mixture_learnware_list = [single_result[0].learnware]

            # test reuse (job selector)
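            # JobSelectorReuser trains a selector over the mixture learnwares
            # and routes each test sample to one of them; herding_num
            # presumably controls how many points are herded per specification.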
            reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
            reuse_predict = reuse_baseline.predict(user_data=test_x)
            reuse_score = loss_func_rmse(reuse_predict, test_y)
            job_selector_score_list.append(reuse_score)
            logger.info(f"mixture reuse rmse (job selector): {reuse_score}")

            # test reuse (ensemble)
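            # AveragingReuser with mode="mean" simply averages the predictions
            # of the mixture learnwares.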
            reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="mean")
            ensemble_predict_y = reuse_ensemble.predict(user_data=test_x)
            ensemble_score = loss_func_rmse(ensemble_predict_y, test_y)
            ensemble_score_list.append(ensemble_score)
            logger.info(f"mixture reuse rmse (ensemble): {ensemble_score}")

            learnware_rmse_list.append(rmse_list)
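
        # Aggregate over users: the selected learnware's RMSE is compared with
        # the mean RMSE of all compatible learnwares (picking one at random)
        # and the minimum RMSE (an oracle that always picks the best one).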
        single_list = np.array(learnware_rmse_list)
        avg_score_list = [np.mean(lst, axis=0) for lst in single_list]
        oracle_score_list = [np.min(lst, axis=0) for lst in single_list]

        logger.info(
            "RMSE of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Oracle performance: %.3f +/- %.3f"
            % (
                np.mean(single_score_list),
                np.std(single_score_list),
                np.mean(avg_score_list),
                np.std(avg_score_list),
                np.mean(oracle_score_list),
                np.std(oracle_score_list),
            )
        )
        logger.info(
            "Average Job Selector Reuse Performance: %.3f +/- %.3f"
            % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
        )
        logger.info(
            "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
            % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
        )
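
    # Labeled setting: each user also holds training data, from which growing
    # labeled subsets are drawn so learnware reuse can be compared with a model
    # trained purely on the user's own labels.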
    def labeled_homo_table_example(self, skip_test):
        logger.info("Total Item: %d" % (len(self.market)))
        methods = ["user_model", "homo_single_aug", "homo_ensemble_pruning"]
        recorders = {method: Recorder() for method in methods}
        user = self.benchmark.name

        if not skip_test:
            for idx in range(self.benchmark.user_num):
                test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
                test_x, test_y = test_x.values, test_y.values
                train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
                train_x, train_y = train_x.values, train_y.values
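                # Draw labeled subsets of the sizes in homo_n_labeled_list,
                # with repetition counts from homo_n_repeat_list (presumably to
                # average out sampling noise in the learning curves).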
                train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y)

                user_stat_spec = generate_stat_spec(type="table", X=test_x)
                user_info = BaseUserInfo(
                    semantic_spec=self.user_semantic, stat_info={"RKMETableSpecification": user_stat_spec}
                )
                logger.info(f"Searching Market for user: {user}_{idx}")

                search_result = self.market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                logger.info(f"search result of user {user}_{idx}:")
                logger.info(
                    f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
                )

                if len(multiple_result) > 0:
                    mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                    logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                    mixture_learnware_list = multiple_result[0].learnwares
                else:
                    mixture_learnware_list = [single_result[0].learnware]

                test_info = {
                    "user": user,
                    "idx": idx,
                    "train_subsets": train_subsets,
                    "test_x": test_x,
                    "test_y": test_y,
                }
                common_config = {"learnwares": mixture_learnware_list}
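                # Three methods are compared: "user_model" trains a LightGBM
                # model on the user's labeled data alone, "homo_single_aug"
                # presumably adapts the top-ranked single learnware with the
                # user's labels, and "homo_ensemble_pruning" presumably prunes
                # and ensembles the mixture learnwares.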
                method_configs = {
                    "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
                    "homo_single_aug": {"single_learnware": [single_result[0].learnware]},
                    "homo_ensemble_pruning": common_config,
                }

                for method_name in methods:
                    logger.info(f"Testing method {method_name}")
                    test_info["method_name"] = method_name
                    test_info.update(method_configs[method_name])
                    self.test_method(test_info, recorders, loss_func=loss_func_rmse)
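
        # Persist each method's recorded scores and plot performance curves
        # against the number of labeled examples.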
        for method, recorder in recorders.items():
            recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))

        plot_performance_curves(
            self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list
        )