
homo.py

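"""Homogeneous table workflow built on the learnware market.

For each benchmark user, this module generates an RKME table statistical
specification from the user's unlabeled test data, searches the market for
matching learnwares, and evaluates (i) the top-ranked single learnware,
(ii) multi-learnware reuse via JobSelectorReuser and AveragingReuser, and
(iii) labeled-data baselines recorded as performance curves.
"""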
import os
import warnings

import numpy as np

warnings.filterwarnings("ignore")

from learnware.market import BaseUserInfo
from learnware.logger import get_module_logger
from learnware.specification import generate_stat_spec
from learnware.reuse import AveragingReuser, JobSelectorReuser, EnsemblePruningReuser

from methods import *
from base import TableWorkflow
from config import homo_n_labeled_list, homo_n_repeat_list
from utils import Recorder, plot_performance_curves

logger = get_module_logger("homo_table", level="INFO")
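
# `loss_func_rmse` arrives via the star import from `methods`. For reference,
# a minimal sketch of what it presumably computes is shown below; the actual
# helper in methods.py may differ (e.g. in axis handling for multi-output
# targets), so this is kept as a comment rather than live code:
#
#     def loss_func_rmse(pred_y, target_y):
#         return np.sqrt(np.mean((np.asarray(pred_y) - np.asarray(target_y)) ** 2))
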
class HomogeneousDatasetWorkflow(TableWorkflow):
    def unlabeled_homo_table_example(self):
        logger.info("Total Item: %d" % len(self.market))
        learnware_rmse_list = []
        single_score_list = []
        job_selector_score_list = []
        ensemble_score_list = []
        all_learnwares = self.market.get_learnwares()
        user = self.benchmark.name

        for idx in range(self.benchmark.user_num):
            test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
            test_x, test_y = test_x.values, test_y.values

            # Build the user's statistical specification from unlabeled test data
            user_stat_spec = generate_stat_spec(type="table", X=test_x)
            user_info = BaseUserInfo(
                semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec}
            )
            logger.info(f"Searching Market for user: {user}_{idx}")
            search_result = self.market.search_learnware(user_info, max_search_num=2)
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()

            logger.info(f"search result of user {user}_{idx}:")
            logger.info(
                f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
            )

            # Evaluate the top-ranked single learnware on the user's test data
            pred_y = single_result[0].learnware.predict(test_x)
            single_score_list.append(loss_func_rmse(pred_y, test_y))

            # Evaluate every learnware whose input dimension matches this user's task
            rmse_list = []
            for learnware in all_learnwares:
                semantic_spec = learnware.specification.get_semantic_spec()
                if semantic_spec["Input"]["Dimension"] == test_x.shape[1]:
                    pred_y = learnware.predict(test_x)
                    rmse_list.append(loss_func_rmse(pred_y, test_y))
            logger.info(
                f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[-1]}"
            )

            if len(multiple_result) > 0:
                mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                mixture_learnware_list = multiple_result[0].learnwares
            else:
                mixture_learnware_list = [single_result[0].learnware]

            # Test reuse (job selector)
            reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
            reuse_predict = reuse_baseline.predict(user_data=test_x)
            reuse_score = loss_func_rmse(reuse_predict, test_y)
            job_selector_score_list.append(reuse_score)
            logger.info(f"mixture reuse rmse (job selector): {reuse_score}")

            # Test reuse (averaging ensemble)
            reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="mean")
            ensemble_predict_y = reuse_ensemble.predict(user_data=test_x)
            ensemble_score = loss_func_rmse(ensemble_predict_y, test_y)
            ensemble_score_list.append(ensemble_score)
            logger.info(f"mixture reuse rmse (ensemble): {ensemble_score}")

            learnware_rmse_list.append(rmse_list)

        # Per-user average and oracle (best) RMSE over all dimension-compatible learnwares
        single_list = np.array(learnware_rmse_list)
        avg_score_list = [np.mean(lst, axis=0) for lst in single_list]
        oracle_score_list = [np.min(lst, axis=0) for lst in single_list]

        logger.info(
            "RMSE of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Oracle performance: %.3f +/- %.3f"
            % (
                np.mean(single_score_list),
                np.std(single_score_list),
                np.mean(avg_score_list),
                np.std(avg_score_list),
                np.mean(oracle_score_list),
                np.std(oracle_score_list),
            )
        )
        logger.info(
            "Average Job Selector Reuse Performance: %.3f +/- %.3f"
            % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
        )
        logger.info(
            "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
            % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
        )
    def labeled_homo_table_example(self, skip_test=True):
        logger.info("Total Item: %d" % len(self.market))
        methods = ["user_model", "homo_single_aug", "homo_ensemble_pruning"]
        methods_to_retest = []
        recorders = {method: Recorder() for method in methods}
        user = self.benchmark.name

        if not skip_test:
            for idx in range(self.benchmark.user_num):
                test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
                test_x, test_y = test_x.values, test_y.values
                train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
                train_x, train_y = train_x.values, train_y.values

                # Sample labeled training subsets of increasing size for the performance curves
                train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y)

                user_stat_spec = generate_stat_spec(type="table", X=test_x)
                user_info = BaseUserInfo(
                    semantic_spec=self.user_semantic, stat_info={"RKMETableSpecification": user_stat_spec}
                )
                logger.info(f"Searching Market for user: {user}_{idx}")
                search_result = self.market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                logger.info(f"search result of user {user}_{idx}:")
                logger.info(
                    f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
                )

                if len(multiple_result) > 0:
                    mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                    logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                    mixture_learnware_list = multiple_result[0].learnwares
                else:
                    mixture_learnware_list = [single_result[0].learnware]

                test_info = {"user": user, "idx": idx, "train_subsets": train_subsets, "test_x": test_x, "test_y": test_y}
                common_config = {"learnwares": mixture_learnware_list}
                method_configs = {
                    "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
                    "homo_single_aug": {"single_learnware": [single_result[0].learnware]},
                    "homo_ensemble_pruning": common_config,
                }

                for method_name in methods:
                    logger.info(f"Testing method {method_name}")
                    test_info["method_name"] = method_name
                    test_info["force"] = method_name in methods_to_retest
                    test_info.update(method_configs[method_name])
                    self.test_method(test_info, recorders, loss_func=loss_func_rmse)

            for method, recorder in recorders.items():
                recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))

        plot_performance_curves(self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list)
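
# --- Hypothetical usage sketch (not part of the original file) ---
# TableWorkflow's constructor is defined in base.py and its exact signature is
# an assumption here, so this entry point is left commented out:
#
#     if __name__ == "__main__":
#         workflow = HomogeneousDatasetWorkflow(...)  # build with the benchmark/market setup from base.py
#         workflow.unlabeled_homo_table_example()
#         workflow.labeled_homo_table_example(skip_test=False)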