| @@ -258,49 +258,40 @@ class TestMarket(unittest.TestCase): | |||||
| semantic_spec["Input"]["Description"] = { | semantic_spec["Input"]["Description"] = { | ||||
| str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) | str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) | ||||
| } | } | ||||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | ||||
| ( | |||||
| sorted_score_list, | |||||
| single_learnware_list, | |||||
| mixture_score, | |||||
| mixture_learnware_list, | |||||
| ) = hetero_market.search_learnware(user_info) | |||||
| search_result = hetero_market.search_learnware(user_info) | |||||
| single_result = search_result.get_single_results() | |||||
| multiple_result = search_result.get_multiple_results() | |||||
| print(f"search result of user{idx}:") | print(f"search result of user{idx}:") | ||||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||||
| print(f"score: {score}, learnware_id: {learnware.id}") | |||||
| print( | |||||
| f"mixture_score: {mixture_score}, mixture_learnware_ids: {[item.id for item in mixture_learnware_list]}" | |||||
| ) | |||||
| for single_item in single_result: | |||||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||||
| for multiple_item in multiple_result: | |||||
| print( | |||||
| f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}" | |||||
| ) | |||||
| # inproper key "Task" in semantic_spec, use homo search and print invalid semantic_spec | # inproper key "Task" in semantic_spec, use homo search and print invalid semantic_spec | ||||
| print(">> test for key 'Task' has empty 'Values':") | print(">> test for key 'Task' has empty 'Values':") | ||||
| semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"} | semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"} | ||||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | ||||
| ( | |||||
| sorted_score_list, | |||||
| single_learnware_list, | |||||
| mixture_score, | |||||
| mixture_learnware_list, | |||||
| ) = hetero_market.search_learnware(user_info) | |||||
| search_result = hetero_market.search_learnware(user_info) | |||||
| single_result = search_result.get_single_results() | |||||
| assert len(single_learnware_list) == 0, f"Statistical search failed!" | |||||
| assert len(single_result) == 0, f"Statistical search failed!" | |||||
| # delete key "Task" in semantic_spec, use homo search and print WARNING INFO with "User doesn't provide correct task type" | # delete key "Task" in semantic_spec, use homo search and print WARNING INFO with "User doesn't provide correct task type" | ||||
| print(">> delele key 'Task' test:") | print(">> delele key 'Task' test:") | ||||
| semantic_spec.pop("Task") | semantic_spec.pop("Task") | ||||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | ||||
| ( | |||||
| sorted_score_list, | |||||
| single_learnware_list, | |||||
| mixture_score, | |||||
| mixture_learnware_list, | |||||
| ) = hetero_market.search_learnware(user_info) | |||||
| search_result = hetero_market.search_learnware(user_info) | |||||
| single_result = search_result.get_single_results() | |||||
| assert len(single_learnware_list) == 0, f"Statistical search failed!" | |||||
| assert len(single_result) == 0, f"Statistical search failed!" | |||||
| # modify semantic info with mismatch dim, use homo search and print "User data feature dimensions mismatch with semantic specification." | # modify semantic info with mismatch dim, use homo search and print "User data feature dimensions mismatch with semantic specification." | ||||
| print(">> mismatch dim test") | print(">> mismatch dim test") | ||||
| @@ -312,14 +303,10 @@ class TestMarket(unittest.TestCase): | |||||
| } | } | ||||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | ||||
| ( | |||||
| sorted_score_list, | |||||
| single_learnware_list, | |||||
| mixture_score, | |||||
| mixture_learnware_list, | |||||
| ) = hetero_market.search_learnware(user_info) | |||||
| search_result = hetero_market.search_learnware(user_info) | |||||
| single_result = search_result.get_single_results() | |||||
| assert len(single_learnware_list) == 0, f"Statistical search failed!" | |||||
| assert len(single_result) == 0, f"Statistical search failed!" | |||||
| rmtree(test_folder) # rm -r test_folder | rmtree(test_folder) # rm -r test_folder | ||||
| @@ -340,21 +327,19 @@ class TestMarket(unittest.TestCase): | |||||
| user_spec = RKMETableSpecification() | user_spec = RKMETableSpecification() | ||||
| user_spec.load(os.path.join(unzip_dir, "stat.json")) | user_spec.load(os.path.join(unzip_dir, "stat.json")) | ||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | ||||
| ( | |||||
| sorted_score_list, | |||||
| single_learnware_list, | |||||
| mixture_score, | |||||
| mixture_learnware_list, | |||||
| ) = hetero_market.search_learnware(user_info) | |||||
| target_spec_num = 3 if idx % 2 == 0 else 2 | |||||
| assert len(single_learnware_list) >= 1, f"Statistical search failed!" | |||||
| search_result = hetero_market.search_learnware(user_info) | |||||
| single_result = search_result.get_single_results() | |||||
| multiple_result = search_result.get_multiple_results() | |||||
| assert len(single_result) >= 1, f"Statistical search failed!" | |||||
| print(f"search result of user{idx}:") | print(f"search result of user{idx}:") | ||||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||||
| print(f"score: {score}, learnware_id: {learnware.id}") | |||||
| print(f"mixture_score: {mixture_score}\n") | |||||
| mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) | |||||
| print(f"mixture_learnware: {mixture_id}\n") | |||||
| for single_item in single_result: | |||||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||||
| for multiple_item in multiple_result: | |||||
| print(f"mixture_score: {multiple_item.score}\n") | |||||
| mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares]) | |||||
| print(f"mixture_learnware: {mixture_id}\n") | |||||
| rmtree(test_folder) # rm -r test_folder | rmtree(test_folder) # rm -r test_folder | ||||
| @@ -372,26 +357,24 @@ class TestMarket(unittest.TestCase): | |||||
| # learnware market search | # learnware market search | ||||
| hetero_market = self.test_train_market_model(learnware_num) | hetero_market = self.test_train_market_model(learnware_num) | ||||
| ( | |||||
| sorted_score_list, | |||||
| single_learnware_list, | |||||
| mixture_score, | |||||
| mixture_learnware_list, | |||||
| ) = hetero_market.search_learnware(user_info) | |||||
| search_result = hetero_market.search_learnware(user_info) | |||||
| single_result = search_result.get_single_results() | |||||
| multiple_result = search_result.get_multiple_results() | |||||
| # print search results | # print search results | ||||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||||
| print(f"score: {score}, learnware_id: {learnware.id}") | |||||
| print(f"mixture_score: {mixture_score}, mixture_learnware_ids: {[item.id for item in mixture_learnware_list]}") | |||||
| for single_item in single_result: | |||||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||||
| for multiple_item in multiple_result: | |||||
| print(f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}") | |||||
| # single model reuse | # single model reuse | ||||
| hetero_learnware = HeteroMapAlignLearnware(single_learnware_list[0], mode="regression") | |||||
| hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression") | |||||
| hetero_learnware.align(user_spec, X[:100], y[:100]) | hetero_learnware.align(user_spec, X[:100], y[:100]) | ||||
| single_predict_y = hetero_learnware.predict(X) | single_predict_y = hetero_learnware.predict(X) | ||||
| # multi model reuse | # multi model reuse | ||||
| hetero_learnware_list = [] | hetero_learnware_list = [] | ||||
| for learnware in mixture_learnware_list: | |||||
| for learnware in multiple_result[0].learnwares: | |||||
| hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression") | hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression") | ||||
| hetero_learnware.align(user_spec, X[:100], y[:100]) | hetero_learnware.align(user_spec, X[:100], y[:100]) | ||||
| hetero_learnware_list.append(hetero_learnware) | hetero_learnware_list.append(hetero_learnware) | ||||