| @@ -40,8 +40,52 @@ URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX | |||
| BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX) | |||
| def _validate_type(param, name, expected_types): | |||
| """ | |||
| Common function to validate type. | |||
| Args: | |||
| param (object): Parameter to be validated. | |||
| name (str): Name of the parameter. | |||
| expected_types (type, tuple[type]): Expected type(s) of param. | |||
| Raises: | |||
| ParamTypeError: When param is not an instance of expected_types. | |||
| """ | |||
| if not isinstance(param, expected_types): | |||
| raise ParamTypeError(name, expected_types) | |||
| def _validate_value(param, name, expected_values): | |||
| """ | |||
| Common function to validate values of param. | |||
| Args: | |||
| param (object): Parameter to be validated. | |||
| name (str): Name of the parameter. | |||
| expected_values (tuple) : Expected values of param. | |||
| Raises: | |||
| ParamValueError: When param is not in expected_values. | |||
| """ | |||
| if param not in expected_values: | |||
| raise ParamValueError(f"Valid options for {name} are {expected_values}, but got {param}.") | |||
| def _image_url_formatter(train_id, image_path, image_type): | |||
| """Returns image url.""" | |||
| """ | |||
| Returns image url. | |||
| Args: | |||
| train_id (str): Id that specifies explain job. | |||
| image_path (str): Local path or unique string that specifies the image for query. | |||
| image_type (str): Image query type. | |||
| Returns: | |||
| str, url string for image query. | |||
| """ | |||
| data = { | |||
| "train_id": train_id, | |||
| "path": image_path, | |||
| @@ -69,38 +113,48 @@ def _read_post_request(post_request): | |||
| def _get_query_sample_parameters(data): | |||
| """Get parameter for query.""" | |||
| """ | |||
| Get parameter for query. | |||
| Args: | |||
| data (dict): Dict that contains request info. | |||
| Returns: | |||
| dict, key-value pairs to call backend query functions. | |||
| Raises: | |||
| ParamMissError: If train_id info is not in the request. | |||
| ParamTypeError: If certain key is not in the expected type in the request. | |||
| ParamValueError: If certain key does not have the expected value in the request. | |||
| """ | |||
| train_id = data.get("train_id") | |||
| if train_id is None: | |||
| raise ParamMissError('train_id') | |||
| labels = data.get("labels") | |||
| if labels is not None and not isinstance(labels, list): | |||
| raise ParamTypeError("labels", (list, None)) | |||
| if labels is not None: | |||
| _validate_type(labels, "labels", list) | |||
| if labels: | |||
| for item in labels: | |||
| if not isinstance(item, str): | |||
| raise ParamTypeError("element of labels", str) | |||
| _validate_type(item, "element of labels", str) | |||
| limit = data.get("limit", 10) | |||
| limit = Validation.check_limit(limit, min_value=1, max_value=100) | |||
| offset = data.get("offset", 0) | |||
| offset = Validation.check_offset(offset=offset) | |||
| sorted_name = data.get("sorted_name", "") | |||
| _validate_value(sorted_name, "sorted_name", ('', 'confidence', 'uncertainty')) | |||
| sorted_type = data.get("sorted_type", "descending") | |||
| if sorted_name not in ("", "confidence", "uncertainty"): | |||
| raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'") | |||
| if sorted_type not in ("ascending", "descending"): | |||
| raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'") | |||
| _validate_value(sorted_type, "sorted_type", ("ascending", "descending")) | |||
| prediction_types = data.get("prediction_types") | |||
| if prediction_types is not None and not isinstance(prediction_types, list): | |||
| raise ParamTypeError("prediction_types", (list, None)) | |||
| if prediction_types is not None: | |||
| _validate_type(prediction_types, "element of labels", list) | |||
| if prediction_types: | |||
| for item in prediction_types: | |||
| if item not in ['TP', 'FN', 'FP']: | |||
| raise ParamValueError(f"Item of prediction_types must be in ['TP', 'FN', 'FP'], but got {item}.") | |||
| _validate_value(item, "element of prediction_types", ('TP', 'FN', 'FP')) | |||
| query_kwarg = {"train_id": train_id, | |||
| "labels": labels, | |||
| @@ -114,7 +168,17 @@ def _get_query_sample_parameters(data): | |||
| @BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"]) | |||
| def query_explain_jobs(): | |||
| """Query explain jobs.""" | |||
| """ | |||
| Query explain jobs. | |||
| Returns: | |||
| Response, contains dict that stores base directory, total number of jobs and their detailed job metadata. | |||
| Raises: | |||
| ParamMissError: If train_id info is not in the request. | |||
| ParamTypeError: If one of (offset, limit) is not integer in the request. | |||
| ParamValueError: If one of (offset, limit) does not have the expected value in the request. | |||
| """ | |||
| offset = request.args.get("offset", default=0) | |||
| limit = request.args.get("limit", default=10) | |||
| offset = Validation.check_offset(offset=offset) | |||
| @@ -132,7 +196,15 @@ def query_explain_jobs(): | |||
| @BLUEPRINT.route("/explainer/explain-job", methods=["GET"]) | |||
| def query_explain_job(): | |||
| """Query explain job meta-data.""" | |||
| """ | |||
| Query explain job meta-data. | |||
| Returns: | |||
| Response, contains dict that stores metadata of the requested job. | |||
| Raises: | |||
| ParamMissError: If train_id info is not in the request. | |||
| """ | |||
| train_id = get_train_id(request) | |||
| if train_id is None: | |||
| raise ParamMissError("train_id") | |||
| @@ -143,7 +215,16 @@ def query_explain_job(): | |||
| @BLUEPRINT.route("/explainer/saliency", methods=["POST"]) | |||
| def query_saliency(): | |||
| """Query saliency map related results.""" | |||
| """ | |||
| Query saliency map related results. | |||
| Returns: | |||
| Response, contains dict that stores number of samples and the detailed sample info. | |||
| Raises: | |||
| ParamTypeError: If certain key is not in the expected type in the request. | |||
| ParamValueError: If certain key does not have the expected value in the request. | |||
| """ | |||
| data = _read_post_request(request) | |||
| query_kwarg = _get_query_sample_parameters(data) | |||
| explainers = data.get("explainers") | |||
| @@ -169,7 +250,16 @@ def query_saliency(): | |||
| @BLUEPRINT.route("/explainer/hoc", methods=["POST"]) | |||
| def query_hoc(): | |||
| """Query hierarchical occlusion related results.""" | |||
| """ | |||
| Query hierarchical occlusion related results. | |||
| Returns: | |||
| Response, contains dict that stores number of samples and the detailed sample info. | |||
| Raises: | |||
| ParamTypeError: If certain key is not in the expected type in the request. | |||
| ParamValueError: If certain key does not have the expected value in the request. | |||
| """ | |||
| data = _read_post_request(request) | |||
| query_kwargs = _get_query_sample_parameters(data) | |||
| @@ -193,7 +283,15 @@ def query_hoc(): | |||
| @BLUEPRINT.route("/explainer/evaluation", methods=["GET"]) | |||
| def query_evaluation(): | |||
| """Query saliency explainer evaluation scores.""" | |||
| """ | |||
| Query saliency explainer evaluation scores. | |||
| Returns: | |||
| Response, contains dict that stores evaluation scores. | |||
| Raises: | |||
| ParamMissError: If train_id info is not in the request. | |||
| """ | |||
| train_id = get_train_id(request) | |||
| if train_id is None: | |||
| raise ParamMissError("train_id") | |||
| @@ -206,7 +304,12 @@ def query_evaluation(): | |||
| @BLUEPRINT.route("/explainer/image", methods=["GET"]) | |||
| def query_image(): | |||
| """Query image.""" | |||
| """ | |||
| Query image. | |||
| Returns: | |||
| bytes, image binary content for UI to demonstrate. | |||
| """ | |||
| train_id = get_train_id(request) | |||
| if train_id is None: | |||
| raise ParamMissError("train_id") | |||
| @@ -230,6 +333,6 @@ def init_module(app): | |||
| Init module entry. | |||
| Args: | |||
| app: the application obj. | |||
| app (flask.app): The application obj. | |||
| """ | |||
| app.register_blueprint(BLUEPRINT) | |||
| @@ -220,10 +220,7 @@ class SummaryWatcher: | |||
| summary_dict[relative_path].update(job_dict) | |||
| if summary_dict[relative_path]['create_time'] < ctime: | |||
| summary_dict[relative_path].update({ | |||
| 'create_time': ctime, | |||
| 'update_time': mtime, | |||
| }) | |||
| summary_dict[relative_path].update({'create_time': ctime, 'update_time': mtime}) | |||
| job_dict = _get_explain_job_info(summary_base_dir, relative_path, timestamp) | |||
| summary_dict[relative_path].update(job_dict) | |||
| @@ -243,12 +240,10 @@ class SummaryWatcher: | |||
| if not is_find: | |||
| return | |||
| profiler = { | |||
| 'directory': os.path.join('.', entry.name), | |||
| 'create_time': ctime, | |||
| 'update_time': mtime, | |||
| "profiler_type": profiler_type | |||
| } | |||
| profiler = {'directory': os.path.join('.', entry.name), | |||
| 'create_time': ctime, | |||
| 'update_time': mtime, | |||
| "profiler_type": profiler_type} | |||
| if relative_path in summary_dict: | |||
| summary_dict[relative_path]['profiler'] = profiler | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -50,3 +50,16 @@ class CacheStatus(enum.Enum): | |||
| NOT_IN_CACHE = "NOT_IN_CACHE" | |||
| CACHING = "CACHING" | |||
| CACHED = "CACHED" | |||
| class ExplanationKeys(enum.Enum): | |||
| """Query type enums.""" | |||
| HOC = "hoc_layers" # HOC: Hierarchical Occlusion, an explanation method we propose | |||
| SALIENCY = "saliency_maps" | |||
| class ImageQueryTypes(enum.Enum): | |||
| """Image query type enums.""" | |||
| ORIGINAL = 'original' # Query for the original image | |||
| OUTCOME = 'outcome' # Query for outcome of HOC explanation | |||
| OVERLAY = 'overlay' # Query for saliency maps overlay | |||
| @@ -63,7 +63,7 @@ def pil_apply_edit_steps(image, mask, edit_steps, by_masking=False, inplace=Fals | |||
| Args: | |||
| image (PIL.Image): The input image in RGB mode. | |||
| mask (Union[str, int, tuple[int, int, int], PIL.Image.Image]): The mask to apply on the image, could be string | |||
| e.g. 'gaussian:9', a single, grey scale intensity [0, 255], a RBG tuple or a PIL Image object. | |||
| e.g. 'gaussian:9', a single, grey scale intensity [0, 255], an RBG tuple or a PIL Image object. | |||
| edit_steps (list[EditStep]): Edit steps to be drawn. | |||
| by_masking (bool): Whether to use masking method. Default: False. | |||
| inplace (bool): True to draw on the input image, otherwise draw on a cloned image. | |||
| @@ -99,7 +99,7 @@ def _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace=False): | |||
| Args: | |||
| image (PIL.Image): The input image. | |||
| mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey | |||
| scale intensity [0, 255], a RBG tuple or a PIL Image. | |||
| scale intensity [0, 255], an RBG tuple or a PIL Image. | |||
| edit_steps (list[EditStep]): Edit steps to be drawn. | |||
| inplace (bool): True to draw on the input image, otherwise draw on a cloned image. | |||
| @@ -132,7 +132,7 @@ def _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace=False): | |||
| Args: | |||
| image (PIL.Image): The input image. | |||
| mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey | |||
| scale intensity [0, 255], a RBG tuple or a PIL Image. | |||
| scale intensity [0, 255], an RBG tuple or a PIL Image. | |||
| edit_steps (list[EditStep]): Edit steps to be drawn. | |||
| inplace (bool): True to draw on the input image, otherwise draw on a cloned image. | |||
| @@ -21,6 +21,7 @@ import numpy as np | |||
| from PIL import Image | |||
| from mindinsight.datavisual.common.exceptions import ImageNotExistError | |||
| from mindinsight.explainer.common.enums import ImageQueryTypes | |||
| from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps | |||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap | |||
| from mindinsight.utils.exceptions import FileSystemPermissionError | |||
| @@ -63,41 +64,10 @@ class DatafileEncap(ExplainDataEncap): | |||
| image_type (str): Image type, Options: 'original', 'overlay' or 'outcome'. | |||
| Returns: | |||
| bytes, image binary. | |||
| bytes, image binary content for UI to demonstrate. | |||
| """ | |||
| if image_type == "outcome": | |||
| sample_id, label, layer = image_path.strip(".jpg").split("_") | |||
| layer = int(layer) | |||
| job = self.job_manager.get_job(train_id) | |||
| samples = job.samples | |||
| label_idx = job.labels.index(label) | |||
| chosen_sample = samples[int(sample_id)] | |||
| original_path_image = chosen_sample['image'] | |||
| abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id), | |||
| original_path_image) | |||
| if self._is_forbidden(abs_image_path): | |||
| raise FileSystemPermissionError("Forbidden.") | |||
| try: | |||
| image = Image.open(abs_image_path) | |||
| except FileNotFoundError: | |||
| raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}") | |||
| except PermissionError: | |||
| raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}") | |||
| except OSError: | |||
| raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}") | |||
| edit_steps = [] | |||
| boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"] | |||
| mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"] | |||
| for box in boxes: | |||
| edit_steps.append(EditStep(layer, *box)) | |||
| image_cp = pil_apply_edit_steps(image, mask, edit_steps) | |||
| buffer = io.BytesIO() | |||
| image_cp.save(buffer, format=_PNG_FORMAT) | |||
| return buffer.getvalue() | |||
| if image_type == ImageQueryTypes.OUTCOME.value: | |||
| return self._get_hoc_image(image_path, train_id) | |||
| abs_image_path = os.path.join(self.job_manager.summary_base_dir, | |||
| _clean_train_id_b4_join(train_id), | |||
| @@ -108,7 +78,7 @@ class DatafileEncap(ExplainDataEncap): | |||
| try: | |||
| if image_type != "overlay": | |||
| if image_type != ImageQueryTypes.OVERLAY.value: | |||
| # no need to convert | |||
| with open(abs_image_path, "rb") as fp: | |||
| return fp.read() | |||
| @@ -153,3 +123,41 @@ class DatafileEncap(ExplainDataEncap): | |||
| base_dir = os.path.realpath(self.job_manager.summary_base_dir) | |||
| path = os.path.realpath(path) | |||
| return not path.startswith(base_dir) | |||
| def _get_hoc_image(self, image_path, train_id): | |||
| """Get hoc image for image data demonstration in UI.""" | |||
| sample_id, label, layer = image_path.strip(".jpg").split("_") | |||
| layer = int(layer) | |||
| job = self.job_manager.get_job(train_id) | |||
| samples = job.samples | |||
| label_idx = job.labels.index(label) | |||
| chosen_sample = samples[int(sample_id)] | |||
| original_path_image = chosen_sample['image'] | |||
| abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id), | |||
| original_path_image) | |||
| if self._is_forbidden(abs_image_path): | |||
| raise FileSystemPermissionError("Forbidden.") | |||
| image_type = ImageQueryTypes.OUTCOME.value | |||
| try: | |||
| image = Image.open(abs_image_path) | |||
| except FileNotFoundError: | |||
| raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}") | |||
| except PermissionError: | |||
| raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}") | |||
| except OSError: | |||
| raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}") | |||
| edit_steps = [] | |||
| boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"] | |||
| mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"] | |||
| for box in boxes: | |||
| edit_steps.append(EditStep(layer, *box)) | |||
| image_cp = pil_apply_edit_steps(image, mask, edit_steps) | |||
| buffer = io.BytesIO() | |||
| image_cp.save(buffer, format=_PNG_FORMAT) | |||
| return buffer.getvalue() | |||
| @@ -15,13 +15,13 @@ | |||
| """Common explain data encapsulator base class.""" | |||
| import copy | |||
| from enum import Enum | |||
| from mindinsight.explainer.common.enums import ExplanationKeys | |||
| from mindinsight.utils.exceptions import ParamValueError | |||
| def _sort_key_min_confidence(sample, labels): | |||
| """Samples sort key by the min. confidence.""" | |||
| """Samples sort key by the minimum confidence.""" | |||
| min_confidence = float("+inf") | |||
| for inference in sample["inferences"]: | |||
| if labels and inference["label"] not in labels: | |||
| @@ -32,7 +32,7 @@ def _sort_key_min_confidence(sample, labels): | |||
| def _sort_key_max_confidence(sample, labels): | |||
| """Samples sort key by the max. confidence.""" | |||
| """Samples sort key by the maximum confidence.""" | |||
| max_confidence = float("-inf") | |||
| for inference in sample["inferences"]: | |||
| if labels and inference["label"] not in labels: | |||
| @@ -43,7 +43,7 @@ def _sort_key_max_confidence(sample, labels): | |||
| def _sort_key_min_confidence_sd(sample, labels): | |||
| """Samples sort key by the min. confidence_sd.""" | |||
| """Samples sort key by the minimum confidence_sd.""" | |||
| min_confidence_sd = float("+inf") | |||
| for inference in sample["inferences"]: | |||
| if labels and inference["label"] not in labels: | |||
| @@ -55,7 +55,7 @@ def _sort_key_min_confidence_sd(sample, labels): | |||
| def _sort_key_max_confidence_sd(sample, labels): | |||
| """Samples sort key by the max. confidence_sd.""" | |||
| """Samples sort key by the maximum confidence_sd.""" | |||
| max_confidence_sd = float("-inf") | |||
| for inference in sample["inferences"]: | |||
| if labels and inference["label"] not in labels: | |||
| @@ -65,11 +65,6 @@ def _sort_key_max_confidence_sd(sample, labels): | |||
| max_confidence_sd = confidence_sd | |||
| return max_confidence_sd | |||
| class ExplanationKeys(Enum): | |||
| """Query type enums.""" | |||
| HOC = "hoc_layers" # HOC: Hierarchical Occlusion, an explanation method we propose | |||
| SALIENCY = "saliency_maps" | |||
| class ExplainDataEncap: | |||
| """Explain data encapsulator base class.""" | |||
| @@ -105,9 +100,9 @@ class ExplanationEncap(ExplainDataEncap): | |||
| sorted_name (str): Field to be sorted. | |||
| sorted_type (str): Sorting order, 'ascending' or 'descending'. | |||
| prediction_types (list[str]): Prediction type filter. | |||
| drop_type (str, None): When it is None, no filer will be applied. When it is 'hoc_layers', samples without | |||
| hoc explanations will be filtered out. When it is 'saliency_maps', samples without saliency explanations | |||
| will be filtered out. | |||
| drop_type (str, None): When it is None, all data will be kept. When it is 'hoc_layers', samples without | |||
| hoc explanations will be drop out. When it is 'saliency_maps', samples without saliency explanations | |||
| will be drop out. | |||
| Returns: | |||
| list[dict], samples to be queried. | |||
| @@ -29,11 +29,13 @@ class ExplainJobEncap(ExplainDataEncap): | |||
| def query_explain_jobs(self, offset, limit): | |||
| """ | |||
| Query explain job list. | |||
| Args: | |||
| offset (int): Page offset. | |||
| limit (int): Max. no. of items to be returned. | |||
| limit (int): Maximum number of items to be returned. | |||
| Returns: | |||
| tuple[int, list[Dict]], total no. of jobs and job list. | |||
| tuple[int, list[Dict]], total number of jobs and job list. | |||
| """ | |||
| total, dir_infos = self.job_manager.get_job_list(offset=offset, limit=limit) | |||
| job_infos = [self._dir_2_info(dir_info) for dir_info in dir_infos] | |||
| @@ -15,7 +15,8 @@ | |||
| """Hierarchical Occlusion encapsulator.""" | |||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys | |||
| from mindinsight.explainer.common.enums import ExplanationKeys, ImageQueryTypes | |||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap | |||
| class HierarchicalOcclusionEncap(ExplanationEncap): | |||
| @@ -81,7 +82,10 @@ class HierarchicalOcclusionEncap(ExplanationEncap): | |||
| Returns: | |||
| dict, the edited sample info. | |||
| """ | |||
| sample["image"] = self._get_image_url(job.train_id, sample["image"], "original") | |||
| original = ImageQueryTypes.ORIGINAL.value | |||
| outcome = ImageQueryTypes.OUTCOME.value | |||
| sample["image"] = self._get_image_url(job.train_id, sample["image"], original) | |||
| inferences = sample["inferences"] | |||
| i = 0 # init index for while loop | |||
| while i < len(inferences): | |||
| @@ -91,9 +95,9 @@ class HierarchicalOcclusionEncap(ExplanationEncap): | |||
| continue | |||
| new_list = [] | |||
| for idx, hoc_layer in enumerate(inference_item[ExplanationKeys.HOC.value]): | |||
| hoc_layer["outcome"] = self._get_image_url(job.train_id, | |||
| f"{sample['id']}_{inference_item['label']}_{idx}.jpg", | |||
| "outcome") | |||
| hoc_layer[outcome] = self._get_image_url(job.train_id, | |||
| f"{sample['id']}_{inference_item['label']}_{idx}.jpg", | |||
| outcome) | |||
| new_list.append(hoc_layer) | |||
| inference_item[ExplanationKeys.HOC.value] = new_list | |||
| i += 1 | |||
| @@ -15,7 +15,8 @@ | |||
| """Saliency map encapsulator.""" | |||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys | |||
| from mindinsight.explainer.common.enums import ExplanationKeys, ImageQueryTypes | |||
| from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap | |||
| class SaliencyEncap(ExplanationEncap): | |||
| @@ -32,6 +33,7 @@ class SaliencyEncap(ExplanationEncap): | |||
| prediction_types=None): | |||
| """ | |||
| Query saliency maps. | |||
| Args: | |||
| train_id (str): Job ID. | |||
| labels (list[str]): Label filter. | |||
| @@ -65,7 +67,8 @@ class SaliencyEncap(ExplanationEncap): | |||
| def _touch_sample(self, sample, job, explainers): | |||
| """ | |||
| Final editing the sample info. | |||
| Final edit on single sample info. | |||
| Args: | |||
| sample (dict): Sample info. | |||
| job (ExplainJob): Explain job. | |||
| @@ -74,14 +77,17 @@ class SaliencyEncap(ExplanationEncap): | |||
| Returns: | |||
| dict, the edited sample info. | |||
| """ | |||
| original = ImageQueryTypes.ORIGINAL.value | |||
| overlay = ImageQueryTypes.OVERLAY.value | |||
| sample_cp = sample.copy() | |||
| sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original") | |||
| sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], original) | |||
| for inference in sample_cp["inferences"]: | |||
| new_list = [] | |||
| for saliency_map in inference[ExplanationKeys.SALIENCY.value]: | |||
| if explainers and saliency_map["explainer"] not in explainers: | |||
| continue | |||
| saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay") | |||
| saliency_map[overlay] = self._get_image_url(job.train_id, saliency_map[overlay], overlay) | |||
| new_list.append(saliency_map) | |||
| inference[ExplanationKeys.SALIENCY.value] = new_list | |||
| return sample_cp | |||
| @@ -270,7 +270,7 @@ class ExplainLoader: | |||
| Update the update_time manually. | |||
| Args: | |||
| new_time stamp (datetime.datetime or float): Updated time for the summary file. | |||
| new_time (datetime.datetime or float): Updated time for the summary file. | |||
| """ | |||
| if isinstance(new_time, datetime): | |||
| self._loader_info['update_time'] = new_time.timestamp() | |||
| @@ -333,11 +333,10 @@ class ExplainLoader: | |||
| def get_all_samples(self) -> List[Dict]: | |||
| """ | |||
| Return a list of sample information cached in the explain job | |||
| Return a list of sample information cached in the explain job. | |||
| Returns: | |||
| sample_list (List[SampleObj]): a list of sample objects, each object | |||
| consists of: | |||
| sample_list (list[SampleObj]): a list of sample objects, each object consists of: | |||
| - id (int): Sample id. | |||
| - name (str): Basename of image. | |||
| @@ -406,7 +405,7 @@ class ExplainLoader: | |||
| } | |||
| Args: | |||
| benchmarks (benchmark_container): Parsed benchmarks data from summary file. | |||
| benchmarks (BenchmarkContainer): Parsed benchmarks data from summary file. | |||
| """ | |||
| explainer_score = self._benchmark['explainer_score'] | |||
| label_score = self._benchmark['label_score'] | |||
| @@ -429,7 +428,7 @@ class ExplainLoader: | |||
| Parse the sample event. | |||
| Detailed data of each sample are store in self._samples, identified by sample_id. Each sample data are stored | |||
| in the following structure. | |||
| in the following structure: | |||
| - ground_truth_labels (list[int]): A list of ground truth labels of the sample. | |||
| - ground_truth_probs (list[float]): A list of confidences of ground-truth label from black-box model. | |||
| @@ -65,7 +65,7 @@ class ExplainManager: | |||
| Args: | |||
| reload_interval (int): Specify the loading period in seconds. If interval == 0, data will only be loaded | |||
| once. Default: 0. | |||
| once. Default: 0. | |||
| """ | |||
| thread = threading.Thread(target=self._repeat_loading, | |||
| name='explainer.start_load_thread', | |||
| @@ -80,10 +80,10 @@ class ExplainManager: | |||
| If explain job w.r.t given loader_id is not found, None will be returned. | |||
| Args: | |||
| loader_id (str): The id of expected ExplainLoader | |||
| loader_id (str): The id of expected ExplainLoader. | |||
| Return: | |||
| explain_job | |||
| Returns: | |||
| ExplainLoader, the data loader specified by loader_id. | |||
| """ | |||
| self._check_status_valid() | |||
| @@ -111,17 +111,19 @@ class ExplainManager: | |||
| Return List of explain jobs. includes job ID, create and update time. | |||
| Args: | |||
| offset (int): An offset for page. Ex, offset is 0, mean current page is 1. Default value is 0. | |||
| limit (int): The max data items for per page. Default value is 10. | |||
| offset (int): An offset for page. Ex, offset is 0, mean current page is 1. Default: 0. | |||
| limit (int): The max data items for per page. Default: 10. | |||
| Returns: | |||
| tuple[total, directories], total indicates the overall number of explain directories and directories | |||
| indicate list of summary directory info including the following attributes. | |||
| tuple, the elements of the returned tuple are: | |||
| - total (int): The overall number of explain directories | |||
| - dir_infos (list): List of summary directory info including the following attributes: | |||
| - relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR, | |||
| starting with "./". | |||
| - create_time (datetime): Creation time of summary file. | |||
| - update_time (datetime): Modification time of summary file. | |||
| - relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR, | |||
| starting with "./". | |||
| - create_time (datetime): Creation time of summary file. | |||
| - update_time (datetime): Modification time of summary file. | |||
| """ | |||
| total, dir_infos = \ | |||
| self._summary_watcher.list_explain_directories(self._summary_base_dir, offset=offset, limit=limit) | |||
| @@ -216,7 +218,7 @@ class ExplainManager: | |||
| return loader | |||
| def _add_loader(self, loader): | |||
| """add loader to the loader_pool.""" | |||
| """Add loader to the loader_pool.""" | |||
| if loader.train_id not in self._loader_pool: | |||
| self._loader_pool[loader.train_id] = loader | |||
| else: | |||
| @@ -1,4 +1,4 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| @@ -59,12 +59,13 @@ class ExplainParser(_SummaryParser): | |||
| Args: | |||
| filenames (list[str]): File name list. | |||
| Returns: | |||
| tuple, will return (file_changed, is_end, event_data), | |||
| tuple, the elements of the tuple are: | |||
| file_changed (bool): True if the 9latest file is changed. | |||
| is_end (bool): True if all the summary files are finished loading. | |||
| event_data (dict): return an event data, key is field. | |||
| - file_changed (bool): True if the latest file is changed. | |||
| - is_end (bool): True if all the summary files are finished loading. | |||
| - event_data (dict): Event data where keys are explanation field. | |||
| """ | |||
| summary_files = self.sort_files(filenames) | |||
| @@ -134,6 +135,12 @@ class ExplainParser(_SummaryParser): | |||
| Args: | |||
| event_str (str): Message event string in summary proto, data read from file handler. | |||
| Returns: | |||
| tuple, the elements of the result tuple are: | |||
| - field_list (list): Explain fields to be parsed. | |||
| - tensor_value_list (list): Parsed data with respect to the field list. | |||
| """ | |||
| logger.debug("Start to parse event string. Event string len: %s.", len(event_str)) | |||
| @@ -172,10 +179,13 @@ class ExplainParser(_SummaryParser): | |||
| @staticmethod | |||
| def _add_image_data(tensor_event_value): | |||
| """ | |||
| Parse image data based on sample_id in Explain message | |||
| Parse image data based on sample_id in Explain message. | |||
| Args: | |||
| tensor_event_value: the object of Explain message | |||
| tensor_event_value (Event): The object of Explain message. | |||
| Returns: | |||
| SampleContainer, a named tuple containing sample data. | |||
| """ | |||
| inference = InferfenceContainer( | |||
| ground_truth_prob=tensor_event_value.inference.ground_truth_prob, | |||
| @@ -205,10 +215,10 @@ class ExplainParser(_SummaryParser): | |||
| Parse benchmark data from Explain message. | |||
| Args: | |||
| tensor_event_value: the object of Explain message | |||
| tensor_event_value (Event): The object of Explain message. | |||
| Returns: | |||
| benchmark_data: An object containing benchmark. | |||
| BenchmarkContainer, a named tuple containing benchmark data. | |||
| """ | |||
| benchmark_data = BenchmarkContainer( | |||
| benchmark=tensor_event_value.benchmark, | |||
| @@ -223,10 +233,10 @@ class ExplainParser(_SummaryParser): | |||
| Parse metadata from Explain message. | |||
| Args: | |||
| tensor_event_value: the object of Explain message | |||
| tensor_event_value (Event): The object of Explain message. | |||
| Returns: | |||
| benchmark_data: An object containing metadata. | |||
| MetadataContainer, a named tuple containing benchmark data. | |||
| """ | |||
| metadata_value = MetadataContainer( | |||
| metadata=tensor_event_value.metadata, | |||