diff --git a/mindinsight/backend/explainer/explainer_api.py b/mindinsight/backend/explainer/explainer_api.py index f17773be..2126b2c9 100644 --- a/mindinsight/backend/explainer/explainer_api.py +++ b/mindinsight/backend/explainer/explainer_api.py @@ -25,7 +25,6 @@ from flask import request from mindinsight.conf import settings from mindinsight.utils.exceptions import ParamMissError from mindinsight.utils.exceptions import ParamValueError -from mindinsight.datavisual.common.exceptions import ImageNotExistError from mindinsight.datavisual.common.validation import Validation from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher from mindinsight.datavisual.utils.tools import get_train_id @@ -168,8 +167,6 @@ def query_image(): encapsulator = DatafileEncap(EXPLAIN_MANAGER) image = encapsulator.query_image_binary(train_id, image_path, image_type) - if image is None: - raise ImageNotExistError(f"{image_path}") return image diff --git a/mindinsight/explainer/encapsulator/saliency_encap.py b/mindinsight/explainer/encapsulator/saliency_encap.py index aa9419dc..07f434da 100644 --- a/mindinsight/explainer/encapsulator/saliency_encap.py +++ b/mindinsight/explainer/encapsulator/saliency_encap.py @@ -18,6 +18,7 @@ import copy from mindinsight.utils.exceptions import ParamValueError from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap +from mindinsight.datavisual.common.exceptions import TrainJobNotExistError def _sort_key_min_confidence(sample): @@ -90,19 +91,17 @@ class SaliencyEncap(ExplainDataEncap): """ job = self.job_manager.get_job(train_id) if job is None: - return 0, None + raise TrainJobNotExistError(train_id) samples = copy.deepcopy(job.get_all_samples()) if labels: filtered = [] for sample in samples: - has_label = False - for label in sample["labels"]: - if label in labels: - has_label = True + infer_labels = [inference["label"] for inference in sample["inferences"]] + for infer_label in infer_labels: + if infer_label in labels: + filtered.append(sample) break - if has_label: - filtered.append(sample) samples = filtered reverse = sorted_type == "descending"