You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

hetero.py 8.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import os
  2. import warnings
  3. import numpy as np
  4. from base import TableWorkflow
  5. from config import align_model_params, hetero_n_labeled_list, hetero_n_repeat_list, user_semantic
  6. from methods import loss_func_rmse
  7. from utils import Recorder, plot_performance_curves, set_seed
  8. from learnware.logger import get_module_logger
  9. from learnware.market import BaseUserInfo
  10. from learnware.reuse import AveragingReuser, FeatureAlignLearnware
  11. from learnware.specification import generate_stat_spec
  12. warnings.filterwarnings("ignore")
  13. logger = get_module_logger("hetero_test", level="INFO")
  14. class HeterogeneousDatasetWorkflow(TableWorkflow):
  15. def unlabeled_hetero_table_example(self):
  16. set_seed(0)
  17. logger.info("Total Item: %d" % len(self.market))
  18. learnware_rmse_list = []
  19. single_score_list = []
  20. ensemble_score_list = []
  21. all_learnwares = self.market.get_learnwares()
  22. user = self.benchmark.name
  23. for idx in range(self.benchmark.user_num):
  24. test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
  25. test_x, test_y, feature_descriptions = test_x.values, test_y.values, test_x.columns
  26. user_stat_spec = generate_stat_spec(type="table", X=test_x)
  27. input_description = {
  28. "Dimension": len(feature_descriptions),
  29. "Description": {str(i): feature_descriptions[i] for i in range(len(feature_descriptions))},
  30. }
  31. user_semantic["Input"] = input_description
  32. user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={user_stat_spec.type: user_stat_spec})
  33. logger.info(f"Searching Market for user: {user}_{idx}")
  34. search_result = self.market.search_learnware(user_info, search_method="auto")
  35. single_result = search_result.get_single_results()
  36. multiple_result = search_result.get_multiple_results()
  37. logger.info(f"hetero search result of user {user}_{idx}: {single_result[0].learnware.id}")
  38. logger.info(
  39. f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
  40. )
  41. single_hetero_learnware = FeatureAlignLearnware(single_result[0].learnware, **align_model_params)
  42. single_hetero_learnware.align(user_rkme=user_stat_spec)
  43. pred_y = single_hetero_learnware.predict(test_x)
  44. single_score_list.append(loss_func_rmse(pred_y, test_y))
  45. rmse_list = []
  46. for learnware in all_learnwares:
  47. hetero_learnware = FeatureAlignLearnware(learnware, **align_model_params)
  48. hetero_learnware.align(user_rkme=user_stat_spec)
  49. pred_y = hetero_learnware.predict(test_x)
  50. rmse_list.append(loss_func_rmse(pred_y, test_y))
  51. logger.info(
  52. f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[0]}"
  53. )
  54. if len(multiple_result) > 0:
  55. mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
  56. logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
  57. mixture_learnware_list = []
  58. for learnware in multiple_result[0].learnwares:
  59. hetero_learnware = FeatureAlignLearnware(learnware, **align_model_params)
  60. hetero_learnware.align(user_rkme=user_stat_spec)
  61. mixture_learnware_list.append(hetero_learnware)
  62. else:
  63. hetero_learnware = FeatureAlignLearnware(single_result[0].learnware, **align_model_params)
  64. hetero_learnware.align(user_rkme=user_stat_spec)
  65. mixture_learnware_list = [hetero_learnware]
  66. # test reuse (ensemble)
  67. reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="mean")
  68. ensemble_predict_y = reuse_ensemble.predict(user_data=test_x)
  69. ensemble_score = loss_func_rmse(ensemble_predict_y, test_y)
  70. ensemble_score_list.append(ensemble_score)
  71. logger.info(f"mixture reuse rmse (ensemble): {ensemble_score}")
  72. learnware_rmse_list.append(rmse_list)
  73. single_list = np.array(learnware_rmse_list)
  74. avg_score_list = [np.mean(lst, axis=0) for lst in single_list]
  75. oracle_score_list = [np.min(lst, axis=0) for lst in single_list]
  76. logger.info(
  77. "RMSE of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Oracle performace: %.3f +/- %.3f"
  78. % (
  79. np.mean(single_score_list),
  80. np.std(single_score_list),
  81. np.mean(avg_score_list),
  82. np.std(avg_score_list),
  83. np.mean(oracle_score_list),
  84. np.std(oracle_score_list),
  85. )
  86. )
  87. logger.info(
  88. "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
  89. % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
  90. )
  91. def labeled_hetero_table_example(self, skip_test):
  92. set_seed(0)
  93. logger.info("Total Items: %d" % len(self.market))
  94. methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"]
  95. recorders = {method: Recorder() for method in methods}
  96. user = self.benchmark.name
  97. if not skip_test:
  98. for idx in range(self.benchmark.user_num):
  99. test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
  100. test_x, test_y = test_x.values, test_y.values
  101. train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
  102. train_x, train_y, feature_descriptions = train_x.values, train_y.values, train_x.columns
  103. train_subsets = self.get_train_subsets(hetero_n_labeled_list, hetero_n_repeat_list, train_x, train_y)
  104. user_stat_spec = generate_stat_spec(type="table", X=test_x)
  105. input_description = {
  106. "Dimension": len(feature_descriptions),
  107. "Description": {str(i): feature_descriptions[i] for i in range(len(feature_descriptions))},
  108. }
  109. user_semantic["Input"] = input_description
  110. user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={user_stat_spec.type: user_stat_spec})
  111. logger.info(f"Searching Market for user: {user}_{idx}")
  112. search_result = self.market.search_learnware(user_info)
  113. single_result = search_result.get_single_results()
  114. multiple_result = search_result.get_multiple_results()
  115. if len(multiple_result) > 0:
  116. mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
  117. logger.info(f"Mixture score: {multiple_result[0].score}, Mixture learnware: {mixture_id}")
  118. mixture_learnware_list = multiple_result[0].learnwares
  119. else:
  120. mixture_learnware_list = [single_result[0].learnware]
  121. logger.info(
  122. f"Hetero search result of user {user}_{idx}: mixture learnware num: {len(mixture_learnware_list)}"
  123. )
  124. test_info = {
  125. "user": user,
  126. "idx": idx,
  127. "train_subsets": train_subsets,
  128. "test_x": test_x,
  129. "test_y": test_y,
  130. "n_labeled_list": hetero_n_labeled_list,
  131. }
  132. common_config = {"user_rkme": user_stat_spec, "learnwares": mixture_learnware_list}
  133. method_configs = {
  134. "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
  135. "hetero_single_aug": {"user_rkme": user_stat_spec, "single_learnware": single_result[0].learnware},
  136. "hetero_multiple_avg": common_config,
  137. "hetero_ensemble_pruning": common_config,
  138. }
  139. for method_name in methods:
  140. logger.info(f"Testing method {method_name}")
  141. test_info["method_name"] = method_name
  142. test_info.update(method_configs[method_name])
  143. self.test_method(test_info, recorders, loss_func=loss_func_rmse)
  144. for method, recorder in recorders.items():
  145. recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))
  146. plot_performance_curves(
  147. self.curves_result_path, user, recorders, task="Hetero", n_labeled_list=hetero_n_labeled_list
  148. )