diff --git a/mindinsight/backend/explainer/explainer_api.py b/mindinsight/backend/explainer/explainer_api.py index 34c74d3a..1d14a396 100644 --- a/mindinsight/backend/explainer/explainer_api.py +++ b/mindinsight/backend/explainer/explainer_api.py @@ -40,8 +40,52 @@ URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX) +def _validate_type(param, name, expected_types): + """ + Common function to validate type. + + Args: + param (object): Parameter to be validated. + name (str): Name of the parameter. + expected_types (type, tuple[type]): Expected type(s) of param. + + Raises: + ParamTypeError: When param is not an instance of expected_types. + """ + + if not isinstance(param, expected_types): + raise ParamTypeError(name, expected_types) + + +def _validate_value(param, name, expected_values): + """ + Common function to validate values of param. + + Args: + param (object): Parameter to be validated. + name (str): Name of the parameter. + expected_values (tuple) : Expected values of param. + + Raises: + ParamValueError: When param is not in expected_values. + """ + + if param not in expected_values: + raise ParamValueError(f"Valid options for {name} are {expected_values}, but got {param}.") + + def _image_url_formatter(train_id, image_path, image_type): - """Returns image url.""" + """ + Returns image url. + + Args: + train_id (str): Id that specifies explain job. + image_path (str): Local path or unique string that specifies the image for query. + image_type (str): Image query type. + + Returns: + str, url string for image query. + """ data = { "train_id": train_id, "path": image_path, @@ -69,38 +113,48 @@ def _read_post_request(post_request): def _get_query_sample_parameters(data): - """Get parameter for query.""" + """ + Get parameter for query. + + Args: + data (dict): Dict that contains request info. + + Returns: + dict, key-value pairs to call backend query functions. + + Raises: + ParamMissError: If train_id info is not in the request. + ParamTypeError: If certain key is not in the expected type in the request. + ParamValueError: If certain key does not have the expected value in the request. + """ train_id = data.get("train_id") if train_id is None: raise ParamMissError('train_id') labels = data.get("labels") - if labels is not None and not isinstance(labels, list): - raise ParamTypeError("labels", (list, None)) + if labels is not None: + _validate_type(labels, "labels", list) if labels: for item in labels: - if not isinstance(item, str): - raise ParamTypeError("element of labels", str) + _validate_type(item, "element of labels", str) limit = data.get("limit", 10) limit = Validation.check_limit(limit, min_value=1, max_value=100) offset = data.get("offset", 0) offset = Validation.check_offset(offset=offset) sorted_name = data.get("sorted_name", "") + _validate_value(sorted_name, "sorted_name", ('', 'confidence', 'uncertainty')) + sorted_type = data.get("sorted_type", "descending") - if sorted_name not in ("", "confidence", "uncertainty"): - raise ParamValueError(f"sorted_name: {sorted_name}, valid options: '' 'confidence' 'uncertainty'") - if sorted_type not in ("ascending", "descending"): - raise ParamValueError(f"sorted_type: {sorted_type}, valid options: 'confidence' 'uncertainty'") + _validate_value(sorted_type, "sorted_type", ("ascending", "descending")) prediction_types = data.get("prediction_types") - if prediction_types is not None and not isinstance(prediction_types, list): - raise ParamTypeError("prediction_types", (list, None)) + if prediction_types is not None: + _validate_type(prediction_types, "element of labels", list) if prediction_types: for item in prediction_types: - if item not in ['TP', 'FN', 'FP']: - raise ParamValueError(f"Item of prediction_types must be in ['TP', 'FN', 'FP'], but got {item}.") + _validate_value(item, "element of prediction_types", ('TP', 'FN', 'FP')) query_kwarg = {"train_id": train_id, "labels": labels, @@ -114,7 +168,17 @@ def _get_query_sample_parameters(data): @BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"]) def query_explain_jobs(): - """Query explain jobs.""" + """ + Query explain jobs. + + Returns: + Response, contains dict that stores base directory, total number of jobs and their detailed job metadata. + + Raises: + ParamMissError: If train_id info is not in the request. + ParamTypeError: If one of (offset, limit) is not integer in the request. + ParamValueError: If one of (offset, limit) does not have the expected value in the request. + """ offset = request.args.get("offset", default=0) limit = request.args.get("limit", default=10) offset = Validation.check_offset(offset=offset) @@ -132,7 +196,15 @@ def query_explain_jobs(): @BLUEPRINT.route("/explainer/explain-job", methods=["GET"]) def query_explain_job(): - """Query explain job meta-data.""" + """ + Query explain job meta-data. + + Returns: + Response, contains dict that stores metadata of the requested job. + + Raises: + ParamMissError: If train_id info is not in the request. + """ train_id = get_train_id(request) if train_id is None: raise ParamMissError("train_id") @@ -143,7 +215,16 @@ def query_explain_job(): @BLUEPRINT.route("/explainer/saliency", methods=["POST"]) def query_saliency(): - """Query saliency map related results.""" + """ + Query saliency map related results. + + Returns: + Response, contains dict that stores number of samples and the detailed sample info. + + Raises: + ParamTypeError: If certain key is not in the expected type in the request. + ParamValueError: If certain key does not have the expected value in the request. + """ data = _read_post_request(request) query_kwarg = _get_query_sample_parameters(data) explainers = data.get("explainers") @@ -169,7 +250,16 @@ def query_saliency(): @BLUEPRINT.route("/explainer/hoc", methods=["POST"]) def query_hoc(): - """Query hierarchical occlusion related results.""" + """ + Query hierarchical occlusion related results. + + Returns: + Response, contains dict that stores number of samples and the detailed sample info. + + Raises: + ParamTypeError: If certain key is not in the expected type in the request. + ParamValueError: If certain key does not have the expected value in the request. + """ data = _read_post_request(request) query_kwargs = _get_query_sample_parameters(data) @@ -193,7 +283,15 @@ def query_hoc(): @BLUEPRINT.route("/explainer/evaluation", methods=["GET"]) def query_evaluation(): - """Query saliency explainer evaluation scores.""" + """ + Query saliency explainer evaluation scores. + + Returns: + Response, contains dict that stores evaluation scores. + + Raises: + ParamMissError: If train_id info is not in the request. + """ train_id = get_train_id(request) if train_id is None: raise ParamMissError("train_id") @@ -206,7 +304,12 @@ def query_evaluation(): @BLUEPRINT.route("/explainer/image", methods=["GET"]) def query_image(): - """Query image.""" + """ + Query image. + + Returns: + bytes, image binary content for UI to demonstrate. + """ train_id = get_train_id(request) if train_id is None: raise ParamMissError("train_id") @@ -230,6 +333,6 @@ def init_module(app): Init module entry. Args: - app: the application obj. + app (flask.app): The application obj. """ app.register_blueprint(BLUEPRINT) diff --git a/mindinsight/datavisual/data_transform/summary_watcher.py b/mindinsight/datavisual/data_transform/summary_watcher.py index a9322840..c55be672 100644 --- a/mindinsight/datavisual/data_transform/summary_watcher.py +++ b/mindinsight/datavisual/data_transform/summary_watcher.py @@ -220,10 +220,7 @@ class SummaryWatcher: summary_dict[relative_path].update(job_dict) if summary_dict[relative_path]['create_time'] < ctime: - summary_dict[relative_path].update({ - 'create_time': ctime, - 'update_time': mtime, - }) + summary_dict[relative_path].update({'create_time': ctime, 'update_time': mtime}) job_dict = _get_explain_job_info(summary_base_dir, relative_path, timestamp) summary_dict[relative_path].update(job_dict) @@ -243,12 +240,10 @@ class SummaryWatcher: if not is_find: return - profiler = { - 'directory': os.path.join('.', entry.name), - 'create_time': ctime, - 'update_time': mtime, - "profiler_type": profiler_type - } + profiler = {'directory': os.path.join('.', entry.name), + 'create_time': ctime, + 'update_time': mtime, + "profiler_type": profiler_type} if relative_path in summary_dict: summary_dict[relative_path]['profiler'] = profiler diff --git a/mindinsight/explainer/common/enums.py b/mindinsight/explainer/common/enums.py index d2c29f5e..d5a473c3 100644 --- a/mindinsight/explainer/common/enums.py +++ b/mindinsight/explainer/common/enums.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -50,3 +50,16 @@ class CacheStatus(enum.Enum): NOT_IN_CACHE = "NOT_IN_CACHE" CACHING = "CACHING" CACHED = "CACHED" + + +class ExplanationKeys(enum.Enum): + """Query type enums.""" + HOC = "hoc_layers" # HOC: Hierarchical Occlusion, an explanation method we propose + SALIENCY = "saliency_maps" + + +class ImageQueryTypes(enum.Enum): + """Image query type enums.""" + ORIGINAL = 'original' # Query for the original image + OUTCOME = 'outcome' # Query for outcome of HOC explanation + OVERLAY = 'overlay' # Query for saliency maps overlay diff --git a/mindinsight/explainer/encapsulator/_hoc_pil_apply.py b/mindinsight/explainer/encapsulator/_hoc_pil_apply.py index 2e0f8996..00113a6e 100644 --- a/mindinsight/explainer/encapsulator/_hoc_pil_apply.py +++ b/mindinsight/explainer/encapsulator/_hoc_pil_apply.py @@ -63,7 +63,7 @@ def pil_apply_edit_steps(image, mask, edit_steps, by_masking=False, inplace=Fals Args: image (PIL.Image): The input image in RGB mode. mask (Union[str, int, tuple[int, int, int], PIL.Image.Image]): The mask to apply on the image, could be string - e.g. 'gaussian:9', a single, grey scale intensity [0, 255], a RBG tuple or a PIL Image object. + e.g. 'gaussian:9', a single, grey scale intensity [0, 255], an RBG tuple or a PIL Image object. edit_steps (list[EditStep]): Edit steps to be drawn. by_masking (bool): Whether to use masking method. Default: False. inplace (bool): True to draw on the input image, otherwise draw on a cloned image. @@ -99,7 +99,7 @@ def _pil_apply_edit_steps_unmask(image, mask, edit_steps, inplace=False): Args: image (PIL.Image): The input image. mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey - scale intensity [0, 255], a RBG tuple or a PIL Image. + scale intensity [0, 255], an RBG tuple or a PIL Image. edit_steps (list[EditStep]): Edit steps to be drawn. inplace (bool): True to draw on the input image, otherwise draw on a cloned image. @@ -132,7 +132,7 @@ def _pil_apply_edit_steps_mask(image, mask, edit_steps, inplace=False): Args: image (PIL.Image): The input image. mask (Union[int, tuple[int, int, int], PIL.Image]): The mask to apply on the image, could be a single grey - scale intensity [0, 255], a RBG tuple or a PIL Image. + scale intensity [0, 255], an RBG tuple or a PIL Image. edit_steps (list[EditStep]): Edit steps to be drawn. inplace (bool): True to draw on the input image, otherwise draw on a cloned image. diff --git a/mindinsight/explainer/encapsulator/datafile_encap.py b/mindinsight/explainer/encapsulator/datafile_encap.py index 249b6858..5c1eaf8d 100644 --- a/mindinsight/explainer/encapsulator/datafile_encap.py +++ b/mindinsight/explainer/encapsulator/datafile_encap.py @@ -21,6 +21,7 @@ import numpy as np from PIL import Image from mindinsight.datavisual.common.exceptions import ImageNotExistError +from mindinsight.explainer.common.enums import ImageQueryTypes from mindinsight.explainer.encapsulator._hoc_pil_apply import EditStep, pil_apply_edit_steps from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap from mindinsight.utils.exceptions import FileSystemPermissionError @@ -63,41 +64,10 @@ class DatafileEncap(ExplainDataEncap): image_type (str): Image type, Options: 'original', 'overlay' or 'outcome'. Returns: - bytes, image binary. + bytes, image binary content for UI to demonstrate. """ - if image_type == "outcome": - sample_id, label, layer = image_path.strip(".jpg").split("_") - layer = int(layer) - job = self.job_manager.get_job(train_id) - samples = job.samples - label_idx = job.labels.index(label) - - chosen_sample = samples[int(sample_id)] - original_path_image = chosen_sample['image'] - abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id), - original_path_image) - if self._is_forbidden(abs_image_path): - raise FileSystemPermissionError("Forbidden.") - try: - image = Image.open(abs_image_path) - except FileNotFoundError: - raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}") - except PermissionError: - raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}") - except OSError: - raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}") - - edit_steps = [] - boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"] - mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"] - - for box in boxes: - edit_steps.append(EditStep(layer, *box)) - image_cp = pil_apply_edit_steps(image, mask, edit_steps) - buffer = io.BytesIO() - image_cp.save(buffer, format=_PNG_FORMAT) - - return buffer.getvalue() + if image_type == ImageQueryTypes.OUTCOME.value: + return self._get_hoc_image(image_path, train_id) abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id), @@ -108,7 +78,7 @@ class DatafileEncap(ExplainDataEncap): try: - if image_type != "overlay": + if image_type != ImageQueryTypes.OVERLAY.value: # no need to convert with open(abs_image_path, "rb") as fp: return fp.read() @@ -153,3 +123,41 @@ class DatafileEncap(ExplainDataEncap): base_dir = os.path.realpath(self.job_manager.summary_base_dir) path = os.path.realpath(path) return not path.startswith(base_dir) + + def _get_hoc_image(self, image_path, train_id): + """Get hoc image for image data demonstration in UI.""" + + sample_id, label, layer = image_path.strip(".jpg").split("_") + layer = int(layer) + job = self.job_manager.get_job(train_id) + samples = job.samples + label_idx = job.labels.index(label) + + chosen_sample = samples[int(sample_id)] + original_path_image = chosen_sample['image'] + abs_image_path = os.path.join(self.job_manager.summary_base_dir, _clean_train_id_b4_join(train_id), + original_path_image) + if self._is_forbidden(abs_image_path): + raise FileSystemPermissionError("Forbidden.") + + image_type = ImageQueryTypes.OUTCOME.value + try: + image = Image.open(abs_image_path) + except FileNotFoundError: + raise ImageNotExistError(f"train_id:{train_id} path:{image_path} type:{image_type}") + except PermissionError: + raise FileSystemPermissionError(f"train_id:{train_id} path:{image_path} type:{image_type}") + except OSError: + raise UnknownError(f"Invalid image file: train_id:{train_id} path:{image_path} type:{image_type}") + + edit_steps = [] + boxes = chosen_sample["hierarchical_occlusion"][label_idx]["hoc_layers"][layer]["boxes"] + mask = chosen_sample["hierarchical_occlusion"][label_idx]["mask"] + + for box in boxes: + edit_steps.append(EditStep(layer, *box)) + image_cp = pil_apply_edit_steps(image, mask, edit_steps) + buffer = io.BytesIO() + image_cp.save(buffer, format=_PNG_FORMAT) + + return buffer.getvalue() diff --git a/mindinsight/explainer/encapsulator/explain_data_encap.py b/mindinsight/explainer/encapsulator/explain_data_encap.py index 4dee8f86..c2110be5 100644 --- a/mindinsight/explainer/encapsulator/explain_data_encap.py +++ b/mindinsight/explainer/encapsulator/explain_data_encap.py @@ -15,13 +15,13 @@ """Common explain data encapsulator base class.""" import copy -from enum import Enum +from mindinsight.explainer.common.enums import ExplanationKeys from mindinsight.utils.exceptions import ParamValueError def _sort_key_min_confidence(sample, labels): - """Samples sort key by the min. confidence.""" + """Samples sort key by the minimum confidence.""" min_confidence = float("+inf") for inference in sample["inferences"]: if labels and inference["label"] not in labels: @@ -32,7 +32,7 @@ def _sort_key_min_confidence(sample, labels): def _sort_key_max_confidence(sample, labels): - """Samples sort key by the max. confidence.""" + """Samples sort key by the maximum confidence.""" max_confidence = float("-inf") for inference in sample["inferences"]: if labels and inference["label"] not in labels: @@ -43,7 +43,7 @@ def _sort_key_max_confidence(sample, labels): def _sort_key_min_confidence_sd(sample, labels): - """Samples sort key by the min. confidence_sd.""" + """Samples sort key by the minimum confidence_sd.""" min_confidence_sd = float("+inf") for inference in sample["inferences"]: if labels and inference["label"] not in labels: @@ -55,7 +55,7 @@ def _sort_key_min_confidence_sd(sample, labels): def _sort_key_max_confidence_sd(sample, labels): - """Samples sort key by the max. confidence_sd.""" + """Samples sort key by the maximum confidence_sd.""" max_confidence_sd = float("-inf") for inference in sample["inferences"]: if labels and inference["label"] not in labels: @@ -65,11 +65,6 @@ def _sort_key_max_confidence_sd(sample, labels): max_confidence_sd = confidence_sd return max_confidence_sd -class ExplanationKeys(Enum): - """Query type enums.""" - HOC = "hoc_layers" # HOC: Hierarchical Occlusion, an explanation method we propose - SALIENCY = "saliency_maps" - class ExplainDataEncap: """Explain data encapsulator base class.""" @@ -105,9 +100,9 @@ class ExplanationEncap(ExplainDataEncap): sorted_name (str): Field to be sorted. sorted_type (str): Sorting order, 'ascending' or 'descending'. prediction_types (list[str]): Prediction type filter. - drop_type (str, None): When it is None, no filer will be applied. When it is 'hoc_layers', samples without - hoc explanations will be filtered out. When it is 'saliency_maps', samples without saliency explanations - will be filtered out. + drop_type (str, None): When it is None, all data will be kept. When it is 'hoc_layers', samples without + hoc explanations will be drop out. When it is 'saliency_maps', samples without saliency explanations + will be drop out. Returns: list[dict], samples to be queried. diff --git a/mindinsight/explainer/encapsulator/explain_job_encap.py b/mindinsight/explainer/encapsulator/explain_job_encap.py index 1c9e5ddd..da2d1df2 100644 --- a/mindinsight/explainer/encapsulator/explain_job_encap.py +++ b/mindinsight/explainer/encapsulator/explain_job_encap.py @@ -29,11 +29,13 @@ class ExplainJobEncap(ExplainDataEncap): def query_explain_jobs(self, offset, limit): """ Query explain job list. + Args: offset (int): Page offset. - limit (int): Max. no. of items to be returned. + limit (int): Maximum number of items to be returned. + Returns: - tuple[int, list[Dict]], total no. of jobs and job list. + tuple[int, list[Dict]], total number of jobs and job list. """ total, dir_infos = self.job_manager.get_job_list(offset=offset, limit=limit) job_infos = [self._dir_2_info(dir_info) for dir_info in dir_infos] diff --git a/mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py b/mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py index 42e3eae6..0b0bf440 100644 --- a/mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py +++ b/mindinsight/explainer/encapsulator/hierarchical_occlusion_encap.py @@ -15,7 +15,8 @@ """Hierarchical Occlusion encapsulator.""" from mindinsight.datavisual.common.exceptions import TrainJobNotExistError -from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys +from mindinsight.explainer.common.enums import ExplanationKeys, ImageQueryTypes +from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap class HierarchicalOcclusionEncap(ExplanationEncap): @@ -81,7 +82,10 @@ class HierarchicalOcclusionEncap(ExplanationEncap): Returns: dict, the edited sample info. """ - sample["image"] = self._get_image_url(job.train_id, sample["image"], "original") + original = ImageQueryTypes.ORIGINAL.value + outcome = ImageQueryTypes.OUTCOME.value + + sample["image"] = self._get_image_url(job.train_id, sample["image"], original) inferences = sample["inferences"] i = 0 # init index for while loop while i < len(inferences): @@ -91,9 +95,9 @@ class HierarchicalOcclusionEncap(ExplanationEncap): continue new_list = [] for idx, hoc_layer in enumerate(inference_item[ExplanationKeys.HOC.value]): - hoc_layer["outcome"] = self._get_image_url(job.train_id, - f"{sample['id']}_{inference_item['label']}_{idx}.jpg", - "outcome") + hoc_layer[outcome] = self._get_image_url(job.train_id, + f"{sample['id']}_{inference_item['label']}_{idx}.jpg", + outcome) new_list.append(hoc_layer) inference_item[ExplanationKeys.HOC.value] = new_list i += 1 diff --git a/mindinsight/explainer/encapsulator/saliency_encap.py b/mindinsight/explainer/encapsulator/saliency_encap.py index 4634b264..789694f0 100644 --- a/mindinsight/explainer/encapsulator/saliency_encap.py +++ b/mindinsight/explainer/encapsulator/saliency_encap.py @@ -15,7 +15,8 @@ """Saliency map encapsulator.""" from mindinsight.datavisual.common.exceptions import TrainJobNotExistError -from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap, ExplanationKeys +from mindinsight.explainer.common.enums import ExplanationKeys, ImageQueryTypes +from mindinsight.explainer.encapsulator.explain_data_encap import ExplanationEncap class SaliencyEncap(ExplanationEncap): @@ -32,6 +33,7 @@ class SaliencyEncap(ExplanationEncap): prediction_types=None): """ Query saliency maps. + Args: train_id (str): Job ID. labels (list[str]): Label filter. @@ -65,7 +67,8 @@ class SaliencyEncap(ExplanationEncap): def _touch_sample(self, sample, job, explainers): """ - Final editing the sample info. + Final edit on single sample info. + Args: sample (dict): Sample info. job (ExplainJob): Explain job. @@ -74,14 +77,17 @@ class SaliencyEncap(ExplanationEncap): Returns: dict, the edited sample info. """ + original = ImageQueryTypes.ORIGINAL.value + overlay = ImageQueryTypes.OVERLAY.value + sample_cp = sample.copy() - sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], "original") + sample_cp["image"] = self._get_image_url(job.train_id, sample['image'], original) for inference in sample_cp["inferences"]: new_list = [] for saliency_map in inference[ExplanationKeys.SALIENCY.value]: if explainers and saliency_map["explainer"] not in explainers: continue - saliency_map["overlay"] = self._get_image_url(job.train_id, saliency_map['overlay'], "overlay") + saliency_map[overlay] = self._get_image_url(job.train_id, saliency_map[overlay], overlay) new_list.append(saliency_map) inference[ExplanationKeys.SALIENCY.value] = new_list return sample_cp diff --git a/mindinsight/explainer/manager/explain_loader.py b/mindinsight/explainer/manager/explain_loader.py index 302eda65..db4f4a94 100644 --- a/mindinsight/explainer/manager/explain_loader.py +++ b/mindinsight/explainer/manager/explain_loader.py @@ -270,7 +270,7 @@ class ExplainLoader: Update the update_time manually. Args: - new_time stamp (datetime.datetime or float): Updated time for the summary file. + new_time (datetime.datetime or float): Updated time for the summary file. """ if isinstance(new_time, datetime): self._loader_info['update_time'] = new_time.timestamp() @@ -333,11 +333,10 @@ class ExplainLoader: def get_all_samples(self) -> List[Dict]: """ - Return a list of sample information cached in the explain job + Return a list of sample information cached in the explain job. Returns: - sample_list (List[SampleObj]): a list of sample objects, each object - consists of: + sample_list (list[SampleObj]): a list of sample objects, each object consists of: - id (int): Sample id. - name (str): Basename of image. @@ -406,7 +405,7 @@ class ExplainLoader: } Args: - benchmarks (benchmark_container): Parsed benchmarks data from summary file. + benchmarks (BenchmarkContainer): Parsed benchmarks data from summary file. """ explainer_score = self._benchmark['explainer_score'] label_score = self._benchmark['label_score'] @@ -429,7 +428,7 @@ class ExplainLoader: Parse the sample event. Detailed data of each sample are store in self._samples, identified by sample_id. Each sample data are stored - in the following structure. + in the following structure: - ground_truth_labels (list[int]): A list of ground truth labels of the sample. - ground_truth_probs (list[float]): A list of confidences of ground-truth label from black-box model. diff --git a/mindinsight/explainer/manager/explain_manager.py b/mindinsight/explainer/manager/explain_manager.py index 7d650698..17710c95 100644 --- a/mindinsight/explainer/manager/explain_manager.py +++ b/mindinsight/explainer/manager/explain_manager.py @@ -65,7 +65,7 @@ class ExplainManager: Args: reload_interval (int): Specify the loading period in seconds. If interval == 0, data will only be loaded - once. Default: 0. + once. Default: 0. """ thread = threading.Thread(target=self._repeat_loading, name='explainer.start_load_thread', @@ -80,10 +80,10 @@ class ExplainManager: If explain job w.r.t given loader_id is not found, None will be returned. Args: - loader_id (str): The id of expected ExplainLoader + loader_id (str): The id of expected ExplainLoader. - Return: - explain_job + Returns: + ExplainLoader, the data loader specified by loader_id. """ self._check_status_valid() @@ -111,17 +111,19 @@ class ExplainManager: Return List of explain jobs. includes job ID, create and update time. Args: - offset (int): An offset for page. Ex, offset is 0, mean current page is 1. Default value is 0. - limit (int): The max data items for per page. Default value is 10. + offset (int): An offset for page. Ex, offset is 0, mean current page is 1. Default: 0. + limit (int): The max data items for per page. Default: 10. Returns: - tuple[total, directories], total indicates the overall number of explain directories and directories - indicate list of summary directory info including the following attributes. + tuple, the elements of the returned tuple are: + + - total (int): The overall number of explain directories + - dir_infos (list): List of summary directory info including the following attributes: - - relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR, - starting with "./". - - create_time (datetime): Creation time of summary file. - - update_time (datetime): Modification time of summary file. + - relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR, + starting with "./". + - create_time (datetime): Creation time of summary file. + - update_time (datetime): Modification time of summary file. """ total, dir_infos = \ self._summary_watcher.list_explain_directories(self._summary_base_dir, offset=offset, limit=limit) @@ -216,7 +218,7 @@ class ExplainManager: return loader def _add_loader(self, loader): - """add loader to the loader_pool.""" + """Add loader to the loader_pool.""" if loader.train_id not in self._loader_pool: self._loader_pool[loader.train_id] = loader else: diff --git a/mindinsight/explainer/manager/explain_parser.py b/mindinsight/explainer/manager/explain_parser.py index ec6aaebd..70177008 100644 --- a/mindinsight/explainer/manager/explain_parser.py +++ b/mindinsight/explainer/manager/explain_parser.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -59,12 +59,13 @@ class ExplainParser(_SummaryParser): Args: filenames (list[str]): File name list. + Returns: - tuple, will return (file_changed, is_end, event_data), + tuple, the elements of the tuple are: - file_changed (bool): True if the 9latest file is changed. - is_end (bool): True if all the summary files are finished loading. - event_data (dict): return an event data, key is field. + - file_changed (bool): True if the latest file is changed. + - is_end (bool): True if all the summary files are finished loading. + - event_data (dict): Event data where keys are explanation field. """ summary_files = self.sort_files(filenames) @@ -134,6 +135,12 @@ class ExplainParser(_SummaryParser): Args: event_str (str): Message event string in summary proto, data read from file handler. + + Returns: + tuple, the elements of the result tuple are: + + - field_list (list): Explain fields to be parsed. + - tensor_value_list (list): Parsed data with respect to the field list. """ logger.debug("Start to parse event string. Event string len: %s.", len(event_str)) @@ -172,10 +179,13 @@ class ExplainParser(_SummaryParser): @staticmethod def _add_image_data(tensor_event_value): """ - Parse image data based on sample_id in Explain message + Parse image data based on sample_id in Explain message. Args: - tensor_event_value: the object of Explain message + tensor_event_value (Event): The object of Explain message. + + Returns: + SampleContainer, a named tuple containing sample data. """ inference = InferfenceContainer( ground_truth_prob=tensor_event_value.inference.ground_truth_prob, @@ -205,10 +215,10 @@ class ExplainParser(_SummaryParser): Parse benchmark data from Explain message. Args: - tensor_event_value: the object of Explain message + tensor_event_value (Event): The object of Explain message. Returns: - benchmark_data: An object containing benchmark. + BenchmarkContainer, a named tuple containing benchmark data. """ benchmark_data = BenchmarkContainer( benchmark=tensor_event_value.benchmark, @@ -223,10 +233,10 @@ class ExplainParser(_SummaryParser): Parse metadata from Explain message. Args: - tensor_event_value: the object of Explain message + tensor_event_value (Event): The object of Explain message. Returns: - benchmark_data: An object containing metadata. + MetadataContainer, a named tuple containing benchmark data. """ metadata_value = MetadataContainer( metadata=tensor_event_value.metadata,