Browse Source

!1135 Adapt mindinsight backend to support CV counterfactual explanation

From: @lixiaohui33
Reviewed-by: @ouwenchang
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
9c83dad1f8
15 changed files with 650 additions and 241 deletions
  1. +82
    -34
      mindinsight/backend/explainer/explainer_api.py
  2. +13
    -0
      mindinsight/datavisual/proto_files/mindinsight_summary.proto
  3. +1
    -0
      mindinsight/explainer/common/enums.py
  4. +156
    -0
      mindinsight/explainer/encapsulator/_hoc_pil_apply.py
  5. +43
    -9
      mindinsight/explainer/encapsulator/datafile_encap.py
  6. +126
    -2
      mindinsight/explainer/encapsulator/explain_data_encap.py
  7. +5
    -3
      mindinsight/explainer/encapsulator/explain_job_encap.py
  8. +87
    -0
      mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py
  9. +16
    -88
      mindinsight/explainer/encapsulator/saliency_encap.py
  10. +111
    -98
      mindinsight/explainer/manager/explain_loader.py
  11. +1
    -0
      mindinsight/explainer/manager/explain_manager.py
  12. +3
    -1
      mindinsight/explainer/manager/explain_parser.py
  13. +3
    -1
      tests/ut/explainer/encapsulator/mock_explain_manager.py
  14. +3
    -1
      tests/ut/explainer/encapsulator/test_explain_job_encap.py
  15. +0
    -4
      tests/ut/explainer/manager/test_explain_loader.py

+ 82
- 34
mindinsight/backend/explainer/explainer_api.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,8 +14,8 @@
# ============================================================================ # ============================================================================
"""Explainer restful api.""" """Explainer restful api."""


import os
import json import json
import os
import urllib.parse import urllib.parse


from flask import Blueprint from flask import Blueprint
@@ -23,16 +23,18 @@ from flask import jsonify
from flask import request from flask import request


from mindinsight.conf import settings from mindinsight.conf import settings
from mindinsight.utils.exceptions import ParamMissError
from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.common.validation import Validation from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.datavisual.utils.tools import get_train_id from mindinsight.datavisual.utils.tools import get_train_id
from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER
from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
from mindinsight.explainer.encapsulator.datafile_encap import DatafileEncap from mindinsight.explainer.encapsulator.datafile_encap import DatafileEncap
from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap
from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
from mindinsight.explainer.encapsulator.hierarchical_occlusion_encap import HierarchicalOcclusionEncap
from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER
from mindinsight.utils.exceptions import ParamMissError
from mindinsight.utils.exceptions import ParamTypeError
from mindinsight.utils.exceptions import ParamValueError


URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX
BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX) BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX)
@@ -66,6 +68,50 @@ def _read_post_request(post_request):
return body return body




def _get_query_sample_parameters(data):
    """
    Extract and validate the sample query parameters from a request body.

    Args:
        data (dict): Decoded request body; only its `get` method is used.

    Returns:
        dict, keyword arguments with the keys "train_id", "labels", "limit", "offset", "sorted_name",
        "sorted_type" and "prediction_types".

    Raises:
        ParamMissError: If "train_id" is absent.
        ParamTypeError: If "labels" or "prediction_types" is neither a list nor None, or if an element of
            "labels" is not a string.
        ParamValueError: If "sorted_name", "sorted_type" or an element of "prediction_types" is not one of
            the accepted options.
    """
    train_id = data.get("train_id")
    if train_id is None:
        raise ParamMissError('train_id')

    labels = data.get("labels")
    if labels is not None and not isinstance(labels, list):
        raise ParamTypeError("labels", (list, None))
    if labels:
        for item in labels:
            if not isinstance(item, str):
                raise ParamTypeError("element of labels", str)

    # Paging: limit defaults to 10 and is checked against the range [1, 100]; offset defaults to 0.
    limit = data.get("limit", 10)
    limit = Validation.check_limit(limit, min_value=1, max_value=100)
    offset = data.get("offset", 0)
    offset = Validation.check_offset(offset=offset)

    sorted_name = data.get("sorted_name", "")
    sorted_type = data.get("sorted_type", "descending")
    if sorted_name not in ("", "confidence", "uncertainty"):
        raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'")
    if sorted_type not in ("ascending", "descending"):
        # The message lists the options of sorted_type itself (it used to repeat those of sorted_name).
        raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'ascending' 'descending'")

    prediction_types = data.get("prediction_types")
    if prediction_types is not None and not isinstance(prediction_types, list):
        raise ParamTypeError("prediction_types", (list, None))
    if prediction_types:
        for item in prediction_types:
            if item not in ['TP', 'FN', 'FP']:
                raise ParamValueError(f"Item of prediction_types must be in ['TP', 'FN', 'FP'], but got {item}.")

    query_kwarg = {"train_id": train_id,
                   "labels": labels,
                   "limit": limit,
                   "offset": offset,
                   "sorted_name": sorted_name,
                   "sorted_type": sorted_type,
                   "prediction_types": prediction_types}
    return query_kwarg


@BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"]) @BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"])
def query_explain_jobs(): def query_explain_jobs():
"""Query explain jobs.""" """Query explain jobs."""
@@ -99,37 +145,40 @@ def query_explain_job():
@BLUEPRINT.route("/explainer/saliency", methods=["POST"]) @BLUEPRINT.route("/explainer/saliency", methods=["POST"])
def query_saliency(): def query_saliency():
"""Query saliency map related results.""" """Query saliency map related results."""

data = _read_post_request(request) data = _read_post_request(request)

train_id = data.get("train_id")
if train_id is None:
raise ParamMissError('train_id')

labels = data.get("labels")
query_kwarg = _get_query_sample_parameters(data)
explainers = data.get("explainers") explainers = data.get("explainers")
limit = data.get("limit", 10)
limit = Validation.check_limit(limit, min_value=1, max_value=100)
offset = data.get("offset", 0)
offset = Validation.check_offset(offset=offset)
sorted_name = data.get("sorted_name", "")
sorted_type = data.get("sorted_type", "descending")
if explainers is not None and not isinstance(explainers, list):
raise ParamTypeError("explainers", (list, None))
if explainers:
for item in explainers:
if not isinstance(item, str):
raise ParamTypeError("element of explainers", str)


if sorted_name not in ("", "confidence", "uncertainty"):
raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'")
if sorted_type not in ("ascending", "descending"):
raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'")
query_kwarg["explainers"] = explainers


encapsulator = SaliencyEncap( encapsulator = SaliencyEncap(
_image_url_formatter, _image_url_formatter,
EXPLAIN_MANAGER) EXPLAIN_MANAGER)
count, samples = encapsulator.query_saliency_maps(train_id=train_id,
labels=labels,
explainers=explainers,
limit=limit,
offset=offset,
sorted_name=sorted_name,
sorted_type=sorted_type)
count, samples = encapsulator.query_saliency_maps(**query_kwarg)

return jsonify({
"count": count,
"samples": samples
})


@BLUEPRINT.route("/explainer/hoc", methods=["POST"])
def query_hoc():
"""Query hierarchical occlusion related results."""
data = _read_post_request(request)

query_kwargs = _get_query_sample_parameters(data)

encapsulator = HierarchicalOcclusionEncap(
_image_url_formatter,
EXPLAIN_MANAGER)
count, samples = encapsulator.query_hierarchical_occlusion(**query_kwargs)


return jsonify({ return jsonify({
"count": count, "count": count,
@@ -162,8 +211,8 @@ def query_image():
image_type = request.args.get("type") image_type = request.args.get("type")
if image_type is None: if image_type is None:
raise ParamMissError("type") raise ParamMissError("type")
if image_type not in ("original", "overlay"):
raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay'")
if image_type not in ("original", "overlay", "outcome"):
raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay' 'outcome'")


encapsulator = DatafileEncap(EXPLAIN_MANAGER) encapsulator = DatafileEncap(EXPLAIN_MANAGER)
image = encapsulator.query_image_binary(train_id, image_path, image_type) image = encapsulator.query_image_binary(train_id, image_path, image_type)
@@ -177,6 +226,5 @@ def init_module(app):


Args: Args:
app: the application obj. app: the application obj.

""" """
app.register_blueprint(BLUEPRINT) app.register_blueprint(BLUEPRINT)

+ 13
- 0
mindinsight/datavisual/proto_files/mindinsight_summary.proto View File

@@ -138,6 +138,17 @@ message Explain {
repeated string benchmark_method = 3; repeated string benchmark_method = 3;
} }


// One layer of a hierarchical occlusion result.
message HocLayer{
// presumably the model probability associated with this layer - TODO confirm with the writer side
optional float prob = 1;
// Flattened boxes: the values x, y, w, h repeated once per box.
repeated int32 box = 2; // List of repeated x, y, w, h
}

// Hierarchical occlusion result: a label, a mask and a list of layers.
message Hoc {
// presumably the index of the label being explained - TODO confirm
optional int32 label = 1;
// Mask description as a string; NOTE(review): looks like values such as "gaussian:9" are expected - confirm
optional string mask = 2;
// Zero or more HocLayer entries.
repeated HocLayer layer = 3;
}

optional int32 sample_id = 1; // The Metadata and sample id must have one fill in optional int32 sample_id = 1; // The Metadata and sample id must have one fill in
optional string image_path = 2; optional string image_path = 2;
repeated int32 ground_truth_label = 3; repeated int32 ground_truth_label = 3;
@@ -148,4 +159,6 @@ message Explain {


optional Metadata metadata = 7; optional Metadata metadata = 7;
optional string status = 8; // enum value: run, end optional string status = 8; // enum value: run, end

repeated Hoc hoc = 9; // hierarchical occlusion counterfactual
} }

+ 1
- 0
mindinsight/explainer/common/enums.py View File

@@ -40,6 +40,7 @@ class ExplainFieldsEnum(BaseEnum):
GROUND_TRUTH_LABEL = 'ground_truth_label' GROUND_TRUTH_LABEL = 'ground_truth_label'
INFERENCE = 'inference' INFERENCE = 'inference'
EXPLANATION = 'explanation' EXPLANATION = 'explanation'
HIERARCHICAL_OCCLUSION = 'hierarchical_occlusion'
STATUS = 'status' STATUS = 'status'






+ 156
- 0
mindinsight/explainer/encapsulator/_hoc_pil_apply.py View File

@@ -0,0 +1,156 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Utility functions for hierarchical occlusion image generation."""

import re

import PIL
from PIL import ImageDraw, ImageEnhance, ImageFilter

# Pattern of a Gaussian mask string such as 'gaussian:9'; group 1 captures the digits used as the blur radius.
MASK_GAUSSIAN_RE = r'^gaussian:(\d+)$'


class EditStep:
    """
    A rectangular edit region that belongs to one layer.

    Args:
        layer (int): Layer index.
        x (int): Left pixel coordinate.
        y (int): Top pixel coordinate.
        width (int): Width in pixels.
        height (int): Height in pixels.
    """
    def __init__(self, layer: int, x: int, y: int, width: int, height: int):
        self.layer = layer
        self.x, self.y = x, y
        self.width, self.height = width, height

    def to_coord_box(self):
        """
        Express the region as a pixel coordinate box.

        Returns:
            tuple[int, int, int, int], the left, top, right and bottom pixel coordinates, where
            right is x + width and bottom is y + height.
        """
        left, top = self.x, self.y
        right = left + self.width
        bottom = top + self.height
        return left, top, right, bottom


def pil_apply_edit_steps(image, mask, edit_steps, by_masking=False, inplace=False):
    """
    Apply edit steps on a PIL image, using either the masking or the unmasking method.

    Args:
        image (PIL.Image): The input image in RGB mode.
        mask (Union[str, int, tuple[int, int, int], PIL.Image.Image]): The mask to apply on the image, could be
            a string e.g. 'gaussian:9', a single grey scale intensity [0, 255], a RGB tuple or a PIL Image
            object. A string is first compiled into a PIL Image with `pil_compile_str_mask`.
        edit_steps (list[EditStep]): Edit steps to be drawn.
        by_masking (bool): True to use the masking method, False to use the unmasking method. Default: False.
        inplace (bool): Forwarded to the selected method. The masking method then draws on the input image
            itself; the unmasking method draws on the mask image instead. Default: False.

    Returns:
        PIL.Image, the result image.
    """
    compiled_mask = pil_compile_str_mask(mask, image) if isinstance(mask, str) else mask
    apply_steps = _pil_apply_edit_steps_mask if by_masking else _pil_apply_edit_steps_unmask
    return apply_steps(image, compiled_mask, edit_steps, inplace)


def pil_compile_str_mask(mask, image):
    """
    Convert a string mask to a PIL Image.

    The only accepted form is the one described by MASK_GAUSSIAN_RE with a radius greater than zero,
    e.g. 'gaussian:9'. The mask image is produced from `image` by a Gaussian blur of that radius, followed
    by a brightness enhancement with factor 0.7 and a color enhancement with factor 0.0.

    Args:
        mask (str): The mask string.
        image (PIL.Image): The image the mask is derived from.

    Returns:
        PIL.Image, the mask image.

    Raises:
        ValueError: If the string does not match the pattern as a whole or the radius is zero.
    """
    # fullmatch is used because '$' under re.match also matches just before a trailing newline,
    # which would let strings such as 'gaussian:9\n' through.
    match = re.fullmatch(MASK_GAUSSIAN_RE, mask)
    if match:
        radius = int(match.group(1))
        if radius > 0:
            blurred = image.filter(ImageFilter.GaussianBlur(radius=radius))
            dimmed = ImageEnhance.Brightness(blurred).enhance(0.7)
            return ImageEnhance.Color(dimmed).enhance(0.0)
    raise ValueError(f"Invalid string mask: '{mask}'.")


def _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace=False):
    """
    Apply edit steps with the unmasking method.

    The result starts as the mask (or an image filled with the mask color) and the box of every edit step
    is then filled with the corresponding region cropped from the input image.

    Args:
        image (PIL.Image): The input image.
        mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single
            grey scale intensity [0, 255], a RGB tuple or a PIL Image.
        edit_steps (list[EditStep]): Edit steps to be drawn.
        inplace (bool): True to draw directly on the mask image, otherwise draw on a copy of it. Only
            allowed when mask is a PIL Image. Default: False.

    Returns:
        PIL.Image, the result image.

    Raises:
        ValueError: If inplace is True while mask is not a PIL Image.
    """
    if isinstance(mask, PIL.Image.Image):
        canvas = mask if inplace else mask.copy()
    else:
        if inplace:
            raise ValueError('Argument inplace cannot be True when mask is not a PIL Image.')
        # A single intensity is expanded to a grey tuple with three equal channels.
        fill = (mask, mask, mask) if isinstance(mask, int) else mask
        canvas = PIL.Image.new(mode="RGB", size=image.size, color=fill)

    for step in edit_steps:
        region = step.to_coord_box()
        canvas.paste(image.crop(region), box=region)
    return canvas


def _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace=False):
    """
    Apply edit steps with the masking method.

    The box of every edit step is covered on the image, either with the corresponding region cropped from
    the mask image or with a rectangle filled with the mask color.

    Args:
        image (PIL.Image): The input image.
        mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single
            grey scale intensity [0, 255], a RGB tuple or a PIL Image.
        edit_steps (list[EditStep]): Edit steps to be drawn.
        inplace (bool): True to draw on the input image, otherwise draw on a copy of it. Default: False.

    Returns:
        PIL.Image, the result image.
    """
    target = image if inplace else image.copy()

    if isinstance(mask, PIL.Image.Image):
        for step in edit_steps:
            region = step.to_coord_box()
            target.paste(mask.crop(region), box=region)
        return target

    # A single intensity is expanded to a grey tuple with three equal channels.
    fill = (mask, mask, mask) if isinstance(mask, int) else mask
    painter = ImageDraw.Draw(target)
    for step in edit_steps:
        painter.rectangle(step.to_coord_box(), fill=fill)
    return target

+ 43
- 9
mindinsight/explainer/encapsulator/datafile_encap.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,16 +14,17 @@
# ============================================================================ # ============================================================================
"""Datafile encapsulator.""" """Datafile encapsulator."""


import os
import io import io
import os


from PIL import Image
import numpy as np import numpy as np
from PIL import Image


from mindinsight.utils.exceptions import UnknownError
from mindinsight.utils.exceptions import FileSystemPermissionError
from mindinsight.datavisual.common.exceptions import ImageNotExistError from mindinsight.datavisual.common.exceptions import ImageNotExistError
from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps
from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
from mindinsight.utils.exceptions import FileSystemPermissionError
from mindinsight.utils.exceptions import UnknownError


# Max uint8 value. for converting RGB pixels to [0,1] intensity. # Max uint8 value. for converting RGB pixels to [0,1] intensity.
_UINT8_MAX = 255 _UINT8_MAX = 255
@@ -59,11 +60,44 @@ class DatafileEncap(ExplainDataEncap):
Args: Args:
train_id (str): Job ID. train_id (str): Job ID.
image_path (str): Image path relative to explain job's summary directory. image_path (str): Image path relative to explain job's summary directory.
image_type (str): Image type, 'original' or 'overlay'.
image_type (str): Image type, Options: 'original', 'overlay' or 'outcome'.


Returns: Returns:
bytes, image binary. bytes, image binary.
""" """
if image_type == "outcome":
sample_id, label, layer = image_path.strip(".jpg").split("_")
layer = int(layer)
job = self.job_manager.get_job(train_id)
samples = job.samples
label_idx = job.labels.index(label)

chosen_sample = samples[int(sample_id)]
original_path_image = chosen_sample['image']
abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id),
original_path_image)
if self._is_forbidden(abs_image_path):
raise FileSystemPermissionError("Forbidden.")
try:
image = Image.open(abs_image_path)
except FileNotFoundError:
raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except PermissionError:
raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}")
except OSError:
raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")

edit_steps = []
boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"]
mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"]

for box in boxes:
edit_steps.append(EditStep(layer, *box))
image_cp = pil_apply_edit_steps(image, mask, edit_steps)
buffer = io.BytesIO()
image_cp.save(buffer, format=_PNG_FORMAT)

return buffer.getvalue()


abs_image_path = os.path.join(self.job_manager.summary_base_dir, abs_image_path = os.path.join(self.job_manager.summary_base_dir,
_clean_train_id_b4_join(train_id), _clean_train_id_b4_join(train_id),
@@ -94,10 +128,10 @@ class DatafileEncap(ExplainDataEncap):
raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}") raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}")


if image.mode == _SINGLE_CHANNEL_MODE: if image.mode == _SINGLE_CHANNEL_MODE:
saliency = np.asarray(image)/_UINT8_MAX
saliency = np.asarray(image) / _UINT8_MAX
elif image.mode == _RGB_MODE: elif image.mode == _RGB_MODE:
saliency = np.asarray(image) saliency = np.asarray(image)
saliency = saliency[:, :, 0]/_UINT8_MAX
saliency = saliency[:, :, 0] / _UINT8_MAX
else: else:
raise UnknownError(f"Invalid overlay image mode:{image.mode}.") raise UnknownError(f"Invalid overlay image mode:{image.mode}.")


@@ -105,7 +139,7 @@ class DatafileEncap(ExplainDataEncap):
for c in range(3): for c in range(3):
saliency_stack[:, :, c] = saliency saliency_stack[:, :, c] = saliency
rgba = saliency_stack * _SALIENCY_CMAP_HI rgba = saliency_stack * _SALIENCY_CMAP_HI
rgba += (1-saliency_stack) * _SALIENCY_CMAP_LOW
rgba += (1 - saliency_stack) * _SALIENCY_CMAP_LOW
rgba[:, :, 3] = saliency * _UINT8_MAX rgba[:, :, 3] = saliency * _UINT8_MAX


overlay = Image.fromarray(np.uint8(rgba), mode=_RGBA_MODE) overlay = Image.fromarray(np.uint8(rgba), mode=_RGBA_MODE)


+ 126
- 2
mindinsight/explainer/encapsulator/explain_data_encap.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -12,7 +12,57 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Explain data encapsulator base class."""
"""Common explain data encapsulator base class."""

import copy

from mindinsight.utils.exceptions import ParamValueError


def _sort_key_min_confidence(sample, labels):
"""Samples sort key by the min. confidence."""
min_confidence = float("+inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
continue
if inference["confidence"] < min_confidence:
min_confidence = inference["confidence"]
return min_confidence


def _sort_key_max_confidence(sample, labels):
"""Samples sort key by the max. confidence."""
max_confidence = float("-inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
continue
if inference["confidence"] > max_confidence:
max_confidence = inference["confidence"]
return max_confidence


def _sort_key_min_confidence_sd(sample, labels):
"""Samples sort key by the min. confidence_sd."""
min_confidence_sd = float("+inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
continue
confidence_sd = inference.get("confidence_sd", float("+inf"))
if confidence_sd < min_confidence_sd:
min_confidence_sd = confidence_sd
return min_confidence_sd


def _sort_key_max_confidence_sd(sample, labels):
"""Samples sort key by the max. confidence_sd."""
max_confidence_sd = float("-inf")
for inference in sample["inferences"]:
if labels and inference["label"] not in labels:
continue
confidence_sd = inference.get("confidence_sd", float("-inf"))
if confidence_sd > max_confidence_sd:
max_confidence_sd = confidence_sd
return max_confidence_sd




class ExplainDataEncap: class ExplainDataEncap:
@@ -24,3 +74,77 @@ class ExplainDataEncap:
@property @property
def job_manager(self): def job_manager(self):
return self._job_manager return self._job_manager


class ExplanationEncap(ExplainDataEncap):
    """Base encapsulator for explanation queries."""

    def __init__(self, image_url_formatter, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Callable taking (train_id, image_path, image_type); may be None, see _get_image_url.
        self._image_url_formatter = image_url_formatter

    def _query_samples(self,
                       job,
                       labels,
                       sorted_name,
                       sorted_type,
                       prediction_types=None,
                       query_type="saliency_maps"):
        """
        Collect, filter and sort the samples of an explain job.

        Args:
            job: Explain job to be queried; must provide `get_all_samples()` and `uncertainty_enabled`.
            labels (list[str]): Label filter. A sample is kept if any of its inferences has a listed label.
            sorted_name (str): Field to be sorted, '', 'confidence' or 'uncertainty'.
            sorted_type (str): Sorting order, 'ascending' or 'descending'.
            prediction_types (list[str]): Prediction type filter. Default: None.
            query_type (str): Key of the inference entry that must be non-empty in at least one inference
                for a sample to be included. Default: "saliency_maps".

        Returns:
            list[dict], deep copies of the matching samples.

        Raises:
            ParamValueError: If sorted_name is 'uncertainty' while uncertainty is not enabled for the job,
                or if sorted_name is not a supported value.
        """
        # Work on deep copies so that later edits of the result do not touch the job's own data.
        samples = []
        for sample in copy.deepcopy(job.get_all_samples()):
            if any(infer[query_type] for infer in sample['inferences']):
                samples.append(sample)

        if labels:
            samples = [sample for sample in samples
                       if any(name in labels
                              for name in [inference["label"] for inference in sample["inferences"]])]

        # The filter is skipped when three or more types are given (presumably because that already
        # covers every prediction type - confirm against the API validation).
        if prediction_types and len(prediction_types) < 3:
            samples = [sample for sample in samples
                       if any(kind in prediction_types
                              for kind in [inference["prediction_type"] for inference in sample["inferences"]])]

        descending = sorted_type == "descending"
        key_func = None
        if sorted_name == "confidence":
            key_func = _sort_key_max_confidence if descending else _sort_key_min_confidence
        elif sorted_name == "uncertainty":
            if not job.uncertainty_enabled:
                raise ParamValueError("Uncertainty is not enabled, sorted_name cannot be 'uncertainty'")
            key_func = _sort_key_max_confidence_sd if descending else _sort_key_min_confidence_sd
        elif sorted_name != "":
            raise ParamValueError("sorted_name")

        if key_func is not None:
            samples.sort(key=lambda item: key_func(item, labels), reverse=descending)
        return samples

    def _get_image_url(self, train_id, image_path, image_type):
        """Return the formatted image url, or image_path unchanged when no formatter was given."""
        formatter = self._image_url_formatter
        if formatter is not None:
            return formatter(train_id, image_path, image_type)
        return image_path

+ 5
- 3
mindinsight/explainer/encapsulator/explain_job_encap.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
"""Explain job list encapsulator.""" """Explain job list encapsulator."""


import copy
from datetime import datetime from datetime import datetime


from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
@@ -61,6 +60,9 @@ class ExplainJobEncap(ExplainDataEncap):
info["train_id"] = dir_info["relative_path"] info["train_id"] = dir_info["relative_path"]
info["create_time"] = dir_info["create_time"].strftime(cls.DATETIME_FORMAT) info["create_time"] = dir_info["create_time"].strftime(cls.DATETIME_FORMAT)
info["update_time"] = dir_info["update_time"].strftime(cls.DATETIME_FORMAT) info["update_time"] = dir_info["update_time"].strftime(cls.DATETIME_FORMAT)
info["saliency_map"] = dir_info["saliency_map"]
info["hierarchical_occlusion"] = dir_info["hierarchical_occlusion"]

return info return info


@classmethod @classmethod
@@ -79,7 +81,7 @@ class ExplainJobEncap(ExplainDataEncap):
"""Convert ExplainJob's meta-data to jsonable info object.""" """Convert ExplainJob's meta-data to jsonable info object."""
info = cls._job_2_info(job) info = cls._job_2_info(job)
info["sample_count"] = job.sample_count info["sample_count"] = job.sample_count
info["classes"] = copy.deepcopy(job.all_classes)
info["classes"] = [item for item in job.all_classes if item['sample_count'] > 0]
saliency_info = dict() saliency_info = dict()
if job.min_confidence is None: if job.min_confidence is None:
saliency_info["min_confidence"] = cls.DEFAULT_MIN_CONFIDENCE saliency_info["min_confidence"] = cls.DEFAULT_MIN_CONFIDENCE


+ 87
- 0
mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py View File

@@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Hierarchical Occlusion encapsulator."""

from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap


class HierarchicalOcclusionEncap(ExplanationEncap):
    """Hierarchical occlusion encapsulator."""

    def query_hierarchical_occlusion(self,
                                     train_id,
                                     labels,
                                     limit,
                                     offset,
                                     sorted_name,
                                     sorted_type,
                                     prediction_types=None
                                     ):
        """
        Query one page of hierarchical occlusion results.

        Args:
            train_id (str): Job ID.
            labels (list[str]): Label filter.
            limit (int): Maximum number of items to be returned.
            offset (int): Page offset; the first returned item has index offset * limit.
            sorted_name (str): Field to be sorted.
            sorted_type (str): Sorting order, 'ascending' or 'descending'.
            prediction_types (list[str]): Prediction types filter. Default: None.

        Returns:
            tuple[int, list[dict]], total number of samples after filtering and list of sample results.

        Raises:
            TrainJobNotExistError: If no job is found for train_id.
        """
        job = self.job_manager.get_job(train_id)
        if job is None:
            raise TrainJobNotExistError(train_id)

        samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types,
                                      query_type="hoc_layers")
        total = len(samples)
        first = offset * limit
        last = min(total, first + limit)
        page = [self._touch_sample(samples[index], job) for index in range(first, last)]
        return total, page

    def _touch_sample(self, sample, job):
        """
        Final edit on single sample info.

        The sample's image path is replaced by its 'original' image url and every hoc layer gets an
        "outcome" url. Only the top level dict is copied, so the inference and layer dicts of `sample`
        are edited in place.

        Args:
            sample (dict): Sample info.
            job: Explain job; its `train_id` is used to build the urls.

        Returns:
            dict, the edited sample info.
        """
        edited = sample.copy()
        edited["image"] = self._get_image_url(job.train_id, sample["image"], "original")
        for inference in edited["inferences"]:
            layers = []
            for index, layer in enumerate(inference["hoc_layers"]):
                # The outcome image is addressed by a name of the form <sample id>_<label>_<layer index>.jpg.
                outcome_name = f"{sample['id']}_{inference['label']}_{index}.jpg"
                layer["outcome"] = self._get_image_url(job.train_id, outcome_name, "outcome")
                layers.append(layer)
            inference["hoc_layers"] = layers
        return edited

+ 16
- 88
mindinsight/explainer/encapsulator/saliency_encap.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -14,58 +14,13 @@
# ============================================================================ # ============================================================================
"""Saliency map encapsulator.""" """Saliency map encapsulator."""


import copy

from mindinsight.utils.exceptions import ParamValueError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap




def _sort_key_min_confidence(sample):
"""Samples sort key by the min. confidence."""
min_confidence = float("+inf")
for inference in sample["inferences"]:
if inference["confidence"] < min_confidence:
min_confidence = inference["confidence"]
return min_confidence


def _sort_key_max_confidence(sample):
"""Samples sort key by the max. confidence."""
max_confidence = float("-inf")
for inference in sample["inferences"]:
if inference["confidence"] > max_confidence:
max_confidence = inference["confidence"]
return max_confidence


def _sort_key_min_confidence_sd(sample):
"""Samples sort key by the min. confidence_sd."""
min_confidence_sd = float("+inf")
for inference in sample["inferences"]:
confidence_sd = inference.get("confidence_sd", float("+inf"))
if confidence_sd < min_confidence_sd:
min_confidence_sd = confidence_sd
return min_confidence_sd


def _sort_key_max_confidence_sd(sample):
    """
    Samples sort key by the max. confidence_sd.

    Args:
        sample (dict): Sample info with an "inferences" list; an inference may carry a "confidence_sd" value.

    Returns:
        float, the largest confidence_sd among the inferences. Inferences without "confidence_sd"
        count as -inf, and -inf is returned if the sample has no inference.
    """
    # Seed the candidates with -inf so a sample without inferences still yields -inf.
    return max([float("-inf"),
                *(inference.get("confidence_sd", float("-inf")) for inference in sample["inferences"])])


class SaliencyEncap(ExplainDataEncap):
class SaliencyEncap(ExplanationEncap):
"""Saliency map encapsulator.""" """Saliency map encapsulator."""


def __init__(self, image_url_formatter, *args, **kwargs):
    """
    Initialize the encapsulator.

    Args:
        image_url_formatter (Callable, optional): Called by `_get_image_url` as
            formatter(train_id, image_path, image_type). If None, `_get_image_url` returns the
            image path unchanged.
        *args: Positional arguments forwarded to the parent initializer.
        **kwargs: Keyword arguments forwarded to the parent initializer.
    """
    super().__init__(*args, **kwargs)
    self._image_url_formatter = image_url_formatter

def query_saliency_maps(self, def query_saliency_maps(self,
train_id, train_id,
labels, labels,
@@ -73,55 +28,32 @@ class SaliencyEncap(ExplainDataEncap):
limit, limit,
offset, offset,
sorted_name, sorted_name,
sorted_type):
sorted_type,
prediction_types=None):
""" """
Query saliency maps. Query saliency maps.
Args: Args:
train_id (str): Job ID. train_id (str): Job ID.
labels (list[str]): Label filter. labels (list[str]): Label filter.
explainers (list[str]): Explainers of saliency maps to be shown. explainers (list[str]): Explainers of saliency maps to be shown.
limit (int): Max. no. of items to be returned.
limit (int): Maximum number of items to be returned.
offset (int): Page offset. offset (int): Page offset.
sorted_name (str): Field to be sorted. sorted_name (str): Field to be sorted.
sorted_type (str): Sorting order, 'ascending' or 'descending'. sorted_type (str): Sorting order, 'ascending' or 'descending'.
prediction_types (list[str]): Prediction types filter. Default: None.


Returns: Returns:
tuple[int, list[dict]], total no. of samples after filtering and
list of sample result.
tuple[int, list[dict]], total number of samples after filtering and list of sample result.
""" """
job = self.job_manager.get_job(train_id) job = self.job_manager.get_job(train_id)
if job is None: if job is None:
raise TrainJobNotExistError(train_id) raise TrainJobNotExistError(train_id)


samples = copy.deepcopy(job.get_all_samples())
if labels:
filtered = []
for sample in samples:
infer_labels = [inference["label"] for inference in sample["inferences"]]
for infer_label in infer_labels:
if infer_label in labels:
filtered.append(sample)
break
samples = filtered

reverse = sorted_type == "descending"
if sorted_name == "confidence":
if reverse:
samples.sort(key=_sort_key_max_confidence, reverse=reverse)
else:
samples.sort(key=_sort_key_min_confidence, reverse=reverse)
elif sorted_name == "uncertainty":
if not job.uncertainty_enabled:
raise ParamValueError("Uncertainty is not enabled, sorted_name cannot be 'uncertainty'")
if reverse:
samples.sort(key=_sort_key_max_confidence_sd, reverse=reverse)
else:
samples.sort(key=_sort_key_min_confidence_sd, reverse=reverse)
elif sorted_name != "":
raise ParamValueError("sorted_name")
samples = self._query_samples(job, labels, sorted_name, sorted_type, prediction_types,
query_type="saliency_maps")


sample_infos = [] sample_infos = []
obj_offset = offset*limit
obj_offset = offset * limit
count = len(samples) count = len(samples)
end = count end = count
if obj_offset + limit < end: if obj_offset + limit < end:
@@ -139,11 +71,13 @@ class SaliencyEncap(ExplainDataEncap):
sample (dict): Sample info. sample (dict): Sample info.
job (ExplainJob): Explain job. job (ExplainJob): Explain job.
explainers (list[str]): Explainer names. explainers (list[str]): Explainer names.

Returns: Returns:
dict, the edited sample info. dict, the edited sample info.
""" """
sample["image"] = self._get_image_url(job.train_id, sample['image'], "original")
for inference in sample["inferences"]:
sample_cp = sample.copy()
sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original")
for inference in sample_cp["inferences"]:
new_list = [] new_list = []
for saliency_map in inference["saliency_maps"]: for saliency_map in inference["saliency_maps"]:
if explainers and saliency_map["explainer"] not in explainers: if explainers and saliency_map["explainer"] not in explainers:
@@ -151,10 +85,4 @@ class SaliencyEncap(ExplainDataEncap):
saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay") saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay")
new_list.append(saliency_map) new_list.append(saliency_map)
inference["saliency_maps"] = new_list inference["saliency_maps"] = new_list
return sample

def _get_image_url(self, train_id, image_path, image_type):
    """
    Build the URL of an image.

    Args:
        train_id (str): Job ID.
        image_path (str): Path of the image.
        image_type (str): Image type passed to the formatter, e.g. "original" or "overlay".

    Returns:
        The formatter's result, or `image_path` itself when no formatter is configured.
    """
    formatter = self._image_url_formatter
    if formatter is not None:
        return formatter(train_id, image_path, image_type)
    return image_path
return sample_cp

+ 111
- 98
mindinsight/explainer/manager/explain_loader.py View File

@@ -14,14 +14,13 @@
# ============================================================================ # ============================================================================
"""ExplainLoader.""" """ExplainLoader."""


from collections import defaultdict
from enum import Enum

import math import math
import os import os
import re import re
import threading import threading
from collections import defaultdict
from datetime import datetime from datetime import datetime
from enum import Enum
from typing import Dict, Iterable, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union


from mindinsight.datavisual.common.exceptions import TrainJobNotExistError from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
@@ -44,6 +43,7 @@ _SAMPLE_FIELD_NAMES = [
ExplainFieldsEnum.GROUND_TRUTH_LABEL, ExplainFieldsEnum.GROUND_TRUTH_LABEL,
ExplainFieldsEnum.INFERENCE, ExplainFieldsEnum.INFERENCE,
ExplainFieldsEnum.EXPLANATION, ExplainFieldsEnum.EXPLANATION,
ExplainFieldsEnum.HIERARCHICAL_OCCLUSION
] ]




@@ -75,10 +75,10 @@ class ExplainLoader:
'create_time': os.stat(summary_dir).st_ctime, 'create_time': os.stat(summary_dir).st_ctime,
'update_time': os.stat(summary_dir).st_mtime, 'update_time': os.stat(summary_dir).st_mtime,
'query_time': os.stat(summary_dir).st_ctime, 'query_time': os.stat(summary_dir).st_ctime,
'uncertainty_enabled': False,
'uncertainty_enabled': False
} }
self._samples = defaultdict(dict) self._samples = defaultdict(dict)
self._metadata = {'explainers': [], 'metrics': [], 'labels': []}
self._metadata = {'explainers': [], 'metrics': [], 'labels': [], 'min_confidence': 0.5}
self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)} self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)}


self._status = _LoaderStatus.STOP.value self._status = _LoaderStatus.STOP.value
@@ -91,25 +91,19 @@ class ExplainLoader:


Returns: Returns:
list[dict], a list of dict, each dict contains: list[dict], a list of dict, each dict contains:
- id (int): label id
- label (str): label name
- sample_count (int): number of samples for each label

- id (int): Label id.
- label (str): Label name.
- sample_count (int): Number of samples for each label.
""" """
sample_count_per_label = defaultdict(int) sample_count_per_label = defaultdict(int)
samples_copy = self._samples.copy()
for sample in samples_copy.values():
if sample.get('image', False) and sample.get('ground_truth_label', False):
for label in sample['ground_truth_label']:
for sample in self._samples.values():
if sample.get('image') and (sample.get('ground_truth_label') or sample.get('predicted_label')):
for label in set(sample['ground_truth_label'] + sample['predicted_label']):
sample_count_per_label[label] += 1 sample_count_per_label[label] += 1


all_classes_return = []
for label_id, label_name in enumerate(self._metadata['labels']):
single_info = {
'id': label_id,
'label': label_name,
'sample_count': sample_count_per_label[label_id]
}
all_classes_return.append(single_info)
all_classes_return = [{'id': label_id, 'label': label_name, 'sample_count': sample_count_per_label[label_id]}
for label_id, label_name in enumerate(self._metadata['labels'])]
return all_classes_return return all_classes_return


@property @property
@@ -166,19 +160,23 @@ class ExplainLoader:


Returns: Returns:
list[dict], A list of evaluation results of each explainer. Each item contains: list[dict], A list of evaluation results of each explainer. Each item contains:

- explainer (str): Name of evaluated explainer. - explainer (str): Name of evaluated explainer.
- evaluations (list[dict]): A list of evaluation results by different metrics. - evaluations (list[dict]): A list of evaluation results by different metrics.
- class_scores (list[dict]): A list of evaluation results on different labels. - class_scores (list[dict]): A list of evaluation results on different labels.


Each item in the evaluations contains: Each item in the evaluations contains:

- metric (str): name of metric method - metric (str): name of metric method
- score (float): evaluation result - score (float): evaluation result


Each item in the class_scores contains: Each item in the class_scores contains:

- label (str): Name of label - label (str): Name of label
- evaluations (list[dict]): A list of evaluation results on different labels by different metrics. - evaluations (list[dict]): A list of evaluation results on different labels by different metrics.


Each item in evaluations contains: Each item in evaluations contains:

- metric (str): Name of metric method - metric (str): Name of metric method
- score (float): Evaluation scores of explainer on specific label by the metric. - score (float): Evaluation scores of explainer on specific label by the metric.
""" """
@@ -215,7 +213,7 @@ class ExplainLoader:
@property @property
def min_confidence(self) -> Optional[float]: def min_confidence(self) -> Optional[float]:
"""Return minimum confidence used to filter the predicted labels.""" """Return minimum confidence used to filter the predicted labels."""
return None
return self._metadata['min_confidence']


@property @property
def sample_count(self) -> int: def sample_count(self) -> int:
@@ -227,19 +225,17 @@ class ExplainLoader:


Return: Return:
int, total number of available samples in the loading job. int, total number of available samples in the loading job.

""" """
sample_count = 0 sample_count = 0
samples_copy = self._samples.copy()
for sample in samples_copy.values():
if sample.get('image', False) and sample.get('ground_truth_label', False):
for sample in self._samples.values():
if sample.get('image', False):
sample_count += 1 sample_count += 1
return sample_count return sample_count


@property @property
def samples(self) -> List[Dict]: def samples(self) -> List[Dict]:
"""Return the information of all samples in the job.""" """Return the information of all samples in the job."""
return self.get_all_samples()
return self._samples


@property @property
def train_id(self) -> str: def train_id(self) -> str:
@@ -298,6 +294,7 @@ class ExplainLoader:
self._clear_job() self._clear_job()
if event_dict: if event_dict:
self._import_data_from_event(event_dict) self._import_data_from_event(event_dict)
self._reform_sample_info()


@property @property
def status(self): def status(self):
@@ -317,59 +314,19 @@ class ExplainLoader:


def get_all_samples(self) -> List[Dict]: def get_all_samples(self) -> List[Dict]:
""" """
Return a list of sample information cachced in the explain job
Return a list of sample information cached in the explain job


Returns: Returns:
sample_list (List[SampleObj]): a list of sample objects, each object sample_list (List[SampleObj]): a list of sample objects, each object
consists of: consists of:


- id (int): sample id
- name (str): basename of image
- labels (list[str]): list of labels
- inferences list[dict])
- id (int): Sample id.
- name (str): Basename of image.
- inferences (list[dict]): List of inferences for all labels.
""" """
returned_samples = []
samples_copy = self._samples.copy()
for sample_id, sample_info in samples_copy.items():
if not sample_info.get('image', False) and not sample_info.get('ground_truth_label', False):
continue
returned_sample = {
'id': sample_id,
'name': str(sample_id),
'image': sample_info['image'],
'labels': [self._metadata['labels'][i] for i in sample_info['ground_truth_label']],
}

# Check whether the sample has valid label-prob pairs.
if not ExplainLoader._is_inference_valid(sample_info):
continue

inferences = {}
for label, prob in zip(sample_info['ground_truth_label'] + sample_info['predicted_label'],
sample_info['ground_truth_prob'] + sample_info['predicted_prob']):
inferences[label] = {
'label': self._metadata['labels'][label],
'confidence': _round(prob),
'saliency_maps': []
}

if sample_info['ground_truth_prob_sd'] or sample_info['predicted_prob_sd']:
for label, std, low, high in zip(
sample_info['ground_truth_label'] + sample_info['predicted_label'],
sample_info['ground_truth_prob_sd'] + sample_info['predicted_prob_sd'],
sample_info['ground_truth_prob_itl95_low'] + sample_info['predicted_prob_itl95_low'],
sample_info['ground_truth_prob_itl95_hi'] + sample_info['predicted_prob_itl95_hi']
):
inferences[label]['confidence_sd'] = _round(std)
inferences[label]['confidence_itl95'] = [_round(low), _round(high)]

for explainer, label_heatmap_path_dict in sample_info['explanation'].items():
for label, heatmap_path in label_heatmap_path_dict.items():
if label in inferences:
inferences[label]['saliency_maps'].append({'explainer': explainer, 'overlay': heatmap_path})

returned_sample['inferences'] = list(inferences.values())
returned_samples.append(returned_sample)
returned_samples = [{'id': sample_id, 'name': info['name'], 'image': info['image'],
'inferences': list(info['inferences'].values())} for sample_id, info in
self._samples.items() if info.get('image', False)]
return returned_samples return returned_samples


def _import_data_from_event(self, event_dict: Dict): def _import_data_from_event(self, event_dict: Dict):
@@ -461,30 +418,29 @@ class ExplainLoader:
- predicted_probs (list[int]): A list of confidences w.r.t the predicted labels. - predicted_probs (list[int]): A list of confidences w.r.t the predicted labels.
- explanations (dict): Explanations is a dictionary where the each explainer name mapping to a dictionary - explanations (dict): Explanations is a dictionary where the each explainer name mapping to a dictionary
of saliency maps. The structure of explanations demonstrates below: of saliency maps. The structure of explanations demonstrates below:

{ {
explainer_name_1: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, explainer_name_1: {label_1: saliency_id_1, label_2: saliency_id_2, ...},
explainer_name_2: {label_1: saliency_id_1, label_2: saliency_id_2, ...}, explainer_name_2: {label_1: saliency_id_1, label_2: saliency_id_2, ...},
... ...
} }
- hierarchical_occlusion (dict): A dictionary where each label is matched to a dictionary:
{label_1: [{prob: layer1_prob, bbox: []}, {prob: layer2_prob, bbox: []}],
label_2:
}
""" """
if getattr(sample, 'sample_id', None) is None: if getattr(sample, 'sample_id', None) is None:
raise ParamValueError('sample_event has no sample_id') raise ParamValueError('sample_event has no sample_id')
sample_id = sample.sample_id sample_id = sample.sample_id
samples_copy = self._samples.copy()
if sample_id not in samples_copy:
if sample_id not in self._samples:
self._samples[sample_id] = { self._samples[sample_id] = {
'id': sample_id,
'name': str(sample_id),
'image': sample.image_path,
'ground_truth_label': [], 'ground_truth_label': [],
'ground_truth_prob': [],
'ground_truth_prob_sd': [],
'ground_truth_prob_itl95_low': [],
'ground_truth_prob_itl95_hi': [],
'predicted_label': [], 'predicted_label': [],
'predicted_prob': [],
'predicted_prob_sd': [],
'predicted_prob_itl95_low': [],
'predicted_prob_itl95_hi': [],
'explanation': defaultdict(dict)
'inferences': defaultdict(dict),
'explanation': defaultdict(dict),
'hierarchical_occlusion': defaultdict(dict)
} }


if sample.image_path: if sample.image_path:
@@ -492,27 +448,66 @@ class ExplainLoader:


for tag in _SAMPLE_FIELD_NAMES: for tag in _SAMPLE_FIELD_NAMES:
if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL: if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL:
self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label))
if not self._samples[sample_id]['ground_truth_label']:
self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label))
elif tag == ExplainFieldsEnum.INFERENCE: elif tag == ExplainFieldsEnum.INFERENCE:
self._import_inference_from_event(sample, sample_id) self._import_inference_from_event(sample, sample_id)
else:
elif tag == ExplainFieldsEnum.EXPLANATION:
self._import_explanation_from_event(sample, sample_id) self._import_explanation_from_event(sample, sample_id)
elif tag == ExplainFieldsEnum.HIERARCHICAL_OCCLUSION:
self._import_hoc_from_event(sample, sample_id)

def _reform_sample_info(self):
    """
    Reform the sample info.

    For every cached sample, regroup the per-explainer heatmap paths by label and store them on the
    matching inference as 'saliency_maps', then copy the hierarchical occlusion layers of each label
    onto the matching inference as 'hoc_layers'.
    """
    for sample_info in self._samples.values():
        inferences = sample_info['inferences']

        # Regroup {explainer: {label: heatmap_path}} into {label: [saliency map, ...]}.
        maps_by_label = defaultdict(list)
        for explainer, label_heatmap_path_dict in sample_info['explanation'].items():
            for label, heatmap_path in label_heatmap_path_dict.items():
                maps_by_label[label].append({'explainer': explainer, 'overlay': heatmap_path})

        for label, item in inferences.items():
            item['saliency_maps'] = maps_by_label.get(label, [])

        for label, item in sample_info['hierarchical_occlusion'].items():
            # Only attach to labels that already have an inference entry. Indexing a missing label
            # would raise KeyError on a plain dict, or create an entry without 'label'/'confidence'
            # on a defaultdict.
            if label in inferences:
                inferences[label]['hoc_layers'] = item['hoc_layers']


def _import_inference_from_event(self, event, sample_id): def _import_inference_from_event(self, event, sample_id):
"""Parse the inference event.""" """Parse the inference event."""
inference = event.inference inference = event.inference
self._samples[sample_id]['ground_truth_prob'].extend(list(inference.ground_truth_prob))
self._samples[sample_id]['ground_truth_prob_sd'].extend(list(inference.ground_truth_prob_sd))
self._samples[sample_id]['ground_truth_prob_itl95_low'].extend(list(inference.ground_truth_prob_itl95_low))
self._samples[sample_id]['ground_truth_prob_itl95_hi'].extend(list(inference.ground_truth_prob_itl95_hi))
self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label))
self._samples[sample_id]['predicted_prob'].extend(list(inference.predicted_prob))
self._samples[sample_id]['predicted_prob_sd'].extend(list(inference.predicted_prob_sd))
self._samples[sample_id]['predicted_prob_itl95_low'].extend(list(inference.predicted_prob_itl95_low))
self._samples[sample_id]['predicted_prob_itl95_hi'].extend(list(inference.predicted_prob_itl95_hi))

if self._samples[sample_id]['ground_truth_prob_sd'] or self._samples[sample_id]['predicted_prob_sd']:
if inference.ground_truth_prob_sd or inference.predicted_prob_sd:
self._loader_info['uncertainty_enabled'] = True self._loader_info['uncertainty_enabled'] = True
if not self._samples[sample_id]['predicted_label']:
self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label))
if not self._samples[sample_id]['inferences']:
inferences = {}
for label, prob in zip(list(event.ground_truth_label) + list(inference.predicted_label),
list(inference.ground_truth_prob) + list(inference.predicted_prob)):
inferences[label] = {
'label': self._metadata['labels'][label],
'confidence': _round(prob),
'saliency_maps': [],
'hoc_layers': {},
}
if not event.ground_truth_label:
inferences[label]['prediction_type'] = None
else:
if prob < self.min_confidence:
inferences[label]['prediction_type'] = 'FN'
elif label in event.ground_truth_label:
inferences[label]['prediction_type'] = 'TP'
else:
inferences[label]['prediction_type'] = 'FP'
if self._loader_info['uncertainty_enabled']:
for label, std, low, high in zip(
list(event.ground_truth_label) + list(inference.predicted_label),
list(inference.ground_truth_prob_sd) + list(inference.predicted_prob_sd),
list(inference.ground_truth_prob_itl95_low) + list(inference.predicted_prob_itl95_low),
list(inference.ground_truth_prob_itl95_hi) + list(inference.predicted_prob_itl95_hi)):
inferences[label]['confidence_sd'] = _round(std)
inferences[label]['confidence_itl95'] = [_round(low), _round(high)]

self._samples[sample_id]['inferences'] = inferences


def _import_explanation_from_event(self, event, sample_id): def _import_explanation_from_event(self, event, sample_id):
"""Parse the explanation event.""" """Parse the explanation event."""
@@ -525,6 +520,24 @@ class ExplainLoader:
label = explanation_item.label label = explanation_item.label
sample_explanation[explainer][label] = explanation_item.heatmap_path sample_explanation[explainer][label] = explanation_item.heatmap_path


def _import_hoc_from_event(self, event, sample_id):
    """
    Parse the hierarchical occlusion (HOC) part of a sample event.

    For each HOC item of the event, store under the sample's 'hierarchical_occlusion' entry a dict with
    the item's label and mask, the confidence of the matching inference and the list of HOC layers.

    Args:
        event: Sample event exposing `hierarchical_occlusion`, an iterable of items with `label`,
            `mask` and `layer`; each layer has `prob` and a flat `box` sequence.
        sample_id (int): ID of the sample the event belongs to.
    """
    sample = self._samples[sample_id]
    sample_hoc = sample['hierarchical_occlusion']
    if not event.hierarchical_occlusion:
        return

    inferences = sample['inferences']
    for hoc_item in event.hierarchical_occlusion:
        label = hoc_item.label
        # Use .get() rather than indexing: 'inferences' may still be a defaultdict, and indexing a
        # missing label would insert an empty entry before failing on 'confidence'.
        inference = inferences.get(label) or {}

        hoc_layers = []
        for hoc_layer in hoc_item.layer:
            box_lst = list(hoc_layer.box)
            hoc_layers.append({
                'confidence': hoc_layer.prob,
                # The flat box sequence is split into consecutive groups of 4 values.
                'boxes': [box_lst[i: i + 4] for i in range(0, len(box_lst), 4)],
            })

        # Assign the fully built entry at once so no half-filled record is left behind.
        sample_hoc[label] = {
            'label': label,
            'mask': hoc_item.mask,
            'confidence': inference.get('confidence'),
            'hoc_layers': hoc_layers,
        }

def _clear_job(self): def _clear_job(self):
"""Clear the cached data and update the time info of the loader.""" """Clear the cached data and update the time info of the loader."""
self._samples.clear() self._samples.clear()


+ 1
- 0
mindinsight/explainer/manager/explain_manager.py View File

@@ -114,6 +114,7 @@ class ExplainManager:
Returns: Returns:
tuple[total, directories], total indicates the overall number of explain directories and directories tuple[total, directories], total indicates the overall number of explain directories and directories
indicate list of summary directory info including the following attributes. indicate list of summary directory info including the following attributes.

- relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR, - relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR,
starting with "./". starting with "./".
- create_time (datetime): Creation time of summary file. - create_time (datetime): Creation time of summary file.


+ 3
- 1
mindinsight/explainer/manager/explain_parser.py View File

@@ -30,6 +30,7 @@ from mindinsight.utils.exceptions import UnknownError
HEADER_SIZE = 8 HEADER_SIZE = 8
CRC_STR_SIZE = 4 CRC_STR_SIZE = 4
MAX_EVENT_STRING = 500000000 MAX_EVENT_STRING = 500000000

BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status']) BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status'])
MetadataContainer = namedtuple('MetadataContainer', ['metadata', 'status']) MetadataContainer = namedtuple('MetadataContainer', ['metadata', 'status'])
InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob', InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob',
@@ -42,7 +43,7 @@ InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob',
'predicted_prob_itl95_low', 'predicted_prob_itl95_low',
'predicted_prob_itl95_hi']) 'predicted_prob_itl95_hi'])
SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label', 'inference', SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label', 'inference',
'explanation', 'status'])
'explanation', 'hierarchical_occlusion', 'status'])




class ExplainParser(_SummaryParser): class ExplainParser(_SummaryParser):
@@ -193,6 +194,7 @@ class ExplainParser(_SummaryParser):
ground_truth_label=tensor_event_value.ground_truth_label, ground_truth_label=tensor_event_value.ground_truth_label,
inference=inference, inference=inference,
explanation=tensor_event_value.explanation, explanation=tensor_event_value.explanation,
hierarchical_occlusion=tensor_event_value.hoc,
status=tensor_event_value.status status=tensor_event_value.status
) )
return sample_data return sample_data


+ 3
- 1
tests/ut/explainer/encapsulator/mock_explain_manager.py View File

@@ -105,7 +105,9 @@ class MockExplainManager:
{ {
"relative_path": "./mock_job_1", "relative_path": "./mock_job_1",
"create_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT), "create_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT),
"update_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT)
"update_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT),
"saliency_map": True,
"hierarchical_occlusion": True
} }
] ]
return 1, job_list return 1, job_list


+ 3
- 1
tests/ut/explainer/encapsulator/test_explain_job_encap.py View File

@@ -31,7 +31,9 @@ class TestExplainJobEncap:
{ {
"train_id": "./mock_job_1", "train_id": "./mock_job_1",
"create_time": "2020-10-01 20:21:23", "create_time": "2020-10-01 20:21:23",
"update_time": "2020-10-01 20:21:23"
"update_time": "2020-10-01 20:21:23",
"saliency_map": True,
"hierarchical_occlusion": True
} }
]) ])
assert job_list == expected_result assert job_list == expected_result


+ 0
- 4
tests/ut/explainer/manager/test_explain_loader.py View File

@@ -24,10 +24,6 @@ from mindinsight.explainer.manager.explain_loader import _LoaderStatus
from mindinsight.explainer.manager.explain_parser import ExplainParser from mindinsight.explainer.manager.explain_parser import ExplainParser




def abc():
FileHandler.is_file('aaa')
print('after')



class TestExplainLoader: class TestExplainLoader:
"""Test explain loader class.""" """Test explain loader class."""


Loading…
Cancel
Save