From 15e6e601d9b0a7936a6da02d9daee0abbcc38a99 Mon Sep 17 00:00:00 2001 From: bxdd Date: Sun, 26 Nov 2023 19:52:16 +0800 Subject: [PATCH] [FIX] fix tests --- tests/test_hetero_market/test_hetero.py | 103 ++++++++++-------------- 1 file changed, 43 insertions(+), 60 deletions(-) diff --git a/tests/test_hetero_market/test_hetero.py b/tests/test_hetero_market/test_hetero.py index a98ab60..be828e5 100644 --- a/tests/test_hetero_market/test_hetero.py +++ b/tests/test_hetero_market/test_hetero.py @@ -258,49 +258,40 @@ class TestMarket(unittest.TestCase): semantic_spec["Input"]["Description"] = { str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) } - user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) - + + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() + multiple_result = search_result.get_multiple_results() + print(f"search result of user{idx}:") - for score, learnware in zip(sorted_score_list, single_learnware_list): - print(f"score: {score}, learnware_id: {learnware.id}") - print( - f"mixture_score: {mixture_score}, mixture_learnware_ids: {[item.id for item in mixture_learnware_list]}" - ) + for single_item in single_result: + print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") + + for multiple_item in multiple_result: + print( + f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}" + ) # inproper key "Task" in semantic_spec, use homo search and print invalid semantic_spec print(">> test for key 'Task' has empty 'Values':") semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"} user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() - assert len(single_learnware_list) == 0, f"Statistical search failed!" + assert len(single_result) == 0, f"Statistical search failed!" # delete key "Task" in semantic_spec, use homo search and print WARNING INFO with "User doesn't provide correct task type" print(">> delele key 'Task' test:") semantic_spec.pop("Task") user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() - assert len(single_learnware_list) == 0, f"Statistical search failed!" + assert len(single_result) == 0, f"Statistical search failed!" # modify semantic info with mismatch dim, use homo search and print "User data feature dimensions mismatch with semantic specification." print(">> mismatch dim test") @@ -312,14 +303,10 @@ class TestMarket(unittest.TestCase): } user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() - assert len(single_learnware_list) == 0, f"Statistical search failed!" + assert len(single_result) == 0, f"Statistical search failed!" rmtree(test_folder) # rm -r test_folder @@ -340,21 +327,19 @@ class TestMarket(unittest.TestCase): user_spec = RKMETableSpecification() user_spec.load(os.path.join(unzip_dir, "stat.json")) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) - - target_spec_num = 3 if idx % 2 == 0 else 2 - assert len(single_learnware_list) >= 1, f"Statistical search failed!" + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() + multiple_result = search_result.get_multiple_results() + + assert len(single_result) >= 1, f"Statistical search failed!" print(f"search result of user{idx}:") - for score, learnware in zip(sorted_score_list, single_learnware_list): - print(f"score: {score}, learnware_id: {learnware.id}") - print(f"mixture_score: {mixture_score}\n") - mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) - print(f"mixture_learnware: {mixture_id}\n") + for single_item in single_result: + print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") + + for multiple_item in multiple_result: + print(f"mixture_score: {multiple_item.score}\n") + mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares]) + print(f"mixture_learnware: {mixture_id}\n") rmtree(test_folder) # rm -r test_folder @@ -372,26 +357,24 @@ class TestMarket(unittest.TestCase): # learnware market search hetero_market = self.test_train_market_model(learnware_num) - ( - sorted_score_list, - single_learnware_list, - mixture_score, - mixture_learnware_list, - ) = hetero_market.search_learnware(user_info) - + search_result = hetero_market.search_learnware(user_info) + single_result = search_result.get_single_results() + multiple_result = search_result.get_multiple_results() # print search results - for score, learnware in zip(sorted_score_list, single_learnware_list): - print(f"score: {score}, learnware_id: {learnware.id}") - print(f"mixture_score: {mixture_score}, mixture_learnware_ids: {[item.id for item in mixture_learnware_list]}") + for single_item in single_result: + print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") + + for multiple_item in multiple_result: + print(f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}") # single model reuse - hetero_learnware = HeteroMapAlignLearnware(single_learnware_list[0], mode="regression") + hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression") hetero_learnware.align(user_spec, X[:100], y[:100]) single_predict_y = hetero_learnware.predict(X) # multi model reuse hetero_learnware_list = [] - for learnware in mixture_learnware_list: + for learnware in multiple_result[0].learnwares: hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression") hetero_learnware.align(user_spec, X[:100], y[:100]) hetero_learnware_list.append(hetero_learnware)