@@ -44,4 +44,5 @@ tmp/
 learnware_pool/
 PFS/
 data/
-examples/results/
+examples/results/
+examples/*/results/
@@ -34,13 +34,13 @@ class TableWorkflow:
                 x_train, y_train = sample["x_train"], sample["y_train"]
                 model = method(x_train, y_train, test_info)
                 subset_scores.append(loss_func(model.predict(test_info["test_x"]), test_info["test_y"]))
-            all_scores.append(np.mean(subset_scores))
+            all_scores.append(subset_scores)
         return all_scores
 
     @staticmethod
-    def get_train_subsets(train_x, train_y):
-        np.random.seed(1)
-        random.seed(1)
+    def get_train_subsets(idx, train_x, train_y):
+        np.random.seed(idx)
+        random.seed(idx)
         train_subsets = []
         for n_label, repeated in zip(n_labeled_list, n_repeat_list):
             train_subsets.append([])
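For reference, a minimal sketch of what the two changes above mean for the data flowing downstream. Only the per-index seeding and the list-of-lists return shape come from the diff; the sampling body and the concrete values of n_labeled_list / n_repeat_list are assumptions for illustration.

import random
import numpy as np

n_labeled_list = [100, 200, 500]   # assumed grid; the real values live elsewhere in the workflow module
n_repeat_list = [10, 10, 3]

def get_train_subsets(idx, train_x, train_y):
    # Seeding on the user index keeps subsets reproducible while varying them across users.
    np.random.seed(idx)
    random.seed(idx)
    train_subsets = []
    for n_label, repeated in zip(n_labeled_list, n_repeat_list):
        train_subsets.append([])
        for _ in range(repeated):
            # Hypothetical sampling body: draw n_label rows without replacement.
            rows = np.random.choice(len(train_x), size=min(n_label, len(train_x)), replace=False)
            train_subsets[-1].append({"x_train": train_x[rows], "y_train": train_y[rows]})
    return train_subsets

# With `all_scores.append(subset_scores)`, _limited_data now returns the raw per-repeat losses:
# all_scores[i] is a list of n_repeat_list[i] scores for n_labeled_list[i] labeled examples,
# so the Recorder and the plotting code can aggregate them later instead of storing only means.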
@@ -82,24 +82,21 @@ class TableWorkflow:
         os.makedirs(save_root_path, exist_ok=True)
         save_path = os.path.join(save_root_path, f"{method_name}.json")
 
-        if method_name == "single_aug":
+        if method_name_full == "hetero_single_aug":
             if test_info["force"] or recorder.should_test_method(user, idx, save_path):
                 for learnware in test_info["learnwares"]:
                     test_info["single_learnware"] = [learnware]
                     scores = self._limited_data(test_methods[method_name_full], test_info, loss_func)
-                    recorder.record(user, idx, scores)
+                    recorder.record(user, scores)
                     process_single_aug(user, idx, scores, recorders, save_root_path)
                 recorder.save(save_path)
-                logger.info(f"Method {method_name} on {user}_{idx} finished")
             else:
-                process_single_aug(user, idx, recorder.data[user][str(idx)], recorders, save_root_path)
-                logger.info(f"Method {method_name} on {user}_{idx} already exists")
+                process_single_aug(user, idx, recorder.data[user], recorders, save_root_path)
         else:
             if test_info["force"] or recorder.should_test_method(user, idx, save_path):
                 scores = self._limited_data(test_methods[method_name_full], test_info, loss_func)
-                recorder.record(user, idx, scores)
+                recorder.record(user, scores)
                 recorder.save(save_path)
-                logger.info(f"Method {method_name} on {user}_{idx} finished")
-            else:
-                logger.info(f"Method {method_name} on {user}_{idx} already exists")
+        logger.info(f"Method {method_name} on {user}_{idx} finished")
@@ -117,8 +117,8 @@ class CorporacionDatasetWorkflow(TableWorkflow):
     def labeled_homo_table_example(self):
         logger.info("Total Item: %d" % (len(self.market)))
         methods = ["user_model", "homo_single_aug", "homo_multiple_aug", "homo_multiple_avg", "homo_ensemble_pruning"]
-        recorders = {method: Recorder() for method in methods}
         methods_to_retest = []
+        recorders = {method: Recorder() for method in methods}
 
         user = self.benchmark.name
         for idx in range(self.benchmark.user_num):
@@ -127,7 +127,7 @@ class CorporacionDatasetWorkflow(TableWorkflow):
             train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
             train_x, train_y = train_x.values, train_y.values
-            train_subsets = self.get_train_subsets(train_x, train_y)
+            train_subsets = self.get_train_subsets(idx, train_x, train_y)
 
             user_stat_spec = generate_stat_spec(type="table", X=test_x)
             user_info = BaseUserInfo(
@@ -155,7 +155,7 @@ class CorporacionDatasetWorkflow(TableWorkflow):
             common_config = {"learnwares": mixture_learnware_list}
             method_configs = {
                 "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
-                "homo_single_aug": {"learnwares": [single_result[0].learnware]},
+                "homo_single_aug": {"single_learnware": [single_result[0].learnware]},
                 "homo_multiple_aug": common_config,
                 "homo_multiple_avg": common_config,
                 "homo_ensemble_pruning": common_config
@@ -14,18 +14,15 @@ logger = get_module_logger("base_table", level="INFO")
 class Recorder:
     def __init__(self, headers=["Mean", "Std Dev"], formats=["{:.2f}", "{:.2f}"]):
         assert len(headers) == len(formats), "Headers and formats length must match."
-        self.data = defaultdict(lambda: defaultdict(list))
+        self.data = defaultdict(list)
         self.headers = headers
         self.formats = formats
 
-    def record(self, user, idx, scores):
-        self.data[user][idx].append(scores)
+    def record(self, user, scores):
+        self.data[user].append(scores)
 
     def get_performance_data(self, user):
-        if user in self.data:
-            return [idx_scores for idx_scores in self.data[user].values()]
-        else:
-            return []
+        return self.data.get(user, [])
 
     def save(self, path):
         with open(path, "w") as f:
@@ -38,7 +35,7 @@ class Recorder:
     def should_test_method(self, user, idx, path):
         if os.path.exists(path):
             self.load(path)
-            return user not in self.data or str(idx) not in self.data[user]
+            return user not in self.data or idx > len(self.data[user]) - 1
         return True
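A minimal sketch of the flattened Recorder bookkeeping after the two hunks above (synthetic scores; the save/load pair and the path argument are omitted): per-user results are now an ordered list with one entry per record call, and should_test_method treats task idx as pending until at least idx + 1 entries exist.

from collections import defaultdict

class RecorderSketch:
    # Stripped-down stand-in for Recorder, kept only to show the new data shape.
    def __init__(self):
        self.data = defaultdict(list)

    def record(self, user, scores):
        self.data[user].append(scores)          # call order now encodes the task index

    def get_performance_data(self, user):
        return self.data.get(user, [])          # dict.get never creates an empty entry

    def should_test_method(self, user, idx):
        return user not in self.data or idx > len(self.data[user]) - 1

rec = RecorderSketch()
rec.record("corporacion", [[0.52, 0.49], [0.47, 0.46], [0.45, 0.44]])   # user task 0
assert rec.should_test_method("corporacion", 0) is False
assert rec.should_test_method("corporacion", 1) is True                 # task 1 not recorded yet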
@@ -53,7 +50,7 @@ def process_single_aug(user, idx, scores, recorders, root_path):
         for method_name, scores in zip(["select_score", "mean_score", "oracle_score"],
                                        [select_scores, mean_scores, oracle_scores]):
-            recorders[method_name].record(user, idx, scores)
+            recorders[method_name].record(user, scores)
             save_path = os.path.join(root_path, f"{method_name}_performance.json")
             recorders[method_name].save(save_path)
     except Exception as e:
@@ -82,33 +79,43 @@ def analyze_performance(user, recorders):
 def plot_performance_curves(user, recorders, task, n_labeled_list):
     plt.figure(figsize=(10, 6))
+    plt.xticks(range(len(n_labeled_list)), n_labeled_list)
     for method, recorder in recorders.items():
         if method == "hetero_single_aug":
             continue
-        user_data = recorder.get_performance_data(user)
-        if user_data:
-            scores_array = np.array([np.array(lst) for lst in user_data])
-            mean_scores = np.squeeze(np.mean(scores_array, axis=0))
-            std_scores = np.squeeze(np.std(scores_array, axis=0))
+        scores_array = recorder.get_performance_data(user)
+        if scores_array:
+            mean_curve, std_curve = [], []
+            for i in range(len(n_labeled_list)):
+                sub_scores_array = np.vstack([lst[i] for lst in scores_array])
+                sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0))
+                mean_curve.append(np.mean(sub_scores_mean))
+                std_curve.append(np.std(sub_scores_mean))
+            mean_curve = np.array(mean_curve)
+            std_curve = np.array(std_curve)
 
             method_plot = '_'.join(method.split('_')[1:]) if method not in ['user_model', 'oracle_score', 'select_score', 'mean_score'] else method
             style = styles.get(method_plot, {"color": "black", "linestyle": "-"})
-            plt.plot(range(len(n_labeled_list)), mean_scores, label=labels.get(method_plot), **style)
-            std_scale = 0.2 if task == "Hetero" else 0.5
-            plt.fill_between(range(len(n_labeled_list)), mean_scores - std_scale * std_scores, mean_scores + std_scale * std_scores, color=style["color"], alpha=0.2)
-    plt.xticks(range(len(n_labeled_list)), n_labeled_list)
-    plt.xlabel('Sample Size')
-    plt.ylabel('RMSE')
-    plt.title(f'Table {task} Limited Labeled Data')
-    plt.legend()
+            plt.plot(mean_curve, label=labels.get(method_plot), **style)
+            plt.fill_between(
+                range(len(mean_curve)),
+                mean_curve - std_curve,
+                mean_curve + std_curve,
+                color=style["color"],
+                alpha=0.2
+            )
+    plt.xlabel("Amount of Labeled User Data", fontsize=14)
+    plt.ylabel("RMSE", fontsize=14)
+    plt.title(f"Results on Homo Table Experimental Scenario", fontsize=16)
+    plt.legend(fontsize=14)
     plt.tight_layout()
 
     root_path = os.path.abspath(os.path.join(__file__, ".."))
     fig_path = os.path.join(root_path, "results", "figs")
     os.makedirs(fig_path, exist_ok=True)
-    plt.savefig(os.path.join(fig_path, f"{user}_labeled_{list(recorders.keys())}.svg"), bbox_inches="tight", dpi=700)
+    plt.savefig(os.path.join(fig_path, f"{task}_labeled_curves.svg"), bbox_inches="tight", dpi=700)
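Finally, a self-contained sketch of the new curve aggregation in plot_performance_curves (synthetic numbers; the grid values are assumptions): for each labeled-data size i, the i-th per-repeat score lists of every recorded entry are stacked, averaged across entries, and the resulting per-repeat means are reduced to one point for the curve and one standard deviation for the shaded band.

import numpy as np

n_labeled_list = [100, 200, 500]                  # assumed grid, matching the x-ticks above
rng = np.random.default_rng(0)
# Mimics recorder.get_performance_data(user): one entry per recorded task,
# each entry holding a list of per-repeat scores for every labeled-data size.
scores_array = [[list(rng.normal(1.0 / (i + 1), 0.05, size=8)) for i in range(len(n_labeled_list))]
                for _ in range(5)]

mean_curve, std_curve = [], []
for i in range(len(n_labeled_list)):
    sub_scores_array = np.vstack([lst[i] for lst in scores_array])    # entries x repeats
    sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0))   # average over entries, per repeat
    mean_curve.append(np.mean(sub_scores_mean))                       # curve point for this size
    std_curve.append(np.std(sub_scores_mean))                         # half-width of the shaded band
print(np.round(mean_curve, 3), np.round(std_curve, 3))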