import os import json import traceback import numpy as np import matplotlib.pyplot as plt from collections import defaultdict from learnware.logger import get_module_logger from config import * logger = get_module_logger("base_table", level="INFO") class Recorder: def __init__(self, headers=["Mean", "Std Dev"], formats=["{:.2f}", "{:.2f}"]): assert len(headers) == len(formats), "Headers and formats length must match." self.data = defaultdict(list) self.headers = headers self.formats = formats def record(self, user, scores): self.data[user].append(scores) def get_performance_data(self, user): return self.data.get(user, []) def save(self, path): with open(path, "w") as f: json.dump(self.data, f, indent=4, default=list) def load(self, path): with open(path, "r") as f: self.data = json.load(f, object_hook=lambda x: defaultdict(list, x)) def should_test_method(self, user, idx, path): if os.path.exists(path): self.load(path) return user not in self.data or idx > len(self.data[user]) - 1 return True def process_single_aug(user, idx, scores, recorders, root_path): try: scores_array = np.array(scores) while scores_array.ndim < 3: scores_array = scores_array[np.newaxis, :] select_scores = scores_array[:, 0, :].tolist() mean_scores = np.mean(scores_array, axis=1).tolist() oracle_scores = np.min(scores_array, axis=1).tolist() for method_name, scores in zip(["select_score", "mean_score", "oracle_score"], [select_scores, mean_scores, oracle_scores]): recorders[method_name].record(user, scores) save_path = os.path.join(root_path, f"{method_name}_performance.json") recorders[method_name].save(save_path) except Exception as e: error_message = traceback.format_exc() logger.error(f"Error in process_single_aug for user {user}, idx {idx}: {error_message}") def analyze_performance(user, recorders): oracle_score_list = recorders["hetero_oracle_score"].get_performance_data(user) select_score_list = recorders["hetero_select_score"].get_performance_data(user) multi_avg_score_list = recorders["hetero_multiple_avg"].get_performance_data(user) mean_differences = {} for user_id in range(len(oracle_score_list)): select_scores = select_score_list[user_id] oracle_scores = oracle_score_list[user_id] mean_difference = np.mean(select_scores) - np.mean(oracle_scores) mean_differences[user_id] = mean_difference sorted_user_ids = sorted(mean_differences, key=mean_differences.get, reverse=True) for user_id in sorted_user_ids: single_multi_diff = np.mean(select_score_list[user_id]) - np.mean(multi_avg_score_list[user_id]) logger.info(f"{user}, {user_id}, {mean_differences[user_id]}, {single_multi_diff}") def plot_performance_curves(user, recorders, task, n_labeled_list): plt.figure(figsize=(10, 6)) plt.xticks(range(len(n_labeled_list)), n_labeled_list) for method, recorder in recorders.items(): if method == "hetero_single_aug": continue scores_array = recorder.get_performance_data(user) if scores_array: mean_curve, std_curve = [], [] for i in range(len(n_labeled_list)): sub_scores_array = np.vstack([lst[i] for lst in scores_array]) sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0)) mean_curve.append(np.mean(sub_scores_mean)) std_curve.append(np.std(sub_scores_mean)) mean_curve = np.array(mean_curve) std_curve = np.array(std_curve) method_plot = '_'.join(method.split('_')[1:]) if method not in ['user_model', 'oracle_score', 'select_score', 'mean_score'] else method style = styles.get(method_plot, {"color": "black", "linestyle": "-"}) plt.plot(mean_curve, label=labels.get(method_plot), **style) plt.fill_between( range(len(mean_curve)), mean_curve - std_curve, mean_curve + std_curve, color=style["color"], alpha=0.2 ) plt.xlabel("Amount of Labeled User Data", fontsize=14) plt.ylabel("RMSE", fontsize=14) plt.title(f"Results on Homo Table Experimental Scenario", fontsize=16) plt.legend(fontsize=14) plt.tight_layout() root_path = os.path.abspath(os.path.join(__file__, "..")) fig_path = os.path.join(root_path, "results", "figs") os.makedirs(fig_path, exist_ok=True) plt.savefig(os.path.join(fig_path, f"{task}_labeled_curves.svg"), bbox_inches="tight", dpi=700)