| @@ -138,7 +138,6 @@ def query_explain_job(): | |||||
| raise ParamMissError("train_id") | raise ParamMissError("train_id") | ||||
| encapsulator = ExplainJobEncap(EXPLAIN_MANAGER) | encapsulator = ExplainJobEncap(EXPLAIN_MANAGER) | ||||
| metadata = encapsulator.query_meta(train_id) | metadata = encapsulator.query_meta(train_id) | ||||
| return jsonify(metadata) | return jsonify(metadata) | ||||
| @@ -175,6 +174,12 @@ def query_hoc(): | |||||
| query_kwargs = _get_query_sample_parameters(data) | query_kwargs = _get_query_sample_parameters(data) | ||||
| filter_empty = data.get("drop_empty", True) | |||||
| if not isinstance(filter_empty, bool): | |||||
| raise ParamTypeError("drop_empty", bool) | |||||
| query_kwargs["drop_empty"] = filter_empty | |||||
| encapsulator = HierarchicalOcclusionEncap( | encapsulator = HierarchicalOcclusionEncap( | ||||
| _image_url_formatter, | _image_url_formatter, | ||||
| EXPLAIN_MANAGER) | EXPLAIN_MANAGER) | ||||
| @@ -15,6 +15,7 @@ | |||||
| """Common explain data encapsulator base class.""" | """Common explain data encapsulator base class.""" | ||||
| import copy | import copy | ||||
| from enum import Enum | |||||
| from mindinsight.utils.exceptions import ParamValueError | from mindinsight.utils.exceptions import ParamValueError | ||||
| @@ -64,6 +65,11 @@ def _sort_key_max_confidence_sd(sample, labels): | |||||
| max_confidence_sd = confidence_sd | max_confidence_sd = confidence_sd | ||||
| return max_confidence_sd | return max_confidence_sd | ||||
class ExplanationKeys(Enum):
    """Keys naming the explanation payloads carried by an inference entry."""

    # HOC: Hierarchical Occlusion, an explanation method we propose
    HOC = "hoc_layers"
    SALIENCY = "saliency_maps"
| class ExplainDataEncap: | class ExplainDataEncap: | ||||
| """Explain data encapsulator base class.""" | """Explain data encapsulator base class.""" | ||||
| @@ -89,7 +95,7 @@ class ExplanationEncap(ExplainDataEncap): | |||||
| sorted_name, | sorted_name, | ||||
| sorted_type, | sorted_type, | ||||
| prediction_types=None, | prediction_types=None, | ||||
| query_type="saliency_maps"): | |||||
| drop_type=None): | |||||
| """ | """ | ||||
| Query samples. | Query samples. | ||||
| @@ -99,13 +105,22 @@ class ExplanationEncap(ExplainDataEncap): | |||||
| sorted_name (str): Field to be sorted. | sorted_name (str): Field to be sorted. | ||||
| sorted_type (str): Sorting order, 'ascending' or 'descending'. | sorted_type (str): Sorting order, 'ascending' or 'descending'. | ||||
| prediction_types (list[str]): Prediction type filter. | prediction_types (list[str]): Prediction type filter. | ||||
| drop_type (str, None): When it is None, no filter will be applied. When it is 'hoc_layers', samples without | |||||
| hoc explanations will be filtered out. When it is 'saliency_maps', samples without saliency explanations | |||||
| will be filtered out. | |||||
| Returns: | Returns: | ||||
| list[dict], samples to be queried. | list[dict], samples to be queried. | ||||
| """ | """ | ||||
| samples = copy.deepcopy(job.get_all_samples()) | samples = copy.deepcopy(job.get_all_samples()) | ||||
| samples = [sample for sample in samples if any(infer[query_type] for infer in sample['inferences'])] | |||||
| if drop_type not in (None, ExplanationKeys.SALIENCY.value, ExplanationKeys.HOC.value): | |||||
| raise ParamValueError( | |||||
| f"Argument drop_type valid options: None, {ExplanationKeys.SALIENCY.value}, " | |||||
| f"{ExplanationKeys.HOC.value}, but got {drop_type}.") | |||||
| if drop_type is not None: | |||||
| samples = [sample for sample in samples if any(infer[drop_type] for infer in sample['inferences'])] | |||||
| if labels: | if labels: | ||||
| filtered = [] | filtered = [] | ||||
| for sample in samples: | for sample in samples: | ||||
| @@ -43,8 +43,10 @@ class ExplainJobEncap(ExplainDataEncap): | |||||
| def query_meta(self, train_id): | def query_meta(self, train_id): | ||||
| """ | """ | ||||
| Query explain job meta-data. | Query explain job meta-data. | ||||
| Args: | Args: | ||||
| train_id (str): Job ID. | train_id (str): Job ID. | ||||
| Returns: | Returns: | ||||
| dict, the metadata. | dict, the metadata. | ||||
| """ | """ | ||||
| @@ -81,7 +83,7 @@ class ExplainJobEncap(ExplainDataEncap): | |||||
| """Convert ExplainJob's meta-data to jsonable info object.""" | """Convert ExplainJob's meta-data to jsonable info object.""" | ||||
| info = cls._job_2_info(job) | info = cls._job_2_info(job) | ||||
| info["sample_count"] = job.sample_count | info["sample_count"] = job.sample_count | ||||
| info["classes"] = [item for item in job.all_classes if item['sample_count'] > 0] | |||||
| info["classes"] = job.all_classes | |||||
| saliency_info = dict() | saliency_info = dict() | ||||
| if job.min_confidence is None: | if job.min_confidence is None: | ||||
| saliency_info["min_confidence"] = cls.DEFAULT_MIN_CONFIDENCE | saliency_info["min_confidence"] = cls.DEFAULT_MIN_CONFIDENCE | ||||
| @@ -15,7 +15,7 @@ | |||||
| """Hierarchical Occlusion encapsulator.""" | """Hierarchical Occlusion encapsulator.""" | ||||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | ||||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap | |||||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys | |||||
| class HierarchicalOcclusionEncap(ExplanationEncap): | class HierarchicalOcclusionEncap(ExplanationEncap): | ||||
| @@ -28,7 +28,8 @@ class HierarchicalOcclusionEncap(ExplanationEncap): | |||||
| offset, | offset, | ||||
| sorted_name, | sorted_name, | ||||
| sorted_type, | sorted_type, | ||||
| prediction_types=None | |||||
| prediction_types=None, | |||||
| drop_empty=True, | |||||
| ): | ): | ||||
| """ | """ | ||||
| Query hierarchical occlusion results. | Query hierarchical occlusion results. | ||||
| @@ -41,6 +42,7 @@ class HierarchicalOcclusionEncap(ExplanationEncap): | |||||
| sorted_name (str): Field to be sorted. | sorted_name (str): Field to be sorted. | ||||
| sorted_type (str): Sorting order, 'ascending' or 'descending'. | sorted_type (str): Sorting order, 'ascending' or 'descending'. | ||||
| prediction_types (list[str]): Prediction types filter. | prediction_types (list[str]): Prediction types filter. | ||||
| drop_empty (bool): Whether to drop samples without HOC data. Default: True. | |||||
| Returns: | Returns: | ||||
| tuple[int, list[dict]], total number of samples after filtering and list of sample results. | tuple[int, list[dict]], total number of samples after filtering and list of sample results. | ||||
| @@ -49,8 +51,12 @@ class HierarchicalOcclusionEncap(ExplanationEncap): | |||||
| if job is None: | if job is None: | ||||
| raise TrainJobNotExistError(train_id) | raise TrainJobNotExistError(train_id) | ||||
| samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types, | |||||
| query_type="hoc_layers") | |||||
| if drop_empty: | |||||
| samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types, | |||||
| drop_type=ExplanationKeys.HOC.value) | |||||
| else: | |||||
| samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types) | |||||
| sample_infos = [] | sample_infos = [] | ||||
| obj_offset = offset * limit | obj_offset = offset * limit | ||||
| count = len(samples) | count = len(samples) | ||||
| @@ -59,29 +65,36 @@ class HierarchicalOcclusionEncap(ExplanationEncap): | |||||
| end = obj_offset + limit | end = obj_offset + limit | ||||
| for i in range(obj_offset, end): | for i in range(obj_offset, end): | ||||
| sample = samples[i] | sample = samples[i] | ||||
| sample_infos.append(self._touch_sample(sample, job)) | |||||
| sample_infos.append(self._touch_sample(sample, job, drop_empty)) | |||||
| return count, sample_infos | return count, sample_infos | ||||
| def _touch_sample(self, sample, job): | |||||
| def _touch_sample(self, sample, job, drop_empty): | |||||
| """ | """ | ||||
| Final edit on single sample info. | Final edit on single sample info. | ||||
| Args: | Args: | ||||
| sample (dict): Sample info. | sample (dict): Sample info. | ||||
| job (ExplainManager): Explain job. | job (ExplainManager): Explain job. | ||||
| drop_empty (bool): Whether to drop out inferences without HOC explanations. | |||||
| Returns: | Returns: | ||||
| dict, the edited sample info. | dict, the edited sample info. | ||||
| """ | """ | ||||
| sample_cp = sample.copy() | |||||
| sample_cp["image"] = self._get_image_url(job.train_id, sample["image"], "original") | |||||
| for inference_item in sample_cp["inferences"]: | |||||
| sample["image"] = self._get_image_url(job.train_id, sample["image"], "original") | |||||
| inferences = sample["inferences"] | |||||
| i = 0 # init index for while loop | |||||
| while i < len(inferences): | |||||
| inference_item = inferences[i] | |||||
| if drop_empty and not inference_item[ExplanationKeys.HOC.value]: | |||||
| inferences.pop(i) | |||||
| continue | |||||
| new_list = [] | new_list = [] | ||||
| for idx, hoc_layer in enumerate(inference_item["hoc_layers"]): | |||||
| for idx, hoc_layer in enumerate(inference_item[ExplanationKeys.HOC.value]): | |||||
| hoc_layer["outcome"] = self._get_image_url(job.train_id, | hoc_layer["outcome"] = self._get_image_url(job.train_id, | ||||
| f"{sample['id']}_{inference_item['label']}_{idx}.jpg", | f"{sample['id']}_{inference_item['label']}_{idx}.jpg", | ||||
| "outcome") | "outcome") | ||||
| new_list.append(hoc_layer) | new_list.append(hoc_layer) | ||||
| inference_item["hoc_layers"] = new_list | |||||
| return sample_cp | |||||
| inference_item[ExplanationKeys.HOC.value] = new_list | |||||
| i += 1 | |||||
| return sample | |||||
| @@ -15,7 +15,7 @@ | |||||
| """Saliency map encapsulator.""" | """Saliency map encapsulator.""" | ||||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | ||||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap | |||||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys | |||||
| class SaliencyEncap(ExplanationEncap): | class SaliencyEncap(ExplanationEncap): | ||||
| @@ -49,8 +49,7 @@ class SaliencyEncap(ExplanationEncap): | |||||
| if job is None: | if job is None: | ||||
| raise TrainJobNotExistError(train_id) | raise TrainJobNotExistError(train_id) | ||||
| samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types, | |||||
| query_type="saliency_maps") | |||||
| samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types) | |||||
| sample_infos = [] | sample_infos = [] | ||||
| obj_offset = offset * limit | obj_offset = offset * limit | ||||
| @@ -79,10 +78,10 @@ class SaliencyEncap(ExplanationEncap): | |||||
| sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original") | sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original") | ||||
| for inference in sample_cp["inferences"]: | for inference in sample_cp["inferences"]: | ||||
| new_list = [] | new_list = [] | ||||
| for saliency_map in inference["saliency_maps"]: | |||||
| for saliency_map in inference[ExplanationKeys.SALIENCY.value]: | |||||
| if explainers and saliency_map["explainer"] not in explainers: | if explainers and saliency_map["explainer"] not in explainers: | ||||
| continue | continue | ||||
| saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay") | saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay") | ||||
| new_list.append(saliency_map) | new_list.append(saliency_map) | ||||
| inference["saliency_maps"] = new_list | |||||
| inference[ExplanationKeys.SALIENCY.value] = new_list | |||||
| return sample_cp | return sample_cp | ||||
| @@ -99,12 +99,22 @@ class ExplainLoader: | |||||
| - sample_count (int): Number of samples for each label. | - sample_count (int): Number of samples for each label. | ||||
| """ | """ | ||||
| sample_count_per_label = defaultdict(int) | sample_count_per_label = defaultdict(int) | ||||
| saliency_count_per_label = defaultdict(int) | |||||
| hoc_count_per_label = defaultdict(int) | |||||
| for sample in self._samples.values(): | for sample in self._samples.values(): | ||||
| if sample.get('image') and (sample.get('ground_truth_label') or sample.get('predicted_label')): | if sample.get('image') and (sample.get('ground_truth_label') or sample.get('predicted_label')): | ||||
| for label in set(sample['ground_truth_label'] + sample['predicted_label']): | for label in set(sample['ground_truth_label'] + sample['predicted_label']): | ||||
| sample_count_per_label[label] += 1 | sample_count_per_label[label] += 1 | ||||
| all_classes_return = [{'id': label_id, 'label': label_name, 'sample_count': sample_count_per_label[label_id]} | |||||
| if sample['inferences'][label]['saliency_maps']: | |||||
| saliency_count_per_label[label] += 1 | |||||
| if sample['inferences'][label]['hoc_layers']: | |||||
| hoc_count_per_label[label] += 1 | |||||
| all_classes_return = [{'id': label_id, | |||||
| 'label': label_name, | |||||
| 'sample_count': sample_count_per_label[label_id], | |||||
| 'saliency_sample_count': saliency_count_per_label[label_id], | |||||
| 'hoc_sample_count': hoc_count_per_label[label_id]} | |||||
| for label_id, label_name in enumerate(self._metadata['labels'])] | for label_id, label_name in enumerate(self._metadata['labels'])] | ||||
| return all_classes_return | return all_classes_return | ||||