import json
import os
import random
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from config import labels, styles

from learnware.logger import get_module_logger

logger = get_module_logger("base_table", level="INFO")


class Recorder:
    """Accumulates per-user performance scores and persists them as JSON."""

    def __init__(self, headers=("Mean", "Std Dev"), formats=("{:.2f}", "{:.2f}")):
        # Immutable defaults avoid sharing one argument list across instances.
        assert len(headers) == len(formats), "Headers and formats length must match."
        self.data = defaultdict(list)
        self.headers = list(headers)
        self.formats = list(formats)

    def record(self, user, scores):
        self.data[user].append(scores)

    def get_performance_data(self, user):
        return self.data.get(user, [])

    def save(self, path):
        with open(path, "w") as f:
            # default=list falls back to list() for objects the JSON encoder cannot handle directly.
            json.dump(self.data, f, indent=4, default=list)

    def load(self, path):
        with open(path, "r") as f:
            # Restore defaultdict semantics for every loaded mapping.
            self.data = json.load(f, object_hook=lambda x: defaultdict(list, x))

    def should_test_method(self, user, idx, path):
        # Re-run the method only if no result file exists yet, or if the stored
        # results do not yet cover repetition `idx` for this user.
        if os.path.exists(path):
            self.load(path)
            return user not in self.data or idx >= len(self.data[user])
        return True

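# Usage sketch (illustrative only; the user name and score values below are
# made up). Each call to record() appends one repetition, itself a list with
# one entry per labeled-data amount:
#
#     recorder = Recorder()
#     recorder.record("demo_user", [[0.52], [0.48], [0.45]])
#     recorder.save("demo_user_user_model_performance.json")
#
# save() writes the underlying mapping as JSON, e.g.
#     {"demo_user": [[[0.52], [0.48], [0.45]]]}
# and load() restores it with defaultdict behavior intact, which is the layout
# plot_performance_curves() below relies on when stacking scores per setting.
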
def plot_performance_curves(path, user, recorders, task, n_labeled_list):
    plt.figure(figsize=(10, 6))
    plt.xticks(range(len(n_labeled_list)), n_labeled_list)
    for method, recorder in recorders.items():
        data_path = os.path.join(path, f"{user}/{user}_{method}_performance.json")
        recorder.load(data_path)
        scores_array = recorder.get_performance_data(user)

        # Aggregate over repetitions: one mean/std point per labeled-data amount.
        mean_curve, std_curve = [], []
        for i in range(len(n_labeled_list)):
            sub_scores_array = np.vstack([lst[i] for lst in scores_array])
            sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0))
            mean_curve.append(np.mean(sub_scores_mean))
            std_curve.append(np.std(sub_scores_mean))

        mean_curve = np.array(mean_curve)
        std_curve = np.array(std_curve)

        # Map the method key to the plotting key used in `styles`/`labels`:
        # baseline keys are kept as-is, other methods drop their leading token.
        method_plot = (
            "_".join(method.split("_")[1:])
            if method not in ["user_model", "oracle_score", "select_score", "mean_score"]
            else method
        )
        style = styles.get(method_plot, {"color": "black", "linestyle": "-"})
        plt.plot(mean_curve, label=labels.get(method_plot), **style)

        # Shade one standard deviation around the mean curve.
        plt.fill_between(
            range(len(mean_curve)), mean_curve - std_curve, mean_curve + std_curve, color=style["color"], alpha=0.2
        )

    plt.xlabel("Amount of Labeled User Data", fontsize=14)
    plt.ylabel("RMSE", fontsize=14)
    plt.title(f"Results on {task} Table Experimental Scenario", fontsize=16)
    plt.legend(fontsize=12)
    plt.tight_layout()

    root_path = os.path.abspath(os.path.join(__file__, ".."))
    fig_path = os.path.join(root_path, "results", "figs")
    os.makedirs(fig_path, exist_ok=True)
    plt.savefig(os.path.join(fig_path, f"{task}_labeled_curves.svg"), bbox_inches="tight", dpi=700)

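# Call-pattern sketch for plot_performance_curves (the directory and method
# names here are hypothetical; the real keys come from the experiment scripts
# that populate the recorders):
#
#     recorders = {"user_model": Recorder(), "hetero_aug": Recorder()}
#     plot_performance_curves("results/curves", "demo_user", recorders,
#                             "Heterogeneous", n_labeled_list=[100, 200, 500])
#
# Each recorder is reloaded from <path>/<user>/<user>_<method>_performance.json,
# and the figure is saved as results/figs/<task>_labeled_curves.svg next to this
# script.
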
def set_seed(seed):
    # Seed every relevant RNG so experiment runs are reproducible.
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
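

if __name__ == "__main__":
    # Minimal smoke test, not part of the benchmark workflow: fix the seeds and
    # round-trip a Recorder through a temporary JSON file. The user name and
    # score values are placeholders.
    import tempfile

    set_seed(0)
    demo_recorder = Recorder()
    demo_recorder.record("demo_user", [[0.52], [0.48], [0.45]])
    with tempfile.TemporaryDirectory() as tmp_dir:
        demo_path = os.path.join(tmp_dir, "demo_user_performance.json")
        demo_recorder.save(demo_path)
        demo_recorder.load(demo_path)
    logger.info(f"Reloaded records for demo_user: {demo_recorder.get_performance_data('demo_user')}")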