|
|
|
@@ -14,18 +14,15 @@ logger = get_module_logger("base_table", level="INFO") |
|
|
|
class Recorder: |
|
|
|
def __init__(self, headers=["Mean", "Std Dev"], formats=["{:.2f}", "{:.2f}"]): |
|
|
|
assert len(headers) == len(formats), "Headers and formats length must match." |
|
|
|
self.data = defaultdict(lambda: defaultdict(list)) |
|
|
|
self.data = defaultdict(list) |
|
|
|
self.headers = headers |
|
|
|
self.formats = formats |
|
|
|
|
|
|
|
def record(self, user, idx, scores): |
|
|
|
self.data[user][idx].append(scores) |
|
|
|
def record(self, user, scores): |
|
|
|
self.data[user].append(scores) |
|
|
|
|
|
|
|
def get_performance_data(self, user): |
|
|
|
if user in self.data: |
|
|
|
return [idx_scores for idx_scores in self.data[user].values()] |
|
|
|
else: |
|
|
|
return [] |
|
|
|
return self.data.get(user, []) |
|
|
|
|
|
|
|
def save(self, path): |
|
|
|
with open(path, "w") as f: |
|
|
|
@@ -38,7 +35,7 @@ class Recorder: |
|
|
|
def should_test_method(self, user, idx, path): |
|
|
|
if os.path.exists(path): |
|
|
|
self.load(path) |
|
|
|
return user not in self.data or str(idx) not in self.data[user] |
|
|
|
return user not in self.data or idx > len(self.data[user]) - 1 |
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
@@ -53,7 +50,7 @@ def process_single_aug(user, idx, scores, recorders, root_path): |
|
|
|
|
|
|
|
for method_name, scores in zip(["select_score", "mean_score", "oracle_score"], |
|
|
|
[select_scores, mean_scores, oracle_scores]): |
|
|
|
recorders[method_name].record(user, idx, scores) |
|
|
|
recorders[method_name].record(user, scores) |
|
|
|
save_path = os.path.join(root_path, f"{method_name}_performance.json") |
|
|
|
recorders[method_name].save(save_path) |
|
|
|
except Exception as e: |
|
|
|
@@ -82,33 +79,43 @@ def analyze_performance(user, recorders): |
|
|
|
|
|
|
|
def plot_performance_curves(user, recorders, task, n_labeled_list): |
|
|
|
plt.figure(figsize=(10, 6)) |
|
|
|
plt.xticks(range(len(n_labeled_list)), n_labeled_list) |
|
|
|
|
|
|
|
for method, recorder in recorders.items(): |
|
|
|
if method == "hetero_single_aug": |
|
|
|
continue |
|
|
|
|
|
|
|
user_data = recorder.get_performance_data(user) |
|
|
|
|
|
|
|
if user_data: |
|
|
|
scores_array = np.array([np.array(lst) for lst in user_data]) |
|
|
|
mean_scores = np.squeeze(np.mean(scores_array, axis=0)) |
|
|
|
std_scores = np.squeeze(np.std(scores_array, axis=0)) |
|
|
|
scores_array = recorder.get_performance_data(user) |
|
|
|
if scores_array: |
|
|
|
mean_curve, std_curve = [], [] |
|
|
|
for i in range(len(n_labeled_list)): |
|
|
|
sub_scores_array = np.vstack([lst[i] for lst in scores_array]) |
|
|
|
sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0)) |
|
|
|
mean_curve.append(np.mean(sub_scores_mean)) |
|
|
|
std_curve.append(np.std(sub_scores_mean)) |
|
|
|
|
|
|
|
mean_curve = np.array(mean_curve) |
|
|
|
std_curve = np.array(std_curve) |
|
|
|
|
|
|
|
method_plot = '_'.join(method.split('_')[1:]) if method not in ['user_model', 'oracle_score', 'select_score', 'mean_score'] else method |
|
|
|
style = styles.get(method_plot, {"color": "black", "linestyle": "-"}) |
|
|
|
plt.plot(range(len(n_labeled_list)), mean_scores, label=labels.get(method_plot), **style) |
|
|
|
|
|
|
|
std_scale = 0.2 if task == "Hetero" else 0.5 |
|
|
|
plt.fill_between(range(len(n_labeled_list)), mean_scores - std_scale * std_scores, mean_scores + std_scale * std_scores, color=style["color"], alpha=0.2) |
|
|
|
|
|
|
|
plt.xticks(range(len(n_labeled_list)), n_labeled_list) |
|
|
|
plt.xlabel('Sample Size') |
|
|
|
plt.ylabel('RMSE') |
|
|
|
plt.title(f'Table {task} Limited Labeled Data') |
|
|
|
plt.legend() |
|
|
|
plt.plot(mean_curve, label=labels.get(method_plot), **style) |
|
|
|
|
|
|
|
plt.fill_between( |
|
|
|
range(len(mean_curve)), |
|
|
|
mean_curve - std_curve, |
|
|
|
mean_curve + std_curve, |
|
|
|
color=style["color"], |
|
|
|
alpha=0.2 |
|
|
|
) |
|
|
|
|
|
|
|
plt.xlabel("Amount of Labeled User Data", fontsize=14) |
|
|
|
plt.ylabel("RMSE", fontsize=14) |
|
|
|
plt.title(f"Results on Homo Table Experimental Scenario", fontsize=16) |
|
|
|
plt.legend(fontsize=14) |
|
|
|
plt.tight_layout() |
|
|
|
|
|
|
|
root_path = os.path.abspath(os.path.join(__file__, "..")) |
|
|
|
fig_path = os.path.join(root_path, "results", "figs") |
|
|
|
os.makedirs(fig_path, exist_ok=True) |
|
|
|
plt.savefig(os.path.join(fig_path, f"{user}_labeled_{list(recorders.keys())}.svg"), bbox_inches="tight", dpi=700) |
|
|
|
plt.savefig(os.path.join(fig_path, f"{task}_labeled_curves.svg"), bbox_inches="tight", dpi=700) |