| @@ -80,7 +80,7 @@ is composed of the following four parts. | |||||
| - ``learnware.yaml`` | - ``learnware.yaml`` | ||||
| A config file describing your model class name, type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMEStatSpecification``), and | |||||
| A config file describing your model class name, type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMETableSpecification``), and | |||||
| the file name of your statistical specification file. | the file name of your statistical specification file. | ||||
| - ``environment.yaml`` | - ``environment.yaml`` | ||||
| @@ -178,10 +178,10 @@ For example, the following code is designed to work with Reduced Set Kernel Embe | |||||
| ```python | ```python | ||||
| import learnware.specification as specification | import learnware.specification as specification | ||||
| user_spec = specification.RKMEStatSpecification() | |||||
| user_spec = specification.RKMETableSpecification() | |||||
| user_spec.load(os.path.join(unzip_path, "rkme.json")) | user_spec.load(os.path.join(unzip_path, "rkme.json")) | ||||
| user_info = BaseUserInfo( | user_info = BaseUserInfo( | ||||
| semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} | |||||
| semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec} | |||||
| ) | ) | ||||
| (sorted_score_list, single_learnware_list, | (sorted_score_list, single_learnware_list, | ||||
| mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) | mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) | ||||
| @@ -50,7 +50,7 @@ Specification | |||||
| .. autoclass:: learnware.specification.BaseStatSpecification | .. autoclass:: learnware.specification.BaseStatSpecification | ||||
| :members: | :members: | ||||
| .. autoclass:: learnware.specification.RKMEStatSpecification | |||||
| .. autoclass:: learnware.specification.RKMETableSpecification | |||||
| :members: | :members: | ||||
| Model | Model | ||||
| @@ -117,13 +117,13 @@ You can search learnwares in official market using semantic specification. All t | |||||
| Statistical Specification Search | Statistical Specification Search | ||||
| --------------------------------- | --------------------------------- | ||||
| You can search learnware by providing a statistical specification. The statistical specification is a json file that contains the statistical information of your training data. For example, the code below searches learnwares with `RKMEStatSpecification`: | |||||
| You can search learnware by providing a statistical specification. The statistical specification is a json file that contains the statistical information of your training data. For example, the code below searches learnwares with `RKMETableSpecification`: | |||||
| .. code-block:: python | .. code-block:: python | ||||
| import learnware.specification as specification | import learnware.specification as specification | ||||
| user_spec = specification.RKMEStatSpecification() | |||||
| user_spec = specification.RKMETableSpecification() | |||||
| user_spec.load(os.path.join(unzip_path, "rkme.json")) | user_spec.load(os.path.join(unzip_path, "rkme.json")) | ||||
| specification = learnware.specification.Specification() | specification = learnware.specification.Specification() | ||||
| @@ -138,7 +138,7 @@ You can search learnware by providing a statistical specification. The statistic | |||||
| Combine Semantic and Statistical Search | Combine Semantic and Statistical Search | ||||
| ---------------------------------------- | ---------------------------------------- | ||||
| You can provide both semantic and statistical specification to search learnwares. The engine will first filter learnwares by semantic specification and then search by statistical specification. For example, the code below searches learnwares with `Table` data type and `RKMEStatSpecification`: | |||||
| You can provide both semantic and statistical specification to search learnwares. The engine will first filter learnwares by semantic specification and then search by statistical specification. For example, the code below searches learnwares with `Table` data type and `RKMETableSpecification`: | |||||
| .. code-block:: python | .. code-block:: python | ||||
| @@ -151,7 +151,7 @@ You can provide both semantic and statistical specification to search learnwares | |||||
| senarioes=[], | senarioes=[], | ||||
| input_description={}, output_description={}) | input_description={}, output_description={}) | ||||
| stat_spec = specification.RKMEStatSpecification() | |||||
| stat_spec = specification.RKMETableSpecification() | |||||
| stat_spec.load(os.path.join(unzip_path, "rkme.json")) | stat_spec.load(os.path.join(unzip_path, "rkme.json")) | ||||
| specification = learnware.specification.Specification() | specification = learnware.specification.Specification() | ||||
| specification.update_semantic_spec(semantic_spec) | specification.update_semantic_spec(semantic_spec) | ||||
| @@ -47,7 +47,7 @@ includes the following four components: | |||||
| - ``learnware.yaml`` | - ``learnware.yaml`` | ||||
| A configuration file that details your model's class name, the type of statistical specification(e.g. ``RKMEStatSpecification`` for Reduced Kernel Mean Embedding), and | |||||
| A configuration file that details your model's class name, the type of statistical specification(e.g. ``RKMETableSpecification`` for Reduced Kernel Mean Embedding), and | |||||
| the file name of your statistical specification file. | the file name of your statistical specification file. | ||||
| - ``environment.yaml`` or ``requirements.txt`` | - ``environment.yaml`` or ``requirements.txt`` | ||||
| @@ -170,12 +170,12 @@ For example, the code below executes learnware search when using Reduced Set Ker | |||||
| import learnware.specification as specification | import learnware.specification as specification | ||||
| user_spec = specification.RKMEStatSpecification() | |||||
| user_spec = specification.RKMETableSpecification() | |||||
| # unzip_path: directory for unzipped learnware zipfile | # unzip_path: directory for unzipped learnware zipfile | ||||
| user_spec.load(os.path.join(unzip_path, "rkme.json")) | user_spec.load(os.path.join(unzip_path, "rkme.json")) | ||||
| user_info = BaseUserInfo( | user_info = BaseUserInfo( | ||||
| semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} | |||||
| semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec} | |||||
| ) | ) | ||||
| (sorted_score_list, single_learnware_list, | (sorted_score_list, single_learnware_list, | ||||
| mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) | mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) | ||||
| @@ -73,10 +73,10 @@ For example, the following code is designed to work with Reduced Kernel Mean Emb | |||||
| import learnware.specification as specification | import learnware.specification as specification | ||||
| user_spec = specification.RKMEStatSpecification() | |||||
| user_spec = specification.RKMETableSpecification() | |||||
| user_spec.load(os.path.join("rkme.json")) | user_spec.load(os.path.join("rkme.json")) | ||||
| user_info = BaseUserInfo( | user_info = BaseUserInfo( | ||||
| semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec} | |||||
| semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec} | |||||
| ) | ) | ||||
| (sorted_score_list, single_learnware_list, | (sorted_score_list, single_learnware_list, | ||||
| mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) | mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) | ||||
| @@ -94,7 +94,7 @@ guaranteeing the security and privacy of your local original data. | |||||
| ------------------ | ------------------ | ||||
| Additionally, you are asked to prepare a configuration file in YAML format. | Additionally, you are asked to prepare a configuration file in YAML format. | ||||
| The file should detail your model's class name, the type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMEStatSpecification``), and | |||||
| The file should detail your model's class name, the type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMETableSpecification``), and | |||||
| the file name of your statistical specification file. The following ``learnware.yaml`` provides an example of | the file name of your statistical specification file. The following ``learnware.yaml`` provides an example of | ||||
| how your learnware configuration file should be structured, based on our previous discussion: | how your learnware configuration file should be structured, based on our previous discussion: | ||||
| @@ -105,7 +105,7 @@ how your learnware configuration file should be structured, based on our previou | |||||
| kwargs: {} | kwargs: {} | ||||
| stat_specifications: | stat_specifications: | ||||
| - module_path: learnware.specification | - module_path: learnware.specification | ||||
| class_name: RKMEStatSpecification | |||||
| class_name: RKMETableSpecification | |||||
| file_name: stat.json | file_name: stat.json | ||||
| kwargs: {} | kwargs: {} | ||||
| @@ -3,6 +3,6 @@ model: | |||||
| kwargs: {} | kwargs: {} | ||||
| stat_specifications: | stat_specifications: | ||||
| - module_path: learnware.specification | - module_path: learnware.specification | ||||
| class_name: RKMEImageStatSpecification | |||||
| class_name: RKMEImageSpecification | |||||
| file_name: rkme.json | file_name: rkme.json | ||||
| kwargs: {} | kwargs: {} | ||||
| @@ -6,7 +6,7 @@ from get_data import * | |||||
| import os | import os | ||||
| import random | import random | ||||
| from learnware.specification.image import RKMEImageStatSpecification | |||||
| from learnware.specification.image import RKMEImageSpecification | |||||
| from learnware.reuse.averaging import AveragingReuser | from learnware.reuse.averaging import AveragingReuser | ||||
| from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction | from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction | ||||
| from learnware.learnware import Learnware | from learnware.learnware import Learnware | ||||
| @@ -100,7 +100,7 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo | |||||
| X_sampled = X[indices] | X_sampled = X[indices] | ||||
| st = time.time() | st = time.time() | ||||
| user_spec = RKMEImageStatSpecification(cuda_idx=0) | |||||
| user_spec = RKMEImageSpecification(cuda_idx=0) | |||||
| user_spec.generate_stat_spec_from_data(X=X_sampled) | user_spec.generate_stat_spec_from_data(X=X_sampled) | ||||
| ed = time.time() | ed = time.time() | ||||
| logger.info("Stat spec generated in %.3f s" % (ed - st)) | logger.info("Stat spec generated in %.3f s" % (ed - st)) | ||||
| @@ -164,9 +164,9 @@ def test_search(gamma=0.1, load_market=True): | |||||
| user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i)) | user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i)) | ||||
| user_data = np.load(user_data_path) | user_data = np.load(user_data_path) | ||||
| user_label = np.load(user_label_path) | user_label = np.load(user_label_path) | ||||
| user_stat_spec = RKMEImageStatSpecification(cuda_idx=0) | |||||
| user_stat_spec = RKMEImageSpecification(cuda_idx=0) | |||||
| user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False) | user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False) | ||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_stat_spec}) | |||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_stat_spec}) | |||||
| logger.info("Searching Market for user: %d" % i) | logger.info("Searching Market for user: %d" % i) | ||||
| sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware( | sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware( | ||||
| user_info | user_info | ||||
| @@ -3,6 +3,6 @@ model: | |||||
| kwargs: {} | kwargs: {} | ||||
| stat_specifications: | stat_specifications: | ||||
| - module_path: learnware.specification | - module_path: learnware.specification | ||||
| class_name: RKMEStatSpecification | |||||
| class_name: RKMETableSpecification | |||||
| file_name: rkme.json | file_name: rkme.json | ||||
| kwargs: {} | kwargs: {} | ||||
| @@ -144,7 +144,7 @@ class M5DatasetWorkflow: | |||||
| user_spec_path = f"./user_spec/user_{idx}.json" | user_spec_path = f"./user_spec/user_{idx}.json" | ||||
| user_spec.save(user_spec_path) | user_spec.save(user_spec_path) | ||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) | |||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||||
| ( | ( | ||||
| sorted_score_list, | sorted_score_list, | ||||
| single_learnware_list, | single_learnware_list, | ||||
| @@ -3,6 +3,6 @@ model: | |||||
| kwargs: {} | kwargs: {} | ||||
| stat_specifications: | stat_specifications: | ||||
| - module_path: learnware.specification | - module_path: learnware.specification | ||||
| class_name: RKMEStatSpecification | |||||
| class_name: RKMETableSpecification | |||||
| file_name: rkme.json | file_name: rkme.json | ||||
| kwargs: {} | kwargs: {} | ||||
| @@ -142,7 +142,7 @@ class PFSDatasetWorkflow: | |||||
| user_spec_path = f"./user_spec/user_{idx}.json" | user_spec_path = f"./user_spec/user_{idx}.json" | ||||
| user_spec.save(user_spec_path) | user_spec.save(user_spec_path) | ||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) | |||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||||
| ( | ( | ||||
| sorted_score_list, | sorted_score_list, | ||||
| single_learnware_list, | single_learnware_list, | ||||
| @@ -3,6 +3,6 @@ model: | |||||
| kwargs: {} | kwargs: {} | ||||
| stat_specifications: | stat_specifications: | ||||
| - module_path: learnware.specification | - module_path: learnware.specification | ||||
| class_name: RKMEStatSpecification | |||||
| class_name: RKMETableSpecification | |||||
| file_name: svm.json | file_name: svm.json | ||||
| kwargs: {} | kwargs: {} | ||||
| @@ -148,9 +148,9 @@ class LearnwareMarketWorkflow: | |||||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | with zipfile.ZipFile(zip_path, "r") as zip_obj: | ||||
| zip_obj.extractall(path=unzip_dir) | zip_obj.extractall(path=unzip_dir) | ||||
| user_spec = specification.RKMEStatSpecification() | |||||
| user_spec = specification.RKMETableSpecification() | |||||
| user_spec.load(os.path.join(unzip_dir, "svm.json")) | user_spec.load(os.path.join(unzip_dir, "svm.json")) | ||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) | |||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||||
| ( | ( | ||||
| sorted_score_list, | sorted_score_list, | ||||
| single_learnware_list, | single_learnware_list, | ||||
| @@ -175,7 +175,7 @@ class LearnwareMarketWorkflow: | |||||
| _, data_X, _, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) | _, data_X, _, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) | ||||
| stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) | stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) | ||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": stat_spec}) | |||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) | |||||
| _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) | _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) | ||||
| @@ -37,7 +37,7 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath: | |||||
| "stat_specifications": [ | "stat_specifications": [ | ||||
| { | { | ||||
| "module_path": "learnware.specification", | "module_path": "learnware.specification", | ||||
| "class_name": "RKMEStatSpecification", | |||||
| "class_name": "RKMETableSpecification", | |||||
| "file_name": "stat_spec.json", | "file_name": "stat_spec.json", | ||||
| "kwargs": {}, | "kwargs": {}, | ||||
| }, | }, | ||||
| @@ -18,7 +18,7 @@ from .. import utils | |||||
| from ..config import C as conf | from ..config import C as conf | ||||
| from ..logger import get_module_logger | from ..logger import get_module_logger | ||||
| from ..learnware import Learnware, get_learnware_from_dirpath | from ..learnware import Learnware, get_learnware_from_dirpath | ||||
| from ..specification import RKMEStatSpecification, Specification | |||||
| from ..specification import RKMETableSpecification, Specification | |||||
| logger = get_module_logger("market", "INFO") | logger = get_module_logger("market", "INFO") | ||||
| @@ -116,7 +116,7 @@ class EasyMarket(LearnwareMarket): | |||||
| pass | pass | ||||
| # check rkme dimension | # check rkme dimension | ||||
| stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") | |||||
| stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") | |||||
| if stat_spec is not None: | if stat_spec is not None: | ||||
| if stat_spec.get_z().shape[1:] != input_shape: | if stat_spec.get_z().shape[1:] != input_shape: | ||||
| logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") | logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") | ||||
| @@ -296,7 +296,7 @@ class EasyMarket(LearnwareMarket): | |||||
| def _calculate_rkme_spec_mixture_weight( | def _calculate_rkme_spec_mixture_weight( | ||||
| self, | self, | ||||
| learnware_list: List[Learnware], | learnware_list: List[Learnware], | ||||
| user_rkme: RKMEStatSpecification, | |||||
| user_rkme: RKMETableSpecification, | |||||
| intermediate_K: np.ndarray = None, | intermediate_K: np.ndarray = None, | ||||
| intermediate_C: np.ndarray = None, | intermediate_C: np.ndarray = None, | ||||
| ) -> Tuple[List[float], float]: | ) -> Tuple[List[float], float]: | ||||
| @@ -306,7 +306,7 @@ class EasyMarket(LearnwareMarket): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| A list of existing learnwares | A list of existing learnwares | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| User RKME statistical specification | User RKME statistical specification | ||||
| intermediate_K : np.ndarray, optional | intermediate_K : np.ndarray, optional | ||||
| Intermediate kernel matrix K, by default None | Intermediate kernel matrix K, by default None | ||||
| @@ -321,7 +321,7 @@ class EasyMarket(LearnwareMarket): | |||||
| """ | """ | ||||
| learnware_num = len(learnware_list) | learnware_num = len(learnware_list) | ||||
| RKME_list = [ | RKME_list = [ | ||||
| learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list | |||||
| learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list | |||||
| ] | ] | ||||
| if type(intermediate_K) == np.ndarray: | if type(intermediate_K) == np.ndarray: | ||||
| @@ -365,7 +365,7 @@ class EasyMarket(LearnwareMarket): | |||||
| def _calculate_intermediate_K_and_C( | def _calculate_intermediate_K_and_C( | ||||
| self, | self, | ||||
| learnware_list: List[Learnware], | learnware_list: List[Learnware], | ||||
| user_rkme: RKMEStatSpecification, | |||||
| user_rkme: RKMETableSpecification, | |||||
| intermediate_K: np.ndarray = None, | intermediate_K: np.ndarray = None, | ||||
| intermediate_C: np.ndarray = None, | intermediate_C: np.ndarray = None, | ||||
| ) -> Tuple[np.ndarray, np.ndarray]: | ) -> Tuple[np.ndarray, np.ndarray]: | ||||
| @@ -375,7 +375,7 @@ class EasyMarket(LearnwareMarket): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| The list of learnwares up till now | The list of learnwares up till now | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| User RKME statistical specification | User RKME statistical specification | ||||
| intermediate_K : np.ndarray, optional | intermediate_K : np.ndarray, optional | ||||
| Intermediate kernel matrix K, by default None | Intermediate kernel matrix K, by default None | ||||
| @@ -390,7 +390,7 @@ class EasyMarket(LearnwareMarket): | |||||
| """ | """ | ||||
| num = intermediate_K.shape[0] - 1 | num = intermediate_K.shape[0] - 1 | ||||
| RKME_list = [ | RKME_list = [ | ||||
| learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list | |||||
| learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list | |||||
| ] | ] | ||||
| for i in range(intermediate_K.shape[0]): | for i in range(intermediate_K.shape[0]): | ||||
| intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) | intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) | ||||
| @@ -400,7 +400,7 @@ class EasyMarket(LearnwareMarket): | |||||
| def _search_by_rkme_spec_mixture_auto( | def _search_by_rkme_spec_mixture_auto( | ||||
| self, | self, | ||||
| learnware_list: List[Learnware], | learnware_list: List[Learnware], | ||||
| user_rkme: RKMEStatSpecification, | |||||
| user_rkme: RKMETableSpecification, | |||||
| max_search_num: int, | max_search_num: int, | ||||
| weight_cutoff: float = 0.98, | weight_cutoff: float = 0.98, | ||||
| ) -> Tuple[float, List[float], List[Learnware]]: | ) -> Tuple[float, List[float], List[Learnware]]: | ||||
| @@ -410,7 +410,7 @@ class EasyMarket(LearnwareMarket): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| The list of learnwares whose mixture approximates the user's rkme | The list of learnwares whose mixture approximates the user's rkme | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| User RKME statistical specification | User RKME statistical specification | ||||
| max_search_num : int | max_search_num : int | ||||
| The maximum number of the returned learnwares | The maximum number of the returned learnwares | ||||
| @@ -446,7 +446,7 @@ class EasyMarket(LearnwareMarket): | |||||
| if len(mixture_list) <= 1: | if len(mixture_list) <= 1: | ||||
| mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] | mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] | ||||
| mixture_weight = [1] | mixture_weight = [1] | ||||
| mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification")) | |||||
| mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification")) | |||||
| else: | else: | ||||
| if len(mixture_list) > max_search_num: | if len(mixture_list) > max_search_num: | ||||
| mixture_list = mixture_list[:max_search_num] | mixture_list = mixture_list[:max_search_num] | ||||
| @@ -488,7 +488,7 @@ class EasyMarket(LearnwareMarket): | |||||
| return sorted_score_list[:idx], learnware_list[:idx] | return sorted_score_list[:idx], learnware_list[:idx] | ||||
| def _filter_by_rkme_spec_dimension( | def _filter_by_rkme_spec_dimension( | ||||
| self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification | |||||
| self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification | |||||
| ) -> List[Learnware]: | ) -> List[Learnware]: | ||||
| """Filter learnwares whose rkme dimension different from user_rkme | """Filter learnwares whose rkme dimension different from user_rkme | ||||
| @@ -496,7 +496,7 @@ class EasyMarket(LearnwareMarket): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| The list of learnwares whose mixture approximates the user's rkme | The list of learnwares whose mixture approximates the user's rkme | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| User RKME statistical specification | User RKME statistical specification | ||||
| Returns | Returns | ||||
| @@ -508,7 +508,7 @@ class EasyMarket(LearnwareMarket): | |||||
| user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) | user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) | ||||
| for learnware in learnware_list: | for learnware in learnware_list: | ||||
| rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") | |||||
| rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification") | |||||
| rkme_dim = str(list(rkme.get_z().shape)[1:]) | rkme_dim = str(list(rkme.get_z().shape)[1:]) | ||||
| if rkme_dim == user_rkme_dim: | if rkme_dim == user_rkme_dim: | ||||
| filtered_learnware_list.append(learnware) | filtered_learnware_list.append(learnware) | ||||
| @@ -518,7 +518,7 @@ class EasyMarket(LearnwareMarket): | |||||
| def _search_by_rkme_spec_mixture_greedy( | def _search_by_rkme_spec_mixture_greedy( | ||||
| self, | self, | ||||
| learnware_list: List[Learnware], | learnware_list: List[Learnware], | ||||
| user_rkme: RKMEStatSpecification, | |||||
| user_rkme: RKMETableSpecification, | |||||
| max_search_num: int, | max_search_num: int, | ||||
| score_cutoff: float = 0.001, | score_cutoff: float = 0.001, | ||||
| ) -> Tuple[float, List[float], List[Learnware]]: | ) -> Tuple[float, List[float], List[Learnware]]: | ||||
| @@ -528,7 +528,7 @@ class EasyMarket(LearnwareMarket): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| The list of learnwares whose mixture approximates the user's rkme | The list of learnwares whose mixture approximates the user's rkme | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| User RKME statistical specification | User RKME statistical specification | ||||
| max_search_num : int | max_search_num : int | ||||
| The maximum number of the returned learnwares | The maximum number of the returned learnwares | ||||
| @@ -588,7 +588,7 @@ class EasyMarket(LearnwareMarket): | |||||
| return mmd_dist, weight_min, mixture_list | return mmd_dist, weight_min, mixture_list | ||||
| def _search_by_rkme_spec_single( | def _search_by_rkme_spec_single( | ||||
| self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification | |||||
| self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification | |||||
| ) -> Tuple[List[float], List[Learnware]]: | ) -> Tuple[List[float], List[Learnware]]: | ||||
| """Calculate the distances between learnwares in the given learnware_list and user_rkme | """Calculate the distances between learnwares in the given learnware_list and user_rkme | ||||
| @@ -596,7 +596,7 @@ class EasyMarket(LearnwareMarket): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| The list of learnwares whose mixture approximates the user's rkme | The list of learnwares whose mixture approximates the user's rkme | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| user RKME statistical specification | user RKME statistical specification | ||||
| Returns | Returns | ||||
| @@ -607,7 +607,7 @@ class EasyMarket(LearnwareMarket): | |||||
| both lists are sorted by mmd dist | both lists are sorted by mmd dist | ||||
| """ | """ | ||||
| RKME_list = [ | RKME_list = [ | ||||
| learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list | |||||
| learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list | |||||
| ] | ] | ||||
| mmd_dist_list = [] | mmd_dist_list = [] | ||||
| for RKME in RKME_list: | for RKME in RKME_list: | ||||
| @@ -819,12 +819,12 @@ class EasyMarket(LearnwareMarket): | |||||
| # if len(learnware_list) == 0: | # if len(learnware_list) == 0: | ||||
| learnware_list = self._search_by_semantic_spec_fuzz(learnware_list, user_info) | learnware_list = self._search_by_semantic_spec_fuzz(learnware_list, user_info) | ||||
| if "RKMEStatSpecification" not in user_info.stat_info: | |||||
| if "RKMETableSpecification" not in user_info.stat_info: | |||||
| return None, learnware_list, 0.0, None | return None, learnware_list, 0.0, None | ||||
| elif len(learnware_list) == 0: | elif len(learnware_list) == 0: | ||||
| return [], [], 0.0, [] | return [], [], 0.0, [] | ||||
| else: | else: | ||||
| user_rkme = user_info.stat_info["RKMEStatSpecification"] | |||||
| user_rkme = user_info.stat_info["RKMETableSpecification"] | |||||
| learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) | learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) | ||||
| logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") | logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") | ||||
| @@ -77,7 +77,7 @@ class EasyStatisticalChecker(BaseChecker): | |||||
| input_shape = learnware_model.input_shape | input_shape = learnware_model.input_shape | ||||
| # Check rkme dimension | # Check rkme dimension | ||||
| stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") | |||||
| stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") | |||||
| if stat_spec is not None: | if stat_spec is not None: | ||||
| if stat_spec.get_z().shape[1:] != input_shape: | if stat_spec.get_z().shape[1:] != input_shape: | ||||
| logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") | logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") | ||||
| @@ -7,7 +7,7 @@ from typing import Tuple, List | |||||
| from .organizer import EasyOrganizer | from .organizer import EasyOrganizer | ||||
| from ..base import BaseUserInfo, BaseSearcher | from ..base import BaseUserInfo, BaseSearcher | ||||
| from ...learnware import Learnware | from ...learnware import Learnware | ||||
| from ...specification import RKMEStatSpecification | |||||
| from ...specification import RKMETableSpecification | |||||
| from ...logger import get_module_logger | from ...logger import get_module_logger | ||||
| logger = get_module_logger("easy_seacher") | logger = get_module_logger("easy_seacher") | ||||
| @@ -227,7 +227,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| def _calculate_rkme_spec_mixture_weight( | def _calculate_rkme_spec_mixture_weight( | ||||
| self, | self, | ||||
| learnware_list: List[Learnware], | learnware_list: List[Learnware], | ||||
| user_rkme: RKMEStatSpecification, | |||||
| user_rkme: RKMETableSpecification, | |||||
| intermediate_K: np.ndarray = None, | intermediate_K: np.ndarray = None, | ||||
| intermediate_C: np.ndarray = None, | intermediate_C: np.ndarray = None, | ||||
| ) -> Tuple[List[float], float]: | ) -> Tuple[List[float], float]: | ||||
| @@ -237,7 +237,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| A list of existing learnwares | A list of existing learnwares | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| User RKME statistical specification | User RKME statistical specification | ||||
| intermediate_K : np.ndarray, optional | intermediate_K : np.ndarray, optional | ||||
| Intermediate kernel matrix K, by default None | Intermediate kernel matrix K, by default None | ||||
| @@ -252,7 +252,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| """ | """ | ||||
| learnware_num = len(learnware_list) | learnware_num = len(learnware_list) | ||||
| RKME_list = [ | RKME_list = [ | ||||
| learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list | |||||
| learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list | |||||
| ] | ] | ||||
| if type(intermediate_K) == np.ndarray: | if type(intermediate_K) == np.ndarray: | ||||
| @@ -296,7 +296,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| def _calculate_intermediate_K_and_C( | def _calculate_intermediate_K_and_C( | ||||
| self, | self, | ||||
| learnware_list: List[Learnware], | learnware_list: List[Learnware], | ||||
| user_rkme: RKMEStatSpecification, | |||||
| user_rkme: RKMETableSpecification, | |||||
| intermediate_K: np.ndarray = None, | intermediate_K: np.ndarray = None, | ||||
| intermediate_C: np.ndarray = None, | intermediate_C: np.ndarray = None, | ||||
| ) -> Tuple[np.ndarray, np.ndarray]: | ) -> Tuple[np.ndarray, np.ndarray]: | ||||
| @@ -306,7 +306,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| The list of learnwares up till now | The list of learnwares up till now | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| User RKME statistical specification | User RKME statistical specification | ||||
| intermediate_K : np.ndarray, optional | intermediate_K : np.ndarray, optional | ||||
| Intermediate kernel matrix K, by default None | Intermediate kernel matrix K, by default None | ||||
| @@ -321,7 +321,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| """ | """ | ||||
| num = intermediate_K.shape[0] - 1 | num = intermediate_K.shape[0] - 1 | ||||
| RKME_list = [ | RKME_list = [ | ||||
| learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list | |||||
| learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list | |||||
| ] | ] | ||||
| for i in range(intermediate_K.shape[0]): | for i in range(intermediate_K.shape[0]): | ||||
| intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) | intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) | ||||
| @@ -331,7 +331,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| def _search_by_rkme_spec_mixture_auto( | def _search_by_rkme_spec_mixture_auto( | ||||
| self, | self, | ||||
| learnware_list: List[Learnware], | learnware_list: List[Learnware], | ||||
| user_rkme: RKMEStatSpecification, | |||||
| user_rkme: RKMETableSpecification, | |||||
| max_search_num: int, | max_search_num: int, | ||||
| weight_cutoff: float = 0.98, | weight_cutoff: float = 0.98, | ||||
| ) -> Tuple[float, List[float], List[Learnware]]: | ) -> Tuple[float, List[float], List[Learnware]]: | ||||
| @@ -341,7 +341,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| The list of learnwares whose mixture approximates the user's rkme | The list of learnwares whose mixture approximates the user's rkme | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| User RKME statistical specification | User RKME statistical specification | ||||
| max_search_num : int | max_search_num : int | ||||
| The maximum number of the returned learnwares | The maximum number of the returned learnwares | ||||
| @@ -377,7 +377,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| if len(mixture_list) <= 1: | if len(mixture_list) <= 1: | ||||
| mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] | mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] | ||||
| mixture_weight = [1] | mixture_weight = [1] | ||||
| mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification")) | |||||
| mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification")) | |||||
| else: | else: | ||||
| if len(mixture_list) > max_search_num: | if len(mixture_list) > max_search_num: | ||||
| mixture_list = mixture_list[:max_search_num] | mixture_list = mixture_list[:max_search_num] | ||||
| @@ -419,7 +419,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| return sorted_score_list[:idx], learnware_list[:idx] | return sorted_score_list[:idx], learnware_list[:idx] | ||||
| def _filter_by_rkme_spec_dimension( | def _filter_by_rkme_spec_dimension( | ||||
| self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification | |||||
| self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification | |||||
| ) -> List[Learnware]: | ) -> List[Learnware]: | ||||
| """Filter learnwares whose rkme dimension different from user_rkme | """Filter learnwares whose rkme dimension different from user_rkme | ||||
| @@ -427,7 +427,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| The list of learnwares whose mixture approximates the user's rkme | The list of learnwares whose mixture approximates the user's rkme | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| User RKME statistical specification | User RKME statistical specification | ||||
| Returns | Returns | ||||
| @@ -439,7 +439,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) | user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) | ||||
| for learnware in learnware_list: | for learnware in learnware_list: | ||||
| rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") | |||||
| rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification") | |||||
| rkme_dim = str(list(rkme.get_z().shape)[1:]) | rkme_dim = str(list(rkme.get_z().shape)[1:]) | ||||
| if rkme_dim == user_rkme_dim: | if rkme_dim == user_rkme_dim: | ||||
| filtered_learnware_list.append(learnware) | filtered_learnware_list.append(learnware) | ||||
| @@ -449,7 +449,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| def _search_by_rkme_spec_mixture_greedy( | def _search_by_rkme_spec_mixture_greedy( | ||||
| self, | self, | ||||
| learnware_list: List[Learnware], | learnware_list: List[Learnware], | ||||
| user_rkme: RKMEStatSpecification, | |||||
| user_rkme: RKMETableSpecification, | |||||
| max_search_num: int, | max_search_num: int, | ||||
| score_cutoff: float = 0.001, | score_cutoff: float = 0.001, | ||||
| ) -> Tuple[float, List[float], List[Learnware]]: | ) -> Tuple[float, List[float], List[Learnware]]: | ||||
| @@ -459,7 +459,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| The list of learnwares whose mixture approximates the user's rkme | The list of learnwares whose mixture approximates the user's rkme | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| User RKME statistical specification | User RKME statistical specification | ||||
| max_search_num : int | max_search_num : int | ||||
| The maximum number of the returned learnwares | The maximum number of the returned learnwares | ||||
| @@ -519,7 +519,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| return mmd_dist, weight_min, mixture_list | return mmd_dist, weight_min, mixture_list | ||||
| def _search_by_rkme_spec_single( | def _search_by_rkme_spec_single( | ||||
| self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification | |||||
| self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification | |||||
| ) -> Tuple[List[float], List[Learnware]]: | ) -> Tuple[List[float], List[Learnware]]: | ||||
| """Calculate the distances between learnwares in the given learnware_list and user_rkme | """Calculate the distances between learnwares in the given learnware_list and user_rkme | ||||
| @@ -527,7 +527,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| ---------- | ---------- | ||||
| learnware_list : List[Learnware] | learnware_list : List[Learnware] | ||||
| The list of learnwares whose mixture approximates the user's rkme | The list of learnwares whose mixture approximates the user's rkme | ||||
| user_rkme : RKMEStatSpecification | |||||
| user_rkme : RKMETableSpecification | |||||
| user RKME statistical specification | user RKME statistical specification | ||||
| Returns | Returns | ||||
| @@ -538,7 +538,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| both lists are sorted by mmd dist | both lists are sorted by mmd dist | ||||
| """ | """ | ||||
| RKME_list = [ | RKME_list = [ | ||||
| learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list | |||||
| learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list | |||||
| ] | ] | ||||
| mmd_dist_list = [] | mmd_dist_list = [] | ||||
| for RKME in RKME_list: | for RKME in RKME_list: | ||||
| @@ -558,7 +558,7 @@ class EasyTableSearcher(BaseSearcher): | |||||
| max_search_num: int = 5, | max_search_num: int = 5, | ||||
| search_method: str = "greedy", | search_method: str = "greedy", | ||||
| ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: | ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: | ||||
| user_rkme = user_info.stat_info["RKMEStatSpecification"] | |||||
| user_rkme = user_info.stat_info["RKMETableSpecification"] | |||||
| learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) | learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) | ||||
| logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") | logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") | ||||
| @@ -631,7 +631,7 @@ class EasySearcher(BaseSearcher): | |||||
| if len(learnware_list) == 0: | if len(learnware_list) == 0: | ||||
| return [], [], 0.0, [] | return [], [], 0.0, [] | ||||
| elif "RKMEStatSpecification" in user_info.stat_info: | |||||
| elif "RKMETableSpecification" in user_info.stat_info: | |||||
| return self.table_searcher(learnware_list, user_info, max_search_num, search_method) | return self.table_searcher(learnware_list, user_info, max_search_num, search_method) | ||||
| else: | else: | ||||
| return None, learnware_list, 0.0, None | return None, learnware_list, 0.0, None | ||||
| @@ -9,7 +9,7 @@ from sklearn.metrics import accuracy_score | |||||
| from learnware.learnware import Learnware | from learnware.learnware import Learnware | ||||
| import learnware.specification as specification | import learnware.specification as specification | ||||
| from .base import BaseReuser | from .base import BaseReuser | ||||
| from ..specification import RKMEStatSpecification | |||||
| from ..specification import RKMETableSpecification | |||||
| from ..logger import get_module_logger | from ..logger import get_module_logger | ||||
| logger = get_module_logger("job_selector_reuse") | logger = get_module_logger("job_selector_reuse") | ||||
| @@ -86,7 +86,7 @@ class JobSelectorReuser(BaseReuser): | |||||
| return np.array([0] * user_data_num) | return np.array([0] * user_data_num) | ||||
| else: | else: | ||||
| learnware_rkme_spec_list = [ | learnware_rkme_spec_list = [ | ||||
| learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") | |||||
| learnware.specification.get_stat_spec_by_name("RKMETableSpecification") | |||||
| for learnware in self.learnware_list | for learnware in self.learnware_list | ||||
| ] | ] | ||||
| @@ -154,7 +154,7 @@ class JobSelectorReuser(BaseReuser): | |||||
| return job_select_result | return job_select_result | ||||
| def _calculate_rkme_spec_mixture_weight( | def _calculate_rkme_spec_mixture_weight( | ||||
| self, user_data: np.ndarray, task_rkme_list: List[RKMEStatSpecification], task_rkme_matrix: np.ndarray | |||||
| self, user_data: np.ndarray, task_rkme_list: List[RKMETableSpecification], task_rkme_matrix: np.ndarray | |||||
| ) -> List[float]: | ) -> List[float]: | ||||
| """_summary_ | """_summary_ | ||||
| @@ -162,7 +162,7 @@ class JobSelectorReuser(BaseReuser): | |||||
| ---------- | ---------- | ||||
| user_data : np.ndarray | user_data : np.ndarray | ||||
| Raw user data. | Raw user data. | ||||
| task_rkme_list : List[RKMEStatSpecification] | |||||
| task_rkme_list : List[RKMETableSpecification] | |||||
| The list of learwares' rkmes whose mixture approximates the user's rkme | The list of learwares' rkmes whose mixture approximates the user's rkme | ||||
| task_rkme_matrix : np.ndarray | task_rkme_matrix : np.ndarray | ||||
| Inner product matrix calculated from task_rkme_list. | Inner product matrix calculated from task_rkme_list. | ||||
| @@ -1,3 +1,3 @@ | |||||
| from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec | from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec | ||||
| from .base import Specification, BaseStatSpecification | from .base import Specification, BaseStatSpecification | ||||
| from .regular import RegularStatsSpecification, RKMEStatSpecification, RKMEImageStatSpecification | |||||
| from .regular import RegularStatsSpecification, RKMETableSpecification, RKMEImageSpecification | |||||
| @@ -1,3 +1,3 @@ | |||||
| from .table import RKMEStatSpecification | |||||
| from .image import RKMEImageStatSpecification | |||||
| from .table import RKMETableSpecification, RKMEStatSpecification | |||||
| from .image import RKMEImageSpecification | |||||
| from .base import RegularStatsSpecification | from .base import RegularStatsSpecification | ||||
| @@ -1 +1 @@ | |||||
| from .rkme import RKMEImageStatSpecification | |||||
| from .rkme import RKMEImageSpecification | |||||
| @@ -21,7 +21,7 @@ from ..base import BaseStatSpecification | |||||
| from ..table.rkme import solve_qp, choose_device, setup_seed | from ..table.rkme import solve_qp, choose_device, setup_seed | ||||
| class RKMEImageStatSpecification(BaseStatSpecification): | |||||
| class RKMEImageSpecification(BaseStatSpecification): | |||||
| # INNER_PRODUCT_COUNT = 0 | # INNER_PRODUCT_COUNT = 0 | ||||
| IMAGE_WIDTH = 32 | IMAGE_WIDTH = 32 | ||||
| @@ -49,7 +49,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): | |||||
| ) | ) | ||||
| setup_seed(0) | setup_seed(0) | ||||
| super(RKMEImageStatSpecification, self).__init__(type=self.__class__.__name__) | |||||
| super(RKMEImageSpecification, self).__init__(type=self.__class__.__name__) | |||||
| def _generate_models(self, n_models: int, channel: int = 3, fixed_seed=None): | def _generate_models(self, n_models: int, channel: int = 3, fixed_seed=None): | ||||
| model_class = functools.partial(_ConvNet_wide, channel=channel, **self.model_config) | model_class = functools.partial(_ConvNet_wide, channel=channel, **self.model_config) | ||||
| @@ -98,12 +98,12 @@ class RKMEImageStatSpecification(BaseStatSpecification): | |||||
| """ | """ | ||||
| if ( | if ( | ||||
| X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH | |||||
| X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH | |||||
| ) and not resize: | ) and not resize: | ||||
| raise ValueError( | raise ValueError( | ||||
| "X should be in shape of [N, C, {0:d}, {0:d}]. " | "X should be in shape of [N, C, {0:d}, {0:d}]. " | ||||
| "Or set resize=True and the image will be automatically resized to {0:d} x {0:d}.".format( | "Or set resize=True and the image will be automatically resized to {0:d} x {0:d}.".format( | ||||
| RKMEImageStatSpecification.IMAGE_WIDTH | |||||
| RKMEImageSpecification.IMAGE_WIDTH | |||||
| ) | ) | ||||
| ) | ) | ||||
| @@ -121,9 +121,9 @@ class RKMEImageStatSpecification(BaseStatSpecification): | |||||
| img_mean = torch.nanmean(img) | img_mean = torch.nanmean(img) | ||||
| X[i] = torch.where(is_nan, img_mean, img) | X[i] = torch.where(is_nan, img_mean, img) | ||||
| if X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH: | |||||
| if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH: | |||||
| X = Resize( | X = Resize( | ||||
| (RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH), antialias=None | |||||
| (RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None | |||||
| )(X) | )(X) | ||||
| num_points = X.shape[0] | num_points = X.shape[0] | ||||
| @@ -253,12 +253,12 @@ class RKMEImageStatSpecification(BaseStatSpecification): | |||||
| Y_features = Y_features / torch.sqrt(torch.asarray(Y_features.shape[1], device=self.device)) | Y_features = Y_features / torch.sqrt(torch.asarray(Y_features.shape[1], device=self.device)) | ||||
| return X_features, Y_features | return X_features, Y_features | ||||
| def inner_prod(self, Phi2: RKMEImageStatSpecification) -> float: | |||||
| def inner_prod(self, Phi2: RKMEImageSpecification) -> float: | |||||
| """Compute the inner product between two RKME Image specifications | """Compute the inner product between two RKME Image specifications | ||||
| Parameters | Parameters | ||||
| ---------- | ---------- | ||||
| Phi2 : RKMEImageStatSpecification | |||||
| Phi2 : RKMEImageSpecification | |||||
| The other RKME Image specification. | The other RKME Image specification. | ||||
| Returns | Returns | ||||
| @@ -269,7 +269,7 @@ class RKMEImageStatSpecification(BaseStatSpecification): | |||||
| v = self._inner_prod_nngp(Phi2) | v = self._inner_prod_nngp(Phi2) | ||||
| return v | return v | ||||
| def _inner_prod_nngp(self, Phi2: RKMEImageStatSpecification) -> float: | |||||
| def _inner_prod_nngp(self, Phi2: RKMEImageSpecification) -> float: | |||||
| beta_1 = self.beta.reshape(1, -1).detach().to(self.device) | beta_1 = self.beta.reshape(1, -1).detach().to(self.device) | ||||
| beta_2 = Phi2.beta.reshape(1, -1).detach().to(self.device) | beta_2 = Phi2.beta.reshape(1, -1).detach().to(self.device) | ||||
| @@ -283,15 +283,15 @@ class RKMEImageStatSpecification(BaseStatSpecification): | |||||
| K_zz = kernel_fn(Z1, Z2) | K_zz = kernel_fn(Z1, Z2) | ||||
| v = torch.sum(K_zz * (beta_1.T @ beta_2)).item() | v = torch.sum(K_zz * (beta_1.T @ beta_2)).item() | ||||
| # RKMEImageStatSpecification.INNER_PRODUCT_COUNT += 1 | |||||
| # RKMEImageSpecification.INNER_PRODUCT_COUNT += 1 | |||||
| return v | return v | ||||
| def dist(self, Phi2: RKMEImageStatSpecification, omit_term1: bool = False) -> float: | |||||
| def dist(self, Phi2: RKMEImageSpecification, omit_term1: bool = False) -> float: | |||||
| """Compute the Maximum-Mean-Discrepancy(MMD) between two RKME Image specifications | """Compute the Maximum-Mean-Discrepancy(MMD) between two RKME Image specifications | ||||
| Parameters | Parameters | ||||
| ---------- | ---------- | ||||
| Phi2 : RKMEImageStatSpecification | |||||
| Phi2 : RKMEImageSpecification | |||||
| The other RKME specification. | The other RKME specification. | ||||
| omit_term1 : bool, optional | omit_term1 : bool, optional | ||||
| True if the inner product of self with itself can be omitted, by default False. | True if the inner product of self with itself can be omitted, by default False. | ||||
| @@ -1 +1 @@ | |||||
| from .rkme import RKMEStatSpecification | |||||
| from .rkme import RKMETableSpecification | |||||
| @@ -29,7 +29,7 @@ if not _FAISS_INSTALLED: | |||||
| logger.warning("Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first") | logger.warning("Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first") | ||||
| class RKMEStatSpecification(RegularStatsSpecification): | |||||
| class RKMETableSpecification(RegularStatsSpecification): | |||||
| """Reduced Kernel Mean Embedding (RKME) Specification""" | """Reduced Kernel Mean Embedding (RKME) Specification""" | ||||
| def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): | def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): | ||||
| @@ -50,7 +50,7 @@ class RKMEStatSpecification(RegularStatsSpecification): | |||||
| torch.cuda.empty_cache() | torch.cuda.empty_cache() | ||||
| self.device = choose_device(cuda_idx=cuda_idx) | self.device = choose_device(cuda_idx=cuda_idx) | ||||
| setup_seed(0) | setup_seed(0) | ||||
| super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__) | |||||
| super(RKMETableSpecification, self).__init__(type=self.__class__.__name__) | |||||
| def get_beta(self) -> np.ndarray: | def get_beta(self) -> np.ndarray: | ||||
| """Move beta(RKME weights) back to memory accessible to the CPU. | """Move beta(RKME weights) back to memory accessible to the CPU. | ||||
| @@ -333,12 +333,12 @@ class RKMEStatSpecification(RegularStatsSpecification): | |||||
| else: | else: | ||||
| logger.warning("Not enough candidates for herding!") | logger.warning("Not enough candidates for herding!") | ||||
| def inner_prod(self, Phi2: RKMEStatSpecification) -> float: | |||||
| def inner_prod(self, Phi2: RKMETableSpecification) -> float: | |||||
| """Compute the inner product between two RKME specifications | """Compute the inner product between two RKME specifications | ||||
| Parameters | Parameters | ||||
| ---------- | ---------- | ||||
| Phi2 : RKMEStatSpecification | |||||
| Phi2 : RKMETableSpecification | |||||
| The other RKME specification. | The other RKME specification. | ||||
| Returns | Returns | ||||
| @@ -354,12 +354,12 @@ class RKMEStatSpecification(RegularStatsSpecification): | |||||
| return float(v) | return float(v) | ||||
| def dist(self, Phi2: RKMEStatSpecification, omit_term1: bool = False) -> float: | |||||
| def dist(self, Phi2: RKMETableSpecification, omit_term1: bool = False) -> float: | |||||
| """Compute the Maximum-Mean-Discrepancy(MMD) between two RKME specifications | """Compute the Maximum-Mean-Discrepancy(MMD) between two RKME specifications | ||||
| Parameters | Parameters | ||||
| ---------- | ---------- | ||||
| Phi2 : RKMEStatSpecification | |||||
| Phi2 : RKMETableSpecification | |||||
| The other RKME specification. | The other RKME specification. | ||||
| omit_term1 : bool, optional | omit_term1 : bool, optional | ||||
| True if the inner product of self with itself can be omitted, by default False. | True if the inner product of self with itself can be omitted, by default False. | ||||
| @@ -463,6 +463,11 @@ class RKMEStatSpecification(RegularStatsSpecification): | |||||
| else: | else: | ||||
| return False | return False | ||||
| class RKMEStatSpecification(RKMETableSpecification): | |||||
| """nickname for RKMETableSpecification, for compatibility currently. | |||||
| TODO: modify all learnware in database and remove this nickname | |||||
| """ | |||||
| pass | |||||
| def setup_seed(seed): | def setup_seed(seed): | ||||
| """Fix a random seed for addressing reproducibility issues. | """Fix a random seed for addressing reproducibility issues. | ||||
| @@ -4,7 +4,7 @@ import pandas as pd | |||||
| from typing import Union | from typing import Union | ||||
| from .base import BaseStatSpecification | from .base import BaseStatSpecification | ||||
| from .regular import RKMEStatSpecification, RKMEImageStatSpecification | |||||
| from .regular import RKMETableSpecification, RKMEImageSpecification | |||||
| from ..config import C | from ..config import C | ||||
| @@ -42,10 +42,10 @@ def generate_rkme_spec( | |||||
| nonnegative_beta: bool = True, | nonnegative_beta: bool = True, | ||||
| reduce: bool = True, | reduce: bool = True, | ||||
| cuda_idx: int = None, | cuda_idx: int = None, | ||||
| ) -> RKMEStatSpecification: | |||||
| ) -> RKMETableSpecification: | |||||
| """ | """ | ||||
| Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification. | Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification. | ||||
| Return a RKMEStatSpecification object, use .save() method to save as json file. | |||||
| Return a RKMETableSpecification object, use .save() method to save as json file. | |||||
| Parameters | Parameters | ||||
| ---------- | ---------- | ||||
| @@ -73,8 +73,8 @@ def generate_rkme_spec( | |||||
| Returns | Returns | ||||
| ------- | ------- | ||||
| RKMEStatSpecification | |||||
| A RKMEStatSpecification object | |||||
| RKMETableSpecification | |||||
| A RKMETableSpecification object | |||||
| """ | """ | ||||
| # Convert data type | # Convert data type | ||||
| X = convert_to_numpy(X) | X = convert_to_numpy(X) | ||||
| @@ -94,7 +94,7 @@ def generate_rkme_spec( | |||||
| cuda_idx = 0 | cuda_idx = 0 | ||||
| # Generate rkme spec | # Generate rkme spec | ||||
| rkme_spec = RKMEStatSpecification(gamma=gamma, cuda_idx=cuda_idx) | |||||
| rkme_spec = RKMETableSpecification(gamma=gamma, cuda_idx=cuda_idx) | |||||
| rkme_spec.generate_stat_spec_from_data(X, reduced_set_size, step_size, steps, nonnegative_beta, reduce) | rkme_spec.generate_stat_spec_from_data(X, reduced_set_size, step_size, steps, nonnegative_beta, reduce) | ||||
| return rkme_spec | return rkme_spec | ||||
| @@ -109,10 +109,10 @@ def generate_rkme_image_spec( | |||||
| reduce: bool = True, | reduce: bool = True, | ||||
| verbose: bool = True, | verbose: bool = True, | ||||
| cuda_idx: int = None, | cuda_idx: int = None, | ||||
| ) -> RKMEImageStatSpecification: | |||||
| ) -> RKMEImageSpecification: | |||||
| """ | """ | ||||
| Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification for Image. | Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification for Image. | ||||
| Return a RKMEImageStatSpecification object, use .save() method to save as json file. | |||||
| Return a RKMEImageSpecification object, use .save() method to save as json file. | |||||
| Parameters | Parameters | ||||
| ---------- | ---------- | ||||
| @@ -144,8 +144,8 @@ def generate_rkme_image_spec( | |||||
| Returns | Returns | ||||
| ------- | ------- | ||||
| RKMEImageStatSpecification | |||||
| A RKMEImageStatSpecification object | |||||
| RKMEImageSpecification | |||||
| A RKMEImageSpecification object | |||||
| """ | """ | ||||
| # Check cuda_idx | # Check cuda_idx | ||||
| @@ -157,7 +157,7 @@ def generate_rkme_image_spec( | |||||
| cuda_idx = 0 | cuda_idx = 0 | ||||
| # Generate rkme spec | # Generate rkme spec | ||||
| rkme_image_spec = RKMEImageStatSpecification(cuda_idx=cuda_idx) | |||||
| rkme_image_spec = RKMEImageSpecification(cuda_idx=cuda_idx) | |||||
| rkme_image_spec.generate_stat_spec_from_data( | rkme_image_spec.generate_stat_spec_from_data( | ||||
| X, reduced_set_size, step_size, steps, resize, nonnegative_beta, reduce, verbose | X, reduced_set_size, step_size, steps, resize, nonnegative_beta, reduce, verbose | ||||
| ) | ) | ||||
| @@ -3,6 +3,6 @@ model: | |||||
| kwargs: {} | kwargs: {} | ||||
| stat_specifications: | stat_specifications: | ||||
| - module_path: learnware.specification | - module_path: learnware.specification | ||||
| class_name: RKMEStatSpecification | |||||
| class_name: RKMETableSpecification | |||||
| file_name: svm.json | file_name: svm.json | ||||
| kwargs: {} | kwargs: {} | ||||
| @@ -170,9 +170,9 @@ class TestMarket(unittest.TestCase): | |||||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | with zipfile.ZipFile(zip_path, "r") as zip_obj: | ||||
| zip_obj.extractall(path=unzip_dir) | zip_obj.extractall(path=unzip_dir) | ||||
| user_spec = specification.rkme.RKMEStatSpecification() | |||||
| user_spec = specification.rkme.RKMETableSpecification() | |||||
| user_spec.load(os.path.join(unzip_dir, "svm.json")) | user_spec.load(os.path.join(unzip_dir, "svm.json")) | ||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) | |||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||||
| ( | ( | ||||
| sorted_score_list, | sorted_score_list, | ||||
| single_learnware_list, | single_learnware_list, | ||||
| @@ -5,7 +5,7 @@ import unittest | |||||
| import tempfile | import tempfile | ||||
| import numpy as np | import numpy as np | ||||
| from learnware.specification import RKMEStatSpecification, RKMEImageStatSpecification | |||||
| from learnware.specification import RKMETableSpecification, RKMEImageSpecification | |||||
| from learnware.specification import generate_rkme_image_spec, generate_rkme_spec | from learnware.specification import generate_rkme_image_spec, generate_rkme_spec | ||||
| @@ -22,11 +22,11 @@ class TestRKME(unittest.TestCase): | |||||
| with open(rkme_path, "r") as f: | with open(rkme_path, "r") as f: | ||||
| data = json.load(f) | data = json.load(f) | ||||
| assert data["type"] == "RKMEStatSpecification" | |||||
| assert data["type"] == "RKMETableSpecification" | |||||
| rkme2 = RKMEStatSpecification() | |||||
| rkme2 = RKMETableSpecification() | |||||
| rkme2.load(rkme_path) | rkme2.load(rkme_path) | ||||
| assert rkme2.type == "RKMEStatSpecification" | |||||
| assert rkme2.type == "RKMETableSpecification" | |||||
| def test_image_rkme(self): | def test_image_rkme(self): | ||||
| def _test_image_rkme(X): | def _test_image_rkme(X): | ||||
| @@ -38,11 +38,11 @@ class TestRKME(unittest.TestCase): | |||||
| with open(rkme_path, "r") as f: | with open(rkme_path, "r") as f: | ||||
| data = json.load(f) | data = json.load(f) | ||||
| assert data["type"] == "RKMEImageStatSpecification" | |||||
| assert data["type"] == "RKMEImageSpecification" | |||||
| rkme2 = RKMEImageStatSpecification() | |||||
| rkme2 = RKMEImageSpecification() | |||||
| rkme2.load(rkme_path) | rkme2.load(rkme_path) | ||||
| assert rkme2.type == "RKMEImageStatSpecification" | |||||
| assert rkme2.type == "RKMEImageSpecification" | |||||
| _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) | _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) | ||||
| _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128))) | _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128))) | ||||
| @@ -3,6 +3,6 @@ model: | |||||
| kwargs: {} | kwargs: {} | ||||
| stat_specifications: | stat_specifications: | ||||
| - module_path: learnware.specification | - module_path: learnware.specification | ||||
| class_name: RKMEStatSpecification | |||||
| class_name: RKMETableSpecification | |||||
| file_name: svm.json | file_name: svm.json | ||||
| kwargs: {} | kwargs: {} | ||||
| @@ -155,9 +155,9 @@ class TestAllWorkflow(unittest.TestCase): | |||||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | with zipfile.ZipFile(zip_path, "r") as zip_obj: | ||||
| zip_obj.extractall(path=unzip_dir) | zip_obj.extractall(path=unzip_dir) | ||||
| user_spec = specification.RKMEStatSpecification() | |||||
| user_spec = specification.RKMETableSpecification() | |||||
| user_spec.load(os.path.join(unzip_dir, "svm.json")) | user_spec.load(os.path.join(unzip_dir, "svm.json")) | ||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}) | |||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||||
| ( | ( | ||||
| sorted_score_list, | sorted_score_list, | ||||
| single_learnware_list, | single_learnware_list, | ||||
| @@ -182,7 +182,7 @@ class TestAllWorkflow(unittest.TestCase): | |||||
| train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) | train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) | ||||
| stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) | stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) | ||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": stat_spec}) | |||||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) | |||||
| _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) | _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) | ||||