From cb839798b22f2a2b69f104f95a8ec11ecdd2a93b Mon Sep 17 00:00:00 2001 From: lixiaohui Date: Thu, 7 Jan 2021 11:34:35 +0800 Subject: [PATCH] explainer add HOC, prediction type filtering and sort with filtered label confidence --- .../backend/explainer/explainer_api.py | 116 +++++++--- .../proto_files/mindinsight_summary.proto | 13 ++ mindinsight/explainer/common/enums.py | 1 + .../explainer/encapsulator/_hoc_pil_apply.py | 156 +++++++++++++ .../explainer/encapsulator/datafile_encap.py | 52 ++++- .../encapsulator/explain_data_encap.py | 128 ++++++++++- .../encapsulator/explain_job_encap.py | 8 +- .../hierarchical_occlusion_encap.py | 87 ++++++++ .../explainer/encapsulator/saliency_encap.py | 104 ++------- .../explainer/manager/explain_loader.py | 209 ++++++++++-------- .../explainer/manager/explain_manager.py | 1 + .../explainer/manager/explain_parser.py | 4 +- .../encapsulator/mock_explain_manager.py | 4 +- .../encapsulator/test_explain_job_encap.py | 4 +- .../explainer/manager/test_explain_loader.py | 4 - 15 files changed, 650 insertions(+), 241 deletions(-) create mode 100644 mindinsight/explainer/encapsulator/_hoc_pil_apply.py create mode 100644 mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py diff --git a/mindinsight/backend/explainer/explainer_api.py b/mindinsight/backend/explainer/explainer_api.py index 2126b2c9..6e669449 100644 --- a/mindinsight/backend/explainer/explainer_api.py +++ b/mindinsight/backend/explainer/explainer_api.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ # ============================================================================ """Explainer restful api.""" -import os import json +import os import urllib.parse from flask import Blueprint @@ -23,16 +23,18 @@ from flask import jsonify from flask import request from mindinsight.conf import settings -from mindinsight.utils.exceptions import ParamMissError -from mindinsight.utils.exceptions import ParamValueError from mindinsight.datavisual.common.validation import Validation from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher from mindinsight.datavisual.utils.tools import get_train_id -from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER -from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap from mindinsight.explainer.encapsulator.datafile_encap import DatafileEncap -from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap +from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap +from mindinsight.explainer.encapsulator.hierarchical_occlusion_encap import HierarchicalOcclusionEncap +from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap +from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER +from mindinsight.utils.exceptions import ParamMissError +from mindinsight.utils.exceptions import ParamTypeError +from mindinsight.utils.exceptions import ParamValueError URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX) @@ -66,6 +68,50 @@ def _read_post_request(post_request): return body +def _get_query_sample_parameters(data): + """Get parameter for query.""" + + train_id = 
data.get("train_id") + if train_id is None: + raise ParamMissError('train_id') + + labels = data.get("labels") + if labels is not None and not isinstance(labels, list): + raise ParamTypeError("labels", (list, None)) + if labels: + for item in labels: + if not isinstance(item, str): + raise ParamTypeError("element of labels", str) + + limit = data.get("limit", 10) + limit = Validation.check_limit(limit, min_value=1, max_value=100) + offset = data.get("offset", 0) + offset = Validation.check_offset(offset=offset) + sorted_name = data.get("sorted_name", "") + sorted_type = data.get("sorted_type", "descending") + if sorted_name not in ("", "confidence", "uncertainty"): + raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'") + if sorted_type not in ("ascending", "descending"): + raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'") + + prediction_types = data.get("prediction_types") + if prediction_types is not None and not isinstance(prediction_types, list): + raise ParamTypeError("prediction_types", (list, None)) + if prediction_types: + for item in prediction_types: + if item not in ['TP', 'FN', 'FP']: + raise ParamValueError(f"Item of prediction_types must be in ['TP', 'FN', 'FP'], but got {item}.") + + query_kwarg = {"train_id": train_id, + "labels": labels, + "limit": limit, + "offset": offset, + "sorted_name": sorted_name, + "sorted_type": sorted_type, + "prediction_types": prediction_types} + return query_kwarg + + @BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"]) def query_explain_jobs(): """Query explain jobs.""" @@ -99,37 +145,40 @@ def query_explain_job(): @BLUEPRINT.route("/explainer/saliency", methods=["POST"]) def query_saliency(): """Query saliency map related results.""" - data = _read_post_request(request) - - train_id = data.get("train_id") - if train_id is None: - raise ParamMissError('train_id') - - labels = data.get("labels") + query_kwarg = _get_query_sample_parameters(data) explainers = data.get("explainers") - limit = data.get("limit", 10) - limit = Validation.check_limit(limit, min_value=1, max_value=100) - offset = data.get("offset", 0) - offset = Validation.check_offset(offset=offset) - sorted_name = data.get("sorted_name", "") - sorted_type = data.get("sorted_type", "descending") + if explainers is not None and not isinstance(explainers, list): + raise ParamTypeError("explainers", (list, None)) + if explainers: + for item in explainers: + if not isinstance(item, str): + raise ParamTypeError("element of explainers", str) - if sorted_name not in ("", "confidence", "uncertainty"): - raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'") - if sorted_type not in ("ascending", "descending"): - raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'") + query_kwarg["explainers"] = explainers encapsulator = SaliencyEncap( _image_url_formatter, EXPLAIN_MANAGER) - count, samples = encapsulator.query_saliency_maps(train_id=train_id, - labels=labels, - explainers=explainers, - limit=limit, - offset=offset, - sorted_name=sorted_name, - sorted_type=sorted_type) + count, samples = encapsulator.query_saliency_maps(**query_kwarg) + + return jsonify({ + "count": count, + "samples": samples + }) + + +@BLUEPRINT.route("/explainer/hoc", methods=["POST"]) +def query_hoc(): + """Query hierarchical occlusion related results.""" + data = _read_post_request(request) + + query_kwargs = 
_get_query_sample_parameters(data) + + encapsulator = HierarchicalOcclusionEncap( + _image_url_formatter, + EXPLAIN_MANAGER) + count, samples = encapsulator.query_hierarchical_occlusion(**query_kwargs) return jsonify({ "count": count, @@ -162,8 +211,8 @@ def query_image(): image_type = request.args.get("type") if image_type is None: raise ParamMissError("type") - if image_type not in ("original", "overlay"): - raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay'") + if image_type not in ("original", "overlay", "outcome"): + raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay' 'outcome'") encapsulator = DatafileEncap(EXPLAIN_MANAGER) image = encapsulator.query_image_binary(train_id, image_path, image_type) @@ -177,6 +226,5 @@ def init_module(app): Args: app: the application obj. - """ app.register_blueprint(BLUEPRINT) diff --git a/mindinsight/datavisual/proto_files/mindinsight_summary.proto b/mindinsight/datavisual/proto_files/mindinsight_summary.proto index 2973820a..51871e1b 100644 --- a/mindinsight/datavisual/proto_files/mindinsight_summary.proto +++ b/mindinsight/datavisual/proto_files/mindinsight_summary.proto @@ -138,6 +138,17 @@ message Explain { repeated string benchmark_method = 3; } + message HocLayer{ + optional float prob = 1; + repeated int32 box = 2; // List of repeated x, y, w, h + } + + message Hoc { + optional int32 label = 1; + optional string mask = 2; + repeated HocLayer layer = 3; + } + optional int32 sample_id = 1; // The Metadata and sample id must have one fill in optional string image_path = 2; repeated int32 ground_truth_label = 3; @@ -148,4 +159,6 @@ message Explain { optional Metadata metadata = 7; optional string status = 8; // enum value: run, end + + repeated Hoc hoc = 9; // hierarchical occlusion counterfactual } diff --git a/mindinsight/explainer/common/enums.py b/mindinsight/explainer/common/enums.py index 6afc8efb..d2c29f5e 100644 --- a/mindinsight/explainer/common/enums.py +++ b/mindinsight/explainer/common/enums.py @@ -40,6 +40,7 @@ class ExplainFieldsEnum(BaseEnum): GROUND_TRUTH_LABEL = 'ground_truth_label' INFERENCE = 'inference' EXPLANATION = 'explanation' + HIERARCHICAL_OCCLUSION = 'hierarchical_occlusion' STATUS = 'status' diff --git a/mindinsight/explainer/encapsulator/_hoc_pil_apply.py b/mindinsight/explainer/encapsulator/_hoc_pil_apply.py new file mode 100644 index 00000000..2e0f8996 --- /dev/null +++ b/mindinsight/explainer/encapsulator/_hoc_pil_apply.py @@ -0,0 +1,156 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Utility functions for hierarchcial occlusion image generation.""" + +import re + +import PIL +from PIL import ImageDraw, ImageEnhance, ImageFilter + +MASK_GAUSSIAN_RE = r'^gaussian:(\d+)$' + + +class EditStep: + """ + Class that represents an edit step. + + Args: + layer (int): Layer index. + x (int): Left pixel coordinate. 
+        y (int): Top pixel coordinate.
+        width (int): Width in pixels.
+        height (int): Height in pixels.
+    """
+    def __init__(self,
+                 layer: int,
+                 x: int,
+                 y: int,
+                 width: int,
+                 height: int):
+        self.layer = layer
+        self.x = x
+        self.y = y
+        self.width = width
+        self.height = height
+
+    def to_coord_box(self):
+        """
+        Convert to pixel coordinate box.
+
+        Returns:
+            tuple[int, int, int, int], tuple of left, top, right, bottom pixel coordinate.
+        """
+        return self.x, self.y, self.x + self.width, self.y + self.height
+
+
+def pil_apply_edit_steps(image, mask, edit_steps, by_masking=False, inplace=False):
+    """
+    Apply edit steps on a PIL image.
+
+    Args:
+        image (PIL.Image): The input image in RGB mode.
+        mask (Union[str, int, tuple[int, int, int], PIL.Image.Image]): The mask to apply on the image, could be a
+            string, e.g. 'gaussian:9', a single grey scale intensity in [0, 255], an RGB tuple or a PIL Image object.
+        edit_steps (list[EditStep]): Edit steps to be drawn.
+        by_masking (bool): Whether to use masking method. Default: False.
+        inplace (bool): True to draw on the input image, otherwise draw on a cloned image.
+
+    Returns:
+        PIL.Image, the result image.
+    """
+    if isinstance(mask, str):
+        mask = pil_compile_str_mask(mask, image)
+    if by_masking:
+        return _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace)
+    return _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace)
+
+
+def pil_compile_str_mask(mask, image):
+    """Convert string mask to PIL Image."""
+    match = re.match(MASK_GAUSSIAN_RE, mask)
+    if match:
+        radius = int(match.group(1))
+        if radius > 0:
+            image_filter = ImageFilter.GaussianBlur(radius=radius)
+            mask_image = image.filter(image_filter)
+            mask_image = ImageEnhance.Brightness(mask_image).enhance(0.7)
+            mask_image = ImageEnhance.Color(mask_image).enhance(0.0)
+            return mask_image
+    raise ValueError(f"Invalid string mask: '{mask}'.")
+
+
+def _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace=False):
+    """
+    Apply edit steps from unmasking method on a PIL image.
+
+    Args:
+        image (PIL.Image): The input image.
+        mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey
+            scale intensity in [0, 255], an RGB tuple or a PIL Image.
+        edit_steps (list[EditStep]): Edit steps to be drawn.
+        inplace (bool): True to draw on the input image, otherwise draw on a cloned image.
+
+    Returns:
+        PIL.Image, the result image.
+    """
+    if isinstance(mask, PIL.Image.Image):
+        if inplace:
+            bg = mask
+        else:
+            bg = mask.copy()
+    else:
+        if inplace:
+            raise ValueError('Argument inplace cannot be True when mask is not a PIL Image.')
+        if isinstance(mask, int):
+            mask = (mask, mask, mask)
+        bg = PIL.Image.new(mode="RGB", size=image.size, color=mask)
+
+    for step in edit_steps:
+        box = step.to_coord_box()
+        cropped = image.crop(box)
+        bg.paste(cropped, box=box)
+    return bg
+
+
+def _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace=False):
+    """
+    Apply edit steps from masking method on a PIL image.
+
+    Args:
+        image (PIL.Image): The input image.
+        mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey
+            scale intensity in [0, 255], an RGB tuple or a PIL Image.
+        edit_steps (list[EditStep]): Edit steps to be drawn.
+        inplace (bool): True to draw on the input image, otherwise draw on a cloned image.
+
+    Returns:
+        PIL.Image, the result image.
+    """
+    if not inplace:
+        image = image.copy()
+
+    if isinstance(mask, PIL.Image.Image):
+        for step in edit_steps:
+            box = step.to_coord_box()
+            cropped = mask.crop(box)
+            image.paste(cropped, box=box)
+    else:
+        if isinstance(mask, int):
+            mask = (mask, mask, mask)
+        draw = ImageDraw.Draw(image)
+        for step in edit_steps:
+            draw.rectangle(step.to_coord_box(), fill=mask)
+    return image
diff --git a/mindinsight/explainer/encapsulator/datafile_encap.py b/mindinsight/explainer/encapsulator/datafile_encap.py
index bfcb06d1..249b6858 100644
--- a/mindinsight/explainer/encapsulator/datafile_encap.py
+++ b/mindinsight/explainer/encapsulator/datafile_encap.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,16 +14,17 @@
 # ============================================================================
 """Datafile encapsulator."""
 
-import os
 import io
+import os
 
-from PIL import Image
 import numpy as np
+from PIL import Image
 
-from mindinsight.utils.exceptions import UnknownError
-from mindinsight.utils.exceptions import FileSystemPermissionError
 from mindinsight.datavisual.common.exceptions import ImageNotExistError
+from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps
 from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
+from mindinsight.utils.exceptions import FileSystemPermissionError
+from mindinsight.utils.exceptions import UnknownError
 
 # Max uint8 value. for converting RGB pixels to [0,1] intensity.
 _UINT8_MAX = 255
@@ -59,11 +60,44 @@ class DatafileEncap(ExplainDataEncap):
         Args:
             train_id (str): Job ID.
             image_path (str): Image path relative to explain job's summary directory.
-            image_type (str): Image type, 'original' or 'overlay'.
+            image_type (str): Image type. Options: 'original', 'overlay' or 'outcome'.
 
         Returns:
             bytes, image binary.
         """
+        if image_type == "outcome":
+            sample_id, label, layer = os.path.splitext(image_path)[0].split("_")
+            layer = int(layer)
+            job = self.job_manager.get_job(train_id)
+            samples = job.samples
+            label_idx = job.labels.index(label)
+
+            chosen_sample = samples[int(sample_id)]
+            original_path_image = chosen_sample['image']
+            abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id),
+                                          original_path_image)
+            if self._is_forbidden(abs_image_path):
+                raise FileSystemPermissionError("Forbidden.")
+            try:
+                image = Image.open(abs_image_path)
+            except FileNotFoundError:
+                raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
+            except PermissionError:
+                raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
+            except OSError:
+                raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")
+
+            edit_steps = []
+            boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"]
+            mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"]
+
+            for box in boxes:
+                edit_steps.append(EditStep(layer, *box))
+            image_cp = pil_apply_edit_steps(image, mask, edit_steps)
+            buffer = io.BytesIO()
+            image_cp.save(buffer, format=_PNG_FORMAT)
+
+            return buffer.getvalue()
 
         abs_image_path = os.path.join(self.job_manager.summary_base_dir,
                                       _clean_train_id_b4_join(train_id),
@@ -94,10 +128,10 @@ class DatafileEncap(ExplainDataEncap):
             raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")
 
         if image.mode == _SINGLE_CHANNEL_MODE:
-            saliency = np.asarray(image)/_UINT8_MAX
+            saliency = np.asarray(image) / _UINT8_MAX
         elif image.mode == _RGB_MODE:
             saliency = np.asarray(image)
-            saliency = saliency[:, :, 0]/_UINT8_MAX
+            saliency = saliency[:, :, 0] / _UINT8_MAX
         else:
             raise UnknownError(f"Invalid overlay image mode:{image.mode}.")
 
@@ -105,7 +139,7 @@ class DatafileEncap(ExplainDataEncap):
         for c in range(3):
             saliency_stack[:, :, c] = saliency
         rgba = saliency_stack * _SALIENCY_CMAP_HI
-        rgba += (1-saliency_stack) * _SALIENCY_CMAP_LOW
+        rgba += (1 - saliency_stack) * _SALIENCY_CMAP_LOW
         rgba[:, :, 3] = saliency * _UINT8_MAX
 
         overlay = Image.fromarray(np.uint8(rgba), mode=_RGBA_MODE)
diff --git a/mindinsight/explainer/encapsulator/explain_data_encap.py b/mindinsight/explainer/encapsulator/explain_data_encap.py
index 7f819bc2..eb12fd7d 100644
--- a/mindinsight/explainer/encapsulator/explain_data_encap.py
+++ b/mindinsight/explainer/encapsulator/explain_data_encap.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,57 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Explain data encapsulator base class."""
+"""Common explain data encapsulator base class."""
+
+import copy
+
+from mindinsight.utils.exceptions import ParamValueError
+
+
+def _sort_key_min_confidence(sample, labels):
+    """Samples sort key by the min. confidence."""
+    min_confidence = float("+inf")
+    for inference in sample["inferences"]:
+        if labels and inference["label"] not in labels:
+            continue
+        if inference["confidence"] < min_confidence:
+            min_confidence = inference["confidence"]
+    return min_confidence
+
+
+def _sort_key_max_confidence(sample, labels):
+    """Samples sort key by the max. confidence."""
+    max_confidence = float("-inf")
+    for inference in sample["inferences"]:
+        if labels and inference["label"] not in labels:
+            continue
+        if inference["confidence"] > max_confidence:
+            max_confidence = inference["confidence"]
+    return max_confidence
+
+
+def _sort_key_min_confidence_sd(sample, labels):
+    """Samples sort key by the min. confidence_sd."""
+    min_confidence_sd = float("+inf")
+    for inference in sample["inferences"]:
+        if labels and inference["label"] not in labels:
+            continue
+        confidence_sd = inference.get("confidence_sd", float("+inf"))
+        if confidence_sd < min_confidence_sd:
+            min_confidence_sd = confidence_sd
+    return min_confidence_sd
+
+
+def _sort_key_max_confidence_sd(sample, labels):
+    """Samples sort key by the max. confidence_sd."""
+    max_confidence_sd = float("-inf")
+    for inference in sample["inferences"]:
+        if labels and inference["label"] not in labels:
+            continue
+        confidence_sd = inference.get("confidence_sd", float("-inf"))
+        if confidence_sd > max_confidence_sd:
+            max_confidence_sd = confidence_sd
+    return max_confidence_sd
 
 
 class ExplainDataEncap:
@@ -24,3 +74,77 @@ class ExplainDataEncap:
     @property
     def job_manager(self):
         return self._job_manager
+
+
+class ExplanationEncap(ExplainDataEncap):
+    """Base encapsulator for explanation queries."""
+
+    def __init__(self, image_url_formatter, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._image_url_formatter = image_url_formatter
+
+    def _query_samples(self,
+                       job,
+                       labels,
+                       sorted_name,
+                       sorted_type,
+                       prediction_types=None,
+                       query_type="saliency_maps"):
+        """
+        Query samples.
+
+        Args:
+            job (ExplainManager): Explain job to be queried from.
+            labels (list[str]): Label filter.
+            sorted_name (str): Field to be sorted.
+            sorted_type (str): Sorting order, 'ascending' or 'descending'.
+            prediction_types (list[str]): Prediction type filter. Default: None.
+            query_type (str): Type of explanation data to query, 'saliency_maps' or 'hoc_layers'.
+                Default: 'saliency_maps'.
+
+        Returns:
+            list[dict], samples to be queried.
+        """
+
+        samples = copy.deepcopy(job.get_all_samples())
+        samples = [sample for sample in samples if any(infer[query_type] for infer in sample['inferences'])]
+        if labels:
+            filtered = []
+            for sample in samples:
+                infer_labels = [inference["label"] for inference in sample["inferences"]]
+                for infer_label in infer_labels:
+                    if infer_label in labels:
+                        filtered.append(sample)
+                        break
+            samples = filtered
+
+        if prediction_types and len(prediction_types) < 3:
+            filtered = []
+            for sample in samples:
+                infer_types = [inference["prediction_type"] for inference in sample["inferences"]]
+                for infer_type in infer_types:
+                    if infer_type in prediction_types:
+                        filtered.append(sample)
+                        break
+            samples = filtered
+
+        reverse = sorted_type == "descending"
+        if sorted_name == "confidence":
+            if reverse:
+                samples.sort(key=lambda x: _sort_key_max_confidence(x, labels), reverse=reverse)
+            else:
+                samples.sort(key=lambda x: _sort_key_min_confidence(x, labels), reverse=reverse)
+        elif sorted_name == "uncertainty":
+            if not job.uncertainty_enabled:
+                raise ParamValueError("Uncertainty is not enabled, sorted_name cannot be 'uncertainty'")
+            if reverse:
+                samples.sort(key=lambda x: _sort_key_max_confidence_sd(x, labels), reverse=reverse)
+            else:
+                samples.sort(key=lambda x: _sort_key_min_confidence_sd(x, labels), reverse=reverse)
+        elif sorted_name != "":
+            raise ParamValueError("sorted_name")
+        return samples
+
+    def _get_image_url(self, train_id, image_path, image_type):
+        """Returns image's url."""
+        if self._image_url_formatter is None:
+            return image_path
+        return self._image_url_formatter(train_id, image_path, image_type)
diff --git a/mindinsight/explainer/encapsulator/explain_job_encap.py b/mindinsight/explainer/encapsulator/explain_job_encap.py
index 0b2bd206..5e45facc 100644
--- a/mindinsight/explainer/encapsulator/explain_job_encap.py
+++ b/mindinsight/explainer/encapsulator/explain_job_encap.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@ # ============================================================================ """Explain job list encapsulator.""" -import copy from datetime import datetime from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap @@ -61,6 +60,9 @@ class ExplainJobEncap(ExplainDataEncap): info["train_id"] = dir_info["relative_path"] info["create_time"] = dir_info["create_time"].strftime(cls.DATETIME_FORMAT) info["update_time"] = dir_info["update_time"].strftime(cls.DATETIME_FORMAT) + info["saliency_map"] = dir_info["saliency_map"] + info["hierarchical_occlusion"] = dir_info["hierarchical_occlusion"] + return info @classmethod @@ -79,7 +81,7 @@ class ExplainJobEncap(ExplainDataEncap): """Convert ExplainJob's meta-data to jsonable info object.""" info = cls._job_2_info(job) info["sample_count"] = job.sample_count - info["classes"] = copy.deepcopy(job.all_classes) + info["classes"] = [item for item in job.all_classes if item['sample_count'] > 0] saliency_info = dict() if job.min_confidence is None: saliency_info["min_confidence"] = cls.DEFAULT_MIN_CONFIDENCE diff --git a/mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py b/mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py new file mode 100644 index 00000000..35098843 --- /dev/null +++ b/mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py @@ -0,0 +1,87 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Hierarchical Occlusion encapsulator.""" + +from mindinsight.datavisual.common.exceptions import TrainJobNotExistError +from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap + + +class HierarchicalOcclusionEncap(ExplanationEncap): + """Hierarchical occlusion encapsulator.""" + + def query_hierarchical_occlusion(self, + train_id, + labels, + limit, + offset, + sorted_name, + sorted_type, + prediction_types=None + ): + """ + Query hierarchical occlusion results. + + Args: + train_id (str): Job ID. + labels (list[str]): Label filter. + limit (int): Maximum number of items to be returned. + offset (int): Page offset. + sorted_name (str): Field to be sorted. + sorted_type (str): Sorting order, 'ascending' or 'descending'. + prediction_types (list[str]): Prediction types filter. + + Returns: + tuple[int, list[dict]], total number of samples after filtering and list of sample results. 
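+
+        Raises:
+            TrainJobNotExistError: If the train job with the given train_id does not exist.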
+ """ + job = self.job_manager.get_job(train_id) + if job is None: + raise TrainJobNotExistError(train_id) + + samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types, + query_type="hoc_layers") + sample_infos = [] + obj_offset = offset * limit + count = len(samples) + end = count + if obj_offset + limit < end: + end = obj_offset + limit + for i in range(obj_offset, end): + sample = samples[i] + sample_infos.append(self._touch_sample(sample, job)) + + return count, sample_infos + + def _touch_sample(self, sample, job): + """ + Final edit on single sample info. + + Args: + sample (dict): Sample info. + job (ExplainManager): Explain job. + + Returns: + dict, the edited sample info. + """ + sample_cp = sample.copy() + sample_cp["image"] = self._get_image_url(job.train_id, sample["image"], "original") + for inference_item in sample_cp["inferences"]: + new_list = [] + for idx, hoc_layer in enumerate(inference_item["hoc_layers"]): + hoc_layer["outcome"] = self._get_image_url(job.train_id, + f"{sample['id']}_{inference_item['label']}_{idx}.jpg", + "outcome") + new_list.append(hoc_layer) + inference_item["hoc_layers"] = new_list + return sample_cp diff --git a/mindinsight/explainer/encapsulator/saliency_encap.py b/mindinsight/explainer/encapsulator/saliency_encap.py index 07f434da..8270a0df 100644 --- a/mindinsight/explainer/encapsulator/saliency_encap.py +++ b/mindinsight/explainer/encapsulator/saliency_encap.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,58 +14,13 @@ # ============================================================================ """Saliency map encapsulator.""" -import copy - -from mindinsight.utils.exceptions import ParamValueError -from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap from mindinsight.datavisual.common.exceptions import TrainJobNotExistError +from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap -def _sort_key_min_confidence(sample): - """Samples sort key by the min. confidence.""" - min_confidence = float("+inf") - for inference in sample["inferences"]: - if inference["confidence"] < min_confidence: - min_confidence = inference["confidence"] - return min_confidence - - -def _sort_key_max_confidence(sample): - """Samples sort key by the max. confidence.""" - max_confidence = float("-inf") - for inference in sample["inferences"]: - if inference["confidence"] > max_confidence: - max_confidence = inference["confidence"] - return max_confidence - - -def _sort_key_min_confidence_sd(sample): - """Samples sort key by the min. confidence_sd.""" - min_confidence_sd = float("+inf") - for inference in sample["inferences"]: - confidence_sd = inference.get("confidence_sd", float("+inf")) - if confidence_sd < min_confidence_sd: - min_confidence_sd = confidence_sd - return min_confidence_sd - - -def _sort_key_max_confidence_sd(sample): - """Samples sort key by the max. 
confidence_sd.""" - max_confidence_sd = float("-inf") - for inference in sample["inferences"]: - confidence_sd = inference.get("confidence_sd", float("-inf")) - if confidence_sd > max_confidence_sd: - max_confidence_sd = confidence_sd - return max_confidence_sd - - -class SaliencyEncap(ExplainDataEncap): +class SaliencyEncap(ExplanationEncap): """Saliency map encapsulator.""" - def __init__(self, image_url_formatter, *args, **kwargs): - super().__init__(*args, **kwargs) - self._image_url_formatter = image_url_formatter - def query_saliency_maps(self, train_id, labels, @@ -73,55 +28,32 @@ class SaliencyEncap(ExplainDataEncap): limit, offset, sorted_name, - sorted_type): + sorted_type, + prediction_types=None): """ Query saliency maps. Args: train_id (str): Job ID. labels (list[str]): Label filter. explainers (list[str]): Explainers of saliency maps to be shown. - limit (int): Max. no. of items to be returned. + limit (int): Maximum number of items to be returned. offset (int): Page offset. sorted_name (str): Field to be sorted. sorted_type (str): Sorting order, 'ascending' or 'descending'. + prediction_types (list[str]): Prediction types filter. Default: None. Returns: - tuple[int, list[dict]], total no. of samples after filtering and - list of sample result. + tuple[int, list[dict]], total number of samples after filtering and list of sample result. """ job = self.job_manager.get_job(train_id) if job is None: raise TrainJobNotExistError(train_id) - samples = copy.deepcopy(job.get_all_samples()) - if labels: - filtered = [] - for sample in samples: - infer_labels = [inference["label"] for inference in sample["inferences"]] - for infer_label in infer_labels: - if infer_label in labels: - filtered.append(sample) - break - samples = filtered - - reverse = sorted_type == "descending" - if sorted_name == "confidence": - if reverse: - samples.sort(key=_sort_key_max_confidence, reverse=reverse) - else: - samples.sort(key=_sort_key_min_confidence, reverse=reverse) - elif sorted_name == "uncertainty": - if not job.uncertainty_enabled: - raise ParamValueError("Uncertainty is not enabled, sorted_name cannot be 'uncertainty'") - if reverse: - samples.sort(key=_sort_key_max_confidence_sd, reverse=reverse) - else: - samples.sort(key=_sort_key_min_confidence_sd, reverse=reverse) - elif sorted_name != "": - raise ParamValueError("sorted_name") + samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types, + query_type="saliency_maps") sample_infos = [] - obj_offset = offset*limit + obj_offset = offset * limit count = len(samples) end = count if obj_offset + limit < end: @@ -139,11 +71,13 @@ class SaliencyEncap(ExplainDataEncap): sample (dict): Sample info. job (ExplainJob): Explain job. explainers (list[str]): Explainer names. + Returns: dict, the edited sample info. 
""" - sample["image"] = self._get_image_url(job.train_id, sample['image'], "original") - for inference in sample["inferences"]: + sample_cp = sample.copy() + sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original") + for inference in sample_cp["inferences"]: new_list = [] for saliency_map in inference["saliency_maps"]: if explainers and saliency_map["explainer"] not in explainers: @@ -151,10 +85,4 @@ class SaliencyEncap(ExplainDataEncap): saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay") new_list.append(saliency_map) inference["saliency_maps"] = new_list - return sample - - def _get_image_url(self, train_id, image_path, image_type): - """Returns image's url.""" - if self._image_url_formatter is None: - return image_path - return self._image_url_formatter(train_id, image_path, image_type) + return sample_cp diff --git a/mindinsight/explainer/manager/explain_loader.py b/mindinsight/explainer/manager/explain_loader.py index 5ee9c332..d6b9f275 100644 --- a/mindinsight/explainer/manager/explain_loader.py +++ b/mindinsight/explainer/manager/explain_loader.py @@ -14,14 +14,13 @@ # ============================================================================ """ExplainLoader.""" -from collections import defaultdict -from enum import Enum - import math import os import re import threading +from collections import defaultdict from datetime import datetime +from enum import Enum from typing import Dict, Iterable, List, Optional, Union from mindinsight.datavisual.common.exceptions import TrainJobNotExistError @@ -44,6 +43,7 @@ _SAMPLE_FIELD_NAMES = [ ExplainFieldsEnum.GROUND_TRUTH_LABEL, ExplainFieldsEnum.INFERENCE, ExplainFieldsEnum.EXPLANATION, + ExplainFieldsEnum.HIERARCHICAL_OCCLUSION ] @@ -75,10 +75,10 @@ class ExplainLoader: 'create_time': os.stat(summary_dir).st_ctime, 'update_time': os.stat(summary_dir).st_mtime, 'query_time': os.stat(summary_dir).st_ctime, - 'uncertainty_enabled': False, + 'uncertainty_enabled': False } self._samples = defaultdict(dict) - self._metadata = {'explainers': [], 'metrics': [], 'labels': []} + self._metadata = {'explainers': [], 'metrics': [], 'labels': [], 'min_confidence': 0.5} self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)} self._status = _LoaderStatus.STOP.value @@ -91,25 +91,19 @@ class ExplainLoader: Returns: list[dict], a list of dict, each dict contains: - - id (int): label id - - label (str): label name - - sample_count (int): number of samples for each label + + - id (int): Label id. + - label (str): Label name. + - sample_count (int): Number of samples for each label. 
""" sample_count_per_label = defaultdict(int) - samples_copy = self._samples.copy() - for sample in samples_copy.values(): - if sample.get('image', False) and sample.get('ground_truth_label', False): - for label in sample['ground_truth_label']: + for sample in self._samples.values(): + if sample.get('image') and (sample.get('ground_truth_label') or sample.get('predicted_label')): + for label in set(sample['ground_truth_label'] + sample['predicted_label']): sample_count_per_label[label] += 1 - all_classes_return = [] - for label_id, label_name in enumerate(self._metadata['labels']): - single_info = { - 'id': label_id, - 'label': label_name, - 'sample_count': sample_count_per_label[label_id] - } - all_classes_return.append(single_info) + all_classes_return = [{'id': label_id, 'label': label_name, 'sample_count': sample_count_per_label[label_id]} + for label_id, label_name in enumerate(self._metadata['labels'])] return all_classes_return @property @@ -166,19 +160,23 @@ class ExplainLoader: Returns: list[dict], A list of evaluation results of each explainer. Each item contains: + - explainer (str): Name of evaluated explainer. - evaluations (list[dict]): A list of evaluation results by different metrics. - class_scores (list[dict]): A list of evaluation results on different labels. Each item in the evaluations contains: + - metric (str): name of metric method - score (float): evaluation result Each item in the class_scores contains: + - label (str): Name of label - evaluations (list[dict]): A list of evaluation results on different labels by different metrics. Each item in evaluations contains: + - metric (str): Name of metric method - score (float): Evaluation scores of explainer on specific label by the metric. """ @@ -215,7 +213,7 @@ class ExplainLoader: @property def min_confidence(self) -> Optional[float]: """Return minimum confidence used to filter the predicted labels.""" - return None + return self._metadata['min_confidence'] @property def sample_count(self) -> int: @@ -227,19 +225,17 @@ class ExplainLoader: Return: int, total number of available samples in the loading job. - """ sample_count = 0 - samples_copy = self._samples.copy() - for sample in samples_copy.values(): - if sample.get('image', False) and sample.get('ground_truth_label', False): + for sample in self._samples.values(): + if sample.get('image', False): sample_count += 1 return sample_count @property def samples(self) -> List[Dict]: """Return the information of all samples in the job.""" - return self.get_all_samples() + return self._samples @property def train_id(self) -> str: @@ -298,6 +294,7 @@ class ExplainLoader: self._clear_job() if event_dict: self._import_data_from_event(event_dict) + self._reform_sample_info() @property def status(self): @@ -317,59 +314,19 @@ class ExplainLoader: def get_all_samples(self) -> List[Dict]: """ - Return a list of sample information cachced in the explain job + Return a list of sample information cached in the explain job Returns: sample_list (List[SampleObj]): a list of sample objects, each object consists of: - - id (int): sample id - - name (str): basename of image - - labels (list[str]): list of labels - - inferences list[dict]) + - id (int): Sample id. + - name (str): Basename of image. + - inferences (list[dict]): List of inferences for all labels. 
""" - returned_samples = [] - samples_copy = self._samples.copy() - for sample_id, sample_info in samples_copy.items(): - if not sample_info.get('image', False) and not sample_info.get('ground_truth_label', False): - continue - returned_sample = { - 'id': sample_id, - 'name': str(sample_id), - 'image': sample_info['image'], - 'labels': [self._metadata['labels'][i] for i in sample_info['ground_truth_label']], - } - - # Check whether the sample has valid label-prob pairs. - if not ExplainLoader._is_inference_valid(sample_info): - continue - - inferences = {} - for label, prob in zip(sample_info['ground_truth_label'] + sample_info['predicted_label'], - sample_info['ground_truth_prob'] + sample_info['predicted_prob']): - inferences[label] = { - 'label': self._metadata['labels'][label], - 'confidence': _round(prob), - 'saliency_maps': [] - } - - if sample_info['ground_truth_prob_sd'] or sample_info['predicted_prob_sd']: - for label, std, low, high in zip( - sample_info['ground_truth_label'] + sample_info['predicted_label'], - sample_info['ground_truth_prob_sd'] + sample_info['predicted_prob_sd'], - sample_info['ground_truth_prob_itl95_low'] + sample_info['predicted_prob_itl95_low'], - sample_info['ground_truth_prob_itl95_hi'] + sample_info['predicted_prob_itl95_hi'] - ): - inferences[label]['confidence_sd'] = _round(std) - inferences[label]['confidence_itl95'] = [_round(low), _round(high)] - - for explainer, label_heatmap_path_dict in sample_info['explanation'].items(): - for label, heatmap_path in label_heatmap_path_dict.items(): - if label in inferences: - inferences[label]['saliency_maps'].append({'explainer': explainer, 'overlay': heatmap_path}) - - returned_sample['inferences'] = list(inferences.values()) - returned_samples.append(returned_sample) + returned_samples = [{'id': sample_id, 'name': info['name'], 'image': info['image'], + 'inferences': list(info['inferences'].values())} for sample_id, info in + self._samples.items() if info.get('image', False)] return returned_samples def _import_data_from_event(self, event_dict: Dict): @@ -461,30 +418,29 @@ class ExplainLoader: - predicted_probs (list[int]): A list of confidences w.r.t the predicted labels. - explanations (dict): Explanations is a dictionary where the each explainer name mapping to a dictionary of saliency maps. The structure of explanations demonstrates below: - { explainer_name_1: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, explainer_name_2: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, ... 
} + - hierarchical_occlusion (dict): A dictionary where each label is matched to a dictionary: + {label_1: [{prob: layer1_prob, bbox: []}, {prob: layer2_prob, bbox: []}], + label_2: + } """ if getattr(sample, 'sample_id', None) is None: raise ParamValueError('sample_event has no sample_id') sample_id = sample.sample_id - samples_copy = self._samples.copy() - if sample_id not in samples_copy: + if sample_id not in self._samples: self._samples[sample_id] = { + 'id': sample_id, + 'name': str(sample_id), + 'image': sample.image_path, 'ground_truth_label': [], - 'ground_truth_prob': [], - 'ground_truth_prob_sd': [], - 'ground_truth_prob_itl95_low': [], - 'ground_truth_prob_itl95_hi': [], 'predicted_label': [], - 'predicted_prob': [], - 'predicted_prob_sd': [], - 'predicted_prob_itl95_low': [], - 'predicted_prob_itl95_hi': [], - 'explanation': defaultdict(dict) + 'inferences': defaultdict(dict), + 'explanation': defaultdict(dict), + 'hierarchical_occlusion': defaultdict(dict) } if sample.image_path: @@ -492,27 +448,66 @@ class ExplainLoader: for tag in _SAMPLE_FIELD_NAMES: if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL: - self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label)) + if not self._samples[sample_id]['ground_truth_label']: + self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label)) elif tag == ExplainFieldsEnum.INFERENCE: self._import_inference_from_event(sample, sample_id) - else: + elif tag == ExplainFieldsEnum.EXPLANATION: self._import_explanation_from_event(sample, sample_id) + elif tag == ExplainFieldsEnum.HIERARCHICAL_OCCLUSION: + self._import_hoc_from_event(sample, sample_id) + + def _reform_sample_info(self): + """Reform the sample info.""" + for _, sample_info in self._samples.items(): + inferences = sample_info['inferences'] + res_dict = defaultdict(list) + for explainer, label_heatmap_path_dict in sample_info['explanation'].items(): + for label, heatmap_path in label_heatmap_path_dict.items(): + res_dict[label].append({'explainer': explainer, 'overlay': heatmap_path}) + + for label, item in inferences.items(): + item['saliency_maps'] = res_dict[label] + + for label, item in sample_info['hierarchical_occlusion'].items(): + inferences[label]['hoc_layers'] = item['hoc_layers'] def _import_inference_from_event(self, event, sample_id): """Parse the inference event.""" inference = event.inference - self._samples[sample_id]['ground_truth_prob'].extend(list(inference.ground_truth_prob)) - self._samples[sample_id]['ground_truth_prob_sd'].extend(list(inference.ground_truth_prob_sd)) - self._samples[sample_id]['ground_truth_prob_itl95_low'].extend(list(inference.ground_truth_prob_itl95_low)) - self._samples[sample_id]['ground_truth_prob_itl95_hi'].extend(list(inference.ground_truth_prob_itl95_hi)) - self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label)) - self._samples[sample_id]['predicted_prob'].extend(list(inference.predicted_prob)) - self._samples[sample_id]['predicted_prob_sd'].extend(list(inference.predicted_prob_sd)) - self._samples[sample_id]['predicted_prob_itl95_low'].extend(list(inference.predicted_prob_itl95_low)) - self._samples[sample_id]['predicted_prob_itl95_hi'].extend(list(inference.predicted_prob_itl95_hi)) - - if self._samples[sample_id]['ground_truth_prob_sd'] or self._samples[sample_id]['predicted_prob_sd']: + if inference.ground_truth_prob_sd or inference.predicted_prob_sd: self._loader_info['uncertainty_enabled'] = True + if not self._samples[sample_id]['predicted_label']: + 
self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label)) + if not self._samples[sample_id]['inferences']: + inferences = {} + for label, prob in zip(list(event.ground_truth_label) + list(inference.predicted_label), + list(inference.ground_truth_prob) + list(inference.predicted_prob)): + inferences[label] = { + 'label': self._metadata['labels'][label], + 'confidence': _round(prob), + 'saliency_maps': [], + 'hoc_layers': {}, + } + if not event.ground_truth_label: + inferences[label]['prediction_type'] = None + else: + if prob < self.min_confidence: + inferences[label]['prediction_type'] = 'FN' + elif label in event.ground_truth_label: + inferences[label]['prediction_type'] = 'TP' + else: + inferences[label]['prediction_type'] = 'FP' + if self._loader_info['uncertainty_enabled']: + for label, std, low, high in zip( + list(event.ground_truth_label) + list(inference.predicted_label), + list(inference.ground_truth_prob_sd) + list(inference.predicted_prob_sd), + list(inference.ground_truth_prob_itl95_low) + list(inference.predicted_prob_itl95_low), + list(inference.ground_truth_prob_itl95_hi) + list(inference.predicted_prob_itl95_hi)): + inferences[label]['confidence_sd'] = _round(std) + inferences[label]['confidence_itl95'] = [_round(low), _round(high)] + + self._samples[sample_id]['inferences'] = inferences def _import_explanation_from_event(self, event, sample_id): """Parse the explanation event.""" @@ -525,6 +520,24 @@ class ExplainLoader: label = explanation_item.label sample_explanation[explainer][label] = explanation_item.heatmap_path + def _import_hoc_from_event(self, event, sample_id): + """Parse the mango event.""" + sample_hoc = self._samples[sample_id]['hierarchical_occlusion'] + if event.hierarchical_occlusion: + for hoc_item in event.hierarchical_occlusion: + label = hoc_item.label + sample_hoc[label] = {} + sample_hoc[label]['label'] = label + sample_hoc[label]['mask'] = hoc_item.mask + sample_hoc[label]['confidence'] = self._samples[sample_id]['inferences'][label]['confidence'] + sample_hoc[label]['hoc_layers'] = [] + for hoc_layer in hoc_item.layer: + sample_hoc_dict = {'confidence': hoc_layer.prob} + box_lst = list(hoc_layer.box) + box = [box_lst[i: i + 4] for i in range(0, len(hoc_layer.box), 4)] + sample_hoc_dict['boxes'] = box + sample_hoc[label]['hoc_layers'].append(sample_hoc_dict) + def _clear_job(self): """Clear the cached data and update the time info of the loader.""" self._samples.clear() diff --git a/mindinsight/explainer/manager/explain_manager.py b/mindinsight/explainer/manager/explain_manager.py index 4d8ea791..f04a5c81 100644 --- a/mindinsight/explainer/manager/explain_manager.py +++ b/mindinsight/explainer/manager/explain_manager.py @@ -114,6 +114,7 @@ class ExplainManager: Returns: tuple[total, directories], total indicates the overall number of explain directories and directories indicate list of summary directory info including the following attributes. + - relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR, starting with "./". - create_time (datetime): Creation time of summary file. 
diff --git a/mindinsight/explainer/manager/explain_parser.py b/mindinsight/explainer/manager/explain_parser.py index 2e7e8e7c..ec6aaebd 100644 --- a/mindinsight/explainer/manager/explain_parser.py +++ b/mindinsight/explainer/manager/explain_parser.py @@ -30,6 +30,7 @@ from mindinsight.utils.exceptions import UnknownError HEADER_SIZE = 8 CRC_STR_SIZE = 4 MAX_EVENT_STRING = 500000000 + BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status']) MetadataContainer = namedtuple('MetadataContainer', ['metadata', 'status']) InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob', @@ -42,7 +43,7 @@ InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob', 'predicted_prob_itl95_low', 'predicted_prob_itl95_hi']) SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label', 'inference', - 'explanation', 'status']) + 'explanation', 'hierarchical_occlusion', 'status']) class ExplainParser(_SummaryParser): @@ -193,6 +194,7 @@ class ExplainParser(_SummaryParser): ground_truth_label=tensor_event_value.ground_truth_label, inference=inference, explanation=tensor_event_value.explanation, + hierarchical_occlusion=tensor_event_value.hoc, status=tensor_event_value.status ) return sample_data diff --git a/tests/ut/explainer/encapsulator/mock_explain_manager.py b/tests/ut/explainer/encapsulator/mock_explain_manager.py index 624beb1e..a54db862 100644 --- a/tests/ut/explainer/encapsulator/mock_explain_manager.py +++ b/tests/ut/explainer/encapsulator/mock_explain_manager.py @@ -105,7 +105,9 @@ class MockExplainManager: { "relative_path": "./mock_job_1", "create_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT), - "update_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT) + "update_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT), + "saliency_map": True, + "hierarchical_occlusion": True } ] return 1, job_list diff --git a/tests/ut/explainer/encapsulator/test_explain_job_encap.py b/tests/ut/explainer/encapsulator/test_explain_job_encap.py index 72675c80..f5752c11 100644 --- a/tests/ut/explainer/encapsulator/test_explain_job_encap.py +++ b/tests/ut/explainer/encapsulator/test_explain_job_encap.py @@ -31,7 +31,9 @@ class TestExplainJobEncap: { "train_id": "./mock_job_1", "create_time": "2020-10-01 20:21:23", - "update_time": "2020-10-01 20:21:23" + "update_time": "2020-10-01 20:21:23", + "saliency_map": True, + "hierarchical_occlusion": True } ]) assert job_list == expected_result diff --git a/tests/ut/explainer/manager/test_explain_loader.py b/tests/ut/explainer/manager/test_explain_loader.py index d0b75991..f719b15a 100644 --- a/tests/ut/explainer/manager/test_explain_loader.py +++ b/tests/ut/explainer/manager/test_explain_loader.py @@ -24,10 +24,6 @@ from mindinsight.explainer.manager.explain_loader import _LoaderStatus from mindinsight.explainer.manager.explain_parser import ExplainParser -def abc(): - FileHandler.is_file('aaa') - print('after') - class TestExplainLoader: """Test explain loader class."""
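--
For reference, a minimal usage sketch of the helpers and the route added above.
It is only a sketch: the file names, box values, mask radius, host/port and URL
prefix below are illustrative assumptions, not values taken from this change.

    from PIL import Image

    from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps

    # Open the original sample in RGB mode, as pil_apply_edit_steps expects.
    image = Image.open("sample.jpg").convert("RGB")

    # One EditStep per (x, y, width, height) box carried in a HocLayer.box field;
    # layer 0 is the first hierarchical occlusion layer.
    steps = [EditStep(0, 10, 20, 64, 64)]

    # A 'gaussian:<radius>' mask blurs, dims and de-saturates the whole image,
    # then the original pixels are pasted back inside the boxes only.
    outcome = pil_apply_edit_steps(image, "gaussian:9", steps)
    outcome.save("outcome.png")

The new /explainer/hoc route accepts the same paging and sorting body as
/explainer/saliency, plus the TP/FN/FP filter validated in
_get_query_sample_parameters, for example:

    import requests

    # Assumed local MindInsight instance and URL prefix.
    body = {"train_id": "./mock_job_1", "limit": 10, "offset": 0,
            "sorted_name": "confidence", "sorted_type": "descending",
            "prediction_types": ["TP"]}
    resp = requests.post("http://127.0.0.1:8080/v1/mindinsight/explainer/hoc", json=body)
    print(resp.json()["count"])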