diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index 7e0983e..eb6fe75 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -186,7 +186,12 @@ class EasyStatChecker(BaseChecker): semantic_output_shape = int(semantic_spec["Output"]["Dimension"]) if model_output_shape == 1: - if not all(int(item) >= 0 and int(item) < semantic_output_shape for item in outputs): + if isinstance(outputs, torch.Tensor): + outputs = outputs.detach().cpu().numpy() + if isinstance(outputs, list): + outputs = np.array(outputs) + + if not np.all(np.logical_and(outputs >= 0, outputs < semantic_output_shape)): message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}" logger.warning(message) return self.INVALID_LEARNWARE, message diff --git a/tests/test_learnware_client/test_check_learnware.py b/tests/test_learnware_client/test_check_learnware.py index 77a44c3..59f0820 100644 --- a/tests/test_learnware_client/test_check_learnware.py +++ b/tests/test_learnware_client/test_check_learnware.py @@ -14,7 +14,7 @@ class TestCheckLearnware(unittest.TestCase): self.client = LearnwareClient() def test_check_learnware_pip(self): - learnware_id = "00000154" + learnware_id = "00000208" with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: self.zip_path = os.path.join(tempdir, "test.zip") self.client.download_learnware(learnware_id, self.zip_path)