|
|
@@ -186,7 +186,12 @@ class EasyStatChecker(BaseChecker): |
|
|
semantic_output_shape = int(semantic_spec["Output"]["Dimension"]) |
|
|
semantic_output_shape = int(semantic_spec["Output"]["Dimension"]) |
|
|
|
|
|
|
|
|
if model_output_shape == 1: |
|
|
if model_output_shape == 1: |
|
|
if not all(int(item) >= 0 and int(item) < semantic_output_shape for item in outputs): |
|
|
|
|
|
|
|
|
if isinstance(outputs, torch.Tensor): |
|
|
|
|
|
outputs = outputs.detach().cpu().numpy() |
|
|
|
|
|
if isinstance(outputs, list): |
|
|
|
|
|
outputs = np.array(outputs) |
|
|
|
|
|
|
|
|
|
|
|
if not np.all(np.logical_and(outputs >= 0, outputs < semantic_output_shape)): |
|
|
message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}" |
|
|
message = f"The learnware [{learnware.id}] output label mismatch, where outputs of model is {outputs}, semantic_shape={(semantic_output_shape, )}" |
|
|
logger.warning(message) |
|
|
logger.warning(message) |
|
|
return self.INVALID_LEARNWARE, message |
|
|
return self.INVALID_LEARNWARE, message |
|
|
|