You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import os
  2. import json
  3. import traceback
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from collections import defaultdict
  7. from learnware.logger import get_module_logger
  8. from config import *
  9. logger = get_module_logger("base_table", level="INFO")
  10. class Recorder:
  11. def __init__(self, headers=["Mean", "Std Dev"], formats=["{:.2f}", "{:.2f}"]):
  12. assert len(headers) == len(formats), "Headers and formats length must match."
  13. self.data = defaultdict(list)
  14. self.headers = headers
  15. self.formats = formats
  16. def record(self, user, scores):
  17. self.data[user].append(scores)
  18. def get_performance_data(self, user):
  19. return self.data.get(user, [])
  20. def save(self, path):
  21. with open(path, "w") as f:
  22. json.dump(self.data, f, indent=4, default=list)
  23. def load(self, path):
  24. with open(path, "r") as f:
  25. self.data = json.load(f, object_hook=lambda x: defaultdict(list, x))
  26. def should_test_method(self, user, idx, path):
  27. if os.path.exists(path):
  28. self.load(path)
  29. return user not in self.data or idx > len(self.data[user]) - 1
  30. return True
  31. def process_single_aug(user, idx, scores, recorders, root_path):
  32. try:
  33. scores_array = np.array(scores)
  34. while scores_array.ndim < 3:
  35. scores_array = scores_array[np.newaxis, :]
  36. select_scores = scores_array[:, 0, :].tolist()
  37. mean_scores = np.mean(scores_array, axis=1).tolist()
  38. oracle_scores = np.min(scores_array, axis=1).tolist()
  39. for method_name, scores in zip(["select_score", "mean_score", "oracle_score"],
  40. [select_scores, mean_scores, oracle_scores]):
  41. recorders[method_name].record(user, scores)
  42. save_path = os.path.join(root_path, f"{method_name}_performance.json")
  43. recorders[method_name].save(save_path)
  44. except Exception as e:
  45. error_message = traceback.format_exc()
  46. logger.error(f"Error in process_single_aug for user {user}, idx {idx}: {error_message}")
  47. def analyze_performance(user, recorders):
  48. oracle_score_list = recorders["hetero_oracle_score"].get_performance_data(user)
  49. select_score_list = recorders["hetero_select_score"].get_performance_data(user)
  50. multi_avg_score_list = recorders["hetero_multiple_avg"].get_performance_data(user)
  51. mean_differences = {}
  52. for user_id in range(len(oracle_score_list)):
  53. select_scores = select_score_list[user_id]
  54. oracle_scores = oracle_score_list[user_id]
  55. mean_difference = np.mean(select_scores) - np.mean(oracle_scores)
  56. mean_differences[user_id] = mean_difference
  57. sorted_user_ids = sorted(mean_differences, key=mean_differences.get, reverse=True)
  58. for user_id in sorted_user_ids:
  59. single_multi_diff = np.mean(select_score_list[user_id]) - np.mean(multi_avg_score_list[user_id])
  60. logger.info(f"{user}, {user_id}, {mean_differences[user_id]}, {single_multi_diff}")
  61. def plot_performance_curves(user, recorders, task, n_labeled_list):
  62. plt.figure(figsize=(10, 6))
  63. plt.xticks(range(len(n_labeled_list)), n_labeled_list)
  64. for method, recorder in recorders.items():
  65. if method == "hetero_single_aug":
  66. continue
  67. scores_array = recorder.get_performance_data(user)
  68. if scores_array:
  69. mean_curve, std_curve = [], []
  70. for i in range(len(n_labeled_list)):
  71. sub_scores_array = np.vstack([lst[i] for lst in scores_array])
  72. sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0))
  73. mean_curve.append(np.mean(sub_scores_mean))
  74. std_curve.append(np.std(sub_scores_mean))
  75. mean_curve = np.array(mean_curve)
  76. std_curve = np.array(std_curve)
  77. method_plot = '_'.join(method.split('_')[1:]) if method not in ['user_model', 'oracle_score', 'select_score', 'mean_score'] else method
  78. style = styles.get(method_plot, {"color": "black", "linestyle": "-"})
  79. plt.plot(mean_curve, label=labels.get(method_plot), **style)
  80. plt.fill_between(
  81. range(len(mean_curve)),
  82. mean_curve - std_curve,
  83. mean_curve + std_curve,
  84. color=style["color"],
  85. alpha=0.2
  86. )
  87. plt.xlabel("Amount of Labeled User Data", fontsize=14)
  88. plt.ylabel("RMSE", fontsize=14)
  89. plt.title(f"Results on Homo Table Experimental Scenario", fontsize=16)
  90. plt.legend(fontsize=14)
  91. plt.tight_layout()
  92. root_path = os.path.abspath(os.path.join(__file__, ".."))
  93. fig_path = os.path.join(root_path, "results", "figs")
  94. os.makedirs(fig_path, exist_ok=True)
  95. plt.savefig(os.path.join(fig_path, f"{task}_labeled_curves.svg"), bbox_inches="tight", dpi=700)