Browse Source

[FIX] fix tests

tags/v0.3.2
bxdd 2 years ago
parent
commit
15e6e601d9
1 changed files with 43 additions and 60 deletions
  1. +43
    -60
      tests/test_hetero_market/test_hetero.py

+ 43
- 60
tests/test_hetero_market/test_hetero.py View File

@@ -258,49 +258,40 @@ class TestMarket(unittest.TestCase):
semantic_spec["Input"]["Description"] = { semantic_spec["Input"]["Description"] = {
str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim)
} }

user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
(
sorted_score_list,
single_learnware_list,
mixture_score,
mixture_learnware_list,
) = hetero_market.search_learnware(user_info)

search_result = hetero_market.search_learnware(user_info)
single_result = search_result.get_single_results()
multiple_result = search_result.get_multiple_results()
print(f"search result of user{idx}:") print(f"search result of user{idx}:")
for score, learnware in zip(sorted_score_list, single_learnware_list):
print(f"score: {score}, learnware_id: {learnware.id}")
print(
f"mixture_score: {mixture_score}, mixture_learnware_ids: {[item.id for item in mixture_learnware_list]}"
)
for single_item in single_result:
print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
for multiple_item in multiple_result:
print(
f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
)


# inproper key "Task" in semantic_spec, use homo search and print invalid semantic_spec # inproper key "Task" in semantic_spec, use homo search and print invalid semantic_spec
print(">> test for key 'Task' has empty 'Values':") print(">> test for key 'Task' has empty 'Values':")
semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"} semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"}


user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
(
sorted_score_list,
single_learnware_list,
mixture_score,
mixture_learnware_list,
) = hetero_market.search_learnware(user_info)
search_result = hetero_market.search_learnware(user_info)
single_result = search_result.get_single_results()


assert len(single_learnware_list) == 0, f"Statistical search failed!"
assert len(single_result) == 0, f"Statistical search failed!"


# delete key "Task" in semantic_spec, use homo search and print WARNING INFO with "User doesn't provide correct task type" # delete key "Task" in semantic_spec, use homo search and print WARNING INFO with "User doesn't provide correct task type"
print(">> delele key 'Task' test:") print(">> delele key 'Task' test:")
semantic_spec.pop("Task") semantic_spec.pop("Task")


user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
(
sorted_score_list,
single_learnware_list,
mixture_score,
mixture_learnware_list,
) = hetero_market.search_learnware(user_info)
search_result = hetero_market.search_learnware(user_info)
single_result = search_result.get_single_results()


assert len(single_learnware_list) == 0, f"Statistical search failed!"
assert len(single_result) == 0, f"Statistical search failed!"


# modify semantic info with mismatch dim, use homo search and print "User data feature dimensions mismatch with semantic specification." # modify semantic info with mismatch dim, use homo search and print "User data feature dimensions mismatch with semantic specification."
print(">> mismatch dim test") print(">> mismatch dim test")
@@ -312,14 +303,10 @@ class TestMarket(unittest.TestCase):
} }


user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
(
sorted_score_list,
single_learnware_list,
mixture_score,
mixture_learnware_list,
) = hetero_market.search_learnware(user_info)
search_result = hetero_market.search_learnware(user_info)
single_result = search_result.get_single_results()


assert len(single_learnware_list) == 0, f"Statistical search failed!"
assert len(single_result) == 0, f"Statistical search failed!"


rmtree(test_folder) # rm -r test_folder rmtree(test_folder) # rm -r test_folder


@@ -340,21 +327,19 @@ class TestMarket(unittest.TestCase):
user_spec = RKMETableSpecification() user_spec = RKMETableSpecification()
user_spec.load(os.path.join(unzip_dir, "stat.json")) user_spec.load(os.path.join(unzip_dir, "stat.json"))
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
(
sorted_score_list,
single_learnware_list,
mixture_score,
mixture_learnware_list,
) = hetero_market.search_learnware(user_info)

target_spec_num = 3 if idx % 2 == 0 else 2
assert len(single_learnware_list) >= 1, f"Statistical search failed!"
search_result = hetero_market.search_learnware(user_info)
single_result = search_result.get_single_results()
multiple_result = search_result.get_multiple_results()

assert len(single_result) >= 1, f"Statistical search failed!"
print(f"search result of user{idx}:") print(f"search result of user{idx}:")
for score, learnware in zip(sorted_score_list, single_learnware_list):
print(f"score: {score}, learnware_id: {learnware.id}")
print(f"mixture_score: {mixture_score}\n")
mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list])
print(f"mixture_learnware: {mixture_id}\n")
for single_item in single_result:
print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
for multiple_item in multiple_result:
print(f"mixture_score: {multiple_item.score}\n")
mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares])
print(f"mixture_learnware: {mixture_id}\n")


rmtree(test_folder) # rm -r test_folder rmtree(test_folder) # rm -r test_folder


@@ -372,26 +357,24 @@ class TestMarket(unittest.TestCase):


# learnware market search # learnware market search
hetero_market = self.test_train_market_model(learnware_num) hetero_market = self.test_train_market_model(learnware_num)
(
sorted_score_list,
single_learnware_list,
mixture_score,
mixture_learnware_list,
) = hetero_market.search_learnware(user_info)

search_result = hetero_market.search_learnware(user_info)
single_result = search_result.get_single_results()
multiple_result = search_result.get_multiple_results()
# print search results # print search results
for score, learnware in zip(sorted_score_list, single_learnware_list):
print(f"score: {score}, learnware_id: {learnware.id}")
print(f"mixture_score: {mixture_score}, mixture_learnware_ids: {[item.id for item in mixture_learnware_list]}")
for single_item in single_result:
print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")
for multiple_item in multiple_result:
print(f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}")


# single model reuse # single model reuse
hetero_learnware = HeteroMapAlignLearnware(single_learnware_list[0], mode="regression")
hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression")
hetero_learnware.align(user_spec, X[:100], y[:100]) hetero_learnware.align(user_spec, X[:100], y[:100])
single_predict_y = hetero_learnware.predict(X) single_predict_y = hetero_learnware.predict(X)


# multi model reuse # multi model reuse
hetero_learnware_list = [] hetero_learnware_list = []
for learnware in mixture_learnware_list:
for learnware in multiple_result[0].learnwares:
hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression") hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression")
hetero_learnware.align(user_spec, X[:100], y[:100]) hetero_learnware.align(user_spec, X[:100], y[:100])
hetero_learnware_list.append(hetero_learnware) hetero_learnware_list.append(hetero_learnware)


Loading…
Cancel
Save