Browse Source

[MNT] rename table and image rkme

tags/v0.3.2
bxdd 2 years ago
parent
commit
6b1694825c
31 changed files with 128 additions and 123 deletions
  1. +3
    -3
      README.md
  2. +1
    -1
      docs/references/api.rst
  3. +4
    -4
      docs/start/client.rst
  4. +3
    -3
      docs/start/quick.rst
  5. +2
    -2
      docs/workflow/identify.rst
  6. +2
    -2
      docs/workflow/submit.rst
  7. +1
    -1
      examples/dataset_image_workflow/example_files/example_yaml.yaml
  8. +4
    -4
      examples/dataset_image_workflow/main.py
  9. +1
    -1
      examples/dataset_m5_workflow/example.yaml
  10. +1
    -1
      examples/dataset_m5_workflow/main.py
  11. +1
    -1
      examples/dataset_pfs_workflow/example.yaml
  12. +1
    -1
      examples/dataset_pfs_workflow/main.py
  13. +1
    -1
      examples/workflow_by_code/learnware_example/example.yaml
  14. +3
    -3
      examples/workflow_by_code/main.py
  15. +1
    -1
      learnware/learnware/__init__.py
  16. +21
    -21
      learnware/market/easy.py
  17. +1
    -1
      learnware/market/easy2/checker.py
  18. +20
    -20
      learnware/market/easy2/searcher.py
  19. +4
    -4
      learnware/reuse/job_selector.py
  20. +1
    -1
      learnware/specification/__init__.py
  21. +2
    -2
      learnware/specification/regular/__init__.py
  22. +1
    -1
      learnware/specification/regular/image/__init__.py
  23. +12
    -12
      learnware/specification/regular/image/rkme.py
  24. +1
    -1
      learnware/specification/regular/table/__init__.py
  25. +11
    -6
      learnware/specification/regular/table/rkme.py
  26. +11
    -11
      learnware/specification/utils.py
  27. +1
    -1
      tests/test_market/learnware_example/example.yaml
  28. +2
    -2
      tests/test_market/test_easy.py
  29. +7
    -7
      tests/test_specification/test_rkme.py
  30. +1
    -1
      tests/test_workflow/learnware_example/example.yaml
  31. +3
    -3
      tests/test_workflow/test_workflow.py

+ 3
- 3
README.md View File

@@ -80,7 +80,7 @@ is composed of the following four parts.


- ``learnware.yaml`` - ``learnware.yaml``
A config file describing your model class name, type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMEStatSpecification``), and
A config file describing your model class name, type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMETableSpecification``), and
the file name of your statistical specification file. the file name of your statistical specification file.


- ``environment.yaml`` - ``environment.yaml``
@@ -178,10 +178,10 @@ For example, the following code is designed to work with Reduced Set Kernel Embe
```python ```python
import learnware.specification as specification import learnware.specification as specification


user_spec = specification.RKMEStatSpecification()
user_spec = specification.RKMETableSpecification()
user_spec.load(os.path.join(unzip_path, "rkme.json")) user_spec.load(os.path.join(unzip_path, "rkme.json"))
user_info = BaseUserInfo( user_info = BaseUserInfo(
semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}
semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}
) )
(sorted_score_list, single_learnware_list, (sorted_score_list, single_learnware_list,
mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info)


+ 1
- 1
docs/references/api.rst View File

@@ -50,7 +50,7 @@ Specification
.. autoclass:: learnware.specification.BaseStatSpecification .. autoclass:: learnware.specification.BaseStatSpecification
:members: :members:


.. autoclass:: learnware.specification.RKMEStatSpecification
.. autoclass:: learnware.specification.RKMETableSpecification
:members: :members:


Model Model


+ 4
- 4
docs/start/client.rst View File

@@ -117,13 +117,13 @@ You can search learnwares in official market using semantic specification. All t
Statistical Specification Search Statistical Specification Search
--------------------------------- ---------------------------------


You can search learnware by providing a statistical specification. The statistical specification is a json file that contains the statistical information of your training data. For example, the code below searches learnwares with `RKMEStatSpecification`:
You can search learnware by providing a statistical specification. The statistical specification is a json file that contains the statistical information of your training data. For example, the code below searches learnwares with `RKMETableSpecification`:


.. code-block:: python .. code-block:: python


import learnware.specification as specification import learnware.specification as specification


user_spec = specification.RKMEStatSpecification()
user_spec = specification.RKMETableSpecification()
user_spec.load(os.path.join(unzip_path, "rkme.json")) user_spec.load(os.path.join(unzip_path, "rkme.json"))
specification = learnware.specification.Specification() specification = learnware.specification.Specification()
@@ -138,7 +138,7 @@ You can search learnware by providing a statistical specification. The statistic


Combine Semantic and Statistical Search Combine Semantic and Statistical Search
---------------------------------------- ----------------------------------------
You can provide both semantic and statistical specification to search learnwares. The engine will first filter learnwares by semantic specification and then search by statistical specification. For example, the code below searches learnwares with `Table` data type and `RKMEStatSpecification`:
You can provide both semantic and statistical specification to search learnwares. The engine will first filter learnwares by semantic specification and then search by statistical specification. For example, the code below searches learnwares with `Table` data type and `RKMETableSpecification`:


.. code-block:: python .. code-block:: python


@@ -151,7 +151,7 @@ You can provide both semantic and statistical specification to search learnwares
senarioes=[], senarioes=[],
input_description={}, output_description={}) input_description={}, output_description={})


stat_spec = specification.RKMEStatSpecification()
stat_spec = specification.RKMETableSpecification()
stat_spec.load(os.path.join(unzip_path, "rkme.json")) stat_spec.load(os.path.join(unzip_path, "rkme.json"))
specification = learnware.specification.Specification() specification = learnware.specification.Specification()
specification.update_semantic_spec(semantic_spec) specification.update_semantic_spec(semantic_spec)


+ 3
- 3
docs/start/quick.rst View File

@@ -47,7 +47,7 @@ includes the following four components:


- ``learnware.yaml`` - ``learnware.yaml``
A configuration file that details your model's class name, the type of statistical specification(e.g. ``RKMEStatSpecification`` for Reduced Kernel Mean Embedding), and
A configuration file that details your model's class name, the type of statistical specification(e.g. ``RKMETableSpecification`` for Reduced Kernel Mean Embedding), and
the file name of your statistical specification file. the file name of your statistical specification file.


- ``environment.yaml`` or ``requirements.txt`` - ``environment.yaml`` or ``requirements.txt``
@@ -170,12 +170,12 @@ For example, the code below executes learnware search when using Reduced Set Ker


import learnware.specification as specification import learnware.specification as specification


user_spec = specification.RKMEStatSpecification()
user_spec = specification.RKMETableSpecification()


# unzip_path: directory for unzipped learnware zipfile # unzip_path: directory for unzipped learnware zipfile
user_spec.load(os.path.join(unzip_path, "rkme.json")) user_spec.load(os.path.join(unzip_path, "rkme.json"))
user_info = BaseUserInfo( user_info = BaseUserInfo(
semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}
semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}
) )
(sorted_score_list, single_learnware_list, (sorted_score_list, single_learnware_list,
mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info)


+ 2
- 2
docs/workflow/identify.rst View File

@@ -73,10 +73,10 @@ For example, the following code is designed to work with Reduced Kernel Mean Emb


import learnware.specification as specification import learnware.specification as specification


user_spec = specification.RKMEStatSpecification()
user_spec = specification.RKMETableSpecification()
user_spec.load(os.path.join("rkme.json")) user_spec.load(os.path.join("rkme.json"))
user_info = BaseUserInfo( user_info = BaseUserInfo(
semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec}
semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}
) )
(sorted_score_list, single_learnware_list, (sorted_score_list, single_learnware_list,
mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info) mixture_score, mixture_learnware_list) = easy_market.search_learnware(user_info)


+ 2
- 2
docs/workflow/submit.rst View File

@@ -94,7 +94,7 @@ guaranteeing the security and privacy of your local original data.
------------------ ------------------


Additionally, you are asked to prepare a configuration file in YAML format. Additionally, you are asked to prepare a configuration file in YAML format.
The file should detail your model's class name, the type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMEStatSpecification``), and
The file should detail your model's class name, the type of statistical specification(e.g. Reduced Kernel Mean Embedding, ``RKMETableSpecification``), and
the file name of your statistical specification file. The following ``learnware.yaml`` provides an example of the file name of your statistical specification file. The following ``learnware.yaml`` provides an example of
how your learnware configuration file should be structured, based on our previous discussion: how your learnware configuration file should be structured, based on our previous discussion:


@@ -105,7 +105,7 @@ how your learnware configuration file should be structured, based on our previou
kwargs: {} kwargs: {}
stat_specifications: stat_specifications:
- module_path: learnware.specification - module_path: learnware.specification
class_name: RKMEStatSpecification
class_name: RKMETableSpecification
file_name: stat.json file_name: stat.json
kwargs: {} kwargs: {}




+ 1
- 1
examples/dataset_image_workflow/example_files/example_yaml.yaml View File

@@ -3,6 +3,6 @@ model:
kwargs: {} kwargs: {}
stat_specifications: stat_specifications:
- module_path: learnware.specification - module_path: learnware.specification
class_name: RKMEImageStatSpecification
class_name: RKMEImageSpecification
file_name: rkme.json file_name: rkme.json
kwargs: {} kwargs: {}

+ 4
- 4
examples/dataset_image_workflow/main.py View File

@@ -6,7 +6,7 @@ from get_data import *
import os import os
import random import random


from learnware.specification.image import RKMEImageStatSpecification
from learnware.specification.image import RKMEImageSpecification
from learnware.reuse.averaging import AveragingReuser from learnware.reuse.averaging import AveragingReuser
from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction from utils import generate_uploader, generate_user, ImageDataLoader, train, eval_prediction
from learnware.learnware import Learnware from learnware.learnware import Learnware
@@ -100,7 +100,7 @@ def prepare_learnware(data_path, model_path, init_file_path, yaml_path, save_roo
X_sampled = X[indices] X_sampled = X[indices]


st = time.time() st = time.time()
user_spec = RKMEImageStatSpecification(cuda_idx=0)
user_spec = RKMEImageSpecification(cuda_idx=0)
user_spec.generate_stat_spec_from_data(X=X_sampled) user_spec.generate_stat_spec_from_data(X=X_sampled)
ed = time.time() ed = time.time()
logger.info("Stat spec generated in %.3f s" % (ed - st)) logger.info("Stat spec generated in %.3f s" % (ed - st))
@@ -164,9 +164,9 @@ def test_search(gamma=0.1, load_market=True):
user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i)) user_label_path = os.path.join(user_save_root, "user_%d_y.npy" % (i))
user_data = np.load(user_data_path) user_data = np.load(user_data_path)
user_label = np.load(user_label_path) user_label = np.load(user_label_path)
user_stat_spec = RKMEImageStatSpecification(cuda_idx=0)
user_stat_spec = RKMEImageSpecification(cuda_idx=0)
user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False) user_stat_spec.generate_stat_spec_from_data(X=user_data, resize=False)
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_stat_spec})
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_stat_spec})
logger.info("Searching Market for user: %d" % i) logger.info("Searching Market for user: %d" % i)
sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware( sorted_score_list, single_learnware_list, mixture_score, mixture_learnware_list = image_market.search_learnware(
user_info user_info


+ 1
- 1
examples/dataset_m5_workflow/example.yaml View File

@@ -3,6 +3,6 @@ model:
kwargs: {} kwargs: {}
stat_specifications: stat_specifications:
- module_path: learnware.specification - module_path: learnware.specification
class_name: RKMEStatSpecification
class_name: RKMETableSpecification
file_name: rkme.json file_name: rkme.json
kwargs: {} kwargs: {}

+ 1
- 1
examples/dataset_m5_workflow/main.py View File

@@ -144,7 +144,7 @@ class M5DatasetWorkflow:
user_spec_path = f"./user_spec/user_{idx}.json" user_spec_path = f"./user_spec/user_{idx}.json"
user_spec.save(user_spec_path) user_spec.save(user_spec_path)


user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
( (
sorted_score_list, sorted_score_list,
single_learnware_list, single_learnware_list,


+ 1
- 1
examples/dataset_pfs_workflow/example.yaml View File

@@ -3,6 +3,6 @@ model:
kwargs: {} kwargs: {}
stat_specifications: stat_specifications:
- module_path: learnware.specification - module_path: learnware.specification
class_name: RKMEStatSpecification
class_name: RKMETableSpecification
file_name: rkme.json file_name: rkme.json
kwargs: {} kwargs: {}

+ 1
- 1
examples/dataset_pfs_workflow/main.py View File

@@ -142,7 +142,7 @@ class PFSDatasetWorkflow:
user_spec_path = f"./user_spec/user_{idx}.json" user_spec_path = f"./user_spec/user_{idx}.json"
user_spec.save(user_spec_path) user_spec.save(user_spec_path)


user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
( (
sorted_score_list, sorted_score_list,
single_learnware_list, single_learnware_list,


+ 1
- 1
examples/workflow_by_code/learnware_example/example.yaml View File

@@ -3,6 +3,6 @@ model:
kwargs: {} kwargs: {}
stat_specifications: stat_specifications:
- module_path: learnware.specification - module_path: learnware.specification
class_name: RKMEStatSpecification
class_name: RKMETableSpecification
file_name: svm.json file_name: svm.json
kwargs: {} kwargs: {}

+ 3
- 3
examples/workflow_by_code/main.py View File

@@ -148,9 +148,9 @@ class LearnwareMarketWorkflow:
with zipfile.ZipFile(zip_path, "r") as zip_obj: with zipfile.ZipFile(zip_path, "r") as zip_obj:
zip_obj.extractall(path=unzip_dir) zip_obj.extractall(path=unzip_dir)


user_spec = specification.RKMEStatSpecification()
user_spec = specification.RKMETableSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json")) user_spec.load(os.path.join(unzip_dir, "svm.json"))
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
( (
sorted_score_list, sorted_score_list,
single_learnware_list, single_learnware_list,
@@ -175,7 +175,7 @@ class LearnwareMarketWorkflow:
_, data_X, _, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) _, data_X, _, data_y = train_test_split(X, y, test_size=0.3, shuffle=True)


stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": stat_spec})
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec})


_, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info)




+ 1
- 1
learnware/learnware/__init__.py View File

@@ -37,7 +37,7 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath:
"stat_specifications": [ "stat_specifications": [
{ {
"module_path": "learnware.specification", "module_path": "learnware.specification",
"class_name": "RKMEStatSpecification",
"class_name": "RKMETableSpecification",
"file_name": "stat_spec.json", "file_name": "stat_spec.json",
"kwargs": {}, "kwargs": {},
}, },


+ 21
- 21
learnware/market/easy.py View File

@@ -18,7 +18,7 @@ from .. import utils
from ..config import C as conf from ..config import C as conf
from ..logger import get_module_logger from ..logger import get_module_logger
from ..learnware import Learnware, get_learnware_from_dirpath from ..learnware import Learnware, get_learnware_from_dirpath
from ..specification import RKMEStatSpecification, Specification
from ..specification import RKMETableSpecification, Specification




logger = get_module_logger("market", "INFO") logger = get_module_logger("market", "INFO")
@@ -116,7 +116,7 @@ class EasyMarket(LearnwareMarket):
pass pass


# check rkme dimension # check rkme dimension
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification")
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification")
if stat_spec is not None: if stat_spec is not None:
if stat_spec.get_z().shape[1:] != input_shape: if stat_spec.get_z().shape[1:] != input_shape:
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification")
@@ -296,7 +296,7 @@ class EasyMarket(LearnwareMarket):
def _calculate_rkme_spec_mixture_weight( def _calculate_rkme_spec_mixture_weight(
self, self,
learnware_list: List[Learnware], learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
user_rkme: RKMETableSpecification,
intermediate_K: np.ndarray = None, intermediate_K: np.ndarray = None,
intermediate_C: np.ndarray = None, intermediate_C: np.ndarray = None,
) -> Tuple[List[float], float]: ) -> Tuple[List[float], float]:
@@ -306,7 +306,7 @@ class EasyMarket(LearnwareMarket):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
A list of existing learnwares A list of existing learnwares
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
User RKME statistical specification User RKME statistical specification
intermediate_K : np.ndarray, optional intermediate_K : np.ndarray, optional
Intermediate kernel matrix K, by default None Intermediate kernel matrix K, by default None
@@ -321,7 +321,7 @@ class EasyMarket(LearnwareMarket):
""" """
learnware_num = len(learnware_list) learnware_num = len(learnware_list)
RKME_list = [ RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list
] ]


if type(intermediate_K) == np.ndarray: if type(intermediate_K) == np.ndarray:
@@ -365,7 +365,7 @@ class EasyMarket(LearnwareMarket):
def _calculate_intermediate_K_and_C( def _calculate_intermediate_K_and_C(
self, self,
learnware_list: List[Learnware], learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
user_rkme: RKMETableSpecification,
intermediate_K: np.ndarray = None, intermediate_K: np.ndarray = None,
intermediate_C: np.ndarray = None, intermediate_C: np.ndarray = None,
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
@@ -375,7 +375,7 @@ class EasyMarket(LearnwareMarket):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
The list of learnwares up till now The list of learnwares up till now
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
User RKME statistical specification User RKME statistical specification
intermediate_K : np.ndarray, optional intermediate_K : np.ndarray, optional
Intermediate kernel matrix K, by default None Intermediate kernel matrix K, by default None
@@ -390,7 +390,7 @@ class EasyMarket(LearnwareMarket):
""" """
num = intermediate_K.shape[0] - 1 num = intermediate_K.shape[0] - 1
RKME_list = [ RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list
] ]
for i in range(intermediate_K.shape[0]): for i in range(intermediate_K.shape[0]):
intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i])
@@ -400,7 +400,7 @@ class EasyMarket(LearnwareMarket):
def _search_by_rkme_spec_mixture_auto( def _search_by_rkme_spec_mixture_auto(
self, self,
learnware_list: List[Learnware], learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
user_rkme: RKMETableSpecification,
max_search_num: int, max_search_num: int,
weight_cutoff: float = 0.98, weight_cutoff: float = 0.98,
) -> Tuple[float, List[float], List[Learnware]]: ) -> Tuple[float, List[float], List[Learnware]]:
@@ -410,7 +410,7 @@ class EasyMarket(LearnwareMarket):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
User RKME statistical specification User RKME statistical specification
max_search_num : int max_search_num : int
The maximum number of the returned learnwares The maximum number of the returned learnwares
@@ -446,7 +446,7 @@ class EasyMarket(LearnwareMarket):
if len(mixture_list) <= 1: if len(mixture_list) <= 1:
mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_list = [learnware_list[sort_by_weight_idx_list[0]]]
mixture_weight = [1] mixture_weight = [1]
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification"))
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification"))
else: else:
if len(mixture_list) > max_search_num: if len(mixture_list) > max_search_num:
mixture_list = mixture_list[:max_search_num] mixture_list = mixture_list[:max_search_num]
@@ -488,7 +488,7 @@ class EasyMarket(LearnwareMarket):
return sorted_score_list[:idx], learnware_list[:idx] return sorted_score_list[:idx], learnware_list[:idx]


def _filter_by_rkme_spec_dimension( def _filter_by_rkme_spec_dimension(
self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification
self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification
) -> List[Learnware]: ) -> List[Learnware]:
"""Filter learnwares whose rkme dimension different from user_rkme """Filter learnwares whose rkme dimension different from user_rkme


@@ -496,7 +496,7 @@ class EasyMarket(LearnwareMarket):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
User RKME statistical specification User RKME statistical specification


Returns Returns
@@ -508,7 +508,7 @@ class EasyMarket(LearnwareMarket):
user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) user_rkme_dim = str(list(user_rkme.get_z().shape)[1:])


for learnware in learnware_list: for learnware in learnware_list:
rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification")
rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification")
rkme_dim = str(list(rkme.get_z().shape)[1:]) rkme_dim = str(list(rkme.get_z().shape)[1:])
if rkme_dim == user_rkme_dim: if rkme_dim == user_rkme_dim:
filtered_learnware_list.append(learnware) filtered_learnware_list.append(learnware)
@@ -518,7 +518,7 @@ class EasyMarket(LearnwareMarket):
def _search_by_rkme_spec_mixture_greedy( def _search_by_rkme_spec_mixture_greedy(
self, self,
learnware_list: List[Learnware], learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
user_rkme: RKMETableSpecification,
max_search_num: int, max_search_num: int,
score_cutoff: float = 0.001, score_cutoff: float = 0.001,
) -> Tuple[float, List[float], List[Learnware]]: ) -> Tuple[float, List[float], List[Learnware]]:
@@ -528,7 +528,7 @@ class EasyMarket(LearnwareMarket):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
User RKME statistical specification User RKME statistical specification
max_search_num : int max_search_num : int
The maximum number of the returned learnwares The maximum number of the returned learnwares
@@ -588,7 +588,7 @@ class EasyMarket(LearnwareMarket):
return mmd_dist, weight_min, mixture_list return mmd_dist, weight_min, mixture_list


def _search_by_rkme_spec_single( def _search_by_rkme_spec_single(
self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification
self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification
) -> Tuple[List[float], List[Learnware]]: ) -> Tuple[List[float], List[Learnware]]:
"""Calculate the distances between learnwares in the given learnware_list and user_rkme """Calculate the distances between learnwares in the given learnware_list and user_rkme


@@ -596,7 +596,7 @@ class EasyMarket(LearnwareMarket):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
user RKME statistical specification user RKME statistical specification


Returns Returns
@@ -607,7 +607,7 @@ class EasyMarket(LearnwareMarket):
both lists are sorted by mmd dist both lists are sorted by mmd dist
""" """
RKME_list = [ RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list
] ]
mmd_dist_list = [] mmd_dist_list = []
for RKME in RKME_list: for RKME in RKME_list:
@@ -819,12 +819,12 @@ class EasyMarket(LearnwareMarket):
# if len(learnware_list) == 0: # if len(learnware_list) == 0:
learnware_list = self._search_by_semantic_spec_fuzz(learnware_list, user_info) learnware_list = self._search_by_semantic_spec_fuzz(learnware_list, user_info)


if "RKMEStatSpecification" not in user_info.stat_info:
if "RKMETableSpecification" not in user_info.stat_info:
return None, learnware_list, 0.0, None return None, learnware_list, 0.0, None
elif len(learnware_list) == 0: elif len(learnware_list) == 0:
return [], [], 0.0, [] return [], [], 0.0, []
else: else:
user_rkme = user_info.stat_info["RKMEStatSpecification"]
user_rkme = user_info.stat_info["RKMETableSpecification"]
learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme)
logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}")




+ 1
- 1
learnware/market/easy2/checker.py View File

@@ -77,7 +77,7 @@ class EasyStatisticalChecker(BaseChecker):
input_shape = learnware_model.input_shape input_shape = learnware_model.input_shape


# Check rkme dimension # Check rkme dimension
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification")
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification")
if stat_spec is not None: if stat_spec is not None:
if stat_spec.get_z().shape[1:] != input_shape: if stat_spec.get_z().shape[1:] != input_shape:
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.") logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification.")


+ 20
- 20
learnware/market/easy2/searcher.py View File

@@ -7,7 +7,7 @@ from typing import Tuple, List
from .organizer import EasyOrganizer from .organizer import EasyOrganizer
from ..base import BaseUserInfo, BaseSearcher from ..base import BaseUserInfo, BaseSearcher
from ...learnware import Learnware from ...learnware import Learnware
from ...specification import RKMEStatSpecification
from ...specification import RKMETableSpecification
from ...logger import get_module_logger from ...logger import get_module_logger


logger = get_module_logger("easy_seacher") logger = get_module_logger("easy_seacher")
@@ -227,7 +227,7 @@ class EasyTableSearcher(BaseSearcher):
def _calculate_rkme_spec_mixture_weight( def _calculate_rkme_spec_mixture_weight(
self, self,
learnware_list: List[Learnware], learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
user_rkme: RKMETableSpecification,
intermediate_K: np.ndarray = None, intermediate_K: np.ndarray = None,
intermediate_C: np.ndarray = None, intermediate_C: np.ndarray = None,
) -> Tuple[List[float], float]: ) -> Tuple[List[float], float]:
@@ -237,7 +237,7 @@ class EasyTableSearcher(BaseSearcher):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
A list of existing learnwares A list of existing learnwares
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
User RKME statistical specification User RKME statistical specification
intermediate_K : np.ndarray, optional intermediate_K : np.ndarray, optional
Intermediate kernel matrix K, by default None Intermediate kernel matrix K, by default None
@@ -252,7 +252,7 @@ class EasyTableSearcher(BaseSearcher):
""" """
learnware_num = len(learnware_list) learnware_num = len(learnware_list)
RKME_list = [ RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list
] ]


if type(intermediate_K) == np.ndarray: if type(intermediate_K) == np.ndarray:
@@ -296,7 +296,7 @@ class EasyTableSearcher(BaseSearcher):
def _calculate_intermediate_K_and_C( def _calculate_intermediate_K_and_C(
self, self,
learnware_list: List[Learnware], learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
user_rkme: RKMETableSpecification,
intermediate_K: np.ndarray = None, intermediate_K: np.ndarray = None,
intermediate_C: np.ndarray = None, intermediate_C: np.ndarray = None,
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
@@ -306,7 +306,7 @@ class EasyTableSearcher(BaseSearcher):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
The list of learnwares up till now The list of learnwares up till now
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
User RKME statistical specification User RKME statistical specification
intermediate_K : np.ndarray, optional intermediate_K : np.ndarray, optional
Intermediate kernel matrix K, by default None Intermediate kernel matrix K, by default None
@@ -321,7 +321,7 @@ class EasyTableSearcher(BaseSearcher):
""" """
num = intermediate_K.shape[0] - 1 num = intermediate_K.shape[0] - 1
RKME_list = [ RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list
] ]
for i in range(intermediate_K.shape[0]): for i in range(intermediate_K.shape[0]):
intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i])
@@ -331,7 +331,7 @@ class EasyTableSearcher(BaseSearcher):
def _search_by_rkme_spec_mixture_auto( def _search_by_rkme_spec_mixture_auto(
self, self,
learnware_list: List[Learnware], learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
user_rkme: RKMETableSpecification,
max_search_num: int, max_search_num: int,
weight_cutoff: float = 0.98, weight_cutoff: float = 0.98,
) -> Tuple[float, List[float], List[Learnware]]: ) -> Tuple[float, List[float], List[Learnware]]:
@@ -341,7 +341,7 @@ class EasyTableSearcher(BaseSearcher):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
User RKME statistical specification User RKME statistical specification
max_search_num : int max_search_num : int
The maximum number of the returned learnwares The maximum number of the returned learnwares
@@ -377,7 +377,7 @@ class EasyTableSearcher(BaseSearcher):
if len(mixture_list) <= 1: if len(mixture_list) <= 1:
mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_list = [learnware_list[sort_by_weight_idx_list[0]]]
mixture_weight = [1] mixture_weight = [1]
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification"))
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification"))
else: else:
if len(mixture_list) > max_search_num: if len(mixture_list) > max_search_num:
mixture_list = mixture_list[:max_search_num] mixture_list = mixture_list[:max_search_num]
@@ -419,7 +419,7 @@ class EasyTableSearcher(BaseSearcher):
return sorted_score_list[:idx], learnware_list[:idx] return sorted_score_list[:idx], learnware_list[:idx]


def _filter_by_rkme_spec_dimension( def _filter_by_rkme_spec_dimension(
self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification
self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification
) -> List[Learnware]: ) -> List[Learnware]:
"""Filter learnwares whose rkme dimension different from user_rkme """Filter learnwares whose rkme dimension different from user_rkme


@@ -427,7 +427,7 @@ class EasyTableSearcher(BaseSearcher):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
User RKME statistical specification User RKME statistical specification


Returns Returns
@@ -439,7 +439,7 @@ class EasyTableSearcher(BaseSearcher):
user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) user_rkme_dim = str(list(user_rkme.get_z().shape)[1:])


for learnware in learnware_list: for learnware in learnware_list:
rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification")
rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification")
rkme_dim = str(list(rkme.get_z().shape)[1:]) rkme_dim = str(list(rkme.get_z().shape)[1:])
if rkme_dim == user_rkme_dim: if rkme_dim == user_rkme_dim:
filtered_learnware_list.append(learnware) filtered_learnware_list.append(learnware)
@@ -449,7 +449,7 @@ class EasyTableSearcher(BaseSearcher):
def _search_by_rkme_spec_mixture_greedy( def _search_by_rkme_spec_mixture_greedy(
self, self,
learnware_list: List[Learnware], learnware_list: List[Learnware],
user_rkme: RKMEStatSpecification,
user_rkme: RKMETableSpecification,
max_search_num: int, max_search_num: int,
score_cutoff: float = 0.001, score_cutoff: float = 0.001,
) -> Tuple[float, List[float], List[Learnware]]: ) -> Tuple[float, List[float], List[Learnware]]:
@@ -459,7 +459,7 @@ class EasyTableSearcher(BaseSearcher):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
User RKME statistical specification User RKME statistical specification
max_search_num : int max_search_num : int
The maximum number of the returned learnwares The maximum number of the returned learnwares
@@ -519,7 +519,7 @@ class EasyTableSearcher(BaseSearcher):
return mmd_dist, weight_min, mixture_list return mmd_dist, weight_min, mixture_list


def _search_by_rkme_spec_single( def _search_by_rkme_spec_single(
self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification
self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification
) -> Tuple[List[float], List[Learnware]]: ) -> Tuple[List[float], List[Learnware]]:
"""Calculate the distances between learnwares in the given learnware_list and user_rkme """Calculate the distances between learnwares in the given learnware_list and user_rkme


@@ -527,7 +527,7 @@ class EasyTableSearcher(BaseSearcher):
---------- ----------
learnware_list : List[Learnware] learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMEStatSpecification
user_rkme : RKMETableSpecification
user RKME statistical specification user RKME statistical specification


Returns Returns
@@ -538,7 +538,7 @@ class EasyTableSearcher(BaseSearcher):
both lists are sorted by mmd dist both lists are sorted by mmd dist
""" """
RKME_list = [ RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list
] ]
mmd_dist_list = [] mmd_dist_list = []
for RKME in RKME_list: for RKME in RKME_list:
@@ -558,7 +558,7 @@ class EasyTableSearcher(BaseSearcher):
max_search_num: int = 5, max_search_num: int = 5,
search_method: str = "greedy", search_method: str = "greedy",
) -> Tuple[List[float], List[Learnware], float, List[Learnware]]: ) -> Tuple[List[float], List[Learnware], float, List[Learnware]]:
user_rkme = user_info.stat_info["RKMEStatSpecification"]
user_rkme = user_info.stat_info["RKMETableSpecification"]
learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme) learnware_list = self._filter_by_rkme_spec_dimension(learnware_list, user_rkme)
logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}") logger.info(f"After filter by rkme dimension, learnware_list length is {len(learnware_list)}")


@@ -631,7 +631,7 @@ class EasySearcher(BaseSearcher):


if len(learnware_list) == 0: if len(learnware_list) == 0:
return [], [], 0.0, [] return [], [], 0.0, []
elif "RKMEStatSpecification" in user_info.stat_info:
elif "RKMETableSpecification" in user_info.stat_info:
return self.table_searcher(learnware_list, user_info, max_search_num, search_method) return self.table_searcher(learnware_list, user_info, max_search_num, search_method)
else: else:
return None, learnware_list, 0.0, None return None, learnware_list, 0.0, None

+ 4
- 4
learnware/reuse/job_selector.py View File

@@ -9,7 +9,7 @@ from sklearn.metrics import accuracy_score
from learnware.learnware import Learnware from learnware.learnware import Learnware
import learnware.specification as specification import learnware.specification as specification
from .base import BaseReuser from .base import BaseReuser
from ..specification import RKMEStatSpecification
from ..specification import RKMETableSpecification
from ..logger import get_module_logger from ..logger import get_module_logger


logger = get_module_logger("job_selector_reuse") logger = get_module_logger("job_selector_reuse")
@@ -86,7 +86,7 @@ class JobSelectorReuser(BaseReuser):
return np.array([0] * user_data_num) return np.array([0] * user_data_num)
else: else:
learnware_rkme_spec_list = [ learnware_rkme_spec_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification")
learnware.specification.get_stat_spec_by_name("RKMETableSpecification")
for learnware in self.learnware_list for learnware in self.learnware_list
] ]


@@ -154,7 +154,7 @@ class JobSelectorReuser(BaseReuser):
return job_select_result return job_select_result


def _calculate_rkme_spec_mixture_weight( def _calculate_rkme_spec_mixture_weight(
self, user_data: np.ndarray, task_rkme_list: List[RKMEStatSpecification], task_rkme_matrix: np.ndarray
self, user_data: np.ndarray, task_rkme_list: List[RKMETableSpecification], task_rkme_matrix: np.ndarray
) -> List[float]: ) -> List[float]:
"""_summary_ """_summary_


@@ -162,7 +162,7 @@ class JobSelectorReuser(BaseReuser):
---------- ----------
user_data : np.ndarray user_data : np.ndarray
Raw user data. Raw user data.
task_rkme_list : List[RKMEStatSpecification]
task_rkme_list : List[RKMETableSpecification]
The list of learwares' rkmes whose mixture approximates the user's rkme The list of learwares' rkmes whose mixture approximates the user's rkme
task_rkme_matrix : np.ndarray task_rkme_matrix : np.ndarray
Inner product matrix calculated from task_rkme_list. Inner product matrix calculated from task_rkme_list.


+ 1
- 1
learnware/specification/__init__.py View File

@@ -1,3 +1,3 @@
from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec from .utils import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec
from .base import Specification, BaseStatSpecification from .base import Specification, BaseStatSpecification
from .regular import RegularStatsSpecification, RKMEStatSpecification, RKMEImageStatSpecification
from .regular import RegularStatsSpecification, RKMETableSpecification, RKMEImageSpecification

+ 2
- 2
learnware/specification/regular/__init__.py View File

@@ -1,3 +1,3 @@
from .table import RKMEStatSpecification
from .image import RKMEImageStatSpecification
from .table import RKMETableSpecification, RKMEStatSpecification
from .image import RKMEImageSpecification
from .base import RegularStatsSpecification from .base import RegularStatsSpecification

+ 1
- 1
learnware/specification/regular/image/__init__.py View File

@@ -1 +1 @@
from .rkme import RKMEImageStatSpecification
from .rkme import RKMEImageSpecification

+ 12
- 12
learnware/specification/regular/image/rkme.py View File

@@ -21,7 +21,7 @@ from ..base import BaseStatSpecification
from ..table.rkme import solve_qp, choose_device, setup_seed from ..table.rkme import solve_qp, choose_device, setup_seed




class RKMEImageStatSpecification(BaseStatSpecification):
class RKMEImageSpecification(BaseStatSpecification):
# INNER_PRODUCT_COUNT = 0 # INNER_PRODUCT_COUNT = 0
IMAGE_WIDTH = 32 IMAGE_WIDTH = 32


@@ -49,7 +49,7 @@ class RKMEImageStatSpecification(BaseStatSpecification):
) )


setup_seed(0) setup_seed(0)
super(RKMEImageStatSpecification, self).__init__(type=self.__class__.__name__)
super(RKMEImageSpecification, self).__init__(type=self.__class__.__name__)


def _generate_models(self, n_models: int, channel: int = 3, fixed_seed=None): def _generate_models(self, n_models: int, channel: int = 3, fixed_seed=None):
model_class = functools.partial(_ConvNet_wide, channel=channel, **self.model_config) model_class = functools.partial(_ConvNet_wide, channel=channel, **self.model_config)
@@ -98,12 +98,12 @@ class RKMEImageStatSpecification(BaseStatSpecification):


""" """
if ( if (
X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH
X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH
) and not resize: ) and not resize:
raise ValueError( raise ValueError(
"X should be in shape of [N, C, {0:d}, {0:d}]. " "X should be in shape of [N, C, {0:d}, {0:d}]. "
"Or set resize=True and the image will be automatically resized to {0:d} x {0:d}.".format( "Or set resize=True and the image will be automatically resized to {0:d} x {0:d}.".format(
RKMEImageStatSpecification.IMAGE_WIDTH
RKMEImageSpecification.IMAGE_WIDTH
) )
) )


@@ -121,9 +121,9 @@ class RKMEImageStatSpecification(BaseStatSpecification):
img_mean = torch.nanmean(img) img_mean = torch.nanmean(img)
X[i] = torch.where(is_nan, img_mean, img) X[i] = torch.where(is_nan, img_mean, img)


if X.shape[2] != RKMEImageStatSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageStatSpecification.IMAGE_WIDTH:
if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH:
X = Resize( X = Resize(
(RKMEImageStatSpecification.IMAGE_WIDTH, RKMEImageStatSpecification.IMAGE_WIDTH), antialias=None
(RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None
)(X) )(X)


num_points = X.shape[0] num_points = X.shape[0]
@@ -253,12 +253,12 @@ class RKMEImageStatSpecification(BaseStatSpecification):
Y_features = Y_features / torch.sqrt(torch.asarray(Y_features.shape[1], device=self.device)) Y_features = Y_features / torch.sqrt(torch.asarray(Y_features.shape[1], device=self.device))
return X_features, Y_features return X_features, Y_features


def inner_prod(self, Phi2: RKMEImageStatSpecification) -> float:
def inner_prod(self, Phi2: RKMEImageSpecification) -> float:
"""Compute the inner product between two RKME Image specifications """Compute the inner product between two RKME Image specifications


Parameters Parameters
---------- ----------
Phi2 : RKMEImageStatSpecification
Phi2 : RKMEImageSpecification
The other RKME Image specification. The other RKME Image specification.


Returns Returns
@@ -269,7 +269,7 @@ class RKMEImageStatSpecification(BaseStatSpecification):
v = self._inner_prod_nngp(Phi2) v = self._inner_prod_nngp(Phi2)
return v return v


def _inner_prod_nngp(self, Phi2: RKMEImageStatSpecification) -> float:
def _inner_prod_nngp(self, Phi2: RKMEImageSpecification) -> float:
beta_1 = self.beta.reshape(1, -1).detach().to(self.device) beta_1 = self.beta.reshape(1, -1).detach().to(self.device)
beta_2 = Phi2.beta.reshape(1, -1).detach().to(self.device) beta_2 = Phi2.beta.reshape(1, -1).detach().to(self.device)


@@ -283,15 +283,15 @@ class RKMEImageStatSpecification(BaseStatSpecification):
K_zz = kernel_fn(Z1, Z2) K_zz = kernel_fn(Z1, Z2)
v = torch.sum(K_zz * (beta_1.T @ beta_2)).item() v = torch.sum(K_zz * (beta_1.T @ beta_2)).item()


# RKMEImageStatSpecification.INNER_PRODUCT_COUNT += 1
# RKMEImageSpecification.INNER_PRODUCT_COUNT += 1
return v return v


def dist(self, Phi2: RKMEImageStatSpecification, omit_term1: bool = False) -> float:
def dist(self, Phi2: RKMEImageSpecification, omit_term1: bool = False) -> float:
"""Compute the Maximum-Mean-Discrepancy(MMD) between two RKME Image specifications """Compute the Maximum-Mean-Discrepancy(MMD) between two RKME Image specifications


Parameters Parameters
---------- ----------
Phi2 : RKMEImageStatSpecification
Phi2 : RKMEImageSpecification
The other RKME specification. The other RKME specification.
omit_term1 : bool, optional omit_term1 : bool, optional
True if the inner product of self with itself can be omitted, by default False. True if the inner product of self with itself can be omitted, by default False.


+ 1
- 1
learnware/specification/regular/table/__init__.py View File

@@ -1 +1 @@
from .rkme import RKMEStatSpecification
from .rkme import RKMETableSpecification

+ 11
- 6
learnware/specification/regular/table/rkme.py View File

@@ -29,7 +29,7 @@ if not _FAISS_INSTALLED:
logger.warning("Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first") logger.warning("Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first")




class RKMEStatSpecification(RegularStatsSpecification):
class RKMETableSpecification(RegularStatsSpecification):
"""Reduced Kernel Mean Embedding (RKME) Specification""" """Reduced Kernel Mean Embedding (RKME) Specification"""


def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): def __init__(self, gamma: float = 0.1, cuda_idx: int = -1):
@@ -50,7 +50,7 @@ class RKMEStatSpecification(RegularStatsSpecification):
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.device = choose_device(cuda_idx=cuda_idx) self.device = choose_device(cuda_idx=cuda_idx)
setup_seed(0) setup_seed(0)
super(RKMEStatSpecification, self).__init__(type=self.__class__.__name__)
super(RKMETableSpecification, self).__init__(type=self.__class__.__name__)


def get_beta(self) -> np.ndarray: def get_beta(self) -> np.ndarray:
"""Move beta(RKME weights) back to memory accessible to the CPU. """Move beta(RKME weights) back to memory accessible to the CPU.
@@ -333,12 +333,12 @@ class RKMEStatSpecification(RegularStatsSpecification):
else: else:
logger.warning("Not enough candidates for herding!") logger.warning("Not enough candidates for herding!")


def inner_prod(self, Phi2: RKMEStatSpecification) -> float:
def inner_prod(self, Phi2: RKMETableSpecification) -> float:
"""Compute the inner product between two RKME specifications """Compute the inner product between two RKME specifications


Parameters Parameters
---------- ----------
Phi2 : RKMEStatSpecification
Phi2 : RKMETableSpecification
The other RKME specification. The other RKME specification.


Returns Returns
@@ -354,12 +354,12 @@ class RKMEStatSpecification(RegularStatsSpecification):


return float(v) return float(v)


def dist(self, Phi2: RKMEStatSpecification, omit_term1: bool = False) -> float:
def dist(self, Phi2: RKMETableSpecification, omit_term1: bool = False) -> float:
"""Compute the Maximum-Mean-Discrepancy(MMD) between two RKME specifications """Compute the Maximum-Mean-Discrepancy(MMD) between two RKME specifications


Parameters Parameters
---------- ----------
Phi2 : RKMEStatSpecification
Phi2 : RKMETableSpecification
The other RKME specification. The other RKME specification.
omit_term1 : bool, optional omit_term1 : bool, optional
True if the inner product of self with itself can be omitted, by default False. True if the inner product of self with itself can be omitted, by default False.
@@ -463,6 +463,11 @@ class RKMEStatSpecification(RegularStatsSpecification):
else: else:
return False return False


class RKMEStatSpecification(RKMETableSpecification):
"""nickname for RKMETableSpecification, for compatibility currently.
TODO: modify all learnware in database and remove this nickname
"""
pass


def setup_seed(seed): def setup_seed(seed):
"""Fix a random seed for addressing reproducibility issues. """Fix a random seed for addressing reproducibility issues.


+ 11
- 11
learnware/specification/utils.py View File

@@ -4,7 +4,7 @@ import pandas as pd
from typing import Union from typing import Union


from .base import BaseStatSpecification from .base import BaseStatSpecification
from .regular import RKMEStatSpecification, RKMEImageStatSpecification
from .regular import RKMETableSpecification, RKMEImageSpecification
from ..config import C from ..config import C




@@ -42,10 +42,10 @@ def generate_rkme_spec(
nonnegative_beta: bool = True, nonnegative_beta: bool = True,
reduce: bool = True, reduce: bool = True,
cuda_idx: int = None, cuda_idx: int = None,
) -> RKMEStatSpecification:
) -> RKMETableSpecification:
""" """
Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification. Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification.
Return a RKMEStatSpecification object, use .save() method to save as json file.
Return a RKMETableSpecification object, use .save() method to save as json file.


Parameters Parameters
---------- ----------
@@ -73,8 +73,8 @@ def generate_rkme_spec(


Returns Returns
------- -------
RKMEStatSpecification
A RKMEStatSpecification object
RKMETableSpecification
A RKMETableSpecification object
""" """
# Convert data type # Convert data type
X = convert_to_numpy(X) X = convert_to_numpy(X)
@@ -94,7 +94,7 @@ def generate_rkme_spec(
cuda_idx = 0 cuda_idx = 0


# Generate rkme spec # Generate rkme spec
rkme_spec = RKMEStatSpecification(gamma=gamma, cuda_idx=cuda_idx)
rkme_spec = RKMETableSpecification(gamma=gamma, cuda_idx=cuda_idx)
rkme_spec.generate_stat_spec_from_data(X, reduced_set_size, step_size, steps, nonnegative_beta, reduce) rkme_spec.generate_stat_spec_from_data(X, reduced_set_size, step_size, steps, nonnegative_beta, reduce)
return rkme_spec return rkme_spec


@@ -109,10 +109,10 @@ def generate_rkme_image_spec(
reduce: bool = True, reduce: bool = True,
verbose: bool = True, verbose: bool = True,
cuda_idx: int = None, cuda_idx: int = None,
) -> RKMEImageStatSpecification:
) -> RKMEImageSpecification:
""" """
Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification for Image. Interface for users to generate Reduced Kernel Mean Embedding (RKME) specification for Image.
Return a RKMEImageStatSpecification object, use .save() method to save as json file.
Return a RKMEImageSpecification object, use .save() method to save as json file.


Parameters Parameters
---------- ----------
@@ -144,8 +144,8 @@ def generate_rkme_image_spec(


Returns Returns
------- -------
RKMEImageStatSpecification
A RKMEImageStatSpecification object
RKMEImageSpecification
A RKMEImageSpecification object
""" """


# Check cuda_idx # Check cuda_idx
@@ -157,7 +157,7 @@ def generate_rkme_image_spec(
cuda_idx = 0 cuda_idx = 0


# Generate rkme spec # Generate rkme spec
rkme_image_spec = RKMEImageStatSpecification(cuda_idx=cuda_idx)
rkme_image_spec = RKMEImageSpecification(cuda_idx=cuda_idx)
rkme_image_spec.generate_stat_spec_from_data( rkme_image_spec.generate_stat_spec_from_data(
X, reduced_set_size, step_size, steps, resize, nonnegative_beta, reduce, verbose X, reduced_set_size, step_size, steps, resize, nonnegative_beta, reduce, verbose
) )


+ 1
- 1
tests/test_market/learnware_example/example.yaml View File

@@ -3,6 +3,6 @@ model:
kwargs: {} kwargs: {}
stat_specifications: stat_specifications:
- module_path: learnware.specification - module_path: learnware.specification
class_name: RKMEStatSpecification
class_name: RKMETableSpecification
file_name: svm.json file_name: svm.json
kwargs: {} kwargs: {}

+ 2
- 2
tests/test_market/test_easy.py View File

@@ -170,9 +170,9 @@ class TestMarket(unittest.TestCase):
with zipfile.ZipFile(zip_path, "r") as zip_obj: with zipfile.ZipFile(zip_path, "r") as zip_obj:
zip_obj.extractall(path=unzip_dir) zip_obj.extractall(path=unzip_dir)


user_spec = specification.rkme.RKMEStatSpecification()
user_spec = specification.rkme.RKMETableSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json")) user_spec.load(os.path.join(unzip_dir, "svm.json"))
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
( (
sorted_score_list, sorted_score_list,
single_learnware_list, single_learnware_list,


+ 7
- 7
tests/test_specification/test_rkme.py View File

@@ -5,7 +5,7 @@ import unittest
import tempfile import tempfile
import numpy as np import numpy as np


from learnware.specification import RKMEStatSpecification, RKMEImageStatSpecification
from learnware.specification import RKMETableSpecification, RKMEImageSpecification
from learnware.specification import generate_rkme_image_spec, generate_rkme_spec from learnware.specification import generate_rkme_image_spec, generate_rkme_spec




@@ -22,11 +22,11 @@ class TestRKME(unittest.TestCase):


with open(rkme_path, "r") as f: with open(rkme_path, "r") as f:
data = json.load(f) data = json.load(f)
assert data["type"] == "RKMEStatSpecification"
assert data["type"] == "RKMETableSpecification"


rkme2 = RKMEStatSpecification()
rkme2 = RKMETableSpecification()
rkme2.load(rkme_path) rkme2.load(rkme_path)
assert rkme2.type == "RKMEStatSpecification"
assert rkme2.type == "RKMETableSpecification"


def test_image_rkme(self): def test_image_rkme(self):
def _test_image_rkme(X): def _test_image_rkme(X):
@@ -38,11 +38,11 @@ class TestRKME(unittest.TestCase):


with open(rkme_path, "r") as f: with open(rkme_path, "r") as f:
data = json.load(f) data = json.load(f)
assert data["type"] == "RKMEImageStatSpecification"
assert data["type"] == "RKMEImageSpecification"


rkme2 = RKMEImageStatSpecification()
rkme2 = RKMEImageSpecification()
rkme2.load(rkme_path) rkme2.load(rkme_path)
assert rkme2.type == "RKMEImageStatSpecification"
assert rkme2.type == "RKMEImageSpecification"


_test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32)))
_test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128))) _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 128, 128)))


+ 1
- 1
tests/test_workflow/learnware_example/example.yaml View File

@@ -3,6 +3,6 @@ model:
kwargs: {} kwargs: {}
stat_specifications: stat_specifications:
- module_path: learnware.specification - module_path: learnware.specification
class_name: RKMEStatSpecification
class_name: RKMETableSpecification
file_name: svm.json file_name: svm.json
kwargs: {} kwargs: {}

+ 3
- 3
tests/test_workflow/test_workflow.py View File

@@ -155,9 +155,9 @@ class TestAllWorkflow(unittest.TestCase):
with zipfile.ZipFile(zip_path, "r") as zip_obj: with zipfile.ZipFile(zip_path, "r") as zip_obj:
zip_obj.extractall(path=unzip_dir) zip_obj.extractall(path=unzip_dir)


user_spec = specification.RKMEStatSpecification()
user_spec = specification.RKMETableSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json")) user_spec.load(os.path.join(unzip_dir, "svm.json"))
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": user_spec})
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
( (
sorted_score_list, sorted_score_list,
single_learnware_list, single_learnware_list,
@@ -182,7 +182,7 @@ class TestAllWorkflow(unittest.TestCase):
train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True)


stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0) stat_spec = specification.utils.generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMEStatSpecification": stat_spec})
user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec})


_, _, _, mixture_learnware_list = easy_market.search_learnware(user_info) _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info)




Loading…
Cancel
Save