You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hetero.py 8.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import os
  2. import warnings
  3. warnings.filterwarnings("ignore")
  4. import numpy as np
  5. from learnware.logger import get_module_logger
  6. from learnware.specification import generate_stat_spec
  7. from learnware.market import BaseUserInfo
  8. from learnware.reuse import AveragingReuser, FeatureAlignLearnware
  9. from methods import *
  10. from base import TableWorkflow
  11. from config import align_model_params, user_semantic, hetero_n_labeled_list, hetero_n_repeat_list
  12. from utils import Recorder, plot_performance_curves
  13. logger = get_module_logger("hetero_test", level="INFO")
  14. class HeterogeneousDatasetWorkflow(TableWorkflow):
  15. def unlabeled_hetero_table_example(self):
  16. logger.info("Total Item: %d" % len(self.market))
  17. learnware_rmse_list = []
  18. single_score_list = []
  19. ensemble_score_list = []
  20. all_learnwares = self.market.get_learnwares()
  21. user = self.benchmark.name
  22. for idx in range(self.benchmark.user_num):
  23. test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
  24. test_x, test_y, feature_descriptions = test_x.values, test_y.values, test_x.columns
  25. user_stat_spec = generate_stat_spec(type="table", X=test_x)
  26. input_description = {
  27. "Dimension": len(feature_descriptions),
  28. "Description": {str(i): feature_descriptions[i] for i in range(len(feature_descriptions))}
  29. }
  30. user_semantic["Input"] = input_description
  31. user_info = BaseUserInfo(
  32. semantic_spec=user_semantic, stat_info={user_stat_spec.type: user_stat_spec}
  33. )
  34. logger.info(f"Searching Market for user: {user}_{idx}")
  35. search_result = self.market.search_learnware(user_info, max_search_num=10)
  36. single_result = search_result.get_single_results()
  37. multiple_result = search_result.get_multiple_results()
  38. logger.info(f"hetero search result of user {user}_{idx}: {single_result[0].learnware.id}")
  39. logger.info(
  40. f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
  41. )
  42. single_hetero_learnware = FeatureAlignLearnware(single_result[0].learnware, **align_model_params)
  43. single_hetero_learnware.align(user_rkme=user_stat_spec)
  44. pred_y = single_hetero_learnware.predict(test_x)
  45. single_score_list.append(loss_func_rmse(pred_y, test_y))
  46. rmse_list = []
  47. for learnware in all_learnwares:
  48. hetero_learnware = FeatureAlignLearnware(learnware, **align_model_params)
  49. hetero_learnware.align(user_rkme=user_stat_spec)
  50. pred_y = hetero_learnware.predict(test_x)
  51. rmse_list.append(loss_func_rmse(pred_y, test_y))
  52. logger.info(
  53. f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[0]}"
  54. )
  55. if len(multiple_result) > 0:
  56. mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
  57. logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
  58. mixture_learnware_list = []
  59. for learnware in multiple_result[0].learnwares:
  60. hetero_learnware = FeatureAlignLearnware(learnware, **align_model_params)
  61. hetero_learnware.align(user_rkme=user_stat_spec)
  62. mixture_learnware_list.append(hetero_learnware)
  63. else:
  64. hetero_learnware = FeatureAlignLearnware(single_result[0].learnware, **align_model_params)
  65. hetero_learnware.align(user_rkme=user_stat_spec)
  66. mixture_learnware_list = [hetero_learnware]
  67. # test reuse (ensemble)
  68. reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="mean")
  69. ensemble_predict_y = reuse_ensemble.predict(user_data=test_x)
  70. ensemble_score = loss_func_rmse(ensemble_predict_y, test_y)
  71. ensemble_score_list.append(ensemble_score)
  72. logger.info(f"mixture reuse rmse (ensemble): {ensemble_score}")
  73. learnware_rmse_list.append(rmse_list)
  74. single_list = np.array(learnware_rmse_list)
  75. avg_score_list = [np.mean(lst, axis=0) for lst in single_list]
  76. oracle_score_list = [np.min(lst, axis=0) for lst in single_list]
  77. logger.info(
  78. "RMSE of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Oracle performace: %.3f +/- %.3f"
  79. % (
  80. np.mean(single_score_list),
  81. np.std(single_score_list),
  82. np.mean(avg_score_list),
  83. np.std(avg_score_list),
  84. np.mean(oracle_score_list),
  85. np.std(oracle_score_list)
  86. )
  87. )
  88. logger.info(
  89. "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
  90. % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
  91. )
  92. def labeled_hetero_table_example(self, skip_test=False):
  93. logger.info("Total Items: %d" % len(self.market))
  94. methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"]
  95. recorders = {method: Recorder() for method in methods}
  96. user = self.benchmark.name
  97. if not skip_test:
  98. for idx in range(self.benchmark.user_num):
  99. test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
  100. test_x, test_y = test_x.values, test_y.values
  101. train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
  102. train_x, train_y, feature_descriptions = train_x.values, train_y.values, train_x.columns
  103. train_subsets = self.get_train_subsets(hetero_n_labeled_list, hetero_n_repeat_list, train_x, train_y)
  104. user_stat_spec = generate_stat_spec(type="table", X=test_x)
  105. input_description = {
  106. "Dimension": len(feature_descriptions),
  107. "Description": {str(i): feature_descriptions[i] for i in range(len(feature_descriptions))}
  108. }
  109. user_semantic["Input"] = input_description
  110. user_info = BaseUserInfo(
  111. semantic_spec=user_semantic, stat_info={user_stat_spec.type: user_stat_spec}
  112. )
  113. logger.info(f"Searching Market for user: {user}_{idx}")
  114. search_result = self.market.search_learnware(user_info)
  115. single_result = search_result.get_single_results()
  116. multiple_result = search_result.get_multiple_results()
  117. if len(multiple_result) > 0:
  118. mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
  119. logger.info(f"Mixture score: {multiple_result[0].score}, Mixture learnware: {mixture_id}")
  120. mixture_learnware_list = multiple_result[0].learnwares
  121. else:
  122. mixture_learnware_list = [single_result[0].learnware]
  123. logger.info(f"Hetero search result of user {user}_{idx}: mixture learnware num: {len(mixture_learnware_list)}")
  124. test_info = {"user": user, "idx": idx, "train_subsets": train_subsets, "test_x": test_x, "test_y": test_y, "n_labeled_list": hetero_n_labeled_list}
  125. common_config = {"user_rkme": user_stat_spec, "learnwares": mixture_learnware_list}
  126. method_configs = {
  127. "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
  128. "hetero_single_aug": {"user_rkme": user_stat_spec, "single_learnware": single_result[0].learnware},
  129. "hetero_multiple_avg": common_config,
  130. "hetero_ensemble_pruning": common_config
  131. }
  132. for method_name in methods:
  133. logger.info(f"Testing method {method_name}")
  134. test_info["method_name"] = method_name
  135. test_info.update(method_configs[method_name])
  136. self.test_method(test_info, recorders, loss_func=loss_func_rmse)
  137. for method, recorder in recorders.items():
  138. recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))
  139. plot_performance_curves(self.curves_result_path, user, recorders, task="Hetero", n_labeled_list=hetero_n_labeled_list)