Browse Source

!958 Filter saliency restful API samples by inference labels, returns job not found error when explain job is not found

From: @ngtony
Reviewed-by: @wangyue01,@wenkai_dist
Signed-off-by: @wenkai_dist
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
2c86447056
2 changed files with 6 additions and 10 deletions
  1. +0
    -3
      mindinsight/backend/explainer/explainer_api.py
  2. +6
    -7
      mindinsight/explainer/encapsulator/saliency_encap.py

+ 0
- 3
mindinsight/backend/explainer/explainer_api.py View File

@@ -25,7 +25,6 @@ from flask import request
from mindinsight.conf import settings
from mindinsight.utils.exceptions import ParamMissError
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.common.exceptions import ImageNotExistError
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.datavisual.utils.tools import get_train_id
@@ -168,8 +167,6 @@ def query_image():

encapsulator = DatafileEncap(EXPLAIN_MANAGER)
image = encapsulator.query_image_binary(train_id, image_path, image_type)
if image is None:
raise ImageNotExistError(f"{image_path}")

return image



+ 6
- 7
mindinsight/explainer/encapsulator/saliency_encap.py View File

@@ -18,6 +18,7 @@ import copy

from mindinsight.utils.exceptions import ParamValueError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError


def _sort_key_min_confidence(sample):
@@ -90,19 +91,17 @@ class SaliencyEncap(ExplainDataEncap):
"""
job = self.job_manager.get_job(train_id)
if job is None:
return 0, None
raise TrainJobNotExistError(train_id)

samples = copy.deepcopy(job.get_all_samples())
if labels:
filtered = []
for sample in samples:
has_label = False
for label in sample["labels"]:
if label in labels:
has_label = True
infer_labels = [inference["label"] for inference in sample["inferences"]]
for infer_label in infer_labels:
if infer_label in labels:
filtered.append(sample)
break
if has_label:
filtered.append(sample)
samples = filtered

reverse = sorted_type == "descending"


Loading…
Cancel
Save