|
|
|
@@ -18,6 +18,7 @@ import copy |
|
|
|
|
|
|
|
from mindinsight.utils.exceptions import ParamValueError |
|
|
|
from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap |
|
|
|
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError |
|
|
|
|
|
|
|
|
|
|
|
def _sort_key_min_confidence(sample): |
|
|
|
@@ -90,19 +91,17 @@ class SaliencyEncap(ExplainDataEncap): |
|
|
|
""" |
|
|
|
job = self.job_manager.get_job(train_id) |
|
|
|
if job is None: |
|
|
|
return 0, None |
|
|
|
raise TrainJobNotExistError(train_id) |
|
|
|
|
|
|
|
samples = copy.deepcopy(job.get_all_samples()) |
|
|
|
if labels: |
|
|
|
filtered = [] |
|
|
|
for sample in samples: |
|
|
|
has_label = False |
|
|
|
for label in sample["labels"]: |
|
|
|
if label in labels: |
|
|
|
has_label = True |
|
|
|
infer_labels = [inference["label"] for inference in sample["inferences"]] |
|
|
|
for infer_label in infer_labels: |
|
|
|
if infer_label in labels: |
|
|
|
filtered.append(sample) |
|
|
|
break |
|
|
|
if has_label: |
|
|
|
filtered.append(sample) |
|
|
|
samples = filtered |
|
|
|
|
|
|
|
reverse = sorted_type == "descending" |
|
|
|
|