import copy
import os
import unittest
import zipfile
from multiprocessing import Pool
from shutil import copyfile, rmtree

import joblib
import numpy as np
import torch
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error

import learnware
from learnware.client import LearnwareClient
from learnware.market import BaseUserInfo, instantiate_learnware_market
from learnware.reuse import AveragingReuser, EnsemblePruningReuser, HeteroMapAlignLearnware
from learnware.specification import RKMETableSpecification, generate_rkme_table_spec

from example_learnwares.config import (
    input_description_list,
    input_shape_list,
    output_description_list,
    user_description_list,
)

curr_root = os.path.dirname(os.path.abspath(__file__))

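# Semantic specification template shared by all uploaded learnwares; the per-learnware
# fields ("Name", "Description", "Input", "Output") are filled in before each upload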
user_semantic = {
    "Data": {"Values": ["Table"], "Type": "Class"},
    "Task": {"Values": ["Regression"], "Type": "Class"},
    "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
    "Scenario": {"Values": ["Education"], "Type": "Tag"},
    "Description": {"Values": "", "Type": "String"},
    "Name": {"Values": "", "Type": "String"},
}


def check_learnware(learnware_name, dir_path=os.path.join(curr_root, "learnware_pool")):
    print(f"Checking Learnware: {learnware_name}")
    zip_file_path = os.path.join(dir_path, learnware_name)
    client = LearnwareClient()
    # Return True if check_learnware raises no exception; otherwise report the failure and return False
    try:
        client.check_learnware(zip_file_path)
        return True
    except Exception as e:
        print(f"Learnware {learnware_name} failed the check: {e}")
        return False


class TestMarket(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        np.random.seed(2023)
        learnware.init()

    def _init_learnware_market(self, organizer_kwargs=None):
        """Instantiate a fresh heterogeneous learnware market."""
        hetero_market = instantiate_learnware_market(
            market_id="hetero_toy", name="hetero", rebuild=True, organizer_kwargs=organizer_kwargs
        )
        return hetero_market

    def test_prepare_learnware_randomly(self, learnware_num=5):
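        """Train toy Ridge models, build their RKME specifications, and pack each into a learnware zip."""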
        self.zip_path_list = []

        for i in range(learnware_num):
            dir_path = os.path.join(curr_root, "learnware_pool", "ridge_%d" % (i))
            os.makedirs(dir_path, exist_ok=True)

            print("Preparing Learnware: %d" % (i))

            example_learnware_idx = i % 2
            input_dim = input_shape_list[example_learnware_idx]
            learnware_example_dir = "example_learnwares"

            X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_dim, noise=0.1, random_state=42)

            clf = Ridge(alpha=1.0)
            clf.fit(X, y)

            joblib.dump(clf, os.path.join(dir_path, "ridge.pkl"))

            spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0)
            spec.save(os.path.join(dir_path, "stat.json"))

            init_file = os.path.join(dir_path, "__init__.py")
            copyfile(
                os.path.join(curr_root, learnware_example_dir, f"model{example_learnware_idx}.py"), init_file
            )  # cp model{example_learnware_idx}.py __init__.py

            yaml_file = os.path.join(dir_path, "learnware.yaml")
            copyfile(
                os.path.join(curr_root, learnware_example_dir, "learnware.yaml"), yaml_file
            )  # cp learnware.yaml yaml_file

            env_file = os.path.join(dir_path, "requirements.txt")
            copyfile(os.path.join(curr_root, learnware_example_dir, "requirements.txt"), env_file)

            zip_file = dir_path + ".zip"
            # zip -q -r -j zip_file dir_path (store files flat, without directory paths)
            with zipfile.ZipFile(zip_file, "w") as zip_obj:
                for foldername, subfolders, filenames in os.walk(dir_path):
                    for filename in filenames:
                        file_path = os.path.join(foldername, filename)
                        zip_info = zipfile.ZipInfo(filename)
                        zip_info.compress_type = zipfile.ZIP_STORED
                        with open(file_path, "rb") as file:
                            zip_obj.writestr(zip_info, file.read())

            rmtree(dir_path)  # rm -r dir_path

            self.zip_path_list.append(zip_file)

    def test_generated_learnwares(self):
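        """Check every generated learnware zip with LearnwareClient, in parallel."""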
        dir_path = os.path.join(curr_root, "learnware_pool")

        # Execute multi-process checking using a process pool
        with Pool() as pool:
            results = pool.starmap(check_learnware, [(name, dir_path) for name in os.listdir(dir_path)])

        # Every learnware must pass the client check
        self.assertTrue(all(results), "Not all learnwares passed the check")

    def test_upload_delete_learnware(self, learnware_num=5, delete=True):
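        """Upload learnwares, verify the market size, and optionally delete them all."""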
        hetero_market = self._init_learnware_market()
        self.test_prepare_learnware_randomly(learnware_num)
        self.learnware_num = learnware_num

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == 0, "The market should be empty!"

        for idx, zip_path in enumerate(self.zip_path_list):
            semantic_spec = copy.deepcopy(user_semantic)
            semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
            semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
            semantic_spec["Input"] = input_description_list[idx % 2]
            semantic_spec["Output"] = output_description_list[idx % 2]
            hetero_market.add_learnware(zip_path, semantic_spec)

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
        curr_inds = hetero_market.get_learnware_ids()
        print("Available ids After Uploading Learnwares:", curr_inds)
        assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        if delete:
            for learnware_id in curr_inds:
                hetero_market.delete_learnware(learnware_id)
                self.learnware_num -= 1
                assert (
                    len(hetero_market) == self.learnware_num
                ), f"The number of learnwares must be {self.learnware_num}!"

            curr_inds = hetero_market.get_learnware_ids()
            print("Available ids After Deleting Learnwares:", curr_inds)
            assert len(curr_inds) == 0, "The market should be empty!"

        return hetero_market

    def test_train_market_model(self, learnware_num=5):
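        """Upload learnwares into a market whose organizer is configured for heterogeneous training."""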
        hetero_market = self._init_learnware_market(
            organizer_kwargs={"auto_update": False, "auto_update_limit": learnware_num}
        )
        self.test_prepare_learnware_randomly(learnware_num)
        self.learnware_num = learnware_num

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == 0, "The market should be empty!"

        for idx, zip_path in enumerate(self.zip_path_list):
            semantic_spec = copy.deepcopy(user_semantic)
            semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
            semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
            semantic_spec["Input"] = input_description_list[idx % 2]
            semantic_spec["Output"] = output_description_list[idx % 2]
            hetero_market.add_learnware(zip_path, semantic_spec)

        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"
        curr_inds = hetero_market.get_learnware_ids()
        print("Available ids After Uploading Learnwares:", curr_inds)
        assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        # organizer = hetero_market.learnware_organizer
        # organizer.train(hetero_market.learnware_organizer.learnware_list.values())
        return hetero_market

    def test_search_semantics(self, learnware_num=5):
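        """Run exact and fuzzy semantic searches over learnware names."""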
        hetero_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(hetero_market))
        assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!"

        semantic_spec = copy.deepcopy(user_semantic)
        semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}"

        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print("User info:", user_info.get_semantic_spec())
        print("Search result:")
        assert len(single_result) == 1, "Exact semantic search failed!"
        for search_item in single_result:
            semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec()
            print("Choose learnware:", search_item.learnware.id, semantic_spec1)
            assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], "Exact semantic search failed!"

        # Deliberately misspelled name: fuzzy semantic search should match all learnwares
        semantic_spec["Name"]["Values"] = "laernwaer"
        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print("User info:", user_info.get_semantic_spec())
        print("Search result:")
        assert len(single_result) == self.learnware_num, "Fuzzy semantic search failed!"
        for search_item in single_result:
            semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec()
            print("Choose learnware:", search_item.learnware.id, semantic_spec1)

    def test_stat_search(self, learnware_num=5):
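        """Run statistical (RKME) searches: heterogeneous users first, then homogeneous users."""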
        hetero_market = self.test_train_market_model(learnware_num)
        print("Total Item:", len(hetero_market))

        # hetero test
        print("+++++ HETERO TEST ++++++")
        user_dim = 15

        test_folder = os.path.join(curr_root, "test_stat")

        for idx, zip_path in enumerate(self.zip_path_list):
            unzip_dir = os.path.join(test_folder, f"{idx}")

            # unzip -o -q zip_path -d unzip_dir
            if os.path.exists(unzip_dir):
                rmtree(unzip_dir)
            os.makedirs(unzip_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_obj:
                zip_obj.extractall(path=unzip_dir)

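            # Truncate the stored RKME reduced set to the first user_dim columns to
            # simulate a heterogeneous user whose feature space is narrower than the uploads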
            user_spec = RKMETableSpecification()
            user_spec.load(os.path.join(unzip_dir, "stat.json"))
            z = user_spec.get_z()
            z = z[:, :user_dim]
            device = user_spec.device
            z = torch.tensor(z, device=device)
            user_spec.z = z

            print(">> normal case test:")
            semantic_spec = copy.deepcopy(user_semantic)
            semantic_spec["Input"] = copy.deepcopy(input_description_list[idx % 2])
            semantic_spec["Input"]["Dimension"] = user_dim
            # Keep only the first user_dim feature descriptions
            semantic_spec["Input"]["Description"] = {
                str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim)
            }
            user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})

            search_result = hetero_market.search_learnware(user_info)
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()

            print(f"search result of user{idx}:")
            for single_item in single_result:
                print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

            for multiple_item in multiple_result:
                print(
                    f"mixture_score: {multiple_item.score}, "
                    f"mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
                )

            # Improper "Task" value in semantic_spec: the market falls back to homogeneous
            # search, logs the invalid semantic_spec, and returns no result
            print(">> test for key 'Task' with mismatched 'Values':")
            semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"}

            user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
            search_result = hetero_market.search_learnware(user_info)
            single_result = search_result.get_single_results()

            assert len(single_result) == 0, "Statistical search failed!"

            # Delete key "Task" from semantic_spec: the market falls back to homogeneous
            # search and warns "User doesn't provide correct task type"
            print(">> delete key 'Task' test:")
            semantic_spec.pop("Task")

            user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
            search_result = hetero_market.search_learnware(user_info)
            single_result = search_result.get_single_results()

            assert len(single_result) == 0, "Statistical search failed!"

            # Declare a mismatched input dimension: the market falls back to homogeneous search
            # and warns "User data feature dimensions mismatch with semantic specification."
            print(">> mismatched dim test")
            semantic_spec = copy.deepcopy(user_semantic)
            semantic_spec["Input"] = copy.deepcopy(input_description_list[idx % 2])
            semantic_spec["Input"]["Dimension"] = user_dim - 2
            semantic_spec["Input"]["Description"] = {
                str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim)
            }

            user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})
            search_result = hetero_market.search_learnware(user_info)
            single_result = search_result.get_single_results()

            assert len(single_result) == 0, "Statistical search failed!"

        rmtree(test_folder)  # rm -r test_folder

        # homo test
        print("\n+++++ HOMO TEST ++++++")
        test_folder = os.path.join(curr_root, "test_stat")

        for idx, zip_path in enumerate(self.zip_path_list):
            unzip_dir = os.path.join(test_folder, f"{idx}")

            # unzip -o -q zip_path -d unzip_dir
            if os.path.exists(unzip_dir):
                rmtree(unzip_dir)
            os.makedirs(unzip_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_obj:
                zip_obj.extractall(path=unzip_dir)

            user_spec = RKMETableSpecification()
            user_spec.load(os.path.join(unzip_dir, "stat.json"))
            user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
            search_result = hetero_market.search_learnware(user_info)
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()

            assert len(single_result) >= 1, "Statistical search failed!"
            print(f"search result of user{idx}:")
            for single_item in single_result:
                print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

            for multiple_item in multiple_result:
                print(f"mixture_score: {multiple_item.score}\n")
                mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares])
                print(f"mixture_learnware: {mixture_id}\n")

        rmtree(test_folder)  # rm -r test_folder

    def test_model_reuse(self, learnware_num=5):
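        """Search the market with a synthetic regression task and reuse the returned learnwares."""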
        # Generate a toy regression problem as the user task
        X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0)

        # Generate the user's RKME statistical specification
        user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0)

        # Assemble the user information (semantic + statistical specification)
        semantic_spec = copy.deepcopy(user_semantic)
        semantic_spec["Input"] = user_description_list[0]
        user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec})

        # Search the learnware market
        hetero_market = self.test_train_market_model(learnware_num)
        search_result = hetero_market.search_learnware(user_info)
        single_result = search_result.get_single_results()
        multiple_result = search_result.get_multiple_results()

        # Print search results
        for single_item in single_result:
            print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

        for multiple_item in multiple_result:
            print(
                f"mixture_score: {multiple_item.score}, "
                f"mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}"
            )

        # Single-model reuse: align the best learnware to the user's feature space with a small labeled sample
        hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression")
        hetero_learnware.align(user_spec, X[:100], y[:100])
        single_predict_y = hetero_learnware.predict(X)

        # Multi-model reuse: align every learnware in the best mixture
        hetero_learnware_list = []
        for learnware in multiple_result[0].learnwares:
            hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression")
            hetero_learnware.align(user_spec, X[:100], y[:100])
            hetero_learnware_list.append(hetero_learnware)

        # Use an averaging ensemble reuser over the aligned learnwares to make predictions
        reuse_ensemble = AveragingReuser(learnware_list=hetero_learnware_list, mode="mean")
        ensemble_predict_y = reuse_ensemble.predict(user_data=X)

        # Use an ensemble pruning reuser over the aligned learnwares to make predictions
        reuse_ensemble = EnsemblePruningReuser(learnware_list=hetero_learnware_list, mode="regression")
        reuse_ensemble.fit(X[:100], y[:100])
        ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=X)

        print("Single model RMSE after alignment:", mean_squared_error(y, single_predict_y, squared=False))
        print("Averaging Reuser RMSE:", mean_squared_error(y, ensemble_predict_y, squared=False))
        print("Ensemble Pruning Reuser RMSE:", mean_squared_error(y, ensemble_pruning_predict_y, squared=False))


def suite():
    _suite = unittest.TestSuite()
    _suite.addTest(TestMarket("test_prepare_learnware_randomly"))
    _suite.addTest(TestMarket("test_generated_learnwares"))
    _suite.addTest(TestMarket("test_upload_delete_learnware"))
    _suite.addTest(TestMarket("test_train_market_model"))
    _suite.addTest(TestMarket("test_search_semantics"))
    _suite.addTest(TestMarket("test_stat_search"))
    _suite.addTest(TestMarket("test_model_reuse"))
    return _suite


if __name__ == "__main__":
    runner = unittest.TextTestRunner()
    runner.run(suite())