
utils.py

import json
import os
import random
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch

from config import labels, styles
from learnware.logger import get_module_logger

logger = get_module_logger("base_table", level="INFO")
class Recorder:
    """Accumulates per-user score lists and persists them as JSON."""

    def __init__(self, headers=["Mean", "Std Dev"], formats=["{:.2f}", "{:.2f}"]):
        assert len(headers) == len(formats), "Headers and formats length must match."
        self.data = defaultdict(list)
        self.headers = headers
        self.formats = formats

    def record(self, user, scores):
        self.data[user].append(scores)

    def get_performance_data(self, user):
        return self.data.get(user, [])

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.data, f, indent=4, default=list)

    def load(self, path):
        with open(path, "r") as f:
            self.data = json.load(f, object_hook=lambda x: defaultdict(list, x))

    def should_test_method(self, user, idx, path):
        # Only rerun a method if no cached result exists for this user and run index.
        if os.path.exists(path):
            self.load(path)
            return user not in self.data or idx > len(self.data[user]) - 1
        return True
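

# Illustrative usage sketch for Recorder (not part of the original file; the
# user name, scores, and file path below are made-up examples):
#
#     recorder = Recorder()
#     recorder.record("user_0", [[0.52, 0.48], [0.41, 0.44]])
#     recorder.save("user_0_mean_score_performance.json")
#     # Returns False once run index 0 for "user_0" is already cached on disk:
#     recorder.should_test_method("user_0", 0, "user_0_mean_score_performance.json")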
def plot_performance_curves(path, user, recorders, task, n_labeled_list):
    """Plot mean +/- std RMSE curves over increasing amounts of labeled user data."""
    plt.figure(figsize=(10, 6))
    plt.xticks(range(len(n_labeled_list)), n_labeled_list)
    for method, recorder in recorders.items():
        data_path = os.path.join(path, f"{user}/{user}_{method}_performance.json")
        recorder.load(data_path)
        scores_array = recorder.get_performance_data(user)

        # Aggregate scores across repeated runs for each labeled-data amount.
        mean_curve, std_curve = [], []
        for i in range(len(n_labeled_list)):
            sub_scores_array = np.vstack([lst[i] for lst in scores_array])
            sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0))
            mean_curve.append(np.mean(sub_scores_mean))
            std_curve.append(np.std(sub_scores_mean))
        mean_curve = np.array(mean_curve)
        std_curve = np.array(std_curve)

        # Strip the leading prefix from method names, except for the baseline keys.
        method_plot = (
            "_".join(method.split("_")[1:])
            if method not in ["user_model", "oracle_score", "select_score", "mean_score"]
            else method
        )
        style = styles.get(method_plot, {"color": "black", "linestyle": "-"})
        plt.plot(mean_curve, label=labels.get(method_plot), **style)
        plt.fill_between(
            range(len(mean_curve)), mean_curve - std_curve, mean_curve + std_curve, color=style["color"], alpha=0.2
        )

    plt.xlabel("Amount of Labeled User Data", fontsize=14)
    plt.ylabel("RMSE", fontsize=14)
    plt.title(f"Results on {task} Table Experimental Scenario", fontsize=16)
    plt.legend(fontsize=12)
    plt.tight_layout()

    root_path = os.path.abspath(os.path.join(__file__, ".."))
    fig_path = os.path.join(root_path, "results", "figs")
    os.makedirs(fig_path, exist_ok=True)
    plt.savefig(os.path.join(fig_path, f"{task}_labeled_curves.svg"), bbox_inches="tight", dpi=700)
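

# The `styles` and `labels` dicts are imported from config.py, which is not
# shown here. Judging only from how they are used above, they presumably map
# method keys to matplotlib style kwargs and legend labels, roughly like this
# (keys and values are illustrative guesses, not the actual config):
#
#     styles = {
#         "mean_score": {"color": "tab:blue", "linestyle": "--"},
#         "select_score": {"color": "tab:orange", "linestyle": "-"},
#     }
#     labels = {
#         "mean_score": "Mean Reuse",
#         "select_score": "Best Single Learnware",
#     }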
def set_seed(seed):
    """Seed every RNG in play (Python, NumPy, PyTorch) for reproducible runs."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for deterministic kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
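

# A minimal end-to-end sketch of how these utilities fit together (illustrative
# only; the directory layout, user name, task name, and method key are
# assumptions, not part of the original file):
#
#     set_seed(0)
#     recorder = Recorder()
#     recorder.record("user_0", [[0.52], [0.41]])  # one run, two labeled-data amounts
#     recorder.save("results/user_0/user_0_mean_score_performance.json")
#     plot_performance_curves(
#         "results", "user_0", {"mean_score": Recorder()}, "Homogeneous", [100, 200]
#     )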