
utils.py

import os
import json
import random
from collections import defaultdict

import torch
import numpy as np
import matplotlib.pyplot as plt

from learnware.logger import get_module_logger
from config import *  # expected to provide the `styles` and `labels` dicts used for plotting

logger = get_module_logger("base_table", level="INFO")


class Recorder:
    """Records per-user score lists and persists them as JSON."""

    def __init__(self, headers=["Mean", "Std Dev"], formats=["{:.2f}", "{:.2f}"]):
        assert len(headers) == len(formats), "Headers and formats length must match."
        self.data = defaultdict(list)
        self.headers = headers
        self.formats = formats

    def record(self, user, scores):
        # Append one run's scores for the given user.
        self.data[user].append(scores)

    def get_performance_data(self, user):
        return self.data.get(user, [])

    def save(self, path):
        # json cannot serialize numpy arrays directly; `default=list` converts them.
        with open(path, "w") as f:
            json.dump(self.data, f, indent=4, default=list)

    def load(self, path):
        # Rebuild loaded dicts as defaultdict(list) so record() keeps working afterwards.
        with open(path, "r") as f:
            self.data = json.load(f, object_hook=lambda x: defaultdict(list, x))

    def should_test_method(self, user, idx, path):
        # Return True when no result for this user/run index has been saved yet.
        if os.path.exists(path):
            self.load(path)
            return user not in self.data or idx > len(self.data[user]) - 1
        return True
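
# A minimal usage sketch of Recorder (illustrative only; the user name, score
# values, and file name below are assumptions, not fixed by this module):
#
#   recorder = Recorder()
#   recorder.record("user_0", [[0.42], [0.40]])  # scores for run index 0
#   recorder.save("user_0_demo_performance.json")
#   # True here, since only run index 0 has been recorded so far:
#   recorder.should_test_method("user_0", 1, "user_0_demo_performance.json")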


def plot_performance_curves(path, user, recorders, task, n_labeled_list):
    """Plot mean +/- std RMSE curves over increasing amounts of labeled user data."""
    plt.figure(figsize=(10, 6))
    plt.xticks(range(len(n_labeled_list)), n_labeled_list)
    for method, recorder in recorders.items():
        data_path = os.path.join(path, f"{user}/{user}_{method}_performance.json")
        recorder.load(data_path)
        scores_array = recorder.get_performance_data(user)
        mean_curve, std_curve = [], []
        for i in range(len(n_labeled_list)):
            # Stack the i-th entry of every recorded run, then average over runs.
            sub_scores_array = np.vstack([lst[i] for lst in scores_array])
            sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0))
            mean_curve.append(np.mean(sub_scores_mean))
            std_curve.append(np.std(sub_scores_mean))
        mean_curve = np.array(mean_curve)
        std_curve = np.array(std_curve)
        # Drop the leading tag (everything before the first "_") from the method
        # name, except for baseline methods, which keep their full names.
        if method in ["user_model", "oracle_score", "select_score", "mean_score"]:
            method_plot = method
        else:
            method_plot = "_".join(method.split("_")[1:])
        style = styles.get(method_plot, {"color": "black", "linestyle": "-"})
        plt.plot(mean_curve, label=labels.get(method_plot), **style)
        # Shade one standard deviation around the mean curve.
        plt.fill_between(
            range(len(mean_curve)),
            mean_curve - std_curve,
            mean_curve + std_curve,
            color=style["color"],
            alpha=0.2,
        )
    plt.xlabel("Amount of Labeled User Data", fontsize=14)
    plt.ylabel("RMSE", fontsize=14)
    plt.title(f"Results on {task} Table Experimental Scenario", fontsize=16)
    plt.legend(fontsize=12)
    plt.tight_layout()
    # Save the figure under <this file's directory>/results/figs/.
    root_path = os.path.abspath(os.path.join(__file__, ".."))
    fig_path = os.path.join(root_path, "results", "figs")
    os.makedirs(fig_path, exist_ok=True)
    plt.savefig(os.path.join(fig_path, f"{task}_labeled_curves.svg"), bbox_inches="tight", dpi=700)
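
# Illustrative call of plot_performance_curves (a sketch; the directory layout,
# method names, task name, and label counts are assumptions and must match the
# JSON files previously written by Recorder.save):
#
#   recorders = {"user_model": Recorder(), "mean_score": Recorder()}
#   plot_performance_curves("results", "user_0", recorders,
#                           task="Demo", n_labeled_list=[100, 200, 500])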


def set_seed(seed):
    """Seed every RNG in use (Python, NumPy, PyTorch) for reproducibility."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for deterministic kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
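
# Typical entry-point usage (a sketch; the seed value is arbitrary):
#
#   if __name__ == "__main__":
#       set_seed(0)
#       # ... run experiments, record scores, plot curves ...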