
Backend adaptation: add an option to choose whether samples without HOC explanations are filtered out

tags/v1.2.0-rc1
lixiaohui · 4 years ago · commit 2fd9e47075
6 changed files with 67 additions and 23 deletions:

  1. mindinsight/backend/explainer/explainer_api.py  (+6 / -1)
  2. mindinsight/explainer/encapsulator/explain_data_encap.py  (+17 / -2)
  3. mindinsight/explainer/encapsulator/explain_job_encap.py  (+3 / -1)
  4. mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py  (+25 / -12)
  5. mindinsight/explainer/encapsulator/saliency_encap.py  (+4 / -5)
  6. mindinsight/explainer/manager/explain_loader.py  (+12 / -2)

mindinsight/backend/explainer/explainer_api.py  (+6 / -1)

@@ -138,7 +138,6 @@ def query_explain_job():
         raise ParamMissError("train_id")
     encapsulator = ExplainJobEncap(EXPLAIN_MANAGER)
     metadata = encapsulator.query_meta(train_id)
-
     return jsonify(metadata)
@@ -175,6 +174,12 @@ def query_hoc():

     query_kwargs = _get_query_sample_parameters(data)

+    filter_empty = data.get("drop_empty", True)
+    if not isinstance(filter_empty, bool):
+        raise ParamTypeError("drop_empty", bool)
+
+    query_kwargs["drop_empty"] = filter_empty
+
     encapsulator = HierarchicalOcclusionEncap(
         _image_url_formatter,
         EXPLAIN_MANAGER)
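For context: the new drop_empty flag travels in the JSON body of the HOC query request and defaults to True on the backend. A minimal client-side sketch follows, assuming the handler is reached via a POST JSON endpoint; the URL and most field values below are hypothetical (the route rule is not part of this diff), only the "drop_empty" key is introduced by this change.

    import requests

    # Hypothetical URL and train_id; field names other than "drop_empty" mirror the
    # query parameters used by the encapsulators elsewhere in this commit.
    url = "http://127.0.0.1:8080/v1/mindinsight/explainer/query_hoc"
    body = {
        "train_id": "./summary_dir",
        "limit": 10,
        "offset": 0,
        "sorted_name": "confidence",
        "sorted_type": "descending",
        "drop_empty": False,  # keep samples even if they carry no HOC explanation
    }
    response = requests.post(url, json=body)
    print(response.json())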


mindinsight/explainer/encapsulator/explain_data_encap.py  (+17 / -2)

@@ -15,6 +15,7 @@
 """Common explain data encapsulator base class."""
 import copy
+from enum import Enum

 from mindinsight.utils.exceptions import ParamValueError


@@ -64,6 +65,11 @@ def _sort_key_max_confidence_sd(sample, labels):
             max_confidence_sd = confidence_sd
     return max_confidence_sd


+class ExplanationKeys(Enum):
+    """Query type enums."""
+    HOC = "hoc_layers"  # HOC: Hierarchical Occlusion, an explanation method we propose
+    SALIENCY = "saliency_maps"
+

 class ExplainDataEncap:
     """Explain data encapsulator base class."""
@@ -89,7 +95,7 @@ class ExplanationEncap(ExplainDataEncap):
                        sorted_name,
                        sorted_type,
                        prediction_types=None,
-                       query_type="saliency_maps"):
+                       drop_type=None):
         """
         Query samples.

@@ -99,13 +105,22 @@ class ExplanationEncap(ExplainDataEncap):
             sorted_name (str): Field to be sorted.
             sorted_type (str): Sorting order, 'ascending' or 'descending'.
             prediction_types (list[str]): Prediction type filter.
+            drop_type (str, None): When it is None, no filter will be applied. When it is 'hoc_layers', samples
+                without HOC explanations will be filtered out. When it is 'saliency_maps', samples without
+                saliency explanations will be filtered out.

         Returns:
             list[dict], samples to be queried.
         """
         samples = copy.deepcopy(job.get_all_samples())
-        samples = [sample for sample in samples if any(infer[query_type] for infer in sample['inferences'])]
+        if drop_type not in (None, ExplanationKeys.SALIENCY.value, ExplanationKeys.HOC.value):
+            raise ParamValueError(
+                f"Argument drop_type valid options: None, {ExplanationKeys.SALIENCY.value}, "
+                f"{ExplanationKeys.HOC.value}, but got {drop_type}.")
+
+        if drop_type is not None:
+            samples = [sample for sample in samples if any(infer[drop_type] for infer in sample['inferences'])]
         if labels:
             filtered = []
             for sample in samples:
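To make the new drop_type rule concrete, here is a small standalone sketch of the filtering step, run on hand-made sample dicts; the sample layout is simplified to the keys the filter actually touches.

    from enum import Enum

    class ExplanationKeys(Enum):
        HOC = "hoc_layers"
        SALIENCY = "saliency_maps"

    def drop_without_explanation(samples, drop_type=None):
        """Keep a sample only if at least one inference carries the requested explanation."""
        if drop_type is None:
            return samples
        return [s for s in samples if any(infer[drop_type] for infer in s["inferences"])]

    # Toy data: sample "a" has a HOC explanation for one label, sample "b" has none.
    samples = [
        {"id": "a", "inferences": [{"hoc_layers": [{"outcome": "a_0_0.jpg"}], "saliency_maps": []}]},
        {"id": "b", "inferences": [{"hoc_layers": [], "saliency_maps": []}]},
    ]
    print([s["id"] for s in drop_without_explanation(samples, ExplanationKeys.HOC.value)])  # ['a']
    print([s["id"] for s in drop_without_explanation(samples)])                             # ['a', 'b']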


mindinsight/explainer/encapsulator/explain_job_encap.py  (+3 / -1)

@@ -43,8 +43,10 @@ class ExplainJobEncap(ExplainDataEncap):
     def query_meta(self, train_id):
         """
         Query explain job meta-data.
+
         Args:
             train_id (str): Job ID.
+
         Returns:
             dict, the metadata.
         """
@@ -81,7 +83,7 @@ class ExplainJobEncap(ExplainDataEncap):
         """Convert ExplainJob's meta-data to jsonable info object."""
         info = cls._job_2_info(job)
         info["sample_count"] = job.sample_count
-        info["classes"] = [item for item in job.all_classes if item['sample_count'] > 0]
+        info["classes"] = job.all_classes
         saliency_info = dict()
         if job.min_confidence is None:
             saliency_info["min_confidence"] = cls.DEFAULT_MIN_CONFIDENCE
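The metadata no longer drops zero-sample classes on the backend. Together with the explain_loader.py change at the end of this commit, each entry in info["classes"] also carries per-explanation counts, so the frontend can decide what to hide. Roughly, the shape is as below; the values are illustrative, the key names come from the loader diff.

    # Illustrative shape of the per-class metadata after this change (values made up).
    classes = [
        {"id": 0, "label": "cat", "sample_count": 42,
         "saliency_sample_count": 42, "hoc_sample_count": 7},
        {"id": 1, "label": "dog", "sample_count": 0,
         "saliency_sample_count": 0, "hoc_sample_count": 0},  # kept even though empty
    ]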


mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py  (+25 / -12)

@@ -15,7 +15,7 @@
 """Hierarchical Occlusion encapsulator."""

 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
-from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap
+from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys


 class HierarchicalOcclusionEncap(ExplanationEncap):
@@ -28,7 +28,8 @@ class HierarchicalOcclusionEncap(ExplanationEncap):
                 offset,
                 sorted_name,
                 sorted_type,
-                prediction_types=None
+                prediction_types=None,
+                drop_empty=True,
                 ):
         """
         Query hierarchical occlusion results.
@@ -41,6 +42,7 @@ class HierarchicalOcclusionEncap(ExplanationEncap):
             sorted_name (str): Field to be sorted.
             sorted_type (str): Sorting order, 'ascending' or 'descending'.
             prediction_types (list[str]): Prediction types filter.
+            drop_empty (bool): Whether to drop samples without HOC explanations. Default: True.

         Returns:
             tuple[int, list[dict]], total number of samples after filtering and list of sample results.
@@ -49,8 +51,12 @@ class HierarchicalOcclusionEncap(ExplanationEncap):
         if job is None:
             raise TrainJobNotExistError(train_id)

-        samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types,
-                                      query_type="hoc_layers")
+        if drop_empty:
+            samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types,
+                                          drop_type=ExplanationKeys.HOC.value)
+        else:
+            samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types)

         sample_infos = []
         obj_offset = offset * limit
         count = len(samples)
@@ -59,29 +65,36 @@ class HierarchicalOcclusionEncap(ExplanationEncap):
             end = obj_offset + limit
         for i in range(obj_offset, end):
             sample = samples[i]
-            sample_infos.append(self._touch_sample(sample, job))
+            sample_infos.append(self._touch_sample(sample, job, drop_empty))

         return count, sample_infos

-    def _touch_sample(self, sample, job):
+    def _touch_sample(self, sample, job, drop_empty):
         """
         Final edit on single sample info.

         Args:
             sample (dict): Sample info.
             job (ExplainManager): Explain job.
+            drop_empty (bool): Whether to drop out inferences without HOC explanations.

         Returns:
             dict, the edited sample info.
         """
-        sample_cp = sample.copy()
-        sample_cp["image"] = self._get_image_url(job.train_id, sample["image"], "original")
-        for inference_item in sample_cp["inferences"]:
+        sample["image"] = self._get_image_url(job.train_id, sample["image"], "original")
+        inferences = sample["inferences"]
+        i = 0  # init index for while loop
+        while i < len(inferences):
+            inference_item = inferences[i]
+            if drop_empty and not inference_item[ExplanationKeys.HOC.value]:
+                inferences.pop(i)
+                continue
             new_list = []
-            for idx, hoc_layer in enumerate(inference_item["hoc_layers"]):
+            for idx, hoc_layer in enumerate(inference_item[ExplanationKeys.HOC.value]):
                 hoc_layer["outcome"] = self._get_image_url(job.train_id,
                                                            f"{sample['id']}_{inference_item['label']}_{idx}.jpg",
                                                            "outcome")
                 new_list.append(hoc_layer)
-            inference_item["hoc_layers"] = new_list
-        return sample_cp
+            inference_item[ExplanationKeys.HOC.value] = new_list
+            i += 1
+        return sample
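The reworked _touch_sample edits the sample in place and walks the inference list with a while loop so it can pop entries that have no HOC layers while iterating; since _query_samples already deep-copies the samples, mutating them here does not touch the loaded data. A self-contained sketch of just the drop step, with the URL formatting stripped out and the data made up:

    def drop_empty_inferences(sample, drop_empty=True):
        """Remove inference entries whose 'hoc_layers' list is empty, in place."""
        inferences = sample["inferences"]
        i = 0
        while i < len(inferences):
            if drop_empty and not inferences[i]["hoc_layers"]:
                inferences.pop(i)  # do not advance i: the next item slides into slot i
                continue
            i += 1
        return sample

    sample = {"id": "a", "inferences": [
        {"label": 0, "hoc_layers": [{"outcome": "a_0_0.jpg"}]},
        {"label": 1, "hoc_layers": []},
    ]}
    print(drop_empty_inferences(sample)["inferences"])  # only the label-0 inference remains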

mindinsight/explainer/encapsulator/saliency_encap.py  (+4 / -5)

@@ -15,7 +15,7 @@
 """Saliency map encapsulator."""

 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
-from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap
+from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys


 class SaliencyEncap(ExplanationEncap):
@@ -49,8 +49,7 @@ class SaliencyEncap(ExplanationEncap):
         if job is None:
             raise TrainJobNotExistError(train_id)

-        samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types,
-                                      query_type="saliency_maps")
+        samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types)

         sample_infos = []
         obj_offset = offset * limit
@@ -79,10 +78,10 @@ class SaliencyEncap(ExplanationEncap):
         sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original")
         for inference in sample_cp["inferences"]:
             new_list = []
-            for saliency_map in inference["saliency_maps"]:
+            for saliency_map in inference[ExplanationKeys.SALIENCY.value]:
                 if explainers and saliency_map["explainer"] not in explainers:
                     continue
                 saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay")
                 new_list.append(saliency_map)
-            inference["saliency_maps"] = new_list
+            inference[ExplanationKeys.SALIENCY.value] = new_list
         return sample_cp

mindinsight/explainer/manager/explain_loader.py  (+12 / -2)

@@ -99,12 +99,22 @@ class ExplainLoader:
                 - sample_count (int): Number of samples for each label.
         """
         sample_count_per_label = defaultdict(int)
+        saliency_count_per_label = defaultdict(int)
+        hoc_count_per_label = defaultdict(int)
         for sample in self._samples.values():
             if sample.get('image') and (sample.get('ground_truth_label') or sample.get('predicted_label')):
                 for label in set(sample['ground_truth_label'] + sample['predicted_label']):
                     sample_count_per_label[label] += 1
+                    if sample['inferences'][label]['saliency_maps']:
+                        saliency_count_per_label[label] += 1
+                    if sample['inferences'][label]['hoc_layers']:
+                        hoc_count_per_label[label] += 1

-        all_classes_return = [{'id': label_id, 'label': label_name, 'sample_count': sample_count_per_label[label_id]}
+        all_classes_return = [{'id': label_id,
+                               'label': label_name,
+                               'sample_count': sample_count_per_label[label_id],
+                               'saliency_sample_count': saliency_count_per_label[label_id],
+                               'hoc_sample_count': hoc_count_per_label[label_id]}
                               for label_id, label_name in enumerate(self._metadata['labels'])]
         return all_classes_return
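A rough standalone illustration of the new per-label bookkeeping; the field names follow the diff, while the sample layout is simplified and the indexing of inferences by label id mirrors the code above.

    from collections import defaultdict

    labels = ["cat", "dog"]
    samples = {
        "s1": {"image": "s1.jpg", "ground_truth_label": [0], "predicted_label": [0],
               "inferences": {0: {"saliency_maps": ["map"], "hoc_layers": []},
                              1: {"saliency_maps": [], "hoc_layers": []}}},
    }

    sample_count = defaultdict(int)
    saliency_count = defaultdict(int)
    hoc_count = defaultdict(int)
    for sample in samples.values():
        for label in set(sample["ground_truth_label"] + sample["predicted_label"]):
            sample_count[label] += 1
            if sample["inferences"][label]["saliency_maps"]:
                saliency_count[label] += 1
            if sample["inferences"][label]["hoc_layers"]:
                hoc_count[label] += 1

    print([{"id": i, "label": name, "sample_count": sample_count[i],
            "saliency_sample_count": saliency_count[i], "hoc_sample_count": hoc_count[i]}
           for i, name in enumerate(labels)])
    # -> cat: 1 sample, 1 with saliency, 0 with HOC; dog: all zeros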



