Browse Source

[FIX | MNT] fix bugs for hetero organizer, and modify tests

tags/v0.3.2
bxdd 2 years ago
parent
commit
c44b76afd5
8 changed files with 9 additions and 26 deletions
  1. +5
    -1
      learnware/market/heterogeneous/organizer/__init__.py
  2. +0
    -8
      tests/test_hetero_market/example_learnwares/example_learnware_1/learnware.yaml
  3. +0
    -1
      tests/test_hetero_market/example_learnwares/example_learnware_1/requirements.txt
  4. +0
    -0
      tests/test_hetero_market/example_learnwares/learnware.yaml
  5. +0
    -6
      tests/test_hetero_market/example_learnwares/model0.py
  6. +0
    -6
      tests/test_hetero_market/example_learnwares/model1.py
  7. +0
    -0
      tests/test_hetero_market/example_learnwares/requirements.txt
  8. +4
    -4
      tests/test_hetero_market/test_hetero.py

+ 5
- 1
learnware/market/heterogeneous/organizer/__init__.py View File

@@ -206,6 +206,9 @@ class HeteroMapTableOrganizer(EasyOrganizer):
str: id of target learware
List[str]: A list of ids of target learnwares
"""
if isinstance(ids, str):
ids = [ids]

for idx in ids:
try:
spec = self.learnware_list[idx].get_specification()
@@ -218,7 +221,8 @@ class HeteroMapTableOrganizer(EasyOrganizer):
hetero_spec.save(save_path)

except Exception as err:
logger.warning(f"Learnware {idx} generate HeteroMapTableSpecification failed! Due to {err}")
traceback.print_exc()
logger.warning(f"Learnware {idx} generate HeteroMapTableSpecification failed!")

def _get_hetero_learnware_ids(self, ids: Union[str, List[str]]) -> List[str]:
"""Get learnware ids that supports heterogeneous market training and search.


+ 0
- 8
tests/test_hetero_market/example_learnwares/example_learnware_1/learnware.yaml View File

@@ -1,8 +0,0 @@
model:
class_name: MyModel
kwargs: {}
stat_specifications:
- module_path: learnware.specification
class_name: RKMETableSpecification
file_name: stat.json
kwargs: {}

+ 0
- 1
tests/test_hetero_market/example_learnwares/example_learnware_1/requirements.txt View File

@@ -1 +0,0 @@
learnware == 0.1.0.999

tests/test_hetero_market/example_learnwares/example_learnware_0/learnware.yaml → tests/test_hetero_market/example_learnwares/learnware.yaml View File


tests/test_hetero_market/example_learnwares/example_learnware_0/__init__.py → tests/test_hetero_market/example_learnwares/model0.py View File

@@ -12,11 +12,5 @@ class MyModel(BaseModel):
model = joblib.load(model_path)
self.model = model

def fit(self, X: np.ndarray, y: np.ndarray):
pass

def predict(self, X: np.ndarray) -> np.ndarray:
return self.model.predict(X)

def finetune(self, X: np.ndarray, y: np.ndarray):
pass

tests/test_hetero_market/example_learnwares/example_learnware_1/__init__.py → tests/test_hetero_market/example_learnwares/model1.py View File

@@ -12,11 +12,5 @@ class MyModel(BaseModel):
model = joblib.load(model_path)
self.model = model

def fit(self, X: np.ndarray, y: np.ndarray):
pass

def predict(self, X: np.ndarray) -> np.ndarray:
return self.model.predict(X)

def finetune(self, X: np.ndarray, y: np.ndarray):
pass

tests/test_hetero_market/example_learnwares/example_learnware_0/requirements.txt → tests/test_hetero_market/example_learnwares/requirements.txt View File


+ 4
- 4
tests/test_hetero_market/test_hetero.py View File

@@ -75,7 +75,7 @@ class TestMarket(unittest.TestCase):

example_learnware_idx = i % 2
input_dim = input_shape_list[example_learnware_idx]
example_learnware_name = "example_learnwares/example_learnware_%d" % (example_learnware_idx)
learnware_example_dir = "example_learnwares"

X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_dim, noise=0.1, random_state=42)

@@ -89,16 +89,16 @@ class TestMarket(unittest.TestCase):

init_file = os.path.join(dir_path, "__init__.py")
copyfile(
os.path.join(curr_root, example_learnware_name, "__init__.py"), init_file
os.path.join(curr_root, learnware_example_dir, f"model{example_learnware_idx}.py"), init_file
) # cp example_init.py init_file

yaml_file = os.path.join(dir_path, "learnware.yaml")
copyfile(
os.path.join(curr_root, example_learnware_name, "learnware.yaml"), yaml_file
os.path.join(curr_root, learnware_example_dir, "learnware.yaml"), yaml_file
) # cp example.yaml yaml_file

env_file = os.path.join(dir_path, "requirements.txt")
copyfile(os.path.join(curr_root, example_learnware_name, "requirements.txt"), env_file)
copyfile(os.path.join(curr_root, learnware_example_dir, "requirements.txt"), env_file)

zip_file = dir_path + ".zip"
# zip -q -r -j zip_file dir_path


Loading…
Cancel
Save