| @@ -105,22 +105,13 @@ class HomogeneousDatasetWorkflow(TableWorkflow): | |||||
| ) | ) | ||||
| def labeled_homo_table_example(self, skip_test=False): | |||||
| def labeled_homo_table_example(self, skip_test=True): | |||||
| logger.info("Total Item: %d" % (len(self.market))) | logger.info("Total Item: %d" % (len(self.market))) | ||||
| methods = ["user_model", "homo_single_aug", "homo_ensemble_pruning"] | methods = ["user_model", "homo_single_aug", "homo_ensemble_pruning"] | ||||
| methods_to_retest = [] | methods_to_retest = [] | ||||
| recorders = {method: Recorder() for method in methods} | recorders = {method: Recorder() for method in methods} | ||||
| user = self.benchmark.name | user = self.benchmark.name | ||||
| if not skip_test: | |||||
| for idx in range(self.benchmark.user_num): | |||||
| test_x, test_y = self.benchmark.get_test_data(user_ids=idx) | |||||
| test_x, test_y = test_x.values, test_y.values | |||||
| train_x, train_y = self.benchmark.get_train_data(user_ids=idx) | |||||
| train_x, train_y = train_x.values, train_y.values | |||||
| train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y) | |||||
| if not skip_test: | if not skip_test: | ||||
| for idx in range(self.benchmark.user_num): | for idx in range(self.benchmark.user_num): | ||||
| test_x, test_y = self.benchmark.get_test_data(user_ids=idx) | test_x, test_y = self.benchmark.get_test_data(user_ids=idx) | ||||
| @@ -130,10 +121,6 @@ class HomogeneousDatasetWorkflow(TableWorkflow): | |||||
| train_x, train_y = train_x.values, train_y.values | train_x, train_y = train_x.values, train_y.values | ||||
| train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y) | train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y) | ||||
| user_stat_spec = generate_stat_spec(type="table", X=test_x) | |||||
| user_info = BaseUserInfo( | |||||
| semantic_spec=self.user_semantic, stat_info={"RKMETableSpecification": user_stat_spec} | |||||
| ) | |||||
| logger.info(f"Searching Market for user: {user}_{idx}") | logger.info(f"Searching Market for user: {user}_{idx}") | ||||
| user_stat_spec = generate_stat_spec(type="table", X=test_x) | user_stat_spec = generate_stat_spec(type="table", X=test_x) | ||||
| user_info = BaseUserInfo( | user_info = BaseUserInfo( | ||||
| @@ -141,28 +128,15 @@ class HomogeneousDatasetWorkflow(TableWorkflow): | |||||
| ) | ) | ||||
| logger.info(f"Searching Market for user: {user}_{idx}") | logger.info(f"Searching Market for user: {user}_{idx}") | ||||
| search_result = self.market.search_learnware(user_info) | |||||
| single_result = search_result.get_single_results() | |||||
| multiple_result = search_result.get_multiple_results() | |||||
| search_result = self.market.search_learnware(user_info) | search_result = self.market.search_learnware(user_info) | ||||
| single_result = search_result.get_single_results() | single_result = search_result.get_single_results() | ||||
| multiple_result = search_result.get_multiple_results() | multiple_result = search_result.get_multiple_results() | ||||
| logger.info(f"search result of user {user}_{idx}:") | |||||
| logger.info( | |||||
| f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" | |||||
| ) | |||||
| logger.info(f"search result of user {user}_{idx}:") | logger.info(f"search result of user {user}_{idx}:") | ||||
| logger.info( | logger.info( | ||||
| f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" | f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" | ||||
| ) | ) | ||||
| if len(multiple_result) > 0: | |||||
| mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) | |||||
| logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") | |||||
| mixture_learnware_list = multiple_result[0].learnwares | |||||
| else: | |||||
| mixture_learnware_list = [single_result[0].learnware] | |||||
| if len(multiple_result) > 0: | if len(multiple_result) > 0: | ||||
| mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) | mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) | ||||
| logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") | logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") | ||||
| @@ -177,13 +151,6 @@ class HomogeneousDatasetWorkflow(TableWorkflow): | |||||
| "homo_single_aug": {"single_learnware": [single_result[0].learnware]}, | "homo_single_aug": {"single_learnware": [single_result[0].learnware]}, | ||||
| "homo_ensemble_pruning": common_config | "homo_ensemble_pruning": common_config | ||||
| } | } | ||||
| test_info = {"user": user, "idx": idx, "train_subsets": train_subsets, "test_x": test_x, "test_y": test_y} | |||||
| common_config = {"learnwares": mixture_learnware_list} | |||||
| method_configs = { | |||||
| "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"}, | |||||
| "homo_single_aug": {"single_learnware": [single_result[0].learnware]}, | |||||
| "homo_ensemble_pruning": common_config | |||||
| } | |||||
| for method_name in methods: | for method_name in methods: | ||||
| logger.info(f"Testing method {method_name}") | logger.info(f"Testing method {method_name}") | ||||
| @@ -194,15 +161,5 @@ class HomogeneousDatasetWorkflow(TableWorkflow): | |||||
| for method, recorder in recorders.items(): | for method, recorder in recorders.items(): | ||||
| recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) | recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) | ||||
| for method_name in methods: | |||||
| logger.info(f"Testing method {method_name}") | |||||
| test_info["method_name"] = method_name | |||||
| test_info["force"] = method_name in methods_to_retest | |||||
| test_info.update(method_configs[method_name]) | |||||
| self.test_method(test_info, recorders, loss_func=loss_func_rmse) | |||||
| for method, recorder in recorders.items(): | |||||
| recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json")) | |||||
| plot_performance_curves(self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list) | |||||
| plot_performance_curves(self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list) | plot_performance_curves(self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list) | ||||