Browse Source

[MNT] modify details

tags/v0.3.2
Gene 2 years ago
parent
commit
77eeceb7e8
2 changed files with 7 additions and 2 deletions
  1. +6
    -1
      learnware/market/easy/checker.py
  2. +1
    -1
      tests/test_learnware_client/test_check_learnware.py

+ 6
- 1
learnware/market/easy/checker.py View File

@@ -186,7 +186,12 @@ class EasyStatChecker(BaseChecker):
semantic_output_shape = int(semantic_spec["Output"]["Dimension"])

if model_output_shape == 1:
if not all(int(item) >= 0 and int(item) < semantic_output_shape for item in outputs):
if isinstance(outputs, torch.Tensor):
outputs = outputs.detach().cpu().numpy()
if isinstance(outputs, list):
outputs = np.array(outputs)

if not np.all(np.logical_and(outputs >= 0, outputs < semantic_output_shape)):
message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}"
logger.warning(message)
return self.INVALID_LEARNWARE, message


+ 1
- 1
tests/test_learnware_client/test_check_learnware.py View File

@@ -14,7 +14,7 @@ class TestCheckLearnware(unittest.TestCase):
self.client = LearnwareClient()

def test_check_learnware_pip(self):
learnware_id = "00000154"
learnware_id = "00000208"
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
self.zip_path = os.path.join(tempdir, "test.zip")
self.client.download_learnware(learnware_id, self.zip_path)


Loading…
Cancel
Save