From: @lixiaohui33
Reviewed-by: @ouwenchang
Signed-off-by:
Tag: tags/v1.2.0-rc1
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
 # ============================================================================
 """Explainer restful api."""
-import os
 import json
+import os
 import urllib.parse

 from flask import Blueprint
@@ -23,16 +23,18 @@ from flask import jsonify
 from flask import request

 from mindinsight.conf import settings
-from mindinsight.utils.exceptions import ParamMissError
-from mindinsight.utils.exceptions import ParamValueError
 from mindinsight.datavisual.common.validation import Validation
 from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
 from mindinsight.datavisual.utils.tools import get_train_id
-from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER
-from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
 from mindinsight.explainer.encapsulator.datafile_encap import DatafileEncap
-from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
 from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap
+from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
+from mindinsight.explainer.encapsulator.hierarchical_occlusion_encap import HierarchicalOcclusionEncap
+from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
+from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER
+from mindinsight.utils.exceptions import ParamMissError
+from mindinsight.utils.exceptions import ParamTypeError
+from mindinsight.utils.exceptions import ParamValueError

 URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX
 BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX)
@@ -66,6 +68,50 @@ def _read_post_request(post_request):
     return body


+def _get_query_sample_parameters(data):
+    """Get and validate the common parameters for sample queries."""
+    train_id = data.get("train_id")
+    if train_id is None:
+        raise ParamMissError('train_id')
+
+    labels = data.get("labels")
+    if labels is not None and not isinstance(labels, list):
+        raise ParamTypeError("labels", (list, None))
+    if labels:
+        for item in labels:
+            if not isinstance(item, str):
+                raise ParamTypeError("element of labels", str)
+
+    limit = data.get("limit", 10)
+    limit = Validation.check_limit(limit, min_value=1, max_value=100)
+    offset = data.get("offset", 0)
+    offset = Validation.check_offset(offset=offset)
+    sorted_name = data.get("sorted_name", "")
+    sorted_type = data.get("sorted_type", "descending")
+    if sorted_name not in ("", "confidence", "uncertainty"):
+        raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'")
+    if sorted_type not in ("ascending", "descending"):
+        raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'ascending' 'descending'")
+
+    prediction_types = data.get("prediction_types")
+    if prediction_types is not None and not isinstance(prediction_types, list):
+        raise ParamTypeError("prediction_types", (list, None))
+    if prediction_types:
+        for item in prediction_types:
+            if item not in ['TP', 'FN', 'FP']:
+                raise ParamValueError(f"Item of prediction_types must be in ['TP', 'FN', 'FP'], but got {item}.")
+
+    query_kwarg = {"train_id": train_id,
+                   "labels": labels,
+                   "limit": limit,
+                   "offset": offset,
+                   "sorted_name": sorted_name,
+                   "sorted_type": sorted_type,
+                   "prediction_types": prediction_types}
+    return query_kwarg
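For reference, a request body that passes this validation might look like the sketch below; the values are illustrative only, not taken from a real job.

```python
# Hypothetical request body for POST /explainer/saliency or /explainer/hoc;
# every value is an example, not real data.
body = {
    "train_id": "./summary-dir",      # required, else ParamMissError
    "labels": ["cat", "dog"],         # optional list[str] filter
    "limit": 20,                      # clamped to [1, 100] by Validation.check_limit
    "offset": 0,                      # page index, not an item index
    "sorted_name": "confidence",      # "", "confidence" or "uncertainty"
    "sorted_type": "descending",      # "ascending" or "descending"
    "prediction_types": ["TP", "FP"]  # optional subset of ["TP", "FN", "FP"]
}
# _get_query_sample_parameters(body) returns these fields as a kwargs dict that
# the encapsulators accept via query_saliency_maps(**kwargs) or
# query_hierarchical_occlusion(**kwargs).
```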
| @BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"]) | @BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"]) | ||||
| def query_explain_jobs(): | def query_explain_jobs(): | ||||
| """Query explain jobs.""" | """Query explain jobs.""" | ||||
| @@ -99,37 +145,40 @@ def query_explain_job(): | |||||
| @BLUEPRINT.route("/explainer/saliency", methods=["POST"]) | @BLUEPRINT.route("/explainer/saliency", methods=["POST"]) | ||||
| def query_saliency(): | def query_saliency(): | ||||
| """Query saliency map related results.""" | """Query saliency map related results.""" | ||||
| data = _read_post_request(request) | data = _read_post_request(request) | ||||
| train_id = data.get("train_id") | |||||
| if train_id is None: | |||||
| raise ParamMissError('train_id') | |||||
| labels = data.get("labels") | |||||
| query_kwarg = _get_query_sample_parameters(data) | |||||
| explainers = data.get("explainers") | explainers = data.get("explainers") | ||||
| limit = data.get("limit", 10) | |||||
| limit = Validation.check_limit(limit, min_value=1, max_value=100) | |||||
| offset = data.get("offset", 0) | |||||
| offset = Validation.check_offset(offset=offset) | |||||
| sorted_name = data.get("sorted_name", "") | |||||
| sorted_type = data.get("sorted_type", "descending") | |||||
| if explainers is not None and not isinstance(explainers, list): | |||||
| raise ParamTypeError("explainers", (list, None)) | |||||
| if explainers: | |||||
| for item in explainers: | |||||
| if not isinstance(item, str): | |||||
| raise ParamTypeError("element of explainers", str) | |||||
| if sorted_name not in ("", "confidence", "uncertainty"): | |||||
| raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'") | |||||
| if sorted_type not in ("ascending", "descending"): | |||||
| raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'") | |||||
| query_kwarg["explainers"] = explainers | |||||
| encapsulator = SaliencyEncap( | encapsulator = SaliencyEncap( | ||||
| _image_url_formatter, | _image_url_formatter, | ||||
| EXPLAIN_MANAGER) | EXPLAIN_MANAGER) | ||||
| count, samples = encapsulator.query_saliency_maps(train_id=train_id, | |||||
| labels=labels, | |||||
| explainers=explainers, | |||||
| limit=limit, | |||||
| offset=offset, | |||||
| sorted_name=sorted_name, | |||||
| sorted_type=sorted_type) | |||||
| count, samples = encapsulator.query_saliency_maps(**query_kwarg) | |||||
| return jsonify({ | |||||
| "count": count, | |||||
| "samples": samples | |||||
| }) | |||||
| @BLUEPRINT.route("/explainer/hoc", methods=["POST"]) | |||||
| def query_hoc(): | |||||
| """Query hierarchical occlusion related results.""" | |||||
| data = _read_post_request(request) | |||||
| query_kwargs = _get_query_sample_parameters(data) | |||||
| encapsulator = HierarchicalOcclusionEncap( | |||||
| _image_url_formatter, | |||||
| EXPLAIN_MANAGER) | |||||
| count, samples = encapsulator.query_hierarchical_occlusion(**query_kwargs) | |||||
| return jsonify({ | return jsonify({ | ||||
| "count": count, | "count": count, | ||||
| @@ -162,8 +211,8 @@ def query_image(): | |||||
| image_type = request.args.get("type") | image_type = request.args.get("type") | ||||
| if image_type is None: | if image_type is None: | ||||
| raise ParamMissError("type") | raise ParamMissError("type") | ||||
| if image_type not in ("original", "overlay"): | |||||
| raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay'") | |||||
| if image_type not in ("original", "overlay", "outcome"): | |||||
| raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay' 'outcome'") | |||||
| encapsulator = DatafileEncap(EXPLAIN_MANAGER) | encapsulator = DatafileEncap(EXPLAIN_MANAGER) | ||||
| image = encapsulator.query_image_binary(train_id, image_path, image_type) | image = encapsulator.query_image_binary(train_id, image_path, image_type) | ||||
| @@ -177,6 +226,5 @@ def init_module(app): | |||||
| Args: | Args: | ||||
| app: the application obj. | app: the application obj. | ||||
| """ | """ | ||||
| app.register_blueprint(BLUEPRINT) | app.register_blueprint(BLUEPRINT) | ||||
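For reviewers who want to poke at the new route by hand, a hypothetical client call is sketched below; the host, port, API prefix and train_id are placeholders and depend on the local MindInsight settings.

```python
# Hypothetical request against the new /explainer/hoc route; the base URL and
# train_id are placeholders, and the `requests` package is assumed available.
import requests

body = {
    "train_id": "./alexnet",
    "limit": 10,
    "offset": 0,
    "sorted_name": "confidence",
    "sorted_type": "descending",
}
resp = requests.post("http://127.0.0.1:8080/v1/mindinsight/explainer/hoc", json=body)
payload = resp.json()
print(payload["count"], len(payload["samples"]))
```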
@@ -138,6 +138,17 @@ message Explain {
     repeated string benchmark_method = 3;
   }

+  message HocLayer {
+    optional float prob = 1;
+    repeated int32 box = 2;  // List of repeated x, y, w, h
+  }
+
+  message Hoc {
+    optional int32 label = 1;
+    optional string mask = 2;
+    repeated HocLayer layer = 3;
+  }
+
   optional int32 sample_id = 1;  // Either the metadata or the sample id must be filled in
   optional string image_path = 2;
   repeated int32 ground_truth_label = 3;
@@ -148,4 +159,6 @@ message Explain {
   optional Metadata metadata = 7;
   optional string status = 8;  // enum value: run, end
+
+  repeated Hoc hoc = 9;  // hierarchical occlusion counterfactual
 }
| @@ -40,6 +40,7 @@ class ExplainFieldsEnum(BaseEnum): | |||||
| GROUND_TRUTH_LABEL = 'ground_truth_label' | GROUND_TRUTH_LABEL = 'ground_truth_label' | ||||
| INFERENCE = 'inference' | INFERENCE = 'inference' | ||||
| EXPLANATION = 'explanation' | EXPLANATION = 'explanation' | ||||
| HIERARCHICAL_OCCLUSION = 'hierarchical_occlusion' | |||||
| STATUS = 'status' | STATUS = 'status' | ||||
@@ -0,0 +1,156 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utility functions for hierarchical occlusion image generation."""
import re

import PIL
from PIL import ImageDraw, ImageEnhance, ImageFilter

MASK_GAUSSIAN_RE = r'^gaussian:(\d+)$'


class EditStep:
    """
    Class that represents an edit step.

    Args:
        layer (int): Layer index.
        x (int): Left pixel coordinate.
        y (int): Top pixel coordinate.
        width (int): Width in pixels.
        height (int): Height in pixels.
    """
    def __init__(self,
                 layer: int,
                 x: int,
                 y: int,
                 width: int,
                 height: int):
        self.layer = layer
        self.x = x
        self.y = y
        self.width = width
        self.height = height

    def to_coord_box(self):
        """
        Convert to a pixel coordinate box.

        Returns:
            tuple[int, int, int, int], tuple of left, top, right, bottom pixel coordinates.
        """
        return self.x, self.y, self.x + self.width, self.y + self.height
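A quick illustration of the coordinate convention, with made-up values: an EditStep stores x, y, width and height, while PIL's crop/paste/rectangle calls expect a left, top, right, bottom box.

```python
# Illustrative only: EditStep(layer, x, y, width, height) maps to the
# (left, top, right, bottom) box that PIL's crop(), paste() and rectangle() use.
step = EditStep(layer=0, x=10, y=20, width=30, height=40)
assert step.to_coord_box() == (10, 20, 40, 60)
```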
def pil_apply_edit_steps(image, mask, edit_steps, by_masking=False, inplace=False):
    """
    Apply edit steps on a PIL image.

    Args:
        image (PIL.Image): The input image in RGB mode.
        mask (Union[str, int, tuple[int, int, int], PIL.Image.Image]): The mask to apply on the image. It can be a
            string, e.g. 'gaussian:9', a single grey scale intensity in [0, 255], an RGB tuple or a PIL Image object.
        edit_steps (list[EditStep]): Edit steps to be drawn.
        by_masking (bool): Whether to use the masking method. Default: False.
        inplace (bool): True to draw on the input image, otherwise draw on a cloned image.

    Returns:
        PIL.Image, the result image.
    """
    if isinstance(mask, str):
        mask = pil_compile_str_mask(mask, image)
    if by_masking:
        return _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace)
    return _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace)


def pil_compile_str_mask(mask, image):
    """Convert a string mask to a PIL Image."""
    match = re.match(MASK_GAUSSIAN_RE, mask)
    if match:
        radius = int(match.group(1))
        if radius > 0:
            image_filter = ImageFilter.GaussianBlur(radius=radius)
            mask_image = image.filter(image_filter)
            mask_image = ImageEnhance.Brightness(mask_image).enhance(0.7)
            mask_image = ImageEnhance.Color(mask_image).enhance(0.0)
            return mask_image
    raise ValueError(f"Invalid string mask: '{mask}'.")


def _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace=False):
    """
    Apply edit steps from the unmasking method on a PIL image.

    Args:
        image (PIL.Image): The input image.
        mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image. It can be a single grey
            scale intensity in [0, 255], an RGB tuple or a PIL Image.
        edit_steps (list[EditStep]): Edit steps to be drawn.
        inplace (bool): True to draw on the input image, otherwise draw on a cloned image.

    Returns:
        PIL.Image, the result image.
    """
    if isinstance(mask, PIL.Image.Image):
        if inplace:
            bg = mask
        else:
            bg = mask.copy()
    else:
        if inplace:
            raise ValueError('Argument inplace cannot be True when mask is not a PIL Image.')
        if isinstance(mask, int):
            mask = (mask, mask, mask)
        bg = PIL.Image.new(mode="RGB", size=image.size, color=mask)
    for step in edit_steps:
        box = step.to_coord_box()
        cropped = image.crop(box)
        bg.paste(cropped, box=box)
    return bg


def _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace=False):
    """
    Apply edit steps from the masking method on a PIL image.

    Args:
        image (PIL.Image): The input image.
        mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image. It can be a single grey
            scale intensity in [0, 255], an RGB tuple or a PIL Image.
        edit_steps (list[EditStep]): Edit steps to be drawn.
        inplace (bool): True to draw on the input image, otherwise draw on a cloned image.

    Returns:
        PIL.Image, the result image.
    """
    if not inplace:
        image = image.copy()
    if isinstance(mask, PIL.Image.Image):
        for step in edit_steps:
            box = step.to_coord_box()
            cropped = mask.crop(box)
            image.paste(cropped, box=box)
    else:
        if isinstance(mask, int):
            mask = (mask, mask, mask)
        draw = ImageDraw.Draw(image)
        for step in edit_steps:
            draw.rectangle(step.to_coord_box(), fill=mask)
    return image
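A minimal usage sketch for the helpers above; the image path and box values are placeholders. With by_masking=False (the unmasking method) only the edited boxes are copied from the original image onto the mask background; with by_masking=True the same boxes are covered by the mask instead.

```python
# Hypothetical usage of the helpers above; "sample.jpg" and the box are made up.
from PIL import Image

original = Image.open("sample.jpg").convert("RGB")
steps = [EditStep(layer=0, x=16, y=16, width=64, height=64)]

# 'gaussian:9' compiles into a blurred, darkened, desaturated copy of the image,
# so the unmasked result reveals only the edited boxes on that background.
revealed = pil_apply_edit_steps(original, "gaussian:9", steps, by_masking=False)

# With by_masking=True the boxes are occluded instead, here with mid grey.
occluded = pil_apply_edit_steps(original, 128, steps, by_masking=True)

revealed.save("revealed.png")
occluded.save("occluded.png")
```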
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,16 +14,17 @@
 # ============================================================================
 """Datafile encapsulator."""
-import os
 import io
+import os

-from PIL import Image
 import numpy as np
+from PIL import Image

-from mindinsight.utils.exceptions import UnknownError
-from mindinsight.utils.exceptions import FileSystemPermissionError
 from mindinsight.datavisual.common.exceptions import ImageNotExistError
+from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps
 from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
+from mindinsight.utils.exceptions import FileSystemPermissionError
+from mindinsight.utils.exceptions import UnknownError

 # Max uint8 value, for converting RGB pixels to [0, 1] intensity.
 _UINT8_MAX = 255
@@ -59,11 +60,44 @@ class DatafileEncap(ExplainDataEncap):
         Args:
             train_id (str): Job ID.
             image_path (str): Image path relative to explain job's summary directory.
-            image_type (str): Image type, 'original' or 'overlay'.
+            image_type (str): Image type. Options: 'original', 'overlay' or 'outcome'.

         Returns:
             bytes, image binary.
         """
+        if image_type == "outcome":
+            sample_id, label, layer = image_path.strip(".jpg").split("_")
+            layer = int(layer)
+            job = self.job_manager.get_job(train_id)
+            samples = job.samples
+            label_idx = job.labels.index(label)
+            chosen_sample = samples[int(sample_id)]
+            original_path_image = chosen_sample['image']
+            abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id),
+                                          original_path_image)
+            if self._is_forbidden(abs_image_path):
+                raise FileSystemPermissionError("Forbidden.")
+
+            try:
+                image = Image.open(abs_image_path)
+            except FileNotFoundError:
+                raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
+            except PermissionError:
+                raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
+            except OSError:
+                raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")
+
+            edit_steps = []
+            boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"]
+            mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"]
+            for box in boxes:
+                edit_steps.append(EditStep(layer, *box))
+            image_cp = pil_apply_edit_steps(image, mask, edit_steps)
+            buffer = io.BytesIO()
+            image_cp.save(buffer, format=_PNG_FORMAT)
+            return buffer.getvalue()
+
         abs_image_path = os.path.join(self.job_manager.summary_base_dir,
                                       _clean_train_id_b4_join(train_id),
@@ -94,10 +128,10 @@ class DatafileEncap(ExplainDataEncap):
             raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")

         if image.mode == _SINGLE_CHANNEL_MODE:
-            saliency = np.asarray(image)/_UINT8_MAX
+            saliency = np.asarray(image) / _UINT8_MAX
         elif image.mode == _RGB_MODE:
             saliency = np.asarray(image)
-            saliency = saliency[:, :, 0]/_UINT8_MAX
+            saliency = saliency[:, :, 0] / _UINT8_MAX
         else:
             raise UnknownError(f"Invalid overlay image mode:{image.mode}.")
@@ -105,7 +139,7 @@ class DatafileEncap(ExplainDataEncap):
         for c in range(3):
             saliency_stack[:, :, c] = saliency
         rgba = saliency_stack * _SALIENCY_CMAP_HI
-        rgba += (1-saliency_stack) * _SALIENCY_CMAP_LOW
+        rgba += (1 - saliency_stack) * _SALIENCY_CMAP_LOW
         rgba[:, :, 3] = saliency * _UINT8_MAX

         overlay = Image.fromarray(np.uint8(rgba), mode=_RGBA_MODE)
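For context, the 'outcome' branch above does not read a stored file: it regenerates the image from the hierarchical occlusion boxes, and the pseudo path it parses is the one built later in this change by HierarchicalOcclusionEncap._touch_sample as f"{sample_id}_{label}_{layer}.jpg". A small sketch of that round trip with invented values:

```python
# Invented values illustrating the pseudo image path used for type="outcome".
sample_id, label, layer = 12, "cat", 1
image_path = f"{sample_id}_{label}_{layer}.jpg"   # built by _touch_sample
parsed = image_path.strip(".jpg").split("_")      # parsed back by query_image_binary
assert parsed == ["12", "cat", "1"]
```

Note that this parsing assumes label names contain no underscores; str.strip removes a character set rather than a suffix, so the scheme relies on the numeric sample id and layer index framing the label.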
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,57 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Explain data encapsulator base class."""
+"""Common explain data encapsulator base class."""
+import copy
+
+from mindinsight.utils.exceptions import ParamValueError
+
+
+def _sort_key_min_confidence(sample, labels):
+    """Samples sort key by the min. confidence."""
+    min_confidence = float("+inf")
+    for inference in sample["inferences"]:
+        if labels and inference["label"] not in labels:
+            continue
+        if inference["confidence"] < min_confidence:
+            min_confidence = inference["confidence"]
+    return min_confidence
+
+
+def _sort_key_max_confidence(sample, labels):
+    """Samples sort key by the max. confidence."""
+    max_confidence = float("-inf")
+    for inference in sample["inferences"]:
+        if labels and inference["label"] not in labels:
+            continue
+        if inference["confidence"] > max_confidence:
+            max_confidence = inference["confidence"]
+    return max_confidence
+
+
+def _sort_key_min_confidence_sd(sample, labels):
+    """Samples sort key by the min. confidence_sd."""
+    min_confidence_sd = float("+inf")
+    for inference in sample["inferences"]:
+        if labels and inference["label"] not in labels:
+            continue
+        confidence_sd = inference.get("confidence_sd", float("+inf"))
+        if confidence_sd < min_confidence_sd:
+            min_confidence_sd = confidence_sd
+    return min_confidence_sd
+
+
+def _sort_key_max_confidence_sd(sample, labels):
+    """Samples sort key by the max. confidence_sd."""
+    max_confidence_sd = float("-inf")
+    for inference in sample["inferences"]:
+        if labels and inference["label"] not in labels:
+            continue
+        confidence_sd = inference.get("confidence_sd", float("-inf"))
+        if confidence_sd > max_confidence_sd:
+            max_confidence_sd = confidence_sd
+    return max_confidence_sd
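To make the sort behaviour concrete, a small sketch with made-up sample dicts: each key function reduces a sample to a single confidence value, restricted to the requested labels, and the caller sorts ascending or descending on that value, mirroring the branches in ExplanationEncap._query_samples below.

```python
# Made-up samples in the shape these helpers expect; only the fields used by
# the sort keys are shown.
samples = [
    {"name": "a", "inferences": [{"label": "cat", "confidence": 0.9},
                                 {"label": "dog", "confidence": 0.2}]},
    {"name": "b", "inferences": [{"label": "cat", "confidence": 0.6}]},
]

# Descending order uses the max-confidence key, ascending the min one.
ordered = sorted(samples, key=lambda s: _sort_key_max_confidence(s, ["cat"]), reverse=True)
print([s["name"] for s in ordered])  # ['a', 'b']
```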
 class ExplainDataEncap:
@@ -24,3 +74,77 @@ class ExplainDataEncap:

     @property
     def job_manager(self):
         return self._job_manager
+
+
+class ExplanationEncap(ExplainDataEncap):
+    """Base encapsulator for explanation queries."""
+
+    def __init__(self, image_url_formatter, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._image_url_formatter = image_url_formatter
+
+    def _query_samples(self,
+                       job,
+                       labels,
+                       sorted_name,
+                       sorted_type,
+                       prediction_types=None,
+                       query_type="saliency_maps"):
+        """
+        Query samples.
+
+        Args:
+            job (ExplainManager): Explain job to be queried from.
+            labels (list[str]): Label filter.
+            sorted_name (str): Field to be sorted.
+            sorted_type (str): Sorting order, 'ascending' or 'descending'.
+            prediction_types (list[str]): Prediction type filter.
+            query_type (str): Query type, 'saliency_maps' or 'hoc_layers'.
+
+        Returns:
+            list[dict], samples to be queried.
+        """
+        samples = copy.deepcopy(job.get_all_samples())
+        samples = [sample for sample in samples if any(infer[query_type] for infer in sample['inferences'])]
+        if labels:
+            filtered = []
+            for sample in samples:
+                infer_labels = [inference["label"] for inference in sample["inferences"]]
+                for infer_label in infer_labels:
+                    if infer_label in labels:
+                        filtered.append(sample)
+                        break
+            samples = filtered
+
+        if prediction_types and len(prediction_types) < 3:
+            filtered = []
+            for sample in samples:
+                infer_types = [inference["prediction_type"] for inference in sample["inferences"]]
+                for infer_type in infer_types:
+                    if infer_type in prediction_types:
+                        filtered.append(sample)
+                        break
+            samples = filtered
+
+        reverse = sorted_type == "descending"
+        if sorted_name == "confidence":
+            if reverse:
+                samples.sort(key=lambda x: _sort_key_max_confidence(x, labels), reverse=reverse)
+            else:
+                samples.sort(key=lambda x: _sort_key_min_confidence(x, labels), reverse=reverse)
+        elif sorted_name == "uncertainty":
+            if not job.uncertainty_enabled:
+                raise ParamValueError("Uncertainty is not enabled, sorted_name cannot be 'uncertainty'")
+            if reverse:
+                samples.sort(key=lambda x: _sort_key_max_confidence_sd(x, labels), reverse=reverse)
+            else:
+                samples.sort(key=lambda x: _sort_key_min_confidence_sd(x, labels), reverse=reverse)
+        elif sorted_name != "":
+            raise ParamValueError("sorted_name")
+
+        return samples
+
+    def _get_image_url(self, train_id, image_path, image_type):
+        """Return the image's URL."""
+        if self._image_url_formatter is None:
+            return image_path
+        return self._image_url_formatter(train_id, image_path, image_type)
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@
 # ============================================================================
 """Explain job list encapsulator."""
-import copy
 from datetime import datetime

 from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
@@ -61,6 +60,9 @@ class ExplainJobEncap(ExplainDataEncap):
         info["train_id"] = dir_info["relative_path"]
         info["create_time"] = dir_info["create_time"].strftime(cls.DATETIME_FORMAT)
         info["update_time"] = dir_info["update_time"].strftime(cls.DATETIME_FORMAT)
+        info["saliency_map"] = dir_info["saliency_map"]
+        info["hierarchical_occlusion"] = dir_info["hierarchical_occlusion"]
+
         return info

     @classmethod
@@ -79,7 +81,7 @@ class ExplainJobEncap(ExplainDataEncap):
         """Convert ExplainJob's meta-data to jsonable info object."""
         info = cls._job_2_info(job)
         info["sample_count"] = job.sample_count
-        info["classes"] = copy.deepcopy(job.all_classes)
+        info["classes"] = [item for item in job.all_classes if item['sample_count'] > 0]
         saliency_info = dict()
         if job.min_confidence is None:
             saliency_info["min_confidence"] = cls.DEFAULT_MIN_CONFIDENCE
@@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hierarchical Occlusion encapsulator."""
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap


class HierarchicalOcclusionEncap(ExplanationEncap):
    """Hierarchical occlusion encapsulator."""

    def query_hierarchical_occlusion(self,
                                     train_id,
                                     labels,
                                     limit,
                                     offset,
                                     sorted_name,
                                     sorted_type,
                                     prediction_types=None):
        """
        Query hierarchical occlusion results.

        Args:
            train_id (str): Job ID.
            labels (list[str]): Label filter.
            limit (int): Maximum number of items to be returned.
            offset (int): Page offset.
            sorted_name (str): Field to be sorted.
            sorted_type (str): Sorting order, 'ascending' or 'descending'.
            prediction_types (list[str]): Prediction types filter.

        Returns:
            tuple[int, list[dict]], total number of samples after filtering and list of sample results.
        """
        job = self.job_manager.get_job(train_id)
        if job is None:
            raise TrainJobNotExistError(train_id)

        samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types,
                                      query_type="hoc_layers")
        sample_infos = []
        obj_offset = offset * limit
        count = len(samples)
        end = count
        if obj_offset + limit < end:
            end = obj_offset + limit
        for i in range(obj_offset, end):
            sample = samples[i]
            sample_infos.append(self._touch_sample(sample, job))

        return count, sample_infos
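The paging arithmetic above treats offset as a page index rather than an item index; a small worked example with invented numbers:

```python
# Invented numbers: 23 filtered samples, limit=10, offset=2 (the third page).
count, limit, offset = 23, 10, 2
obj_offset = offset * limit           # 20: index of the first item on this page
end = min(obj_offset + limit, count)  # 23: clamped to the number of samples
print(count, list(range(obj_offset, end)))  # 23 [20, 21, 22]
```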
    def _touch_sample(self, sample, job):
        """
        Final edit on a single sample info.

        Args:
            sample (dict): Sample info.
            job (ExplainManager): Explain job.

        Returns:
            dict, the edited sample info.
        """
        sample_cp = sample.copy()
        sample_cp["image"] = self._get_image_url(job.train_id, sample["image"], "original")
        for inference_item in sample_cp["inferences"]:
            new_list = []
            for idx, hoc_layer in enumerate(inference_item["hoc_layers"]):
                hoc_layer["outcome"] = self._get_image_url(job.train_id,
                                                           f"{sample['id']}_{inference_item['label']}_{idx}.jpg",
                                                           "outcome")
                new_list.append(hoc_layer)
            inference_item["hoc_layers"] = new_list
        return sample_cp
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,58 +14,13 @@
 # ============================================================================
 """Saliency map encapsulator."""
-import copy
-
-from mindinsight.utils.exceptions import ParamValueError
-from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
+from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap


-def _sort_key_min_confidence(sample):
-    """Samples sort key by the min. confidence."""
-    min_confidence = float("+inf")
-    for inference in sample["inferences"]:
-        if inference["confidence"] < min_confidence:
-            min_confidence = inference["confidence"]
-    return min_confidence
-
-
-def _sort_key_max_confidence(sample):
-    """Samples sort key by the max. confidence."""
-    max_confidence = float("-inf")
-    for inference in sample["inferences"]:
-        if inference["confidence"] > max_confidence:
-            max_confidence = inference["confidence"]
-    return max_confidence
-
-
-def _sort_key_min_confidence_sd(sample):
-    """Samples sort key by the min. confidence_sd."""
-    min_confidence_sd = float("+inf")
-    for inference in sample["inferences"]:
-        confidence_sd = inference.get("confidence_sd", float("+inf"))
-        if confidence_sd < min_confidence_sd:
-            min_confidence_sd = confidence_sd
-    return min_confidence_sd
-
-
-def _sort_key_max_confidence_sd(sample):
-    """Samples sort key by the max. confidence_sd."""
-    max_confidence_sd = float("-inf")
-    for inference in sample["inferences"]:
-        confidence_sd = inference.get("confidence_sd", float("-inf"))
-        if confidence_sd > max_confidence_sd:
-            max_confidence_sd = confidence_sd
-    return max_confidence_sd
-
-
-class SaliencyEncap(ExplainDataEncap):
+class SaliencyEncap(ExplanationEncap):
     """Saliency map encapsulator."""

-    def __init__(self, image_url_formatter, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._image_url_formatter = image_url_formatter
-
     def query_saliency_maps(self,
                             train_id,
                             labels,
@@ -73,55 +28,32 @@ class SaliencyEncap(ExplainDataEncap):
                             limit,
                             offset,
                             sorted_name,
-                            sorted_type):
+                            sorted_type,
+                            prediction_types=None):
         """
         Query saliency maps.

        Args:
             train_id (str): Job ID.
             labels (list[str]): Label filter.
             explainers (list[str]): Explainers of saliency maps to be shown.
-            limit (int): Max. no. of items to be returned.
+            limit (int): Maximum number of items to be returned.
             offset (int): Page offset.
             sorted_name (str): Field to be sorted.
             sorted_type (str): Sorting order, 'ascending' or 'descending'.
+            prediction_types (list[str]): Prediction types filter. Default: None.

         Returns:
-            tuple[int, list[dict]], total no. of samples after filtering and
-                list of sample result.
+            tuple[int, list[dict]], total number of samples after filtering and list of sample results.
         """
         job = self.job_manager.get_job(train_id)
         if job is None:
             raise TrainJobNotExistError(train_id)

-        samples = copy.deepcopy(job.get_all_samples())
-        if labels:
-            filtered = []
-            for sample in samples:
-                infer_labels = [inference["label"] for inference in sample["inferences"]]
-                for infer_label in infer_labels:
-                    if infer_label in labels:
-                        filtered.append(sample)
-                        break
-            samples = filtered
-
-        reverse = sorted_type == "descending"
-        if sorted_name == "confidence":
-            if reverse:
-                samples.sort(key=_sort_key_max_confidence, reverse=reverse)
-            else:
-                samples.sort(key=_sort_key_min_confidence, reverse=reverse)
-        elif sorted_name == "uncertainty":
-            if not job.uncertainty_enabled:
-                raise ParamValueError("Uncertainty is not enabled, sorted_name cannot be 'uncertainty'")
-            if reverse:
-                samples.sort(key=_sort_key_max_confidence_sd, reverse=reverse)
-            else:
-                samples.sort(key=_sort_key_min_confidence_sd, reverse=reverse)
-        elif sorted_name != "":
-            raise ParamValueError("sorted_name")
-
+        samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types,
+                                      query_type="saliency_maps")
         sample_infos = []
-        obj_offset = offset*limit
+        obj_offset = offset * limit
         count = len(samples)
         end = count
         if obj_offset + limit < end:
@@ -139,11 +71,13 @@ class SaliencyEncap(ExplainDataEncap):
             sample (dict): Sample info.
             job (ExplainJob): Explain job.
             explainers (list[str]): Explainer names.

         Returns:
             dict, the edited sample info.
         """
-        sample["image"] = self._get_image_url(job.train_id, sample['image'], "original")
-        for inference in sample["inferences"]:
+        sample_cp = sample.copy()
+        sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original")
+        for inference in sample_cp["inferences"]:
             new_list = []
             for saliency_map in inference["saliency_maps"]:
                 if explainers and saliency_map["explainer"] not in explainers:
@@ -151,10 +85,4 @@ class SaliencyEncap(ExplainDataEncap):
                 saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay")
                 new_list.append(saliency_map)
             inference["saliency_maps"] = new_list
-        return sample
-
-    def _get_image_url(self, train_id, image_path, image_type):
-        """Returns image's url."""
-        if self._image_url_formatter is None:
-            return image_path
-        return self._image_url_formatter(train_id, image_path, image_type)
+        return sample_cp
@@ -14,14 +14,13 @@
 # ============================================================================
 """ExplainLoader."""
-from collections import defaultdict
-from enum import Enum
 import math
 import os
 import re
 import threading
+from collections import defaultdict
 from datetime import datetime
+from enum import Enum
 from typing import Dict, Iterable, List, Optional, Union

 from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
@@ -44,6 +43,7 @@ _SAMPLE_FIELD_NAMES = [
     ExplainFieldsEnum.GROUND_TRUTH_LABEL,
     ExplainFieldsEnum.INFERENCE,
     ExplainFieldsEnum.EXPLANATION,
+    ExplainFieldsEnum.HIERARCHICAL_OCCLUSION,
 ]
@@ -75,10 +75,10 @@ class ExplainLoader:
             'create_time': os.stat(summary_dir).st_ctime,
             'update_time': os.stat(summary_dir).st_mtime,
             'query_time': os.stat(summary_dir).st_ctime,
-            'uncertainty_enabled': False,
+            'uncertainty_enabled': False
         }
         self._samples = defaultdict(dict)
-        self._metadata = {'explainers': [], 'metrics': [], 'labels': []}
+        self._metadata = {'explainers': [], 'metrics': [], 'labels': [], 'min_confidence': 0.5}
        self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)}
         self._status = _LoaderStatus.STOP.value
@@ -91,25 +91,19 @@ class ExplainLoader:

         Returns:
             list[dict], a list of dict, each dict contains:
-                - id (int): label id
-                - label (str): label name
-                - sample_count (int): number of samples for each label
+                - id (int): Label id.
+                - label (str): Label name.
+                - sample_count (int): Number of samples for each label.
         """
         sample_count_per_label = defaultdict(int)
-        samples_copy = self._samples.copy()
-        for sample in samples_copy.values():
-            if sample.get('image', False) and sample.get('ground_truth_label', False):
-                for label in sample['ground_truth_label']:
+        for sample in self._samples.values():
+            if sample.get('image') and (sample.get('ground_truth_label') or sample.get('predicted_label')):
+                for label in set(sample['ground_truth_label'] + sample['predicted_label']):
                    sample_count_per_label[label] += 1

-        all_classes_return = []
-        for label_id, label_name in enumerate(self._metadata['labels']):
-            single_info = {
-                'id': label_id,
-                'label': label_name,
-                'sample_count': sample_count_per_label[label_id]
-            }
-            all_classes_return.append(single_info)
+        all_classes_return = [{'id': label_id, 'label': label_name, 'sample_count': sample_count_per_label[label_id]}
+                              for label_id, label_name in enumerate(self._metadata['labels'])]
         return all_classes_return
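For illustration, with invented metadata of three labels and counts for only two of them, the comprehension above would produce something like this:

```python
# Invented example of the all_classes payload built by the comprehension above.
labels = ['cat', 'dog', 'bird']
sample_count_per_label = {0: 2, 1: 1}   # a defaultdict(int) in the real code
all_classes_return = [{'id': label_id, 'label': label_name,
                       'sample_count': sample_count_per_label.get(label_id, 0)}
                      for label_id, label_name in enumerate(labels)]
# -> [{'id': 0, 'label': 'cat', 'sample_count': 2},
#     {'id': 1, 'label': 'dog', 'sample_count': 1},
#     {'id': 2, 'label': 'bird', 'sample_count': 0}]
```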
     @property
@@ -166,19 +160,23 @@ class ExplainLoader:

         Returns:
             list[dict], A list of evaluation results of each explainer. Each item contains:

                 - explainer (str): Name of evaluated explainer.
                 - evaluations (list[dict]): A list of evaluation results by different metrics.
                 - class_scores (list[dict]): A list of evaluation results on different labels.

                 Each item in the evaluations contains:

                     - metric (str): name of metric method
                     - score (float): evaluation result

                 Each item in the class_scores contains:

                     - label (str): Name of label
                     - evaluations (list[dict]): A list of evaluation results on different labels by different metrics.

                     Each item in evaluations contains:

                         - metric (str): Name of metric method
                         - score (float): Evaluation scores of explainer on specific label by the metric.
         """
@@ -215,7 +213,7 @@ class ExplainLoader:
     @property
     def min_confidence(self) -> Optional[float]:
         """Return minimum confidence used to filter the predicted labels."""
-        return None
+        return self._metadata['min_confidence']

     @property
     def sample_count(self) -> int:
@@ -227,19 +225,17 @@ class ExplainLoader:

         Return:
             int, total number of available samples in the loading job.
         """
         sample_count = 0
-        samples_copy = self._samples.copy()
-        for sample in samples_copy.values():
-            if sample.get('image', False) and sample.get('ground_truth_label', False):
+        for sample in self._samples.values():
+            if sample.get('image', False):
                 sample_count += 1
         return sample_count

     @property
     def samples(self) -> List[Dict]:
         """Return the information of all samples in the job."""
-        return self.get_all_samples()
+        return self._samples

     @property
     def train_id(self) -> str:
@@ -298,6 +294,7 @@ class ExplainLoader:
             self._clear_job()
         if event_dict:
             self._import_data_from_event(event_dict)
+        self._reform_sample_info()

     @property
     def status(self):
@@ -317,59 +314,19 @@ class ExplainLoader:
     def get_all_samples(self) -> List[Dict]:
         """
-        Return a list of sample information cachced in the explain job
+        Return a list of sample information cached in the explain job.

         Returns:
             sample_list (List[SampleObj]): a list of sample objects, each object
                 consists of:
-                    - id (int): sample id
-                    - name (str): basename of image
-                    - labels (list[str]): list of labels
-                    - inferences list[dict])
+                    - id (int): Sample id.
+                    - name (str): Basename of image.
+                    - inferences (list[dict]): List of inferences for all labels.
         """
-        returned_samples = []
-        samples_copy = self._samples.copy()
-        for sample_id, sample_info in samples_copy.items():
-            if not sample_info.get('image', False) and not sample_info.get('ground_truth_label', False):
-                continue
-            returned_sample = {
-                'id': sample_id,
-                'name': str(sample_id),
-                'image': sample_info['image'],
-                'labels': [self._metadata['labels'][i] for i in sample_info['ground_truth_label']],
-            }
-            # Check whether the sample has valid label-prob pairs.
-            if not ExplainLoader._is_inference_valid(sample_info):
-                continue
-            inferences = {}
-            for label, prob in zip(sample_info['ground_truth_label'] + sample_info['predicted_label'],
-                                   sample_info['ground_truth_prob'] + sample_info['predicted_prob']):
-                inferences[label] = {
-                    'label': self._metadata['labels'][label],
-                    'confidence': _round(prob),
-                    'saliency_maps': []
-                }
-            if sample_info['ground_truth_prob_sd'] or sample_info['predicted_prob_sd']:
-                for label, std, low, high in zip(
-                        sample_info['ground_truth_label'] + sample_info['predicted_label'],
-                        sample_info['ground_truth_prob_sd'] + sample_info['predicted_prob_sd'],
-                        sample_info['ground_truth_prob_itl95_low'] + sample_info['predicted_prob_itl95_low'],
-                        sample_info['ground_truth_prob_itl95_hi'] + sample_info['predicted_prob_itl95_hi']
-                ):
-                    inferences[label]['confidence_sd'] = _round(std)
-                    inferences[label]['confidence_itl95'] = [_round(low), _round(high)]
-            for explainer, label_heatmap_path_dict in sample_info['explanation'].items():
-                for label, heatmap_path in label_heatmap_path_dict.items():
-                    if label in inferences:
-                        inferences[label]['saliency_maps'].append({'explainer': explainer, 'overlay': heatmap_path})
-            returned_sample['inferences'] = list(inferences.values())
-            returned_samples.append(returned_sample)
+        returned_samples = [{'id': sample_id, 'name': info['name'], 'image': info['image'],
+                             'inferences': list(info['inferences'].values())} for sample_id, info in
+                            self._samples.items() if info.get('image', False)]
         return returned_samples
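A sketch of what one entry of the returned list might look like after _reform_sample_info has filled in 'inferences'; all values are invented, and the exact fields inside 'hoc_layers' depend on _import_hoc_from_event, which is not shown in this diff.

```python
# Invented example of a single entry returned by get_all_samples().
returned_sample = {
    'id': 3,
    'name': '3',
    'image': '3.jpg',
    'inferences': [{
        'label': 'cat',
        'confidence': 0.83,
        'prediction_type': 'TP',   # None when the sample has no ground-truth labels
        'saliency_maps': [{'explainer': 'Gradient', 'overlay': '3_0_Gradient.jpg'}],
        'hoc_layers': [{'prob': 0.71, 'boxes': [[10, 10, 32, 32]]}],
    }],
}
```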
| def _import_data_from_event(self, event_dict: Dict): | def _import_data_from_event(self, event_dict: Dict): | ||||
| @@ -461,30 +418,29 @@ class ExplainLoader: | |||||
| - predicted_probs (list[int]): A list of confidences w.r.t the predicted labels. | - predicted_probs (list[int]): A list of confidences w.r.t the predicted labels. | ||||
| - explanations (dict): Explanations is a dictionary where the each explainer name mapping to a dictionary | - explanations (dict): Explanations is a dictionary where the each explainer name mapping to a dictionary | ||||
| of saliency maps. The structure of explanations demonstrates below: | of saliency maps. The structure of explanations demonstrates below: | ||||
| { | { | ||||
| explainer_name_1: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, | explainer_name_1: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, | ||||
| explainer_name_2: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, | explainer_name_2: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, | ||||
| ... | ... | ||||
| } | } | ||||
| - hierarchical_occlusion (dict): A dictionary where each label is matched to a dictionary: | |||||
| {label_1: [{prob: layer1_prob, bbox: []}, {prob: layer2_prob, bbox: []}], | |||||
| label_2: | |||||
| } | |||||
| """ | """ | ||||
| if getattr(sample, 'sample_id', None) is None: | if getattr(sample, 'sample_id', None) is None: | ||||
| raise ParamValueError('sample_event has no sample_id') | raise ParamValueError('sample_event has no sample_id') | ||||
| sample_id = sample.sample_id | sample_id = sample.sample_id | ||||
| samples_copy = self._samples.copy() | |||||
| if sample_id not in samples_copy: | |||||
| if sample_id not in self._samples: | |||||
| self._samples[sample_id] = { | self._samples[sample_id] = { | ||||
| 'id': sample_id, | |||||
| 'name': str(sample_id), | |||||
| 'image': sample.image_path, | |||||
| 'ground_truth_label': [], | 'ground_truth_label': [], | ||||
| 'ground_truth_prob': [], | |||||
| 'ground_truth_prob_sd': [], | |||||
| 'ground_truth_prob_itl95_low': [], | |||||
| 'ground_truth_prob_itl95_hi': [], | |||||
| 'predicted_label': [], | 'predicted_label': [], | ||||
| 'predicted_prob': [], | |||||
| 'predicted_prob_sd': [], | |||||
| 'predicted_prob_itl95_low': [], | |||||
| 'predicted_prob_itl95_hi': [], | |||||
| 'explanation': defaultdict(dict) | |||||
| 'inferences': defaultdict(dict), | |||||
| 'explanation': defaultdict(dict), | |||||
| 'hierarchical_occlusion': defaultdict(dict) | |||||
| } | } | ||||
| if sample.image_path: | if sample.image_path: | ||||
| @@ -492,27 +448,66 @@ class ExplainLoader: | |||||
| for tag in _SAMPLE_FIELD_NAMES: | for tag in _SAMPLE_FIELD_NAMES: | ||||
| if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL: | if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL: | ||||
| self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label)) | |||||
| if not self._samples[sample_id]['ground_truth_label']: | |||||
| self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label)) | |||||
| elif tag == ExplainFieldsEnum.INFERENCE: | elif tag == ExplainFieldsEnum.INFERENCE: | ||||
| self._import_inference_from_event(sample, sample_id) | self._import_inference_from_event(sample, sample_id) | ||||
| else: | |||||
| elif tag == ExplainFieldsEnum.EXPLANATION: | |||||
| self._import_explanation_from_event(sample, sample_id) | self._import_explanation_from_event(sample, sample_id) | ||||
| elif tag == ExplainFieldsEnum.HIERARCHICAL_OCCLUSION: | |||||
| self._import_hoc_from_event(sample, sample_id) | |||||
| def _reform_sample_info(self): | |||||
| """Reform the sample info.""" | |||||
| for _, sample_info in self._samples.items(): | |||||
| inferences = sample_info['inferences'] | |||||
| res_dict = defaultdict(list) | |||||
| for explainer, label_heatmap_path_dict in sample_info['explanation'].items(): | |||||
| for label, heatmap_path in label_heatmap_path_dict.items(): | |||||
| res_dict[label].append({'explainer': explainer, 'overlay': heatmap_path}) | |||||
| for label, item in inferences.items(): | |||||
| item['saliency_maps'] = res_dict[label] | |||||
| for label, item in sample_info['hierarchical_occlusion'].items(): | |||||
| inferences[label]['hoc_layers'] = item['hoc_layers'] | |||||
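The new _reform_sample_info above only regroups data that is already cached per sample: heatmap paths are bucketed by label and attached to the matching inference entry, and HOC layers are copied over label by label. A standalone sketch with hypothetical labels and paths, assuming the per-sample shapes built earlier in this hunk:

    from collections import defaultdict

    # Hypothetical per-sample inputs in the shapes the loader caches.
    explanation = {'Gradient': {0: 'g_0.png'}, 'GradCAM': {0: 'c_0.png', 1: 'c_1.png'}}
    inferences = {0: {'saliency_maps': [], 'hoc_layers': {}}, 1: {'saliency_maps': [], 'hoc_layers': {}}}
    hierarchical_occlusion = {1: {'hoc_layers': [{'confidence': 0.8, 'boxes': [[0, 0, 8, 8]]}]}}

    # Bucket heatmaps by label, then attach them (and any HOC layers) to each inference entry.
    per_label = defaultdict(list)
    for explainer, label_to_path in explanation.items():
        for label, path in label_to_path.items():
            per_label[label].append({'explainer': explainer, 'overlay': path})
    for label, item in inferences.items():
        item['saliency_maps'] = per_label[label]   # empty list when no heatmap exists for the label
    for label, item in hierarchical_occlusion.items():
        inferences[label]['hoc_layers'] = item['hoc_layers']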
| def _import_inference_from_event(self, event, sample_id): | def _import_inference_from_event(self, event, sample_id): | ||||
| """Parse the inference event.""" | """Parse the inference event.""" | ||||
| inference = event.inference | inference = event.inference | ||||
| self._samples[sample_id]['ground_truth_prob'].extend(list(inference.ground_truth_prob)) | |||||
| self._samples[sample_id]['ground_truth_prob_sd'].extend(list(inference.ground_truth_prob_sd)) | |||||
| self._samples[sample_id]['ground_truth_prob_itl95_low'].extend(list(inference.ground_truth_prob_itl95_low)) | |||||
| self._samples[sample_id]['ground_truth_prob_itl95_hi'].extend(list(inference.ground_truth_prob_itl95_hi)) | |||||
| self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label)) | |||||
| self._samples[sample_id]['predicted_prob'].extend(list(inference.predicted_prob)) | |||||
| self._samples[sample_id]['predicted_prob_sd'].extend(list(inference.predicted_prob_sd)) | |||||
| self._samples[sample_id]['predicted_prob_itl95_low'].extend(list(inference.predicted_prob_itl95_low)) | |||||
| self._samples[sample_id]['predicted_prob_itl95_hi'].extend(list(inference.predicted_prob_itl95_hi)) | |||||
| if self._samples[sample_id]['ground_truth_prob_sd'] or self._samples[sample_id]['predicted_prob_sd']: | |||||
| if inference.ground_truth_prob_sd or inference.predicted_prob_sd: | |||||
| self._loader_info['uncertainty_enabled'] = True | self._loader_info['uncertainty_enabled'] = True | ||||
| if not self._samples[sample_id]['predicted_label']: | |||||
| self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label)) | |||||
| if not self._samples[sample_id]['inferences']: | |||||
| inferences = {} | |||||
| for label, prob in zip(list(event.ground_truth_label) + list(inference.predicted_label), | |||||
| list(inference.ground_truth_prob) + list(inference.predicted_prob)): | |||||
| inferences[label] = { | |||||
| 'label': self._metadata['labels'][label], | |||||
| 'confidence': _round(prob), | |||||
| 'saliency_maps': [], | |||||
| 'hoc_layers': {}, | |||||
| } | |||||
| if not event.ground_truth_label: | |||||
| inferences[label]['prediction_type'] = None | |||||
| else: | |||||
| if prob < self.min_confidence: | |||||
| inferences[label]['prediction_type'] = 'FN' | |||||
| elif label in event.ground_truth_label: | |||||
| inferences[label]['prediction_type'] = 'TP' | |||||
| else: | |||||
| inferences[label]['prediction_type'] = 'FP' | |||||
| if self._loader_info['uncertainty_enabled']: | |||||
| for label, std, low, high in zip( | |||||
| list(event.ground_truth_label) + list(inference.predicted_label), | |||||
| list(inference.ground_truth_prob_sd) + list(inference.predicted_prob_sd), | |||||
| list(inference.ground_truth_prob_itl95_low) + list(inference.predicted_prob_itl95_low), | |||||
| list(inference.ground_truth_prob_itl95_hi) + list(inference.predicted_prob_itl95_hi)): | |||||
| inferences[label]['confidence_sd'] = _round(std) | |||||
| inferences[label]['confidence_itl95'] = [_round(low), _round(high)] | |||||
| self._samples[sample_id]['inferences'] = inferences | |||||
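The prediction_type branches above reduce to a simple rule: no ground-truth labels at all means None; a probability below the confidence threshold means 'FN'; otherwise the label is 'TP' when it appears in the ground truth and 'FP' when it does not. A hypothetical standalone restatement (0.5 stands in for self.min_confidence, which is configured on the loader):

    # Hypothetical helper mirroring the branches above; 0.5 is only a stand-in threshold.
    def prediction_type(label, prob, ground_truth_labels, min_confidence=0.5):
        if not ground_truth_labels:
            return None              # sample carries no ground-truth labels
        if prob < min_confidence:
            return 'FN'              # confidence too low to count as a detection
        return 'TP' if label in ground_truth_labels else 'FP'

    assert prediction_type(3, 0.9, [3]) == 'TP'
    assert prediction_type(5, 0.9, [3]) == 'FP'
    assert prediction_type(3, 0.2, [3]) == 'FN'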
| def _import_explanation_from_event(self, event, sample_id): | def _import_explanation_from_event(self, event, sample_id): | ||||
| """Parse the explanation event.""" | """Parse the explanation event.""" | ||||
| @@ -525,6 +520,24 @@ class ExplainLoader: | |||||
| label = explanation_item.label | label = explanation_item.label | ||||
| sample_explanation[explainer][label] = explanation_item.heatmap_path | sample_explanation[explainer][label] = explanation_item.heatmap_path | ||||
| def _import_hoc_from_event(self, event, sample_id): | |||||
| """Parse the mango event.""" | |||||
| sample_hoc = self._samples[sample_id]['hierarchical_occlusion'] | |||||
| if event.hierarchical_occlusion: | |||||
| for hoc_item in event.hierarchical_occlusion: | |||||
| label = hoc_item.label | |||||
| sample_hoc[label] = {} | |||||
| sample_hoc[label]['label'] = label | |||||
| sample_hoc[label]['mask'] = hoc_item.mask | |||||
| sample_hoc[label]['confidence'] = self._samples[sample_id]['inferences'][label]['confidence'] | |||||
| sample_hoc[label]['hoc_layers'] = [] | |||||
| for hoc_layer in hoc_item.layer: | |||||
| sample_hoc_dict = {'confidence': hoc_layer.prob} | |||||
| box_lst = list(hoc_layer.box) | |||||
| box = [box_lst[i: i + 4] for i in range(0, len(box_lst), 4)] |||||
| sample_hoc_dict['boxes'] = box | |||||
| sample_hoc[label]['hoc_layers'].append(sample_hoc_dict) | |||||
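The box field of each HOC layer arrives as a flat sequence with four coordinates per box, which the loop above chunks into per-box lists. A quick sketch with made-up numbers:

    # Four consecutive values describe one box; chunk the flat list accordingly.
    flat_box = [16, 16, 64, 64, 80, 8, 40, 40]
    boxes = [flat_box[i: i + 4] for i in range(0, len(flat_box), 4)]
    assert boxes == [[16, 16, 64, 64], [80, 8, 40, 40]]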
| def _clear_job(self): | def _clear_job(self): | ||||
| """Clear the cached data and update the time info of the loader.""" | """Clear the cached data and update the time info of the loader.""" | ||||
| self._samples.clear() | self._samples.clear() | ||||
| @@ -114,6 +114,7 @@ class ExplainManager: | |||||
| Returns: | Returns: | ||||
| tuple[total, directories], where total indicates the overall number of explain directories and directories | tuple[total, directories], where total indicates the overall number of explain directories and directories | ||||
| is a list of summary directory info including the following attributes. | is a list of summary directory info including the following attributes. | ||||
| - relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR, | - relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR, | ||||
| starting with "./". | starting with "./". | ||||
| - create_time (datetime): Creation time of summary file. | - create_time (datetime): Creation time of summary file. | ||||
| @@ -30,6 +30,7 @@ from mindinsight.utils.exceptions import UnknownError | |||||
| HEADER_SIZE = 8 | HEADER_SIZE = 8 | ||||
| CRC_STR_SIZE = 4 | CRC_STR_SIZE = 4 | ||||
| MAX_EVENT_STRING = 500000000 | MAX_EVENT_STRING = 500000000 | ||||
| BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status']) | BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status']) | ||||
| MetadataContainer = namedtuple('MetadataContainer', ['metadata', 'status']) | MetadataContainer = namedtuple('MetadataContainer', ['metadata', 'status']) | ||||
| InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob', | InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob', | ||||
| @@ -42,7 +43,7 @@ InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob', | |||||
| 'predicted_prob_itl95_low', | 'predicted_prob_itl95_low', | ||||
| 'predicted_prob_itl95_hi']) | 'predicted_prob_itl95_hi']) | ||||
| SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label', 'inference', | SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label', 'inference', | ||||
| 'explanation', 'status']) | |||||
| 'explanation', 'hierarchical_occlusion', 'status']) | |||||
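With the extra hierarchical_occlusion field, building a sample container might look like the sketch below; the namedtuple is re-declared only so the snippet stands alone, and every field value is a placeholder rather than real parsed data.

    from collections import namedtuple

    # Mirrors the widened SampleContainer above; values are placeholders.
    SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label',
                                                     'inference', 'explanation', 'hierarchical_occlusion',
                                                     'status'])
    sample = SampleContainer(sample_id=0, image_path='images/0.png', ground_truth_label=[1],
                             inference=None, explanation=[], hierarchical_occlusion=[], status=0)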
| class ExplainParser(_SummaryParser): | class ExplainParser(_SummaryParser): | ||||
| @@ -193,6 +194,7 @@ class ExplainParser(_SummaryParser): | |||||
| ground_truth_label=tensor_event_value.ground_truth_label, | ground_truth_label=tensor_event_value.ground_truth_label, | ||||
| inference=inference, | inference=inference, | ||||
| explanation=tensor_event_value.explanation, | explanation=tensor_event_value.explanation, | ||||
| hierarchical_occlusion=tensor_event_value.hoc, | |||||
| status=tensor_event_value.status | status=tensor_event_value.status | ||||
| ) | ) | ||||
| return sample_data | return sample_data | ||||
| @@ -105,7 +105,9 @@ class MockExplainManager: | |||||
| { | { | ||||
| "relative_path": "./mock_job_1", | "relative_path": "./mock_job_1", | ||||
| "create_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT), | "create_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT), | ||||
| "update_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT) | |||||
| "update_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT), | |||||
| "saliency_map": True, | |||||
| "hierarchical_occlusion": True | |||||
| } | } | ||||
| ] | ] | ||||
| return 1, job_list | return 1, job_list | ||||
| @@ -31,7 +31,9 @@ class TestExplainJobEncap: | |||||
| { | { | ||||
| "train_id": "./mock_job_1", | "train_id": "./mock_job_1", | ||||
| "create_time": "2020-10-01 20:21:23", | "create_time": "2020-10-01 20:21:23", | ||||
| "update_time": "2020-10-01 20:21:23" | |||||
| "update_time": "2020-10-01 20:21:23", | |||||
| "saliency_map": True, | |||||
| "hierarchical_occlusion": True | |||||
| } | } | ||||
| ]) | ]) | ||||
| assert job_list == expected_result | assert job_list == expected_result | ||||
| @@ -24,10 +24,6 @@ from mindinsight.explainer.manager.explain_loader import _LoaderStatus | |||||
| from mindinsight.explainer.manager.explain_parser import ExplainParser | from mindinsight.explainer.manager.explain_parser import ExplainParser | ||||
| def abc(): | |||||
| FileHandler.is_file('aaa') | |||||
| print('after') | |||||
| class TestExplainLoader: | class TestExplainLoader: | ||||
| """Test explain loader class.""" | """Test explain loader class.""" | ||||