diff --git a/.gitignore b/.gitignore
index d8c41d5..f361d74 100644
--- a/.gitignore
+++ b/.gitignore
@@ -44,4 +44,5 @@ tmp/
 learnware_pool/
 PFS/
 data/
-examples/results/
\ No newline at end of file
+examples/results/
+examples/*/results/
\ No newline at end of file
diff --git a/examples/dataset_table_workflow/base.py b/examples/dataset_table_workflow/base.py
index fe31bd6..39924c8 100644
--- a/examples/dataset_table_workflow/base.py
+++ b/examples/dataset_table_workflow/base.py
@@ -34,13 +34,13 @@ class TableWorkflow:
                 x_train, y_train = sample["x_train"], sample["y_train"]
                 model = method(x_train, y_train, test_info)
                 subset_scores.append(loss_func(model.predict(test_info["test_x"]), test_info["test_y"]))
-            all_scores.append(np.mean(subset_scores))
+            all_scores.append(subset_scores)
         return all_scores
 
     @staticmethod
-    def get_train_subsets(train_x, train_y):
-        np.random.seed(1)
-        random.seed(1)
+    def get_train_subsets(idx, train_x, train_y):
+        np.random.seed(idx)
+        random.seed(idx)
         train_subsets = []
         for n_label, repeated in zip(n_labeled_list, n_repeat_list):
             train_subsets.append([])
@@ -82,24 +82,21 @@ class TableWorkflow:
         os.makedirs(save_root_path, exist_ok=True)
         save_path = os.path.join(save_root_path, f"{method_name}.json")
 
-        if method_name == "single_aug":
+        if method_name_full == "hetero_single_aug":
             if test_info["force"] or recorder.should_test_method(user, idx, save_path):
                 for learnware in test_info["learnwares"]:
                     test_info["single_learnware"] = [learnware]
                     scores = self._limited_data(test_methods[method_name_full], test_info, loss_func)
-                    recorder.record(user, idx, scores)
+                    recorder.record(user, scores)
                     process_single_aug(user, idx, scores, recorders, save_root_path)
                 recorder.save(save_path)
-                logger.info(f"Method {method_name} on {user}_{idx} finished")
             else:
-                process_single_aug(user, idx, recorder.data[user][str(idx)], recorders, save_root_path)
-                logger.info(f"Method {method_name} on {user}_{idx} already exists")
+                process_single_aug(user, idx, recorder.data[user], recorders, save_root_path)
         else:
             if test_info["force"] or recorder.should_test_method(user, idx, save_path):
                 scores = self._limited_data(test_methods[method_name_full], test_info, loss_func)
-                recorder.record(user, idx, scores)
+                recorder.record(user, scores)
                 recorder.save(save_path)
-                logger.info(f"Method {method_name} on {user}_{idx} finished")
-            else:
-                logger.info(f"Method {method_name} on {user}_{idx} already exists")
\ No newline at end of file
+
+        logger.info(f"Method {method_name} on {user}_{idx} finished")
\ No newline at end of file
diff --git a/examples/dataset_table_workflow/homo.py b/examples/dataset_table_workflow/homo.py
index d6bb8df..03e45a9 100644
--- a/examples/dataset_table_workflow/homo.py
+++ b/examples/dataset_table_workflow/homo.py
@@ -117,8 +117,8 @@ class CorporacionDatasetWorkflow(TableWorkflow):
     def labeled_homo_table_example(self):
         logger.info("Total Item: %d" % (len(self.market)))
         methods = ["user_model", "homo_single_aug", "homo_multiple_aug", "homo_multiple_avg", "homo_ensemble_pruning"]
-        recorders = {method: Recorder() for method in methods}
         methods_to_retest = []
+        recorders = {method: Recorder() for method in methods}
         user = self.benchmark.name
 
         for idx in range(self.benchmark.user_num):
@@ -127,7 +127,7 @@ class CorporacionDatasetWorkflow(TableWorkflow):
             train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
             train_x, train_y = train_x.values, train_y.values
-            train_subsets = self.get_train_subsets(train_x, train_y)
+            train_subsets = self.get_train_subsets(idx, train_x, train_y)
 
             user_stat_spec = generate_stat_spec(type="table", X=test_x)
             user_info = BaseUserInfo(
@@ -155,7 +155,7 @@ class CorporacionDatasetWorkflow(TableWorkflow):
             common_config = {"learnwares": mixture_learnware_list}
             method_configs = {
                 "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
-                "homo_single_aug": {"learnwares": [single_result[0].learnware]},
+                "homo_single_aug": {"single_learnware": [single_result[0].learnware]},
                 "homo_multiple_aug": common_config,
                 "homo_multiple_avg": common_config,
                 "homo_ensemble_pruning": common_config
diff --git a/examples/dataset_table_workflow/utils.py b/examples/dataset_table_workflow/utils.py
index c467b9b..fd31064 100644
--- a/examples/dataset_table_workflow/utils.py
+++ b/examples/dataset_table_workflow/utils.py
@@ -14,18 +14,15 @@ logger = get_module_logger("base_table", level="INFO")
 class Recorder:
     def __init__(self, headers=["Mean", "Std Dev"], formats=["{:.2f}", "{:.2f}"]):
         assert len(headers) == len(formats), "Headers and formats length must match."
-        self.data = defaultdict(lambda: defaultdict(list))
+        self.data = defaultdict(list)
         self.headers = headers
         self.formats = formats
 
-    def record(self, user, idx, scores):
-        self.data[user][idx].append(scores)
+    def record(self, user, scores):
+        self.data[user].append(scores)
 
     def get_performance_data(self, user):
-        if user in self.data:
-            return [idx_scores for idx_scores in self.data[user].values()]
-        else:
-            return []
+        return self.data.get(user, [])
 
     def save(self, path):
         with open(path, "w") as f:
@@ -38,7 +35,7 @@
     def should_test_method(self, user, idx, path):
         if os.path.exists(path):
             self.load(path)
-            return user not in self.data or str(idx) not in self.data[user]
+            return user not in self.data or idx > len(self.data[user]) - 1
         return True
 
 
@@ -53,7 +50,7 @@ def process_single_aug(user, idx, scores, recorders, root_path):
 
         for method_name, scores in zip(["select_score", "mean_score", "oracle_score"],
                                        [select_scores, mean_scores, oracle_scores]):
-            recorders[method_name].record(user, idx, scores)
+            recorders[method_name].record(user, scores)
             save_path = os.path.join(root_path, f"{method_name}_performance.json")
             recorders[method_name].save(save_path)
     except Exception as e:
@@ -82,33 +79,43 @@ def analyze_performance(user, recorders):
 
 def plot_performance_curves(user, recorders, task, n_labeled_list):
     plt.figure(figsize=(10, 6))
+    plt.xticks(range(len(n_labeled_list)), n_labeled_list)
     for method, recorder in recorders.items():
         if method == "hetero_single_aug":
             continue
-        user_data = recorder.get_performance_data(user)
-
-        if user_data:
-            scores_array = np.array([np.array(lst) for lst in user_data])
-            mean_scores = np.squeeze(np.mean(scores_array, axis=0))
-            std_scores = np.squeeze(np.std(scores_array, axis=0))
+        scores_array = recorder.get_performance_data(user)
+        if scores_array:
+            mean_curve, std_curve = [], []
+            for i in range(len(n_labeled_list)):
+                sub_scores_array = np.vstack([lst[i] for lst in scores_array])
+                sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0))
+                mean_curve.append(np.mean(sub_scores_mean))
+                std_curve.append(np.std(sub_scores_mean))
+
+            mean_curve = np.array(mean_curve)
+            std_curve = np.array(std_curve)
 
             method_plot = '_'.join(method.split('_')[1:]) if method not in ['user_model', 'oracle_score', 'select_score', 'mean_score'] else method
             style = styles.get(method_plot, {"color": "black", "linestyle": "-"})
-            plt.plot(range(len(n_labeled_list)), mean_scores, label=labels.get(method_plot), **style)
-
-            std_scale = 0.2 if task == "Hetero" else 0.5
-            plt.fill_between(range(len(n_labeled_list)), mean_scores - std_scale * std_scores, mean_scores + std_scale * std_scores, color=style["color"], alpha=0.2)
-
-    plt.xticks(range(len(n_labeled_list)), n_labeled_list)
-    plt.xlabel('Sample Size')
-    plt.ylabel('RMSE')
-    plt.title(f'Table {task} Limited Labeled Data')
-    plt.legend()
+            plt.plot(mean_curve, label=labels.get(method_plot), **style)
+
+            plt.fill_between(
+                range(len(mean_curve)),
+                mean_curve - std_curve,
+                mean_curve + std_curve,
+                color=style["color"],
+                alpha=0.2
+            )
+
+    plt.xlabel("Amount of Labeled User Data", fontsize=14)
+    plt.ylabel("RMSE", fontsize=14)
+    plt.title(f"Results on Homo Table Experimental Scenario", fontsize=16)
+    plt.legend(fontsize=14)
     plt.tight_layout()
 
     root_path = os.path.abspath(os.path.join(__file__, ".."))
     fig_path = os.path.join(root_path, "results", "figs")
     os.makedirs(fig_path, exist_ok=True)
-    plt.savefig(os.path.join(fig_path, f"{user}_labeled_{list(recorders.keys())}.svg"), bbox_inches="tight", dpi=700)
\ No newline at end of file
+    plt.savefig(os.path.join(fig_path, f"{task}_labeled_curves.svg"), bbox_inches="tight", dpi=700)
\ No newline at end of file