From: @lixiaohui33 Reviewed-by: @ouwenchang Signed-off-by:tags/v1.2.0-rc1
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,8 +14,8 @@ | |||
| # ============================================================================ | |||
| """Explainer restful api.""" | |||
| import os | |||
| import json | |||
| import os | |||
| import urllib.parse | |||
| from flask import Blueprint | |||
| @@ -23,16 +23,18 @@ from flask import jsonify | |||
| from flask import request | |||
| from mindinsight.conf import settings | |||
| from mindinsight.utils.exceptions import ParamMissError | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from mindinsight.datavisual.common.validation import Validation | |||
| from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher | |||
| from mindinsight.datavisual.utils.tools import get_train_id | |||
| from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER | |||
| from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap | |||
| from mindinsight.explainer.encapsulator.datafile_encap import DatafileEncap | |||
| from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap | |||
| from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap | |||
| from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap | |||
| from mindinsight.explainer.encapsulator.hierarchical_occlusion_encap import HierarchicalOcclusionEncap | |||
| from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap | |||
| from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER | |||
| from mindinsight.utils.exceptions import ParamMissError | |||
| from mindinsight.utils.exceptions import ParamTypeError | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX | |||
| BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX) | |||
| @@ -66,6 +68,50 @@ def _read_post_request(post_request): | |||
| return body | |||
| def _get_query_sample_parameters(data): | |||
| """Get parameter for query.""" | |||
| train_id = data.get("train_id") | |||
| if train_id is None: | |||
| raise ParamMissError('train_id') | |||
| labels = data.get("labels") | |||
| if labels is not None and not isinstance(labels, list): | |||
| raise ParamTypeError("labels", (list, None)) | |||
| if labels: | |||
| for item in labels: | |||
| if not isinstance(item, str): | |||
| raise ParamTypeError("element of labels", str) | |||
| limit = data.get("limit", 10) | |||
| limit = Validation.check_limit(limit, min_value=1, max_value=100) | |||
| offset = data.get("offset", 0) | |||
| offset = Validation.check_offset(offset=offset) | |||
| sorted_name = data.get("sorted_name", "") | |||
| sorted_type = data.get("sorted_type", "descending") | |||
| if sorted_name not in ("", "confidence", "uncertainty"): | |||
| raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'") | |||
| if sorted_type not in ("ascending", "descending"): | |||
| raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'") | |||
| prediction_types = data.get("prediction_types") | |||
| if prediction_types is not None and not isinstance(prediction_types, list): | |||
| raise ParamTypeError("prediction_types", (list, None)) | |||
| if prediction_types: | |||
| for item in prediction_types: | |||
| if item not in ['TP', 'FN', 'FP']: | |||
| raise ParamValueError(f"Item of prediction_types must be in ['TP', 'FN', 'FP'], but got {item}.") | |||
| query_kwarg = {"train_id": train_id, | |||
| "labels": labels, | |||
| "limit": limit, | |||
| "offset": offset, | |||
| "sorted_name": sorted_name, | |||
| "sorted_type": sorted_type, | |||
| "prediction_types": prediction_types} | |||
| return query_kwarg | |||
| @BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"]) | |||
| def query_explain_jobs(): | |||
| """Query explain jobs.""" | |||
| @@ -99,37 +145,40 @@ def query_explain_job(): | |||
| @BLUEPRINT.route("/explainer/saliency", methods=["POST"]) | |||
| def query_saliency(): | |||
| """Query saliency map related results.""" | |||
| data = _read_post_request(request) | |||
| train_id = data.get("train_id") | |||
| if train_id is None: | |||
| raise ParamMissError('train_id') | |||
| labels = data.get("labels") | |||
| query_kwarg = _get_query_sample_parameters(data) | |||
| explainers = data.get("explainers") | |||
| limit = data.get("limit", 10) | |||
| limit = Validation.check_limit(limit, min_value=1, max_value=100) | |||
| offset = data.get("offset", 0) | |||
| offset = Validation.check_offset(offset=offset) | |||
| sorted_name = data.get("sorted_name", "") | |||
| sorted_type = data.get("sorted_type", "descending") | |||
| if explainers is not None and not isinstance(explainers, list): | |||
| raise ParamTypeError("explainers", (list, None)) | |||
| if explainers: | |||
| for item in explainers: | |||
| if not isinstance(item, str): | |||
| raise ParamTypeError("element of explainers", str) | |||
| if sorted_name not in ("", "confidence", "uncertainty"): | |||
| raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'") | |||
| if sorted_type not in ("ascending", "descending"): | |||
| raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'") | |||
| query_kwarg["explainers"] = explainers | |||
| encapsulator = SaliencyEncap( | |||
| _image_url_formatter, | |||
| EXPLAIN_MANAGER) | |||
| count, samples = encapsulator.query_saliency_maps(train_id=train_id, | |||
| labels=labels, | |||
| explainers=explainers, | |||
| limit=limit, | |||
| offset=offset, | |||
| sorted_name=sorted_name, | |||
| sorted_type=sorted_type) | |||
| count, samples = encapsulator.query_saliency_maps(**query_kwarg) | |||
| return jsonify({ | |||
| "count": count, | |||
| "samples": samples | |||
| }) | |||
| @BLUEPRINT.route("/explainer/hoc", methods=["POST"]) | |||
| def query_hoc(): | |||
| """Query hierarchical occlusion related results.""" | |||
| data = _read_post_request(request) | |||
| query_kwargs = _get_query_sample_parameters(data) | |||
| encapsulator = HierarchicalOcclusionEncap( | |||
| _image_url_formatter, | |||
| EXPLAIN_MANAGER) | |||
| count, samples = encapsulator.query_hierarchical_occlusion(**query_kwargs) | |||
| return jsonify({ | |||
| "count": count, | |||
| @@ -162,8 +211,8 @@ def query_image(): | |||
| image_type = request.args.get("type") | |||
| if image_type is None: | |||
| raise ParamMissError("type") | |||
| if image_type not in ("original", "overlay"): | |||
| raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay'") | |||
| if image_type not in ("original", "overlay", "outcome"): | |||
| raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay' 'outcome'") | |||
| encapsulator = DatafileEncap(EXPLAIN_MANAGER) | |||
| image = encapsulator.query_image_binary(train_id, image_path, image_type) | |||
| @@ -177,6 +226,5 @@ def init_module(app): | |||
| Args: | |||
| app: the application obj. | |||
| """ | |||
| app.register_blueprint(BLUEPRINT) | |||
| @@ -138,6 +138,17 @@ message Explain { | |||
| repeated string benchmark_method = 3; | |||
| } | |||
| message HocLayer{ | |||
| optional float prob = 1; | |||
| repeated int32 box = 2; // List of repeated x, y, w, h | |||
| } | |||
| message Hoc { | |||
| optional int32 label = 1; | |||
| optional string mask = 2; | |||
| repeated HocLayer layer = 3; | |||
| } | |||
| optional int32 sample_id = 1; // The Metadata and sample id must have one fill in | |||
| optional string image_path = 2; | |||
| repeated int32 ground_truth_label = 3; | |||
| @@ -148,4 +159,6 @@ message Explain { | |||
| optional Metadata metadata = 7; | |||
| optional string status = 8; // enum value: run, end | |||
| repeated Hoc hoc = 9; // hierarchical occlusion counterfactual | |||
| } | |||
| @@ -40,6 +40,7 @@ class ExplainFieldsEnum(BaseEnum): | |||
| GROUND_TRUTH_LABEL = 'ground_truth_label' | |||
| INFERENCE = 'inference' | |||
| EXPLANATION = 'explanation' | |||
| HIERARCHICAL_OCCLUSION = 'hierarchical_occlusion' | |||
| STATUS = 'status' | |||
| @@ -0,0 +1,156 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Utility functions for hierarchcial occlusion image generation.""" | |||
| import re | |||
| import PIL | |||
| from PIL import ImageDraw, ImageEnhance, ImageFilter | |||
| MASK_GAUSSIAN_RE = r'^gaussian:(\d+)$' | |||
| class EditStep: | |||
| """ | |||
| Class that represents an edit step. | |||
| Args: | |||
| layer (int): Layer index. | |||
| x (int): Left pixel coordinate. | |||
| y (int): Top pixel coordinate. | |||
| width (int): Width in pixels. | |||
| height (int): Height in pixels. | |||
| """ | |||
| def __init__(self, | |||
| layer: int, | |||
| x: int, | |||
| y: int, | |||
| width: int, | |||
| height: int): | |||
| self.layer = layer | |||
| self.x = x | |||
| self.y = y | |||
| self.width = width | |||
| self.height = height | |||
| def to_coord_box(self): | |||
| """ | |||
| Convert to pixel coordinate box. | |||
| Returns: | |||
| tuple[int, int, int, int], tuple of left, top, right, bottom pixel coordinate. | |||
| """ | |||
| return self.x, self.y, self.x + self.width, self.y + self.height | |||
| def pil_apply_edit_steps(image, mask, edit_steps, by_masking=False, inplace=False): | |||
| """ | |||
| Apply edit steps on a PIL image. | |||
| Args: | |||
| image (PIL.Image): The input image in RGB mode. | |||
| mask (Union[str, int, tuple[int, int, int], PIL.Image.Image]): The mask to apply on the image, could be string | |||
| e.g. 'gaussian:9', a single, grey scale intensity [0, 255], a RBG tuple or a PIL Image object. | |||
| edit_steps (list[EditStep]): Edit steps to be drawn. | |||
| by_masking (bool): Whether to use masking method. Default: False. | |||
| inplace (bool): True to draw on the input image, otherwise draw on a cloned image. | |||
| Returns: | |||
| PIL.Image, the result image. | |||
| """ | |||
| if isinstance(mask, str): | |||
| mask = pil_compile_str_mask(mask, image) | |||
| if by_masking: | |||
| return _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace) | |||
| return _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace) | |||
| def pil_compile_str_mask(mask, image): | |||
| """Concert string mask to PIL Image.""" | |||
| match = re.match(MASK_GAUSSIAN_RE, mask) | |||
| if match: | |||
| radius = int(match.group(1)) | |||
| if radius > 0: | |||
| image_filter = ImageFilter.GaussianBlur(radius=radius) | |||
| mask_image = image.filter(image_filter) | |||
| mask_image = ImageEnhance.Brightness(mask_image).enhance(0.7) | |||
| mask_image = ImageEnhance.Color(mask_image).enhance(0.0) | |||
| return mask_image | |||
| raise ValueError(f"Invalid string mask: '{mask}'.") | |||
| def _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace=False): | |||
| """ | |||
| Apply edit steps from unmasking method on a PIL image. | |||
| Args: | |||
| image (PIL.Image): The input image. | |||
| mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey | |||
| scale intensity [0, 255], a RBG tuple or a PIL Image. | |||
| edit_steps (list[EditStep]): Edit steps to be drawn. | |||
| inplace (bool): True to draw on the input image, otherwise draw on a cloned image. | |||
| Returns: | |||
| PIL.Image, the result image. | |||
| """ | |||
| if isinstance(mask, PIL.Image.Image): | |||
| if inplace: | |||
| bg = mask | |||
| else: | |||
| bg = mask.copy() | |||
| else: | |||
| if inplace: | |||
| raise ValueError('Argument inplace cannot be True when mask is not a PIL Image.') | |||
| if isinstance(mask, int): | |||
| mask = (mask, mask, mask) | |||
| bg = PIL.Image.new(mode="RGB", size=image.size, color=mask) | |||
| for step in edit_steps: | |||
| box = step.to_coord_box() | |||
| cropped = image.crop(box) | |||
| bg.paste(cropped, box=box) | |||
| return bg | |||
| def _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace=False): | |||
| """ | |||
| Apply edit steps from unmasking method on a PIL image. | |||
| Args: | |||
| image (PIL.Image): The input image. | |||
| mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey | |||
| scale intensity [0, 255], a RBG tuple or a PIL Image. | |||
| edit_steps (list[EditStep]): Edit steps to be drawn. | |||
| inplace (bool): True to draw on the input image, otherwise draw on a cloned image. | |||
| Returns: | |||
| PIL.Image, the result image. | |||
| """ | |||
| if not inplace: | |||
| image = image.copy() | |||
| if isinstance(mask, PIL.Image.Image): | |||
| for step in edit_steps: | |||
| box = step.to_coord_box() | |||
| cropped = mask.crop(box) | |||
| image.paste(cropped, box=box) | |||
| else: | |||
| if isinstance(mask, int): | |||
| mask = (mask, mask, mask) | |||
| draw = ImageDraw.Draw(image) | |||
| for step in edit_steps: | |||
| draw.rectangle(step.to_coord_box(), fill=mask) | |||
| return image | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,16 +14,17 @@ | |||
| # ============================================================================ | |||
| """Datafile encapsulator.""" | |||
| import os | |||
| import io | |||
| import os | |||
| from PIL import Image | |||
| import numpy as np | |||
| from PIL import Image | |||
| from mindinsight.utils.exceptions import UnknownError | |||
| from mindinsight.utils.exceptions import FileSystemPermissionError | |||
| from mindinsight.datavisual.common.exceptions import ImageNotExistError | |||
| from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps | |||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap | |||
| from mindinsight.utils.exceptions import FileSystemPermissionError | |||
| from mindinsight.utils.exceptions import UnknownError | |||
| # Max uint8 value. for converting RGB pixels to [0,1] intensity. | |||
| _UINT8_MAX = 255 | |||
| @@ -59,11 +60,44 @@ class DatafileEncap(ExplainDataEncap): | |||
| Args: | |||
| train_id (str): Job ID. | |||
| image_path (str): Image path relative to explain job's summary directory. | |||
| image_type (str): Image type, 'original' or 'overlay'. | |||
| image_type (str): Image type, Options: 'original', 'overlay' or 'outcome'. | |||
| Returns: | |||
| bytes, image binary. | |||
| """ | |||
| if image_type == "outcome": | |||
| sample_id, label, layer = image_path.strip(".jpg").split("_") | |||
| layer = int(layer) | |||
| job = self.job_manager.get_job(train_id) | |||
| samples = job.samples | |||
| label_idx = job.labels.index(label) | |||
| chosen_sample = samples[int(sample_id)] | |||
| original_path_image = chosen_sample['image'] | |||
| abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id), | |||
| original_path_image) | |||
| if self._is_forbidden(abs_image_path): | |||
| raise FileSystemPermissionError("Forbidden.") | |||
| try: | |||
| image = Image.open(abs_image_path) | |||
| except FileNotFoundError: | |||
| raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}") | |||
| except PermissionError: | |||
| raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}") | |||
| except OSError: | |||
| raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}") | |||
| edit_steps = [] | |||
| boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"] | |||
| mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"] | |||
| for box in boxes: | |||
| edit_steps.append(EditStep(layer, *box)) | |||
| image_cp = pil_apply_edit_steps(image, mask, edit_steps) | |||
| buffer = io.BytesIO() | |||
| image_cp.save(buffer, format=_PNG_FORMAT) | |||
| return buffer.getvalue() | |||
| abs_image_path = os.path.join(self.job_manager.summary_base_dir, | |||
| _clean_train_id_b4_join(train_id), | |||
| @@ -94,10 +128,10 @@ class DatafileEncap(ExplainDataEncap): | |||
| raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}") | |||
| if image.mode == _SINGLE_CHANNEL_MODE: | |||
| saliency = np.asarray(image)/_UINT8_MAX | |||
| saliency = np.asarray(image) / _UINT8_MAX | |||
| elif image.mode == _RGB_MODE: | |||
| saliency = np.asarray(image) | |||
| saliency = saliency[:, :, 0]/_UINT8_MAX | |||
| saliency = saliency[:, :, 0] / _UINT8_MAX | |||
| else: | |||
| raise UnknownError(f"Invalid overlay image mode:{image.mode}.") | |||
| @@ -105,7 +139,7 @@ class DatafileEncap(ExplainDataEncap): | |||
| for c in range(3): | |||
| saliency_stack[:, :, c] = saliency | |||
| rgba = saliency_stack * _SALIENCY_CMAP_HI | |||
| rgba += (1-saliency_stack) * _SALIENCY_CMAP_LOW | |||
| rgba += (1 - saliency_stack) * _SALIENCY_CMAP_LOW | |||
| rgba[:, :, 3] = saliency * _UINT8_MAX | |||
| overlay = Image.fromarray(np.uint8(rgba), mode=_RGBA_MODE) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -12,7 +12,57 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Explain data encapsulator base class.""" | |||
| """Common explain data encapsulator base class.""" | |||
| import copy | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| def _sort_key_min_confidence(sample, labels): | |||
| """Samples sort key by the min. confidence.""" | |||
| min_confidence = float("+inf") | |||
| for inference in sample["inferences"]: | |||
| if labels and inference["label"] not in labels: | |||
| continue | |||
| if inference["confidence"] < min_confidence: | |||
| min_confidence = inference["confidence"] | |||
| return min_confidence | |||
| def _sort_key_max_confidence(sample, labels): | |||
| """Samples sort key by the max. confidence.""" | |||
| max_confidence = float("-inf") | |||
| for inference in sample["inferences"]: | |||
| if labels and inference["label"] not in labels: | |||
| continue | |||
| if inference["confidence"] > max_confidence: | |||
| max_confidence = inference["confidence"] | |||
| return max_confidence | |||
| def _sort_key_min_confidence_sd(sample, labels): | |||
| """Samples sort key by the min. confidence_sd.""" | |||
| min_confidence_sd = float("+inf") | |||
| for inference in sample["inferences"]: | |||
| if labels and inference["label"] not in labels: | |||
| continue | |||
| confidence_sd = inference.get("confidence_sd", float("+inf")) | |||
| if confidence_sd < min_confidence_sd: | |||
| min_confidence_sd = confidence_sd | |||
| return min_confidence_sd | |||
| def _sort_key_max_confidence_sd(sample, labels): | |||
| """Samples sort key by the max. confidence_sd.""" | |||
| max_confidence_sd = float("-inf") | |||
| for inference in sample["inferences"]: | |||
| if labels and inference["label"] not in labels: | |||
| continue | |||
| confidence_sd = inference.get("confidence_sd", float("-inf")) | |||
| if confidence_sd > max_confidence_sd: | |||
| max_confidence_sd = confidence_sd | |||
| return max_confidence_sd | |||
| class ExplainDataEncap: | |||
| @@ -24,3 +74,77 @@ class ExplainDataEncap: | |||
| @property | |||
| def job_manager(self): | |||
| return self._job_manager | |||
| class ExplanationEncap(ExplainDataEncap): | |||
| """Base encapsulator for explanation queries.""" | |||
| def __init__(self, image_url_formatter, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| self._image_url_formatter = image_url_formatter | |||
| def _query_samples(self, | |||
| job, | |||
| labels, | |||
| sorted_name, | |||
| sorted_type, | |||
| prediction_types=None, | |||
| query_type="saliency_maps"): | |||
| """ | |||
| Query samples. | |||
| Args: | |||
| job (ExplainManager): Explain job to be query from. | |||
| labels (list[str]): Label filter. | |||
| sorted_name (str): Field to be sorted. | |||
| sorted_type (str): Sorting order, 'ascending' or 'descending'. | |||
| prediction_types (list[str]): Prediction type filter. | |||
| Returns: | |||
| list[dict], samples to be queried. | |||
| """ | |||
| samples = copy.deepcopy(job.get_all_samples()) | |||
| samples = [sample for sample in samples if any(infer[query_type] for infer in sample['inferences'])] | |||
| if labels: | |||
| filtered = [] | |||
| for sample in samples: | |||
| infer_labels = [inference["label"] for inference in sample["inferences"]] | |||
| for infer_label in infer_labels: | |||
| if infer_label in labels: | |||
| filtered.append(sample) | |||
| break | |||
| samples = filtered | |||
| if prediction_types and len(prediction_types) < 3: | |||
| filtered = [] | |||
| for sample in samples: | |||
| infer_types = [inference["prediction_type"] for inference in sample["inferences"]] | |||
| for infer_type in infer_types: | |||
| if infer_type in prediction_types: | |||
| filtered.append(sample) | |||
| break | |||
| samples = filtered | |||
| reverse = sorted_type == "descending" | |||
| if sorted_name == "confidence": | |||
| if reverse: | |||
| samples.sort(key=lambda x: _sort_key_max_confidence(x, labels), reverse=reverse) | |||
| else: | |||
| samples.sort(key=lambda x: _sort_key_min_confidence(x, labels), reverse=reverse) | |||
| elif sorted_name == "uncertainty": | |||
| if not job.uncertainty_enabled: | |||
| raise ParamValueError("Uncertainty is not enabled, sorted_name cannot be 'uncertainty'") | |||
| if reverse: | |||
| samples.sort(key=lambda x: _sort_key_max_confidence_sd(x, labels), reverse=reverse) | |||
| else: | |||
| samples.sort(key=lambda x: _sort_key_min_confidence_sd(x, labels), reverse=reverse) | |||
| elif sorted_name != "": | |||
| raise ParamValueError("sorted_name") | |||
| return samples | |||
| def _get_image_url(self, train_id, image_path, image_type): | |||
| """Returns image's url.""" | |||
| if self._image_url_formatter is None: | |||
| return image_path | |||
| return self._image_url_formatter(train_id, image_path, image_type) | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,7 +14,6 @@ | |||
| # ============================================================================ | |||
| """Explain job list encapsulator.""" | |||
| import copy | |||
| from datetime import datetime | |||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap | |||
| @@ -61,6 +60,9 @@ class ExplainJobEncap(ExplainDataEncap): | |||
| info["train_id"] = dir_info["relative_path"] | |||
| info["create_time"] = dir_info["create_time"].strftime(cls.DATETIME_FORMAT) | |||
| info["update_time"] = dir_info["update_time"].strftime(cls.DATETIME_FORMAT) | |||
| info["saliency_map"] = dir_info["saliency_map"] | |||
| info["hierarchical_occlusion"] = dir_info["hierarchical_occlusion"] | |||
| return info | |||
| @classmethod | |||
| @@ -79,7 +81,7 @@ class ExplainJobEncap(ExplainDataEncap): | |||
| """Convert ExplainJob's meta-data to jsonable info object.""" | |||
| info = cls._job_2_info(job) | |||
| info["sample_count"] = job.sample_count | |||
| info["classes"] = copy.deepcopy(job.all_classes) | |||
| info["classes"] = [item for item in job.all_classes if item['sample_count'] > 0] | |||
| saliency_info = dict() | |||
| if job.min_confidence is None: | |||
| saliency_info["min_confidence"] = cls.DEFAULT_MIN_CONFIDENCE | |||
| @@ -0,0 +1,87 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Hierarchical Occlusion encapsulator.""" | |||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap | |||
| class HierarchicalOcclusionEncap(ExplanationEncap): | |||
| """Hierarchical occlusion encapsulator.""" | |||
| def query_hierarchical_occlusion(self, | |||
| train_id, | |||
| labels, | |||
| limit, | |||
| offset, | |||
| sorted_name, | |||
| sorted_type, | |||
| prediction_types=None | |||
| ): | |||
| """ | |||
| Query hierarchical occlusion results. | |||
| Args: | |||
| train_id (str): Job ID. | |||
| labels (list[str]): Label filter. | |||
| limit (int): Maximum number of items to be returned. | |||
| offset (int): Page offset. | |||
| sorted_name (str): Field to be sorted. | |||
| sorted_type (str): Sorting order, 'ascending' or 'descending'. | |||
| prediction_types (list[str]): Prediction types filter. | |||
| Returns: | |||
| tuple[int, list[dict]], total number of samples after filtering and list of sample results. | |||
| """ | |||
| job = self.job_manager.get_job(train_id) | |||
| if job is None: | |||
| raise TrainJobNotExistError(train_id) | |||
| samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types, | |||
| query_type="hoc_layers") | |||
| sample_infos = [] | |||
| obj_offset = offset * limit | |||
| count = len(samples) | |||
| end = count | |||
| if obj_offset + limit < end: | |||
| end = obj_offset + limit | |||
| for i in range(obj_offset, end): | |||
| sample = samples[i] | |||
| sample_infos.append(self._touch_sample(sample, job)) | |||
| return count, sample_infos | |||
| def _touch_sample(self, sample, job): | |||
| """ | |||
| Final edit on single sample info. | |||
| Args: | |||
| sample (dict): Sample info. | |||
| job (ExplainManager): Explain job. | |||
| Returns: | |||
| dict, the edited sample info. | |||
| """ | |||
| sample_cp = sample.copy() | |||
| sample_cp["image"] = self._get_image_url(job.train_id, sample["image"], "original") | |||
| for inference_item in sample_cp["inferences"]: | |||
| new_list = [] | |||
| for idx, hoc_layer in enumerate(inference_item["hoc_layers"]): | |||
| hoc_layer["outcome"] = self._get_image_url(job.train_id, | |||
| f"{sample['id']}_{inference_item['label']}_{idx}.jpg", | |||
| "outcome") | |||
| new_list.append(hoc_layer) | |||
| inference_item["hoc_layers"] = new_list | |||
| return sample_cp | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -14,58 +14,13 @@ | |||
| # ============================================================================ | |||
| """Saliency map encapsulator.""" | |||
| import copy | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap | |||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap | |||
| def _sort_key_min_confidence(sample): | |||
| """Samples sort key by the min. confidence.""" | |||
| min_confidence = float("+inf") | |||
| for inference in sample["inferences"]: | |||
| if inference["confidence"] < min_confidence: | |||
| min_confidence = inference["confidence"] | |||
| return min_confidence | |||
| def _sort_key_max_confidence(sample): | |||
| """Samples sort key by the max. confidence.""" | |||
| max_confidence = float("-inf") | |||
| for inference in sample["inferences"]: | |||
| if inference["confidence"] > max_confidence: | |||
| max_confidence = inference["confidence"] | |||
| return max_confidence | |||
| def _sort_key_min_confidence_sd(sample): | |||
| """Samples sort key by the min. confidence_sd.""" | |||
| min_confidence_sd = float("+inf") | |||
| for inference in sample["inferences"]: | |||
| confidence_sd = inference.get("confidence_sd", float("+inf")) | |||
| if confidence_sd < min_confidence_sd: | |||
| min_confidence_sd = confidence_sd | |||
| return min_confidence_sd | |||
| def _sort_key_max_confidence_sd(sample): | |||
| """Samples sort key by the max. confidence_sd.""" | |||
| max_confidence_sd = float("-inf") | |||
| for inference in sample["inferences"]: | |||
| confidence_sd = inference.get("confidence_sd", float("-inf")) | |||
| if confidence_sd > max_confidence_sd: | |||
| max_confidence_sd = confidence_sd | |||
| return max_confidence_sd | |||
| class SaliencyEncap(ExplainDataEncap): | |||
| class SaliencyEncap(ExplanationEncap): | |||
| """Saliency map encapsulator.""" | |||
| def __init__(self, image_url_formatter, *args, **kwargs): | |||
| super().__init__(*args, **kwargs) | |||
| self._image_url_formatter = image_url_formatter | |||
| def query_saliency_maps(self, | |||
| train_id, | |||
| labels, | |||
| @@ -73,55 +28,32 @@ class SaliencyEncap(ExplainDataEncap): | |||
| limit, | |||
| offset, | |||
| sorted_name, | |||
| sorted_type): | |||
| sorted_type, | |||
| prediction_types=None): | |||
| """ | |||
| Query saliency maps. | |||
| Args: | |||
| train_id (str): Job ID. | |||
| labels (list[str]): Label filter. | |||
| explainers (list[str]): Explainers of saliency maps to be shown. | |||
| limit (int): Max. no. of items to be returned. | |||
| limit (int): Maximum number of items to be returned. | |||
| offset (int): Page offset. | |||
| sorted_name (str): Field to be sorted. | |||
| sorted_type (str): Sorting order, 'ascending' or 'descending'. | |||
| prediction_types (list[str]): Prediction types filter. Default: None. | |||
| Returns: | |||
| tuple[int, list[dict]], total no. of samples after filtering and | |||
| list of sample result. | |||
| tuple[int, list[dict]], total number of samples after filtering and list of sample result. | |||
| """ | |||
| job = self.job_manager.get_job(train_id) | |||
| if job is None: | |||
| raise TrainJobNotExistError(train_id) | |||
| samples = copy.deepcopy(job.get_all_samples()) | |||
| if labels: | |||
| filtered = [] | |||
| for sample in samples: | |||
| infer_labels = [inference["label"] for inference in sample["inferences"]] | |||
| for infer_label in infer_labels: | |||
| if infer_label in labels: | |||
| filtered.append(sample) | |||
| break | |||
| samples = filtered | |||
| reverse = sorted_type == "descending" | |||
| if sorted_name == "confidence": | |||
| if reverse: | |||
| samples.sort(key=_sort_key_max_confidence, reverse=reverse) | |||
| else: | |||
| samples.sort(key=_sort_key_min_confidence, reverse=reverse) | |||
| elif sorted_name == "uncertainty": | |||
| if not job.uncertainty_enabled: | |||
| raise ParamValueError("Uncertainty is not enabled, sorted_name cannot be 'uncertainty'") | |||
| if reverse: | |||
| samples.sort(key=_sort_key_max_confidence_sd, reverse=reverse) | |||
| else: | |||
| samples.sort(key=_sort_key_min_confidence_sd, reverse=reverse) | |||
| elif sorted_name != "": | |||
| raise ParamValueError("sorted_name") | |||
| samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types, | |||
| query_type="saliency_maps") | |||
| sample_infos = [] | |||
| obj_offset = offset*limit | |||
| obj_offset = offset * limit | |||
| count = len(samples) | |||
| end = count | |||
| if obj_offset + limit < end: | |||
| @@ -139,11 +71,13 @@ class SaliencyEncap(ExplainDataEncap): | |||
| sample (dict): Sample info. | |||
| job (ExplainJob): Explain job. | |||
| explainers (list[str]): Explainer names. | |||
| Returns: | |||
| dict, the edited sample info. | |||
| """ | |||
| sample["image"] = self._get_image_url(job.train_id, sample['image'], "original") | |||
| for inference in sample["inferences"]: | |||
| sample_cp = sample.copy() | |||
| sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original") | |||
| for inference in sample_cp["inferences"]: | |||
| new_list = [] | |||
| for saliency_map in inference["saliency_maps"]: | |||
| if explainers and saliency_map["explainer"] not in explainers: | |||
| @@ -151,10 +85,4 @@ class SaliencyEncap(ExplainDataEncap): | |||
| saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay") | |||
| new_list.append(saliency_map) | |||
| inference["saliency_maps"] = new_list | |||
| return sample | |||
| def _get_image_url(self, train_id, image_path, image_type): | |||
| """Returns image's url.""" | |||
| if self._image_url_formatter is None: | |||
| return image_path | |||
| return self._image_url_formatter(train_id, image_path, image_type) | |||
| return sample_cp | |||
| @@ -14,14 +14,13 @@ | |||
| # ============================================================================ | |||
| """ExplainLoader.""" | |||
| from collections import defaultdict | |||
| from enum import Enum | |||
| import math | |||
| import os | |||
| import re | |||
| import threading | |||
| from collections import defaultdict | |||
| from datetime import datetime | |||
| from enum import Enum | |||
| from typing import Dict, Iterable, List, Optional, Union | |||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| @@ -44,6 +43,7 @@ _SAMPLE_FIELD_NAMES = [ | |||
| ExplainFieldsEnum.GROUND_TRUTH_LABEL, | |||
| ExplainFieldsEnum.INFERENCE, | |||
| ExplainFieldsEnum.EXPLANATION, | |||
| ExplainFieldsEnum.HIERARCHICAL_OCCLUSION | |||
| ] | |||
| @@ -75,10 +75,10 @@ class ExplainLoader: | |||
| 'create_time': os.stat(summary_dir).st_ctime, | |||
| 'update_time': os.stat(summary_dir).st_mtime, | |||
| 'query_time': os.stat(summary_dir).st_ctime, | |||
| 'uncertainty_enabled': False, | |||
| 'uncertainty_enabled': False | |||
| } | |||
| self._samples = defaultdict(dict) | |||
| self._metadata = {'explainers': [], 'metrics': [], 'labels': []} | |||
| self._metadata = {'explainers': [], 'metrics': [], 'labels': [], 'min_confidence': 0.5} | |||
| self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)} | |||
| self._status = _LoaderStatus.STOP.value | |||
| @@ -91,25 +91,19 @@ class ExplainLoader: | |||
| Returns: | |||
| list[dict], a list of dict, each dict contains: | |||
| - id (int): label id | |||
| - label (str): label name | |||
| - sample_count (int): number of samples for each label | |||
| - id (int): Label id. | |||
| - label (str): Label name. | |||
| - sample_count (int): Number of samples for each label. | |||
| """ | |||
| sample_count_per_label = defaultdict(int) | |||
| samples_copy = self._samples.copy() | |||
| for sample in samples_copy.values(): | |||
| if sample.get('image', False) and sample.get('ground_truth_label', False): | |||
| for label in sample['ground_truth_label']: | |||
| for sample in self._samples.values(): | |||
| if sample.get('image') and (sample.get('ground_truth_label') or sample.get('predicted_label')): | |||
| for label in set(sample['ground_truth_label'] + sample['predicted_label']): | |||
| sample_count_per_label[label] += 1 | |||
| all_classes_return = [] | |||
| for label_id, label_name in enumerate(self._metadata['labels']): | |||
| single_info = { | |||
| 'id': label_id, | |||
| 'label': label_name, | |||
| 'sample_count': sample_count_per_label[label_id] | |||
| } | |||
| all_classes_return.append(single_info) | |||
| all_classes_return = [{'id': label_id, 'label': label_name, 'sample_count': sample_count_per_label[label_id]} | |||
| for label_id, label_name in enumerate(self._metadata['labels'])] | |||
| return all_classes_return | |||
| @property | |||
| @@ -166,19 +160,23 @@ class ExplainLoader: | |||
| Returns: | |||
| list[dict], A list of evaluation results of each explainer. Each item contains: | |||
| - explainer (str): Name of evaluated explainer. | |||
| - evaluations (list[dict]): A list of evaluation results by different metrics. | |||
| - class_scores (list[dict]): A list of evaluation results on different labels. | |||
| Each item in the evaluations contains: | |||
| - metric (str): name of metric method | |||
| - score (float): evaluation result | |||
| Each item in the class_scores contains: | |||
| - label (str): Name of label | |||
| - evaluations (list[dict]): A list of evaluation results on different labels by different metrics. | |||
| Each item in evaluations contains: | |||
| - metric (str): Name of metric method | |||
| - score (float): Evaluation scores of explainer on specific label by the metric. | |||
| """ | |||
| @@ -215,7 +213,7 @@ class ExplainLoader: | |||
| @property | |||
| def min_confidence(self) -> Optional[float]: | |||
| """Return minimum confidence used to filter the predicted labels.""" | |||
| return None | |||
| return self._metadata['min_confidence'] | |||
| @property | |||
| def sample_count(self) -> int: | |||
| @@ -227,19 +225,17 @@ class ExplainLoader: | |||
| Return: | |||
| int, total number of available samples in the loading job. | |||
| """ | |||
| sample_count = 0 | |||
| samples_copy = self._samples.copy() | |||
| for sample in samples_copy.values(): | |||
| if sample.get('image', False) and sample.get('ground_truth_label', False): | |||
| for sample in self._samples.values(): | |||
| if sample.get('image', False): | |||
| sample_count += 1 | |||
| return sample_count | |||
| @property | |||
| def samples(self) -> List[Dict]: | |||
| """Return the information of all samples in the job.""" | |||
| return self.get_all_samples() | |||
| return self._samples | |||
| @property | |||
| def train_id(self) -> str: | |||
| @@ -298,6 +294,7 @@ class ExplainLoader: | |||
| self._clear_job() | |||
| if event_dict: | |||
| self._import_data_from_event(event_dict) | |||
| self._reform_sample_info() | |||
| @property | |||
| def status(self): | |||
| @@ -317,59 +314,19 @@ class ExplainLoader: | |||
| def get_all_samples(self) -> List[Dict]: | |||
| """ | |||
| Return a list of sample information cachced in the explain job | |||
| Return a list of sample information cached in the explain job | |||
| Returns: | |||
| sample_list (List[SampleObj]): a list of sample objects, each object | |||
| consists of: | |||
| - id (int): sample id | |||
| - name (str): basename of image | |||
| - labels (list[str]): list of labels | |||
| - inferences list[dict]) | |||
| - id (int): Sample id. | |||
| - name (str): Basename of image. | |||
| - inferences (list[dict]): List of inferences for all labels. | |||
| """ | |||
| returned_samples = [] | |||
| samples_copy = self._samples.copy() | |||
| for sample_id, sample_info in samples_copy.items(): | |||
| if not sample_info.get('image', False) and not sample_info.get('ground_truth_label', False): | |||
| continue | |||
| returned_sample = { | |||
| 'id': sample_id, | |||
| 'name': str(sample_id), | |||
| 'image': sample_info['image'], | |||
| 'labels': [self._metadata['labels'][i] for i in sample_info['ground_truth_label']], | |||
| } | |||
| # Check whether the sample has valid label-prob pairs. | |||
| if not ExplainLoader._is_inference_valid(sample_info): | |||
| continue | |||
| inferences = {} | |||
| for label, prob in zip(sample_info['ground_truth_label'] + sample_info['predicted_label'], | |||
| sample_info['ground_truth_prob'] + sample_info['predicted_prob']): | |||
| inferences[label] = { | |||
| 'label': self._metadata['labels'][label], | |||
| 'confidence': _round(prob), | |||
| 'saliency_maps': [] | |||
| } | |||
| if sample_info['ground_truth_prob_sd'] or sample_info['predicted_prob_sd']: | |||
| for label, std, low, high in zip( | |||
| sample_info['ground_truth_label'] + sample_info['predicted_label'], | |||
| sample_info['ground_truth_prob_sd'] + sample_info['predicted_prob_sd'], | |||
| sample_info['ground_truth_prob_itl95_low'] + sample_info['predicted_prob_itl95_low'], | |||
| sample_info['ground_truth_prob_itl95_hi'] + sample_info['predicted_prob_itl95_hi'] | |||
| ): | |||
| inferences[label]['confidence_sd'] = _round(std) | |||
| inferences[label]['confidence_itl95'] = [_round(low), _round(high)] | |||
| for explainer, label_heatmap_path_dict in sample_info['explanation'].items(): | |||
| for label, heatmap_path in label_heatmap_path_dict.items(): | |||
| if label in inferences: | |||
| inferences[label]['saliency_maps'].append({'explainer': explainer, 'overlay': heatmap_path}) | |||
| returned_sample['inferences'] = list(inferences.values()) | |||
| returned_samples.append(returned_sample) | |||
| returned_samples = [{'id': sample_id, 'name': info['name'], 'image': info['image'], | |||
| 'inferences': list(info['inferences'].values())} for sample_id, info in | |||
| self._samples.items() if info.get('image', False)] | |||
| return returned_samples | |||
| def _import_data_from_event(self, event_dict: Dict): | |||
| @@ -461,30 +418,29 @@ class ExplainLoader: | |||
| - predicted_probs (list[int]): A list of confidences w.r.t the predicted labels. | |||
| - explanations (dict): Explanations is a dictionary where the each explainer name mapping to a dictionary | |||
| of saliency maps. The structure of explanations demonstrates below: | |||
| { | |||
| explainer_name_1: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, | |||
| explainer_name_2: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, | |||
| ... | |||
| } | |||
| - hierarchical_occlusion (dict): A dictionary where each label is matched to a dictionary: | |||
| {label_1: [{prob: layer1_prob, bbox: []}, {prob: layer2_prob, bbox: []}], | |||
| label_2: | |||
| } | |||
| """ | |||
| if getattr(sample, 'sample_id', None) is None: | |||
| raise ParamValueError('sample_event has no sample_id') | |||
| sample_id = sample.sample_id | |||
| samples_copy = self._samples.copy() | |||
| if sample_id not in samples_copy: | |||
| if sample_id not in self._samples: | |||
| self._samples[sample_id] = { | |||
| 'id': sample_id, | |||
| 'name': str(sample_id), | |||
| 'image': sample.image_path, | |||
| 'ground_truth_label': [], | |||
| 'ground_truth_prob': [], | |||
| 'ground_truth_prob_sd': [], | |||
| 'ground_truth_prob_itl95_low': [], | |||
| 'ground_truth_prob_itl95_hi': [], | |||
| 'predicted_label': [], | |||
| 'predicted_prob': [], | |||
| 'predicted_prob_sd': [], | |||
| 'predicted_prob_itl95_low': [], | |||
| 'predicted_prob_itl95_hi': [], | |||
| 'explanation': defaultdict(dict) | |||
| 'inferences': defaultdict(dict), | |||
| 'explanation': defaultdict(dict), | |||
| 'hierarchical_occlusion': defaultdict(dict) | |||
| } | |||
| if sample.image_path: | |||
| @@ -492,27 +448,66 @@ class ExplainLoader: | |||
| for tag in _SAMPLE_FIELD_NAMES: | |||
| if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL: | |||
| self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label)) | |||
| if not self._samples[sample_id]['ground_truth_label']: | |||
| self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label)) | |||
| elif tag == ExplainFieldsEnum.INFERENCE: | |||
| self._import_inference_from_event(sample, sample_id) | |||
| else: | |||
| elif tag == ExplainFieldsEnum.EXPLANATION: | |||
| self._import_explanation_from_event(sample, sample_id) | |||
| elif tag == ExplainFieldsEnum.HIERARCHICAL_OCCLUSION: | |||
| self._import_hoc_from_event(sample, sample_id) | |||
| def _reform_sample_info(self): | |||
| """Reform the sample info.""" | |||
| for _, sample_info in self._samples.items(): | |||
| inferences = sample_info['inferences'] | |||
| res_dict = defaultdict(list) | |||
| for explainer, label_heatmap_path_dict in sample_info['explanation'].items(): | |||
| for label, heatmap_path in label_heatmap_path_dict.items(): | |||
| res_dict[label].append({'explainer': explainer, 'overlay': heatmap_path}) | |||
| for label, item in inferences.items(): | |||
| item['saliency_maps'] = res_dict[label] | |||
| for label, item in sample_info['hierarchical_occlusion'].items(): | |||
| inferences[label]['hoc_layers'] = item['hoc_layers'] | |||
| def _import_inference_from_event(self, event, sample_id): | |||
| """Parse the inference event.""" | |||
| inference = event.inference | |||
| self._samples[sample_id]['ground_truth_prob'].extend(list(inference.ground_truth_prob)) | |||
| self._samples[sample_id]['ground_truth_prob_sd'].extend(list(inference.ground_truth_prob_sd)) | |||
| self._samples[sample_id]['ground_truth_prob_itl95_low'].extend(list(inference.ground_truth_prob_itl95_low)) | |||
| self._samples[sample_id]['ground_truth_prob_itl95_hi'].extend(list(inference.ground_truth_prob_itl95_hi)) | |||
| self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label)) | |||
| self._samples[sample_id]['predicted_prob'].extend(list(inference.predicted_prob)) | |||
| self._samples[sample_id]['predicted_prob_sd'].extend(list(inference.predicted_prob_sd)) | |||
| self._samples[sample_id]['predicted_prob_itl95_low'].extend(list(inference.predicted_prob_itl95_low)) | |||
| self._samples[sample_id]['predicted_prob_itl95_hi'].extend(list(inference.predicted_prob_itl95_hi)) | |||
| if self._samples[sample_id]['ground_truth_prob_sd'] or self._samples[sample_id]['predicted_prob_sd']: | |||
| if inference.ground_truth_prob_sd or inference.predicted_prob_sd: | |||
| self._loader_info['uncertainty_enabled'] = True | |||
| if not self._samples[sample_id]['predicted_label']: | |||
| self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label)) | |||
| if not self._samples[sample_id]['inferences']: | |||
| inferences = {} | |||
| for label, prob in zip(list(event.ground_truth_label) + list(inference.predicted_label), | |||
| list(inference.ground_truth_prob) + list(inference.predicted_prob)): | |||
| inferences[label] = { | |||
| 'label': self._metadata['labels'][label], | |||
| 'confidence': _round(prob), | |||
| 'saliency_maps': [], | |||
| 'hoc_layers': {}, | |||
| } | |||
| if not event.ground_truth_label: | |||
| inferences[label]['prediction_type'] = None | |||
| else: | |||
| if prob < self.min_confidence: | |||
| inferences[label]['prediction_type'] = 'FN' | |||
| elif label in event.ground_truth_label: | |||
| inferences[label]['prediction_type'] = 'TP' | |||
| else: | |||
| inferences[label]['prediction_type'] = 'FP' | |||
| if self._loader_info['uncertainty_enabled']: | |||
| for label, std, low, high in zip( | |||
| list(event.ground_truth_label) + list(inference.predicted_label), | |||
| list(inference.ground_truth_prob_sd) + list(inference.predicted_prob_sd), | |||
| list(inference.ground_truth_prob_itl95_low) + list(inference.predicted_prob_itl95_low), | |||
| list(inference.ground_truth_prob_itl95_hi) + list(inference.predicted_prob_itl95_hi)): | |||
| inferences[label]['confidence_sd'] = _round(std) | |||
| inferences[label]['confidence_itl95'] = [_round(low), _round(high)] | |||
| self._samples[sample_id]['inferences'] = inferences | |||
| def _import_explanation_from_event(self, event, sample_id): | |||
| """Parse the explanation event.""" | |||
| @@ -525,6 +520,24 @@ class ExplainLoader: | |||
| label = explanation_item.label | |||
| sample_explanation[explainer][label] = explanation_item.heatmap_path | |||
| def _import_hoc_from_event(self, event, sample_id): | |||
| """Parse the mango event.""" | |||
| sample_hoc = self._samples[sample_id]['hierarchical_occlusion'] | |||
| if event.hierarchical_occlusion: | |||
| for hoc_item in event.hierarchical_occlusion: | |||
| label = hoc_item.label | |||
| sample_hoc[label] = {} | |||
| sample_hoc[label]['label'] = label | |||
| sample_hoc[label]['mask'] = hoc_item.mask | |||
| sample_hoc[label]['confidence'] = self._samples[sample_id]['inferences'][label]['confidence'] | |||
| sample_hoc[label]['hoc_layers'] = [] | |||
| for hoc_layer in hoc_item.layer: | |||
| sample_hoc_dict = {'confidence': hoc_layer.prob} | |||
| box_lst = list(hoc_layer.box) | |||
| box = [box_lst[i: i + 4] for i in range(0, len(hoc_layer.box), 4)] | |||
| sample_hoc_dict['boxes'] = box | |||
| sample_hoc[label]['hoc_layers'].append(sample_hoc_dict) | |||
| def _clear_job(self): | |||
| """Clear the cached data and update the time info of the loader.""" | |||
| self._samples.clear() | |||
| @@ -114,6 +114,7 @@ class ExplainManager: | |||
| Returns: | |||
| tuple[total, directories], total indicates the overall number of explain directories and directories | |||
| indicate list of summary directory info including the following attributes. | |||
| - relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR, | |||
| starting with "./". | |||
| - create_time (datetime): Creation time of summary file. | |||
| @@ -30,6 +30,7 @@ from mindinsight.utils.exceptions import UnknownError | |||
| HEADER_SIZE = 8 | |||
| CRC_STR_SIZE = 4 | |||
| MAX_EVENT_STRING = 500000000 | |||
| BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status']) | |||
| MetadataContainer = namedtuple('MetadataContainer', ['metadata', 'status']) | |||
| InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob', | |||
| @@ -42,7 +43,7 @@ InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob', | |||
| 'predicted_prob_itl95_low', | |||
| 'predicted_prob_itl95_hi']) | |||
| SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label', 'inference', | |||
| 'explanation', 'status']) | |||
| 'explanation', 'hierarchical_occlusion', 'status']) | |||
| class ExplainParser(_SummaryParser): | |||
| @@ -193,6 +194,7 @@ class ExplainParser(_SummaryParser): | |||
| ground_truth_label=tensor_event_value.ground_truth_label, | |||
| inference=inference, | |||
| explanation=tensor_event_value.explanation, | |||
| hierarchical_occlusion=tensor_event_value.hoc, | |||
| status=tensor_event_value.status | |||
| ) | |||
| return sample_data | |||
| @@ -105,7 +105,9 @@ class MockExplainManager: | |||
| { | |||
| "relative_path": "./mock_job_1", | |||
| "create_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT), | |||
| "update_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT) | |||
| "update_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT), | |||
| "saliency_map": True, | |||
| "hierarchical_occlusion": True | |||
| } | |||
| ] | |||
| return 1, job_list | |||
| @@ -31,7 +31,9 @@ class TestExplainJobEncap: | |||
| { | |||
| "train_id": "./mock_job_1", | |||
| "create_time": "2020-10-01 20:21:23", | |||
| "update_time": "2020-10-01 20:21:23" | |||
| "update_time": "2020-10-01 20:21:23", | |||
| "saliency_map": True, | |||
| "hierarchical_occlusion": True | |||
| } | |||
| ]) | |||
| assert job_list == expected_result | |||
| @@ -24,10 +24,6 @@ from mindinsight.explainer.manager.explain_loader import _LoaderStatus | |||
| from mindinsight.explainer.manager.explain_parser import ExplainParser | |||
| def abc(): | |||
| FileHandler.is_file('aaa') | |||
| print('after') | |||
| class TestExplainLoader: | |||
| """Test explain loader class.""" | |||