[ENH] add search result class (tags/v0.3.2)
| @@ -1,8 +1,7 @@ | |||
| .. _dev: | |||
| ============= | |||
| Code Standard | |||
| ============= | |||
| ================ | |||
| For Developer | |||
| ================ | |||
| Docstring | |||
| ============ | |||
| @@ -13,7 +13,7 @@ from learnware.learnware import Learnware | |||
| import time | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.market import database_ops | |||
| from learnware.market.easy import database_ops | |||
| from learnware.learnware import Learnware | |||
| import learnware.specification as specification | |||
| from learnware.logger import get_module_logger | |||
| @@ -168,15 +168,14 @@ def test_search(gamma=0.1, load_market=True): | |||
| user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_stat_spec}) | |||
| logger.info("Searching Market for user: %d" % i) | |||
| sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware( | |||
| user_info | |||
| ) | |||
| search_result = image_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| acc_list = [] | |||
| for idx, (score, learnware) in enumerate(zip(sorted_score_list[:5], single_learnware_list[:5])): | |||
| pred_y = learnware.predict(user_data) | |||
| for idx, single_item in enumerate(single_result[:5]): | |||
| pred_y = single_item.learnware.predict(user_data) | |||
| acc = eval_prediction(pred_y, user_label) | |||
| acc_list.append(acc) | |||
| logger.info("Search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, score, learnware.id, acc)) | |||
| logger.info("Search rank: %d, score: %.3f, learnware_id: %s, acc: %.3f" % (idx, single_item.score, single_item.learnware.id, acc)) | |||
| # test reuse (job selector) | |||
| # reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) | |||
| @@ -186,6 +185,7 @@ def test_search(gamma=0.1, load_market=True): | |||
| # print(f"mixture reuse loss: {reuse_score}") | |||
| # test reuse (ensemble) | |||
| single_learnware_list = [single_item.learnware for single_item in single_result] | |||
| reuse_ensemble = AveragingReuser(learnware_list=single_learnware_list[:3], mode="vote_by_prob") | |||
| ensemble_predict_y = reuse_ensemble.predict(user_data=user_data) | |||
| ensemble_score = eval_prediction(ensemble_predict_y, user_label) | |||
| @@ -155,29 +155,28 @@ class M5DatasetWorkflow: | |||
| user_spec.save(user_spec_path) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = easy_market.search_learnware(user_info) | |||
| search_result = easy_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| print(f"search result of user{idx}:") | |||
| print( | |||
| f"single model num: {len(sorted_score_list)}, max_score: {sorted_score_list[0]}, min_score: {sorted_score_list[-1]}" | |||
| f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" | |||
| ) | |||
| loss_list = [] | |||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||
| pred_y = learnware.predict(test_x) | |||
| for single_item in single_result: | |||
| pred_y = single_item.learnware.predict(test_x) | |||
| loss_list.append(m5.score(test_y, pred_y)) | |||
| print( | |||
| f"Top1-score: {sorted_score_list[0]}, learnware_id: {single_learnware_list[0].id}, loss: {loss_list[0]}" | |||
| f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, loss: {loss_list[0]}" | |||
| ) | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) | |||
| print(f"mixture_score: {mixture_score}, mixture_learnware: {mixture_id}") | |||
| if not mixture_learnware_list: | |||
| mixture_learnware_list = [single_learnware_list[0]] | |||
| if len(multiple_result) > 0: | |||
| mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) | |||
| print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") | |||
| mixture_learnware_list = multiple_result[0].learnwares | |||
| else: | |||
| mixture_learnware_list = [single_result[0].learnware] | |||
| reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list, use_herding=False) | |||
| job_selector_predict_y = reuse_job_selector.predict(user_data=test_x) | |||
| @@ -152,29 +152,28 @@ class PFSDatasetWorkflow: | |||
| user_spec.save(user_spec_path) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = easy_market.search_learnware(user_info) | |||
| search_result = easy_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| print(f"search result of user{idx}:") | |||
| print( | |||
| f"single model num: {len(sorted_score_list)}, max_score: {sorted_score_list[0]}, min_score: {sorted_score_list[-1]}" | |||
| f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" | |||
| ) | |||
| loss_list = [] | |||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||
| pred_y = learnware.predict(test_x) | |||
| for single_item in single_result: | |||
| pred_y = single_item.learnware.predict(test_x) | |||
| loss_list.append(pfs.score(test_y, pred_y)) | |||
| print( | |||
| f"Top1-score: {sorted_score_list[0]}, learnware_id: {single_learnware_list[0].id}, loss: {loss_list[0]}, random: {np.mean(loss_list)}" | |||
| f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, loss: {loss_list[0]}, random: {np.mean(loss_list)}" | |||
| ) | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) | |||
| print(f"mixture_score: {mixture_score}, mixture_learnware: {mixture_id}") | |||
| if not mixture_learnware_list: | |||
| mixture_learnware_list = [single_learnware_list[0]] | |||
| if len(multiple_result) > 0: | |||
| mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) | |||
| print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") | |||
| mixture_learnware_list = multiple_result[0].learnwares | |||
| else: | |||
| mixture_learnware_list = [single_result[0].learnware] | |||
| reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list, use_herding=False) | |||
| job_selector_predict_y = reuse_job_selector.predict(user_data=test_x) | |||
| @@ -199,31 +199,34 @@ class TextDatasetWorkflow: | |||
| user_stat_spec.generate_stat_spec_from_data(X=user_data) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETextSpecification": user_stat_spec}) | |||
| logger.info("Searching Market for user: %d" % (i)) | |||
| sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = text_market.search_learnware( | |||
| user_info | |||
| ) | |||
| search_result = text_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| print(f"search result of user{i}:") | |||
| print( | |||
| f"single model num: {len(sorted_score_list)}, max_score: {sorted_score_list[0]}, min_score: {sorted_score_list[-1]}" | |||
| f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}" | |||
| ) | |||
| l = len(sorted_score_list) | |||
| l = len(single_result) | |||
| acc_list = [] | |||
| for idx in range(l): | |||
| learnware = single_learnware_list[idx] | |||
| score = sorted_score_list[idx] | |||
| learnware = single_result[idx].learnware | |||
| score = single_result[idx].score | |||
| pred_y = learnware.predict(user_data) | |||
| acc = eval_prediction(pred_y, user_label) | |||
| acc_list.append(acc) | |||
| print( | |||
| f"Top1-score: {sorted_score_list[0]}, learnware_id: {single_learnware_list[0].id}, acc: {acc_list[0]}" | |||
| f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, acc: {acc_list[0]}" | |||
| ) | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) | |||
| print(f"mixture_score: {mixture_score}, mixture_learnware: {mixture_id}") | |||
| if not mixture_learnware_list: | |||
| mixture_learnware_list = [single_learnware_list[0]] | |||
| if len(multiple_result) > 0: | |||
| mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares]) | |||
| print(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}") | |||
| mixture_learnware_list = multiple_result[0].learnwares | |||
| else: | |||
| mixture_learnware_list = [single_result[0].learnware] | |||
| # test reuse (job selector) | |||
| reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100) | |||
| @@ -1,10 +0,0 @@ | |||
| ## How to Generate the Environment YAML | |||
| * Create the environment config for conda: | |||
| ```shell | |||
| conda env export | grep -v "^prefix: " > environment.yml | |||
| ``` | |||
| * Recreate the environment from the config: | |||
| ```shell | |||
| conda env create -f environment.yml | |||
| ``` | |||
| @@ -1,27 +0,0 @@ | |||
| name: learnware_example_env | |||
| channels: | |||
| - defaults | |||
| dependencies: | |||
| - _libgcc_mutex=0.1=main | |||
| - _openmp_mutex=5.1=1_gnu | |||
| - ca-certificates=2023.01.10=h06a4308_0 | |||
| - ld_impl_linux-64=2.38=h1181459_1 | |||
| - libffi=3.4.2=h6a678d5_6 | |||
| - libgcc-ng=11.2.0=h1234567_1 | |||
| - libgomp=11.2.0=h1234567_1 | |||
| - libstdcxx-ng=11.2.0=h1234567_1 | |||
| - ncurses=6.4=h6a678d5_0 | |||
| - openssl=1.1.1t=h7f8727e_0 | |||
| - pip=23.0.1=py38h06a4308_0 | |||
| - python=3.8.16=h7a1cb2a_3 | |||
| - readline=8.2=h5eee18b_0 | |||
| - setuptools=66.0.0=py38h06a4308_0 | |||
| - sqlite=3.41.2=h5eee18b_0 | |||
| - tk=8.6.12=h1ccaba5_0 | |||
| - wheel=0.38.4=py38h06a4308_0 | |||
| - xz=5.2.10=h5eee18b_1 | |||
| - zlib=1.2.13=h5eee18b_0 | |||
| - pip: | |||
| - joblib==1.2.0 | |||
| - learnware==0.0.1.99 | |||
| - numpy==1.19.5 | |||
| @@ -1,8 +0,0 @@ | |||
| model: | |||
| class_name: SVM | |||
| kwargs: {} | |||
| stat_specifications: | |||
| - module_path: learnware.specification | |||
| class_name: RKMETableSpecification | |||
| file_name: svm.json | |||
| kwargs: {} | |||
| @@ -1,20 +0,0 @@ | |||
| import os | |||
| import joblib | |||
| import numpy as np | |||
| from learnware.model import BaseModel | |||
class SVM(BaseModel):
    """Learnware model wrapper around the estimator pickled as ``svm.pkl``.

    The pickle is looked up in the directory that contains this module, so the
    file must be shipped alongside it.
    """

    def __init__(self):
        # Shapes handed to BaseModel: 64 input features, 10 outputs.
        super().__init__(input_shape=(64,), output_shape=(10,))
        module_dir = os.path.dirname(os.path.abspath(__file__))
        pickle_path = os.path.join(module_dir, "svm.pkl")
        self.model = joblib.load(pickle_path)

    def fit(self, X: np.ndarray, y: np.ndarray):
        """No-op; this wrapper performs no training."""

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the wrapped estimator's ``predict_proba`` output for ``X``."""
        probabilities = self.model.predict_proba(X)
        return probabilities

    def finetune(self, X: np.ndarray, y: np.ndarray):
        """No-op; this wrapper performs no fine-tuning."""
| @@ -1,197 +0,0 @@ | |||
| import os | |||
| import fire | |||
| import copy | |||
| import joblib | |||
| import zipfile | |||
| import numpy as np | |||
| from sklearn import svm | |||
| from sklearn.datasets import load_digits | |||
| from sklearn.model_selection import train_test_split | |||
| from shutil import copyfile, rmtree | |||
| import learnware | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.reuse import JobSelectorReuser, AveragingReuser | |||
| from learnware.specification import generate_rkme_table_spec, RKMETableSpecification | |||
# Directory containing this script; learnware pools and scratch folders are
# created underneath it.
curr_root = os.path.dirname(os.path.abspath(__file__))

# Semantic specification template. It is deep-copied and given a name and
# description for each uploaded learnware, and used directly for user queries.
user_semantic = {
    "Data": {"Values": ["Table"], "Type": "Class"},
    "Task": {"Values": ["Classification"], "Type": "Class"},
    "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
    "Scenario": {"Values": ["Education"], "Type": "Tag"},
    "Description": {"Values": "", "Type": "String"},
    "Name": {"Values": "", "Type": "String"},
}
class LearnwareMarketWorkflow:
    """Example workflow for a learnware market built on the sklearn digits data.

    Every ``test_*`` method rebuilds the "easy" market, uploads randomly
    trained SVM learnwares and then exercises one market feature. Search
    results are consumed through the ``SearchResults`` object returned by
    ``search_learnware`` (``get_single_results`` / ``get_multiple_results``).
    """

    def _init_learnware_market(self):
        """Initialize the learnware package and return a rebuilt "easy" market."""
        learnware.init()
        np.random.seed(2023)
        easy_market = instantiate_learnware_market(market_id="sklearn_digits", name="easy", rebuild=True)
        return easy_market

    def prepare_learnware_randomly(self, learnware_num=5):
        """Train ``learnware_num`` SVMs on random digit splits and zip each as a learnware.

        The paths of the generated zip files are stored in ``self.zip_path_list``.
        """
        self.zip_path_list = []
        X, y = load_digits(return_X_y=True)

        for i in range(learnware_num):
            dir_path = os.path.join(curr_root, "learnware_pool", "svm_%d" % (i))
            os.makedirs(dir_path, exist_ok=True)

            print("Preparing Learnware: %d" % (i))

            data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True)
            clf = svm.SVC(kernel="linear", probability=True)
            clf.fit(data_X, data_y)
            joblib.dump(clf, os.path.join(dir_path, "svm.pkl"))

            spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0)
            spec.save(os.path.join(dir_path, "svm.json"))

            # cp example_init.py init_file
            init_file = os.path.join(dir_path, "__init__.py")
            copyfile(os.path.join(curr_root, "learnware_example/example_init.py"), init_file)

            # cp example.yaml yaml_file
            yaml_file = os.path.join(dir_path, "learnware.yaml")
            copyfile(os.path.join(curr_root, "learnware_example/example.yaml"), yaml_file)

            # zip -q -r -j zip_file dir_path (files are stored flat, uncompressed)
            zip_file = dir_path + ".zip"
            with zipfile.ZipFile(zip_file, "w") as zip_obj:
                for foldername, _subfolders, filenames in os.walk(dir_path):
                    for filename in filenames:
                        file_path = os.path.join(foldername, filename)
                        zip_info = zipfile.ZipInfo(filename)
                        zip_info.compress_type = zipfile.ZIP_STORED
                        with open(file_path, "rb") as file:
                            zip_obj.writestr(zip_info, file.read())

            rmtree(dir_path)  # rm -r dir_path
            self.zip_path_list.append(zip_file)

    def test_upload_delete_learnware(self, learnware_num=5, delete=False):
        """Upload freshly prepared learnwares, optionally delete them, and return the market."""
        easy_market = self._init_learnware_market()
        self.prepare_learnware_randomly(learnware_num)

        print("Total Item:", len(easy_market))

        for idx, zip_path in enumerate(self.zip_path_list):
            semantic_spec = copy.deepcopy(user_semantic)
            semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
            semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
            easy_market.add_learnware(zip_path, semantic_spec)

        print("Total Item:", len(easy_market))
        curr_inds = easy_market.get_learnware_ids()
        print("Available ids After Uploading Learnwares:", curr_inds)

        if delete:
            for learnware_id in curr_inds:
                easy_market.delete_learnware(learnware_id)
            curr_inds = easy_market.get_learnware_ids()
            print("Available ids After Deleting Learnwares:", curr_inds)

        return easy_market

    def test_search_semantics(self, learnware_num=5):
        """Search the market with a semantic specification only and print the matches."""
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))

        test_folder = os.path.join(curr_root, "test_semantics")

        # unzip -o -q zip_path -d unzip_dir
        if os.path.exists(test_folder):
            rmtree(test_folder)
        os.makedirs(test_folder, exist_ok=True)

        with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj:
            zip_obj.extractall(path=test_folder)

        semantic_spec = copy.deepcopy(user_semantic)
        semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}"
        semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}"

        user_info = BaseUserInfo(semantic_spec=semantic_spec)
        search_result = easy_market.search_learnware(user_info)
        single_result = search_result.get_single_results()

        print("User info:", user_info.get_semantic_spec())
        print("Search result:")
        for single_item in single_result:
            matched = single_item.learnware
            print("Choose learnware:", matched.id, matched.get_specification().get_semantic_spec())

        rmtree(test_folder)  # rm -r test_folder

    def test_stat_search(self, learnware_num=5):
        """Search the market with each learnware's own RKME spec and print the scores."""
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))

        test_folder = os.path.join(curr_root, "test_stat")

        for idx, zip_path in enumerate(self.zip_path_list):
            unzip_dir = os.path.join(test_folder, f"{idx}")

            # unzip -o -q zip_path -d unzip_dir
            if os.path.exists(unzip_dir):
                rmtree(unzip_dir)
            os.makedirs(unzip_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_obj:
                zip_obj.extractall(path=unzip_dir)

            user_spec = RKMETableSpecification()
            user_spec.load(os.path.join(unzip_dir, "svm.json"))
            user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})

            search_result = easy_market.search_learnware(user_info)
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()

            print(f"search result of user{idx}:")
            for single_item in single_result:
                print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}")

            # A mixture is only reported when the searcher produced one.
            if multiple_result:
                best_mixture = multiple_result[0]
                print(f"mixture_score: {best_mixture.score}\n")
                mixture_id = " ".join([item.id for item in best_mixture.learnwares])
                print(f"mixture_learnware: {mixture_id}\n")

        rmtree(test_folder)  # rm -r test_folder

    def test_learnware_reuse(self, learnware_num=5):
        """Search with a user RKME spec and reuse the returned learnwares for prediction."""
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))

        X, y = load_digits(return_X_y=True)
        _, data_X, _, data_y = train_test_split(X, y, test_size=0.3, shuffle=True)

        stat_spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0)
        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec})

        search_result = easy_market.search_learnware(user_info)
        single_result = search_result.get_single_results()
        multiple_result = search_result.get_multiple_results()

        # Prefer the mixture; fall back to the best single learnware when the
        # search returned no multiple-learnware result.
        if multiple_result:
            mixture_learnware_list = multiple_result[0].learnwares
        else:
            mixture_learnware_list = [single_result[0].learnware]

        # Based on user information, the learnware market returns a list of learnwares (learnware_list)
        # Use jobselector reuser to reuse the searched learnwares to make prediction
        reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list)
        job_selector_predict_y = reuse_job_selector.predict(user_data=data_X)

        # Use averaging ensemble reuser to reuse the searched learnwares to make prediction
        reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list)
        ensemble_predict_y = reuse_ensemble.predict(user_data=data_X)

        print("Job Selector Acc:", np.sum(np.argmax(job_selector_predict_y, axis=1) == data_y) / len(data_y))
        print("Averaging Selector Acc:", np.sum(np.argmax(ensemble_predict_y, axis=1) == data_y) / len(data_y))
| if __name__ == "__main__": | |||
| fire.Fire(LearnwareMarketWorkflow) | |||
| @@ -18,6 +18,7 @@ from ..market import BaseChecker, EasySemanticChecker, EasyStatChecker | |||
| from ..logger import get_module_logger | |||
| from ..specification import Specification | |||
| from ..learnware import get_learnware_from_dirpath | |||
| from ..market import BaseUserInfo | |||
| from ..tests import get_semantic_specification | |||
| CHUNK_SIZE = 1024 * 1024 | |||
| @@ -204,10 +205,10 @@ class LearnwareClient: | |||
| return learnware_list | |||
| @require_login | |||
| def search_learnware(self, specification: Specification, page_size=10, page_index=0): | |||
| def search_learnware(self, user_info: BaseUserInfo, page_size=10, page_index=0): | |||
| url = f"{self.host}/engine/search_learnware" | |||
| stat_spec = specification.get_stat_spec() | |||
| stat_spec = user_info.stat_info | |||
| if len(stat_spec) > 1: | |||
| raise Exception("statistical specification must have only one key.") | |||
| @@ -222,10 +223,7 @@ class LearnwareClient: | |||
| stat_spec.save(ftemp.name) | |||
| with open(ftemp.name, "r") as fin: | |||
| semantic_specification = specification.get_semantic_spec() | |||
| if semantic_specification is None: | |||
| semantic_specification = {} | |||
| semantic_specification = user_info.get_semantic_spec() | |||
| if stat_spec is None: | |||
| files = None | |||
| else: | |||
| @@ -235,7 +233,7 @@ class LearnwareClient: | |||
| url, | |||
| files=files, | |||
| data={ | |||
| "semantic_specification": json.dumps(specification.get_semantic_spec()), | |||
| "semantic_specification": json.dumps(semantic_specification), | |||
| "limit": page_size, | |||
| "page": page_index, | |||
| }, | |||
| @@ -249,13 +247,25 @@ class LearnwareClient: | |||
| for learnware in result["data"]["learnware_list_single"]: | |||
| returns.append( | |||
| { | |||
| { | |||
| "type": "single", | |||
| "learnware_id": learnware["learnware_id"], | |||
| "semantic_specification": learnware["semantic_specification"], | |||
| "matching": learnware["matching"], | |||
| } | |||
| ) | |||
| if len(result["data"]["learnware_list_multi"]) > 0: | |||
| multiple_learnware = { | |||
| "type": "multiple", | |||
| "learnware_ids": [], | |||
| "semantic_specifications": [], | |||
| "matching": result["data"]["learnware_list_multi"][0]["matching"] | |||
| } | |||
| for learnware in result["data"]["learnware_list_multi"]: | |||
| multiple_learnware["learnware_ids"].append(learnware["learnware_id"]) | |||
| multiple_learnware["semantic_specifications"].append(learnware["semantic_specification"]) | |||
| returns.append(multiple_learnware) | |||
| return returns | |||
| @require_login | |||
| @@ -3,11 +3,12 @@ from __future__ import annotations | |||
| import traceback | |||
| import zipfile | |||
| import tempfile | |||
| from typing import Tuple, Any, List, Union | |||
| from typing import Tuple, Any, List, Union, Dict, Optional | |||
| from dataclasses import dataclass | |||
| from ..learnware import Learnware, get_learnware_from_dirpath | |||
| from ..logger import get_module_logger | |||
| logger = get_module_logger("market_base", "INFO") | |||
| logger = get_module_logger("market_base") | |||
| class BaseUserInfo: | |||
| @@ -42,6 +43,9 @@ class BaseUserInfo: | |||
| def get_stat_info(self, name: str): | |||
| return self.stat_info.get(name, None) | |||
| def update_semantic_spec(self, semantic_spec: dict): | |||
| self.semantic_spec = semantic_spec | |||
| def update_stat_info(self, name: str, item: Any): | |||
| """Update stat_info by market | |||
| @@ -55,6 +59,33 @@ class BaseUserInfo: | |||
| self.stat_info[name] = item | |||
@dataclass
class SingleSearchItem:
    """One learnware returned by a search, optionally paired with a score."""

    # The matched learnware.
    learnware: Learnware
    # Score attached by the searcher; stays None when none is supplied.
    score: Optional[float] = None
@dataclass
class MultipleSearchItem:
    """A group of learnwares returned together by a search, with one shared score."""

    # The learnwares that make up this combined result.
    learnwares: List[Learnware]
    # Score of the group as a whole (required).
    score: float
class SearchResults:
    """Container for the single and multiple learnware results of one search."""

    def __init__(
        self,
        single_results: Optional[List[SingleSearchItem]] = None,
        multiple_results: Optional[List[MultipleSearchItem]] = None,
    ):
        """Store the given result lists; an omitted list becomes a new empty list."""
        if single_results is None:
            single_results = []
        if multiple_results is None:
            multiple_results = []
        # Go through the updaters so any override of them is honoured.
        self.update_single_results(single_results)
        self.update_multiple_results(multiple_results)

    def get_single_results(self) -> List[SingleSearchItem]:
        """Return the stored list of single-learnware results."""
        return self.single_results

    def get_multiple_results(self) -> List[MultipleSearchItem]:
        """Return the stored list of multiple-learnware results."""
        return self.multiple_results

    def update_single_results(self, single_results: List[SingleSearchItem]):
        """Replace the stored single-learnware results with ``single_results``."""
        self.single_results = single_results

    def update_multiple_results(self, multiple_results: List[MultipleSearchItem]):
        """Replace the stored multiple-learnware results with ``multiple_results``."""
        self.multiple_results = multiple_results
| class LearnwareMarket: | |||
| """Base interface for the market; it provides the interface to search/add/delete/update learnwares""" | |||
| @@ -150,7 +181,7 @@ class LearnwareMarket: | |||
| def search_learnware( | |||
| self, user_info: BaseUserInfo, check_status: int = None, **kwargs | |||
| ) -> Tuple[Any, List[Learnware]]: | |||
| ) -> SearchResults: | |||
| """Search learnwares based on user_info from learnwares with check_status | |||
| Parameters | |||
| @@ -163,7 +194,7 @@ class LearnwareMarket: | |||
| Returns | |||
| ------- | |||
| Tuple[Any, List[Learnware]] | |||
| SearchResults | |||
| Search results | |||
| """ | |||
| return self.learnware_searcher(user_info, check_status, **kwargs) | |||
| @@ -450,7 +481,7 @@ class BaseSearcher: | |||
| def reset(self, organizer: BaseOrganizer, **kwargs): | |||
| self.learnware_organizer = organizer | |||
| def __call__(self, user_info: BaseUserInfo, check_status: int = None): | |||
| def __call__(self, user_info: BaseUserInfo, check_status: int = None) -> SearchResults: | |||
| """Search learnwares based on user_info from learnwares with check_status | |||
| Parameters | |||
| @@ -2,11 +2,11 @@ import math | |||
| import torch | |||
| import numpy as np | |||
| from rapidfuzz import fuzz | |||
| from typing import Tuple, List, Union | |||
| from typing import Tuple, List, Union, Optional | |||
| from .organizer import EasyOrganizer | |||
| from ..utils import parse_specification_type | |||
| from ..base import BaseUserInfo, BaseSearcher | |||
| from ..base import BaseUserInfo, BaseSearcher, SearchResults, SingleSearchItem, MultipleSearchItem | |||
| from ...learnware import Learnware | |||
| from ...specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification, rkme_solve_qp | |||
| from ...logger import get_module_logger | |||
| @@ -57,7 +57,7 @@ class EasyExactSemanticSearcher(BaseSearcher): | |||
| return True | |||
| def __call__(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> List[Learnware]: | |||
| def __call__(self, learnware_list: List[Learnware], user_info: BaseUserInfo) -> SearchResults: | |||
| match_learnwares = [] | |||
| for learnware in learnware_list: | |||
| learnware_semantic_spec = learnware.get_specification().get_semantic_spec() | |||
| @@ -65,8 +65,7 @@ class EasyExactSemanticSearcher(BaseSearcher): | |||
| if self._match_semantic_spec(user_semantic_spec, learnware_semantic_spec): | |||
| match_learnwares.append(learnware) | |||
| logger.info("semantic_spec search: choose %d from %d learnwares" % (len(match_learnwares), len(learnware_list))) | |||
| return match_learnwares | |||
| return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in match_learnwares]) | |||
| class EasyFuzzSemanticSearcher(BaseSearcher): | |||
| def _match_semantic_spec_tag(self, semantic_spec1, semantic_spec2) -> bool: | |||
| @@ -111,7 +110,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): | |||
| def __call__( | |||
| self, learnware_list: List[Learnware], user_info: BaseUserInfo, max_num: int = 50000, min_score: float = 75.0 | |||
| ) -> List[Learnware]: | |||
| ) -> SearchResults: | |||
| """Search learnware by fuzzy matching of semantic spec | |||
| Parameters | |||
| @@ -182,7 +181,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): | |||
| final_result = matched_learnware_tag | |||
| logger.info("semantic_spec search: choose %d from %d learnwares" % (len(final_result), len(learnware_list))) | |||
| return final_result | |||
| return SearchResults(single_results=[SingleSearchItem(learnware=_learnware) for _learnware in final_result]) | |||
| class EasyStatSearcher(BaseSearcher): | |||
| @@ -328,7 +327,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| user_rkme: RKMETableSpecification, | |||
| max_search_num: int, | |||
| weight_cutoff: float = 0.98, | |||
| ) -> Tuple[float, List[float], List[Learnware]]: | |||
| ) -> Tuple[Optional[float], List[float], List[Learnware]]: | |||
| """Select learnwares based on a total mixture ratio, then recalculate their mixture weights | |||
| Parameters | |||
| @@ -351,7 +350,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| """ | |||
| learnware_num = len(learnware_list) | |||
| if learnware_num == 0: | |||
| return [], [] | |||
| return None, [], [] | |||
| if learnware_num < max_search_num: | |||
| logger.warning("Available Learnware num less than search_num!") | |||
| max_search_num = learnware_num | |||
| @@ -370,7 +369,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| if len(mixture_list) <= 1: | |||
| mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] | |||
| mixture_weight = [1] | |||
| mixture_weight = [1.0] | |||
| mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name(self.stat_spec_type)) | |||
| else: | |||
| if len(mixture_list) > max_search_num: | |||
| @@ -455,7 +454,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| user_rkme: RKMETableSpecification, | |||
| max_search_num: int, | |||
| decay_rate: float = 0.95, | |||
| ) -> Tuple[float, List[float], List[Learnware]]: | |||
| ) -> Tuple[Optional[float], List[float], List[Learnware]]: | |||
| """Greedily match learnwares such that their mixture become closer and closer to user's rkme | |||
| Parameters | |||
| @@ -484,7 +483,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| max_search_num = learnware_num | |||
| flag_list = [0 for _ in range(learnware_num)] | |||
| mixture_list, weight_list, mmd_dist = [], None, None | |||
| mixture_list, weight_list, mmd_dist = [], [], None | |||
| intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1)) | |||
| for k in range(max_search_num): | |||
| @@ -543,10 +542,10 @@ class EasyStatSearcher(BaseSearcher): | |||
| the second is the list of Learnware | |||
| both lists are sorted by mmd dist | |||
| """ | |||
| RKME_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] | |||
| rkme_list = [learnware.specification.get_stat_spec_by_name(self.stat_spec_type) for learnware in learnware_list] | |||
| mmd_dist_list = [] | |||
| for RKME in RKME_list: | |||
| mmd_dist = RKME.dist(user_rkme) | |||
| for rkme in rkme_list: | |||
| mmd_dist = rkme.dist(user_rkme) | |||
| mmd_dist_list.append(mmd_dist) | |||
| sorted_idx_list = sorted(range(len(learnware_list)), key=lambda k: mmd_dist_list[k]) | |||
| @@ -561,7 +560,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| user_info: BaseUserInfo, | |||
| max_search_num: int = 5, | |||
| search_method: str = "greedy", | |||
| ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: | |||
| ) -> SearchResults: | |||
| self.stat_spec_type = parse_specification_type(stat_specs=user_info.stat_info) | |||
| if self.stat_spec_type is None: | |||
| raise KeyError("No supported stat specification is given in the user info") | |||
| @@ -572,7 +571,7 @@ class EasyStatSearcher(BaseSearcher): | |||
| sorted_dist_list, single_learnware_list = self._search_by_rkme_spec_single(learnware_list, user_rkme) | |||
| if len(single_learnware_list) == 0: | |||
| return [], [], None, [] | |||
| return SearchResults() | |||
| processed_learnware_list = single_learnware_list[: max_search_num * max_search_num] | |||
| if sorted_dist_list[0] > 0 and search_method == "auto": | |||
| @@ -622,7 +621,16 @@ class EasyStatSearcher(BaseSearcher): | |||
| mixture_score = min(1, mixture_score * ratio) if mixture_score is not None else None | |||
| logger.info(f"After filter by rkme spec, learnware_list length is {len(learnware_list)}") | |||
| return sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list | |||
| search_results = SearchResults() | |||
| search_results.update_single_results( | |||
| [SingleSearchItem(learnware=_learnware, score=_score) for _score, _learnware in zip(sorted_score_list, single_learnware_list)] | |||
| ) | |||
| if mixture_score is not None and len(mixture_learnware_list) > 0: | |||
| search_results.update_multiple_results( | |||
| [MultipleSearchItem(learnwares=mixture_learnware_list, score=mixture_score)] | |||
| ) | |||
| return search_results | |||
| class EasySearcher(BaseSearcher): | |||
| @@ -638,7 +646,7 @@ class EasySearcher(BaseSearcher): | |||
| def __call__( | |||
| self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy" | |||
| ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: | |||
| ) -> SearchResults: | |||
| """Search learnwares based on user_info from learnwares with check_status | |||
| Parameters | |||
| @@ -660,12 +668,13 @@ class EasySearcher(BaseSearcher): | |||
| the fourth is the list of Learnware (mixture), the size is search_num | |||
| """ | |||
| learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) | |||
| learnware_list = self.semantic_searcher(learnware_list, user_info) | |||
| semantic_search_result = self.semantic_searcher(learnware_list, user_info) | |||
| learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()] | |||
| if len(learnware_list) == 0: | |||
| return [], [], 0.0, [] | |||
| return SearchResults() | |||
| if parse_specification_type(stat_specs=user_info.stat_info) is not None: | |||
| return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) | |||
| else: | |||
| return None, learnware_list, 0.0, None | |||
| return semantic_search_result | |||
| @@ -2,7 +2,7 @@ import traceback | |||
| from typing import Tuple, List | |||
| from .utils import is_hetero | |||
| from ..base import BaseUserInfo | |||
| from ..base import BaseUserInfo, SearchResults | |||
| from ..easy import EasySearcher | |||
| from ..utils import parse_specification_type | |||
| from ...learnware import Learnware | |||
| @@ -15,7 +15,7 @@ logger = get_module_logger("hetero_searcher") | |||
| class HeteroSearcher(EasySearcher): | |||
| def __call__( | |||
| self, user_info: BaseUserInfo, check_status: int = None, max_search_num: int = 5, search_method: str = "greedy" | |||
| ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: | |||
| ) -> SearchResults: | |||
| """Search learnwares based on user_info from learnwares with check_status. | |||
| Employs heterogeneous learnware search if specific requirements are met, otherwise resorts to homogeneous search methods. | |||
| @@ -38,10 +38,11 @@ class HeteroSearcher(EasySearcher): | |||
| the fourth is the list of Learnware (mixture), the size is search_num | |||
| """ | |||
| learnware_list = self.learnware_organizer.get_learnwares(check_status=check_status) | |||
| learnware_list = self.semantic_searcher(learnware_list, user_info) | |||
| semantic_search_result = self.semantic_searcher(learnware_list, user_info) | |||
| learnware_list = [search_item.learnware for search_item in semantic_search_result.get_single_results()] | |||
| if len(learnware_list) == 0: | |||
| return [], [], 0.0, [] | |||
| return SearchResults() | |||
| if parse_specification_type(stat_specs=user_info.stat_info) is not None: | |||
| if is_hetero(stat_specs=user_info.stat_info, semantic_spec=user_info.semantic_spec): | |||
| @@ -49,4 +50,4 @@ class HeteroSearcher(EasySearcher): | |||
| user_info.update_stat_info(user_hetero_spec.type, user_hetero_spec) | |||
| return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) | |||
| else: | |||
| return None, learnware_list, 0.0, None | |||
| return semantic_search_result | |||
| @@ -199,26 +199,28 @@ class TestMarket(unittest.TestCase): | |||
| semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}" | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| _, single_learnware_list, _, _ = hetero_market.search_learnware(user_info) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print("User info:", user_info.get_semantic_spec()) | |||
| print(f"Search result:") | |||
| assert len(single_learnware_list) == 1, f"Exact semantic search failed!" | |||
| for learnware in single_learnware_list: | |||
| semantic_spec1 = learnware.get_specification().get_semantic_spec() | |||
| print("Choose learnware:", learnware.id, semantic_spec1) | |||
| assert len(single_result) == 1, f"Exact semantic search failed!" | |||
| for search_item in single_result: | |||
| semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() | |||
| print("Choose learnware:", search_item.learnware.id, semantic_spec1) | |||
| assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], f"Exact semantic search failed!" | |||
| semantic_spec["Name"]["Values"] = "laernwaer" | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| _, single_learnware_list, _, _ = hetero_market.search_learnware(user_info) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print("User info:", user_info.get_semantic_spec()) | |||
| print(f"Search result:") | |||
| assert len(single_learnware_list) == self.learnware_num, f"Fuzzy semantic search failed!" | |||
| for learnware in single_learnware_list: | |||
| semantic_spec1 = learnware.get_specification().get_semantic_spec() | |||
| print("Choose learnware:", learnware.id, semantic_spec1) | |||
| assert len(single_result) == self.learnware_num, f"Fuzzy semantic search failed!" | |||
| for search_item in single_result: | |||
| semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() | |||
| print("Choose learnware:", search_item.learnware.id, semantic_spec1) | |||
| def test_stat_search(self, learnware_num=5): | |||
| hetero_market = self.test_train_market_model(learnware_num) | |||
| @@ -256,49 +258,40 @@ class TestMarket(unittest.TestCase): | |||
| semantic_spec["Input"]["Description"] = { | |||
| str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) | |||
| } | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = hetero_market.search_learnware(user_info) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| print(f"search result of user{idx}:") | |||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||
| print(f"score: {score}, learnware_id: {learnware.id}") | |||
| print( | |||
| f"mixture_score: {mixture_score}, mixture_learnware_ids: {[item.id for item in mixture_learnware_list]}" | |||
| ) | |||
| for single_item in single_result: | |||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||
| for multiple_item in multiple_result: | |||
| print( | |||
| f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}" | |||
| ) | |||
| # improper key "Task" in semantic_spec, use homo search and print invalid semantic_spec | |||
| print(">> test for key 'Task' has empty 'Values':") | |||
| semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"} | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = hetero_market.search_learnware(user_info) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| assert len(single_learnware_list) == 0, f"Statistical search failed!" | |||
| assert len(single_result) == 0, f"Statistical search failed!" | |||
| # delete key "Task" in semantic_spec, use homo search and print WARNING INFO with "User doesn't provide correct task type" | |||
| print(">> delele key 'Task' test:") | |||
| semantic_spec.pop("Task") | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = hetero_market.search_learnware(user_info) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| assert len(single_learnware_list) == 0, f"Statistical search failed!" | |||
| assert len(single_result) == 0, f"Statistical search failed!" | |||
| # modify semantic info with mismatch dim, use homo search and print "User data feature dimensions mismatch with semantic specification." | |||
| print(">> mismatch dim test") | |||
| @@ -310,14 +303,10 @@ class TestMarket(unittest.TestCase): | |||
| } | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = hetero_market.search_learnware(user_info) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| assert len(single_learnware_list) == 0, f"Statistical search failed!" | |||
| assert len(single_result) == 0, f"Statistical search failed!" | |||
| rmtree(test_folder) # rm -r test_folder | |||
| @@ -338,21 +327,19 @@ class TestMarket(unittest.TestCase): | |||
| user_spec = RKMETableSpecification() | |||
| user_spec.load(os.path.join(unzip_dir, "stat.json")) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = hetero_market.search_learnware(user_info) | |||
| target_spec_num = 3 if idx % 2 == 0 else 2 | |||
| assert len(single_learnware_list) >= 1, f"Statistical search failed!" | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| assert len(single_result) >= 1, f"Statistical search failed!" | |||
| print(f"search result of user{idx}:") | |||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||
| print(f"score: {score}, learnware_id: {learnware.id}") | |||
| print(f"mixture_score: {mixture_score}\n") | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| for single_item in single_result: | |||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||
| for multiple_item in multiple_result: | |||
| print(f"mixture_score: {multiple_item.score}\n") | |||
| mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| rmtree(test_folder) # rm -r test_folder | |||
| @@ -370,26 +357,24 @@ class TestMarket(unittest.TestCase): | |||
| # learnware market search | |||
| hetero_market = self.test_train_market_model(learnware_num) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = hetero_market.search_learnware(user_info) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| # print search results | |||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||
| print(f"score: {score}, learnware_id: {learnware.id}") | |||
| print(f"mixture_score: {mixture_score}, mixture_learnware_ids: {[item.id for item in mixture_learnware_list]}") | |||
| for single_item in single_result: | |||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||
| for multiple_item in multiple_result: | |||
| print(f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}") | |||
| # single model reuse | |||
| hetero_learnware = HeteroMapAlignLearnware(single_learnware_list[0], mode="regression") | |||
| hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression") | |||
| hetero_learnware.align(user_spec, X[:100], y[:100]) | |||
| single_predict_y = hetero_learnware.predict(X) | |||
| # multi model reuse | |||
| hetero_learnware_list = [] | |||
| for learnware in mixture_learnware_list: | |||
| for learnware in multiple_result[0].learnwares: | |||
| hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression") | |||
| hetero_learnware.align(user_spec, X[:100], y[:100]) | |||
| hetero_learnware_list.append(hetero_learnware) | |||
| @@ -6,6 +6,7 @@ import tempfile | |||
| from learnware.client import LearnwareClient | |||
| from learnware.specification import Specification | |||
| from learnware.market import BaseUserInfo | |||
| class TestAllLearnware(unittest.TestCase): | |||
| @@ -30,16 +31,9 @@ class TestAllLearnware(unittest.TestCase): | |||
| def test_all_learnware(self): | |||
| max_learnware_num = 1000 | |||
| semantic_spec = dict() | |||
| semantic_spec["Data"] = {"Type": "Class", "Values": []} | |||
| semantic_spec["Task"] = {"Type": "Class", "Values": []} | |||
| semantic_spec["Library"] = {"Type": "Class", "Values": []} | |||
| semantic_spec["Scenario"] = {"Type": "Tag", "Values": []} | |||
| semantic_spec["Name"] = {"Type": "String", "Values": ""} | |||
| semantic_spec["Description"] = {"Type": "String", "Values": ""} | |||
| specification = Specification(semantic_spec=semantic_spec) | |||
| result = self.client.search_learnware(specification, page_size=max_learnware_num) | |||
| semantic_spec = self.client.create_semantic_specification() | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={}) | |||
| result = self.client.search_learnware(user_info, page_size=max_learnware_num) | |||
| print(f"result size: {len(result)}") | |||
| print(f"key in result: {[key for key in result[0]]}") | |||
| @@ -143,12 +143,13 @@ class TestWorkflow(unittest.TestCase): | |||
| semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}" | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| _, single_learnware_list, _, _ = easy_market.search_learnware(user_info) | |||
| search_result = easy_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print("User info:", user_info.get_semantic_spec()) | |||
| print(f"Search result:") | |||
| for learnware in single_learnware_list: | |||
| print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec()) | |||
| for search_item in single_result: | |||
| print("Choose learnware:", search_item.learnware.id, search_item.learnware.get_specification().get_semantic_spec()) | |||
| rmtree(test_folder) # rm -r test_folder | |||
| @@ -171,20 +172,20 @@ class TestWorkflow(unittest.TestCase): | |||
| user_spec = RKMETableSpecification() | |||
| user_spec.load(os.path.join(unzip_dir, "svm.json")) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||
| ( | |||
| sorted_score_list, | |||
| single_learnware_list, | |||
| mixture_score, | |||
| mixture_learnware_list, | |||
| ) = easy_market.search_learnware(user_info) | |||
| assert len(single_learnware_list) >= 1, f"Statistical search failed!" | |||
| search_results = easy_market.search_learnware(user_info) | |||
| single_result = search_results.get_single_results() | |||
| multiple_result = search_results.get_multiple_results() | |||
| assert len(single_result) >= 1, f"Statistical search failed!" | |||
| print(f"search result of user{idx}:") | |||
| for score, learnware in zip(sorted_score_list, single_learnware_list): | |||
| print(f"score: {score}, learnware_id: {learnware.id}") | |||
| print(f"mixture_score: {mixture_score}\n") | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| for search_item in single_result: | |||
| print(f"score: {search_item.score}, learnware_id: {search_item.learnware.id}") | |||
| for mixture_item in multiple_result: | |||
| print(f"mixture_score: {mixture_item.score}\n") | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_item.learnwares]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| rmtree(test_folder) # rm -r test_folder | |||
| @@ -198,24 +199,25 @@ class TestWorkflow(unittest.TestCase): | |||
| stat_spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) | |||
| _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) | |||
| search_results = easy_market.search_learnware(user_info) | |||
| multiple_result = search_results.get_multiple_results() | |||
| mixture_item = multiple_result[0] | |||
| # Based on user information, the learnware market returns a list of learnwares (learnware_list) | |||
| # Use jobselector reuser to reuse the searched learnwares to make prediction | |||
| reuse_job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list) | |||
| reuse_job_selector = JobSelectorReuser(learnware_list=mixture_item.learnwares) | |||
| job_selector_predict_y = reuse_job_selector.predict(user_data=data_X) | |||
| # Use averaging ensemble reuser to reuse the searched learnwares to make prediction | |||
| reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob") | |||
| reuse_ensemble = AveragingReuser(learnware_list=mixture_item.learnwares, mode="vote_by_prob") | |||
| ensemble_predict_y = reuse_ensemble.predict(user_data=data_X) | |||
| # Use ensemble pruning reuser to reuse the searched learnwares to make prediction | |||
| reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification") | |||
| reuse_ensemble = EnsemblePruningReuser(learnware_list=mixture_item.learnwares, mode="classification") | |||
| reuse_ensemble.fit(train_X[-200:], train_y[-200:]) | |||
| ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=data_X) | |||
| # Use feature augment reuser to reuse the searched learnwares to make prediction | |||
| reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_learnware_list, mode="classification") | |||
| reuse_feature_augment = FeatureAugmentReuser(learnware_list=mixture_item.learnwares, mode="classification") | |||
| reuse_feature_augment.fit(train_X[-200:], train_y[-200:]) | |||
| feature_augment_predict_y = reuse_feature_augment.predict(user_data=data_X) | |||
| @@ -227,8 +229,8 @@ class TestWorkflow(unittest.TestCase): | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestWorkflow("test_prepare_learnware_randomly")) | |||
| _suite.addTest(TestWorkflow("test_upload_delete_learnware")) | |||
| #_suite.addTest(TestWorkflow("test_prepare_learnware_randomly")) | |||
| #_suite.addTest(TestWorkflow("test_upload_delete_learnware")) | |||
| _suite.addTest(TestWorkflow("test_search_semantics")) | |||
| _suite.addTest(TestWorkflow("test_stat_search")) | |||
| _suite.addTest(TestWorkflow("test_learnware_reuse")) | |||