|
- import logging
- import os
- import pickle
- import tempfile
- import unittest
- import zipfile
-
- import torch
- from hetero_config import input_description_list, input_shape_list, output_description_list, user_description_list
- from sklearn.datasets import make_regression
- from sklearn.linear_model import Ridge
- from sklearn.metrics import mean_squared_error
-
- import learnware
- from learnware.market import BaseUserInfo, instantiate_learnware_market
- from learnware.reuse import AveragingReuser, EnsemblePruningReuser, HeteroMapAlignLearnware
- from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec
- from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate
-
- learnware.init(logging_level=logging.WARNING)
- curr_root = os.path.dirname(os.path.abspath(__file__))
-
-
class TestHeteroWorkflow(unittest.TestCase):
    """End-to-end tests for the heterogeneous learnware market workflow.

    The tests build toy Ridge-regression learnwares whose feature dimensions
    alternate between the two entries of ``input_shape_list``, upload them to
    a freshly rebuilt market, and then exercise semantic search, statistical
    (RKME) search — both homogeneous and heterogeneous — and model reuse via
    heterogeneous feature alignment.
    """

    # Semantic-spec fields shared by every learnware and user spec in these tests.
    universal_semantic_config = {
        "data_type": "Table",
        "task_type": "Regression",
        "library_type": "Scikit-learn",
        "scenarios": "Education",
        "license": "MIT",
    }

    def _init_learnware_market(self, organizer_kwargs=None):
        """Instantiate a fresh "hetero" market; ``rebuild=True`` discards prior state.

        Args:
            organizer_kwargs: Optional kwargs forwarded to the market organizer
                (e.g. the auto-update settings used by ``test_train_market_model``).

        Returns:
            The newly instantiated learnware market.
        """
        hetero_market = instantiate_learnware_market(
            market_id="hetero_toy", name="hetero", rebuild=True, organizer_kwargs=organizer_kwargs
        )
        return hetero_market

    def test_prepare_learnware_randomly(self, learnware_num=5):
        """Build ``learnware_num`` Ridge learnware zips under ``learnware_pool_hetero``.

        Each zip bundles a pickled Ridge model and an RKME table specification
        generated from its training data; the resulting paths are recorded in
        ``self.zip_path_list`` for the upload tests.
        """
        self.zip_path_list = []

        for i in range(learnware_num):
            learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero")
            os.makedirs(learnware_pool_dirpath, exist_ok=True)
            learnware_zippath = os.path.join(learnware_pool_dirpath, "ridge_%d.zip" % (i))

            print("Preparing Learnware: %d" % (i))

            # Alternate the feature dimensionality to create a heterogeneous pool.
            X, y = make_regression(
                n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42
            )
            clf = Ridge(alpha=1.0)
            clf.fit(X, y)
            # NOTE: the pickle path is reused across iterations; only the zip is kept per-learnware.
            pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl")
            with open(pickle_filepath, "wb") as fout:
                pickle.dump(clf, fout)

            # Statistical specification (RKME) computed on the training features.
            spec = generate_rkme_table_spec(X=X, gamma=0.1)
            spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json")
            spec.save(spec_filepath)

            LearnwareTemplate.generate_learnware_zipfile(
                learnware_zippath=learnware_zippath,
                model_template=PickleModelTemplate(
                    pickle_filepath=pickle_filepath,
                    model_kwargs={"input_shape": (input_shape_list[i % 2],), "output_shape": (1,)},
                ),
                stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"),
                requirements=["scikit-learn==0.22"],
            )

            self.zip_path_list.append(learnware_zippath)

    def _upload_delete_learnware(self, hetero_market, learnware_num, delete):
        """Upload ``learnware_num`` prepared learnwares; optionally delete them all.

        Asserts the market item count after each phase and returns the market
        so callers can chain further search tests on it.
        """
        self.test_prepare_learnware_randomly(learnware_num)
        self.learnware_num = learnware_num

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == 0, "The market should be empty!"

        for idx, zip_path in enumerate(self.zip_path_list):
            # Input/output descriptions alternate in step with the feature dims
            # used in test_prepare_learnware_randomly.
            semantic_spec = generate_semantic_spec(
                name=f"learnware_{idx}",
                description=f"test_learnware_number_{idx}",
                input_description=input_description_list[idx % 2],
                output_description=output_description_list[idx % 2],
                **self.universal_semantic_config,
            )
            hetero_market.add_learnware(zip_path, semantic_spec)

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
        curr_inds = hetero_market.get_learnware_ids()
        print("Available ids After Uploading Learnwares:", curr_inds)
        assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        if delete:
            # Delete one at a time, checking that the count shrinks after each removal.
            for learnware_id in curr_inds:
                hetero_market.delete_learnware(learnware_id)
                self.learnware_num -= 1
                assert (
                    len(hetero_market) == self.learnware_num
                ), f"The number of learnwares must be {self.learnware_num}!"

            curr_inds = hetero_market.get_learnware_ids()
            print("Available ids After Deleting Learnwares:", curr_inds)
            assert len(curr_inds) == 0, "The market should be empty!"

        return hetero_market

    def test_upload_delete_learnware(self, learnware_num=5, delete=True):
        """Upload (and by default delete) learnwares in a plain market."""
        hetero_market = self._init_learnware_market()
        return self._upload_delete_learnware(hetero_market, learnware_num, delete)

    def test_train_market_model(self, learnware_num=5, delete=False):
        """Upload learnwares into a market configured to auto-update its model.

        ``auto_update_limit=learnware_num`` means the organizer's update is
        triggered once all test learnwares are uploaded — presumably training
        the hetero mapping used by the stat-search tests (TODO confirm).
        """
        hetero_market = self._init_learnware_market(
            organizer_kwargs={"auto_update": True, "auto_update_limit": learnware_num}
        )
        hetero_market = self._upload_delete_learnware(hetero_market, learnware_num, delete)
        # organizer=hetero_market.learnware_organizer
        # organizer.train(hetero_market.learnware_organizer.learnware_list.values())
        return hetero_market

    def test_search_semantics(self, learnware_num=5):
        """Exercise exact and fuzzy semantic (name-based) search."""
        hetero_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        # Exact match on the name of the last uploaded learnware.
        semantic_spec = generate_semantic_spec(
            name=f"learnware_{learnware_num - 1}",
            **self.universal_semantic_config,
        )

        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print("Search result1:")
        assert len(single_result) == 1, "Exact semantic search failed!"
        for search_item in single_result:
            semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec()
            print("Choose learnware:", search_item.learnware.id)
            assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], "Exact semantic search failed!"

        # Intentionally misspelled name: no exact hit, so fuzzy search should
        # fall back to returning every learnware in the market.
        semantic_spec["Name"]["Values"] = "laernwaer"
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print("Search result2:")
        assert len(single_result) == self.learnware_num, "Fuzzy semantic search failed!"
        for search_item in single_result:
            print("Choose learnware:", search_item.learnware.id)

    def test_hetero_stat_search(self, learnware_num=5):
        """Heterogeneous statistical search plus its failure modes.

        Builds a user RKME spec whose feature dimension (15) differs from the
        uploaded learnwares' dimensions, then checks: (1) a well-formed request
        returns results; (2) a mismatched task type, a missing task type, and a
        dimension mismatch each yield no single results.
        """
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))

        user_dim = 15

        with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
            for idx, zip_path in enumerate(self.zip_path_list):
                with zipfile.ZipFile(zip_path, "r") as zip_obj:
                    zip_obj.extractall(path=test_folder)

                # Load the learnware's own RKME spec and truncate its reduced-set
                # points to the first user_dim features, simulating a user whose
                # table schema only partially overlaps the learnware's.
                user_spec = RKMETableSpecification()
                user_spec.load(os.path.join(test_folder, "stat_spec.json"))
                z = user_spec.get_z()
                z = z[:, :user_dim]
                device = user_spec.device
                # Rebuild the truncated points as a tensor on the spec's device.
                z = torch.tensor(z, device=device)
                user_spec.z = z

                print(">> normal case test:")
                semantic_spec = generate_semantic_spec(
                    input_description={
                        "Dimension": user_dim,
                        "Description": {
                            str(key): input_description_list[idx % 2]["Description"][str(key)]
                            for key in range(user_dim)
                        },
                    },
                    **self.universal_semantic_config,
                )
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                print(f"search result of user{idx}:")
                for single_item in single_result:
                    print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

                for multiple_item in multiple_result:
                    print(
                        f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
                    )

                # Improper key "Task" in semantic_spec: task type conflicts with the
                # market's regression learnwares, so no single result is expected.
                print(">> test for key 'Task' has empty 'Values':")
                semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"}
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()

                assert len(single_result) == 0, "Statistical search failed!"

                # Delete key "Task": the search cannot determine the task type and
                # is expected to return nothing (with a warning logged upstream).
                print(">> delele key 'Task' test:")
                semantic_spec.pop("Task")
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()

                assert len(single_result) == 0, "Statistical search failed!"

                # Declared Dimension (user_dim - 2) disagrees with the described
                # columns / stat spec, so the search should reject the request.
                print(">> mismatch dim test")
                semantic_spec = generate_semantic_spec(
                    input_description={
                        "Dimension": user_dim - 2,
                        "Description": {
                            str(key): input_description_list[idx % 2]["Description"][str(key)]
                            for key in range(user_dim)
                        },
                    },
                    **self.universal_semantic_config,
                )
                user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()

                assert len(single_result) == 0, "Statistical search failed!"

    def test_homo_stat_search(self, learnware_num=5):
        """Homogeneous statistical search: user spec matches a learnware's spec exactly.

        Each uploaded learnware's own RKME spec is replayed as the user spec,
        so at least one single result must always be found.
        """
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))

        with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder:
            for idx, zip_path in enumerate(self.zip_path_list):
                with zipfile.ZipFile(zip_path, "r") as zip_obj:
                    zip_obj.extractall(path=test_folder)

                user_spec = RKMETableSpecification()
                user_spec.load(os.path.join(test_folder, "stat_spec.json"))
                user_semantic = generate_semantic_spec(**self.universal_semantic_config)
                user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
                search_result = hetero_market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                assert len(single_result) >= 1, "Statistical search failed!"
                print(f"search result of user{idx}:")
                for single_item in single_result:
                    print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

                for multiple_item in multiple_result:
                    print(f"mixture_score: {multiple_item.score}\n")
                    mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares])
                    print(f"mixture_learnware: {mixture_id}\n")

    def test_model_reuse(self, learnware_num=5):
        """Reuse searched learnwares on a 15-feature user task via hetero alignment.

        Aligns single and multiple searched learnwares to the user's feature
        space with a small labeled sample (100 points), then compares RMSE of
        single-model reuse, averaging ensemble, and ensemble pruning.
        """
        # generate toy regression problem
        X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0)

        # generate rkme
        user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0)

        # generate specification
        semantic_spec = generate_semantic_spec(
            input_description=user_description_list[0], **self.universal_semantic_config
        )
        user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})

        # learnware market search
        hetero_market = self.test_train_market_model(learnware_num, delete=False)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()
        multiple_result = search_result.get_multiple_results()

        # print search results
        for single_item in single_result:
            print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

        for multiple_item in multiple_result:
            print(
                f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
            )

        # single model reuse: align the top-ranked learnware to the user's
        # feature space using the user spec plus 100 labeled points.
        hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression")
        hetero_learnware.align(user_spec, X[:100], y[:100])
        single_predict_y = hetero_learnware.predict(X)

        # multi model reuse: align every learnware from the best mixture.
        hetero_learnware_list = []
        for org_learnware in multiple_result[0].learnwares:
            hetero_learnware = HeteroMapAlignLearnware(org_learnware, mode="regression")
            hetero_learnware.align(user_spec, X[:100], y[:100])
            hetero_learnware_list.append(hetero_learnware)

        # Use averaging ensemble reuser to reuse the searched learnwares to make prediction
        reuse_ensemble = AveragingReuser(learnware_list=hetero_learnware_list, mode="mean")
        ensemble_predict_y = reuse_ensemble.predict(user_data=X)

        # Use ensemble pruning reuser to reuse the searched learnwares to make prediction
        reuse_ensemble = EnsemblePruningReuser(learnware_list=hetero_learnware_list, mode="regression")
        reuse_ensemble.fit(X[:100], y[:100])
        ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=X)

        # squared=False -> RMSE rather than MSE.
        print("Single model RMSE by finetune:", mean_squared_error(y, single_predict_y, squared=False))
        print("Averaging Reuser RMSE:", mean_squared_error(y, ensemble_predict_y, squared=False))
        print("Ensemble Pruning Reuser RMSE:", mean_squared_error(y, ensemble_pruning_predict_y, squared=False))
-
-
def suite():
    """Assemble the hetero workflow tests in their intended execution order."""
    test_names = [
        # "test_prepare_learnware_randomly",
        # "test_upload_delete_learnware",
        # "test_train_market_model",
        "test_search_semantics",
        "test_hetero_stat_search",
        "test_homo_stat_search",
        "test_model_reuse",
    ]
    _suite = unittest.TestSuite()
    for test_name in test_names:
        _suite.addTest(TestHeteroWorkflow(test_name))
    return _suite
-
-
if __name__ == "__main__":
    # Run the curated suite with per-test progress output.
    unittest.TextTestRunner(verbosity=2).run(suite())
|