| @@ -13,7 +13,9 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Module init file.""" | |||
| from mindinsight.conf import settings | |||
| from mindinsight.backend.explainer.explainer_api import init_module as init_query_module | |||
| from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER | |||
| def init_module(app): | |||
| @@ -27,3 +29,4 @@ def init_module(app): | |||
| """ | |||
| init_query_module(app) | |||
| EXPLAIN_MANAGER.start_load_data(reload_interval=settings.RELOAD_INTERVAL) | |||
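# Not part of the change above — a minimal usage sketch of how this init hook is typically
# wired into the backend (the Flask app construction and import path below are assumptions):
#
#     from flask import Flask
#     from mindinsight.backend.explainer import init_module
#
#     app = Flask(__name__)
#     init_module(app)   # registers the explainer blueprint and starts background data loading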
| @@ -29,32 +29,16 @@ from mindinsight.datavisual.common.exceptions import ImageNotExistError | |||
| from mindinsight.datavisual.common.validation import Validation | |||
| from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher | |||
| from mindinsight.datavisual.utils.tools import get_train_id | |||
| from mindinsight.explainer.manager.explain_manager import ExplainManager | |||
| from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER | |||
| from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap | |||
| from mindinsight.explainer.encapsulator.datafile_encap import DatafileEncap | |||
| from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap | |||
| from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap | |||
| URL_PREFIX = settings.URL_PATH_PREFIX+settings.API_PREFIX | |||
| URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX | |||
| BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX) | |||
| class ExplainManagerHolder: | |||
| """ExplainManger instance holder.""" | |||
| static_instance = None | |||
| @classmethod | |||
| def get_instance(cls): | |||
| return cls.static_instance | |||
| @classmethod | |||
| def initialize(cls): | |||
| cls.static_instance = ExplainManager(settings.SUMMARY_BASE_DIR) | |||
| cls.static_instance.start_load_data() | |||
| def _image_url_formatter(train_id, image_path, image_type): | |||
| """Returns image url.""" | |||
| data = { | |||
| @@ -91,7 +75,7 @@ def query_explain_jobs(): | |||
| offset = Validation.check_offset(offset=offset) | |||
| limit = Validation.check_limit(limit, min_value=1, max_value=SummaryWatcher.MAX_SUMMARY_DIR_COUNT) | |||
| encapsulator = ExplainJobEncap(ExplainManagerHolder.get_instance()) | |||
| encapsulator = ExplainJobEncap(EXPLAIN_MANAGER) | |||
| total, jobs = encapsulator.query_explain_jobs(offset, limit) | |||
| return jsonify({ | |||
| @@ -107,7 +91,7 @@ def query_explain_job(): | |||
| train_id = get_train_id(request) | |||
| if train_id is None: | |||
| raise ParamMissError("train_id") | |||
| encapsulator = ExplainJobEncap(ExplainManagerHolder.get_instance()) | |||
| encapsulator = ExplainJobEncap(EXPLAIN_MANAGER) | |||
| metadata = encapsulator.query_meta(train_id) | |||
| return jsonify(metadata) | |||
| @@ -139,7 +123,7 @@ def query_saliency(): | |||
| encapsulator = SaliencyEncap( | |||
| _image_url_formatter, | |||
| ExplainManagerHolder.get_instance()) | |||
| EXPLAIN_MANAGER) | |||
| count, samples = encapsulator.query_saliency_maps(train_id=train_id, | |||
| labels=labels, | |||
| explainers=explainers, | |||
| @@ -160,7 +144,7 @@ def query_evaluation(): | |||
| train_id = get_train_id(request) | |||
| if train_id is None: | |||
| raise ParamMissError("train_id") | |||
| encapsulator = EvaluationEncap(ExplainManagerHolder.get_instance()) | |||
| encapsulator = EvaluationEncap(EXPLAIN_MANAGER) | |||
| scores = encapsulator.query_explainer_scores(train_id) | |||
| return jsonify({ | |||
| "explainer_scores": scores, | |||
| @@ -182,7 +166,7 @@ def query_image(): | |||
| if image_type not in ("original", "overlay"): | |||
raise ParamValueError(f"type:{image_type}, valid options: 'original', 'overlay'")
| encapsulator = DatafileEncap(ExplainManagerHolder.get_instance()) | |||
| encapsulator = DatafileEncap(EXPLAIN_MANAGER) | |||
| image = encapsulator.query_image_binary(train_id, image_path, image_type) | |||
| if image is None: | |||
| raise ImageNotExistError(f"{image_path}") | |||
| @@ -198,5 +182,4 @@ def init_module(app): | |||
| app: the application obj. | |||
| """ | |||
| ExplainManagerHolder.initialize() | |||
| app.register_blueprint(BLUEPRINT) | |||
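# Sketch of the refactor applied throughout this file: the removed ExplainManagerHolder
# singleton is replaced by a module-level instance imported from explain_manager.py.
# Its exact definition is not shown in this diff; the line below is an assumption:
#
#     # mindinsight/explainer/manager/explain_manager.py (assumed)
#     EXPLAIN_MANAGER = ExplainManager(summary_base_dir=settings.SUMMARY_BASE_DIR)
#
# Each handler can then use EXPLAIN_MANAGER directly instead of calling
# ExplainManagerHolder.get_instance().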
| @@ -31,7 +31,7 @@ class DataManagerStatus(BaseEnum): | |||
| INVALID = 'INVALID' | |||
| class PluginNameEnum(BaseEnum): | |||
| class ExplainFieldsEnum(BaseEnum): | |||
| """Plugin Name Enum.""" | |||
| EXPLAIN = 'explain' | |||
| SAMPLE_ID = 'sample_id' | |||
| @@ -70,7 +70,7 @@ class ExplainJobEncap(ExplainDataEncap): | |||
| info["train_id"] = job.train_id | |||
| info["create_time"] = datetime.fromtimestamp(job.create_time)\ | |||
| .strftime(cls.DATETIME_FORMAT) | |||
| info["update_time"] = datetime.fromtimestamp(job.latest_update_time)\ | |||
| info["update_time"] = datetime.fromtimestamp(job.update_time)\ | |||
| .strftime(cls.DATETIME_FORMAT) | |||
| return info | |||
| @@ -1,168 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """EventParser for summary event.""" | |||
| from collections import namedtuple, defaultdict | |||
| from typing import Dict, List, Optional, Tuple | |||
| from mindinsight.explainer.common.enums import PluginNameEnum | |||
| from mindinsight.explainer.common.log import logger | |||
| from mindinsight.utils.exceptions import UnknownError | |||
| _IMAGE_DATA_TAGS = { | |||
| 'sample_id': PluginNameEnum.SAMPLE_ID.value, | |||
| 'ground_truth_label': PluginNameEnum.GROUND_TRUTH_LABEL.value, | |||
| 'inference': PluginNameEnum.INFERENCE.value, | |||
| 'explanation': PluginNameEnum.EXPLANATION.value | |||
| } | |||
| _NUM_DIGIT = 7 | |||
| class EventParser: | |||
| """Parser for event data.""" | |||
| def __init__(self, job): | |||
| self._job = job | |||
| self._sample_pool = {} | |||
| @staticmethod | |||
| def parse_metadata(metadata) -> Tuple[List, List, List]: | |||
| """Parse the metadata event.""" | |||
| explainers = list(metadata.explain_method) | |||
| metrics = list(metadata.benchmark_method) | |||
| labels = list(metadata.label) | |||
| return explainers, metrics, labels | |||
| @staticmethod | |||
| def parse_benchmark(benchmarks) -> Tuple[Dict, Dict]: | |||
| """Parse the benchmark event.""" | |||
| explainer_score_dict = defaultdict(list) | |||
| label_score_dict = defaultdict(dict) | |||
| for benchmark in benchmarks: | |||
| explainer = benchmark.explain_method | |||
| metric = benchmark.benchmark_method | |||
| metric_score = benchmark.total_score | |||
| label_score_event = benchmark.label_score | |||
| explainer_score_dict[explainer].append({ | |||
| 'metric': metric, | |||
| 'score': round(metric_score, _NUM_DIGIT)}) | |||
| new_label_score_dict = EventParser._score_event_to_dict(label_score_event, metric) | |||
| for label, label_scores in new_label_score_dict.items(): | |||
| label_score_dict[explainer][label] = label_score_dict[explainer].get(label, []) + label_scores | |||
| return explainer_score_dict, label_score_dict | |||
| def parse_sample(self, sample: namedtuple) -> Optional[namedtuple]: | |||
| """Parse the sample event.""" | |||
| sample_id = sample.sample_id | |||
| if sample_id not in self._sample_pool: | |||
| self._sample_pool[sample_id] = sample | |||
| return None | |||
| for tag in _IMAGE_DATA_TAGS: | |||
| try: | |||
| if tag == PluginNameEnum.INFERENCE.value: | |||
| self._parse_inference(sample, sample_id) | |||
| elif tag == PluginNameEnum.EXPLANATION.value: | |||
| self._parse_explanation(sample, sample_id) | |||
| else: | |||
| self._parse_sample_info(sample, sample_id, tag) | |||
| except UnknownError as ex: | |||
| logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex)) | |||
| continue | |||
| if EventParser._is_ready_for_display(self._sample_pool[sample_id]): | |||
| return self._sample_pool[sample_id] | |||
| return None | |||
| def clear(self): | |||
| """Clear the loaded data.""" | |||
| self._sample_pool.clear() | |||
| @staticmethod | |||
| def _is_ready_for_display(image_container: namedtuple) -> bool: | |||
| """ | |||
| Check whether the image_container is ready for frontend display. | |||
| Args: | |||
| image_container (namedtuple): container consists of sample data | |||
| Return: | |||
bool: whether the image_container is ready for display
| """ | |||
| required_attrs = ['image_path', 'ground_truth_label', 'inference'] | |||
| for attr in required_attrs: | |||
| if not EventParser.is_attr_ready(image_container, attr): | |||
| return False | |||
| return True | |||
| @staticmethod | |||
| def is_attr_ready(image_container: namedtuple, attr: str) -> bool: | |||
| """ | |||
| Check whether the given attribute is ready in image_container. | |||
| Args: | |||
image_container (namedtuple): container consisting of sample data
| attr (str): attribute to check | |||
| Returns: | |||
| bool, whether the attr is ready | |||
| """ | |||
| if getattr(image_container, attr, False): | |||
| return True | |||
| return False | |||
| @staticmethod | |||
| def _score_event_to_dict(label_score_event, metric): | |||
| """Transfer metric scores per label to pre-defined structure.""" | |||
| new_label_score_dict = defaultdict(list) | |||
| for label_id, label_score in enumerate(label_score_event): | |||
| new_label_score_dict[label_id].append({ | |||
| 'metric': metric, | |||
| 'score': round(label_score, _NUM_DIGIT), | |||
| }) | |||
| return new_label_score_dict | |||
| def _parse_inference(self, event, sample_id): | |||
| """Parse the inference event.""" | |||
| self._sample_pool[sample_id].inference.ground_truth_prob.extend(event.inference.ground_truth_prob) | |||
| self._sample_pool[sample_id].inference.ground_truth_prob_sd.extend(event.inference.ground_truth_prob_sd) | |||
| self._sample_pool[sample_id].inference.ground_truth_prob_itl95_low.\ | |||
| extend(event.inference.ground_truth_prob_itl95_low) | |||
| self._sample_pool[sample_id].inference.ground_truth_prob_itl95_hi.\ | |||
| extend(event.inference.ground_truth_prob_itl95_hi) | |||
| self._sample_pool[sample_id].inference.predicted_label.extend(event.inference.predicted_label) | |||
| self._sample_pool[sample_id].inference.predicted_prob.extend(event.inference.predicted_prob) | |||
| self._sample_pool[sample_id].inference.predicted_prob_sd.extend(event.inference.predicted_prob_sd) | |||
| self._sample_pool[sample_id].inference.predicted_prob_itl95_low.extend(event.inference.predicted_prob_itl95_low) | |||
| self._sample_pool[sample_id].inference.predicted_prob_itl95_hi.extend(event.inference.predicted_prob_itl95_hi) | |||
| def _parse_explanation(self, event, sample_id): | |||
| """Parse the explanation event.""" | |||
| if event.explanation: | |||
| for explanation_item in event.explanation: | |||
| new_explanation = self._sample_pool[sample_id].explanation.add() | |||
| new_explanation.explain_method = explanation_item.explain_method | |||
| new_explanation.label = explanation_item.label | |||
| new_explanation.heatmap_path = explanation_item.heatmap_path | |||
| def _parse_sample_info(self, event, sample_id, tag): | |||
| """Parse the event containing image info.""" | |||
| if not getattr(self._sample_pool[sample_id], tag): | |||
| setattr(self._sample_pool[sample_id], tag, getattr(event, tag)) | |||
| @@ -1,398 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ExplainJob.""" | |||
| import os | |||
| from collections import defaultdict | |||
| from datetime import datetime | |||
| from typing import Union | |||
| from mindinsight.explainer.common.enums import PluginNameEnum | |||
| from mindinsight.explainer.common.log import logger | |||
| from mindinsight.explainer.manager.explain_parser import _ExplainParser | |||
| from mindinsight.explainer.manager.event_parse import EventParser | |||
| from mindinsight.datavisual.data_access.file_handler import FileHandler | |||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| _NUM_DIGIT = 7 | |||
| class ExplainJob: | |||
| """ExplainJob which manage the record in the summary file.""" | |||
| def __init__(self, | |||
| job_id: str, | |||
| summary_dir: str, | |||
| create_time: float, | |||
| latest_update_time: float): | |||
| self._job_id = job_id | |||
| self._summary_dir = summary_dir | |||
| self._parser = _ExplainParser(summary_dir) | |||
| self._event_parser = EventParser(self) | |||
| self._latest_update_time = latest_update_time | |||
| self._create_time = create_time | |||
| self._uncertainty_enabled = False | |||
| self._labels = [] | |||
| self._metrics = [] | |||
| self._explainers = [] | |||
| self._samples_info = {} | |||
| self._labels_info = {} | |||
| self._explainer_score_dict = defaultdict(list) | |||
| self._label_score_dict = defaultdict(dict) | |||
| @property | |||
| def all_classes(self): | |||
| """ | |||
| Return a list of label info | |||
| Returns: | |||
| class_objs (List[ClassObj]): a list of class_objects, each object | |||
| contains: | |||
| - id (int): label id | |||
| - label (str): label name | |||
| - sample_count (int): number of samples for each label | |||
| """ | |||
| all_classes_return = [] | |||
| for label_id, label_info in self._labels_info.items(): | |||
| single_info = { | |||
| 'id': label_id, | |||
| 'label': label_info['label'], | |||
| 'sample_count': len(label_info['sample_ids'])} | |||
| all_classes_return.append(single_info) | |||
| return all_classes_return | |||
| @property | |||
| def explainers(self): | |||
| """ | |||
| Return a list of explainer names | |||
| Returns: | |||
| list(str), explainer names | |||
| """ | |||
| return self._explainers | |||
| @property | |||
| def explainer_scores(self): | |||
| """Return evaluation results for every explainer.""" | |||
| merged_scores = [] | |||
| for explainer, explainer_score_on_metric in self._explainer_score_dict.items(): | |||
| label_scores = [] | |||
| for label, label_score_on_metric in self._label_score_dict[explainer].items(): | |||
| score_single_label = { | |||
| 'label': self._labels[label], | |||
| 'evaluations': label_score_on_metric, | |||
| } | |||
| label_scores.append(score_single_label) | |||
| merged_scores.append({ | |||
| 'explainer': explainer, | |||
| 'evaluations': explainer_score_on_metric, | |||
| 'class_scores': label_scores, | |||
| }) | |||
| return merged_scores | |||
| @property | |||
| def sample_count(self): | |||
| """ | |||
| Return total number of samples in the job. | |||
| Return: | |||
| int, total number of samples | |||
| """ | |||
| return len(self._samples_info) | |||
| @property | |||
| def train_id(self): | |||
| """ | |||
| Return ID of explain job | |||
| Returns: | |||
| str, id of ExplainJob object | |||
| """ | |||
| return self._job_id | |||
| @property | |||
| def metrics(self): | |||
| """ | |||
| Return a list of metric names | |||
| Returns: | |||
| list(str), metric names | |||
| """ | |||
| return self._metrics | |||
| @property | |||
| def min_confidence(self): | |||
| """ | |||
| Return minimum confidence | |||
| Returns: | |||
| min_confidence (float): | |||
| """ | |||
| return None | |||
| @property | |||
| def uncertainty_enabled(self): | |||
| return self._uncertainty_enabled | |||
| @property | |||
| def create_time(self): | |||
| """ | |||
| Return the create time of summary file | |||
| Returns: | |||
| creation timestamp (float) | |||
| """ | |||
| return self._create_time | |||
| @property | |||
| def labels(self): | |||
| """Return the label contained in the job.""" | |||
| return self._labels | |||
| @property | |||
| def latest_update_time(self): | |||
| """ | |||
| Return last modification time stamp of summary file. | |||
| Returns: | |||
| float, last_modification_time stamp | |||
| """ | |||
| return self._latest_update_time | |||
| @latest_update_time.setter | |||
| def latest_update_time(self, new_time: Union[float, datetime]): | |||
| """ | |||
| Update the latest_update_time timestamp manually. | |||
| Args: | |||
new_time (Union[float, datetime]): updated timestamp for the job
| """ | |||
| if isinstance(new_time, datetime): | |||
| self._latest_update_time = new_time.timestamp() | |||
| elif isinstance(new_time, float): | |||
| self._latest_update_time = new_time | |||
| else: | |||
| raise TypeError('new_time should have type of float or datetime') | |||
| @property | |||
| def loader_id(self): | |||
| """Return the job id.""" | |||
| return self._job_id | |||
| @property | |||
| def samples(self): | |||
| """Return the information of all samples in the job.""" | |||
| return self._samples_info | |||
| @staticmethod | |||
| def get_create_time(file_path: str) -> float: | |||
| """Return timestamp of create time of specific path.""" | |||
| create_time = os.stat(file_path).st_ctime | |||
| return create_time | |||
| @staticmethod | |||
| def get_update_time(file_path: str) -> float: | |||
| """Return timestamp of update time of specific path.""" | |||
| update_time = os.stat(file_path).st_mtime | |||
| return update_time | |||
| def _initialize_labels_info(self): | |||
| """Initialize a dict for labels in the job.""" | |||
| if self._labels is None: | |||
logger.warning('No labels are provided in job %s', self._job_id)
| return | |||
| for label_id, label in enumerate(self._labels): | |||
| self._labels_info[label_id] = {'label': label, | |||
| 'sample_ids': set()} | |||
| def _explanation_to_dict(self, explanation): | |||
| """Transfer the explanation from event to dict storage.""" | |||
| explain_info = { | |||
| 'explainer': explanation.explain_method, | |||
| 'overlay': explanation.heatmap_path, | |||
| } | |||
| return explain_info | |||
| def _image_container_to_dict(self, sample_data): | |||
| """Transfer the image container to dict storage.""" | |||
| has_uncertainty = False | |||
| sample_id = sample_data.sample_id | |||
| sample_info = { | |||
| 'id': sample_id, | |||
| 'image': sample_data.image_path, | |||
| 'name': str(sample_id), | |||
| 'labels': [self._labels_info[x]['label'] | |||
| for x in sample_data.ground_truth_label], | |||
| 'inferences': []} | |||
| ground_truth_labels = list(sample_data.ground_truth_label) | |||
| ground_truth_probs = list(sample_data.inference.ground_truth_prob) | |||
| predicted_labels = list(sample_data.inference.predicted_label) | |||
| predicted_probs = list(sample_data.inference.predicted_prob) | |||
| if sample_data.inference.predicted_prob_sd or sample_data.inference.ground_truth_prob_sd: | |||
| ground_truth_prob_sds = list(sample_data.inference.ground_truth_prob_sd) | |||
| ground_truth_prob_lows = list(sample_data.inference.ground_truth_prob_itl95_low) | |||
| ground_truth_prob_his = list(sample_data.inference.ground_truth_prob_itl95_hi) | |||
| predicted_prob_sds = list(sample_data.inference.predicted_prob_sd) | |||
| predicted_prob_lows = list(sample_data.inference.predicted_prob_itl95_low) | |||
| predicted_prob_his = list(sample_data.inference.predicted_prob_itl95_hi) | |||
| has_uncertainty = True | |||
| else: | |||
| ground_truth_prob_sds = ground_truth_prob_lows = ground_truth_prob_his = None | |||
| predicted_prob_sds = predicted_prob_lows = predicted_prob_his = None | |||
| inference_info = {} | |||
| for label, prob in zip( | |||
| ground_truth_labels + predicted_labels, | |||
| ground_truth_probs + predicted_probs): | |||
| inference_info[label] = { | |||
| 'label': self._labels_info[label]['label'], | |||
| 'confidence': round(prob, _NUM_DIGIT), | |||
| 'saliency_maps': []} | |||
| if ground_truth_prob_sds or predicted_prob_sds: | |||
| for label, sd, low, hi in zip( | |||
| ground_truth_labels + predicted_labels, | |||
| ground_truth_prob_sds + predicted_prob_sds, | |||
| ground_truth_prob_lows + predicted_prob_lows, | |||
| ground_truth_prob_his + predicted_prob_his): | |||
| inference_info[label]['confidence_sd'] = sd | |||
| inference_info[label]['confidence_itl95'] = [low, hi] | |||
| if EventParser.is_attr_ready(sample_data, 'explanation'): | |||
| for explanation in sample_data.explanation: | |||
| explanation_dict = self._explanation_to_dict(explanation) | |||
| inference_info[explanation.label]['saliency_maps'].append(explanation_dict) | |||
| sample_info['inferences'] = list(inference_info.values()) | |||
| return sample_info, has_uncertainty | |||
| def _import_sample(self, sample): | |||
| """Add sample object of given sample id.""" | |||
| for label_id in sample.ground_truth_label: | |||
| self._labels_info[label_id]['sample_ids'].add(sample.sample_id) | |||
| sample_info, has_uncertainty = self._image_container_to_dict(sample) | |||
| self._samples_info.update({sample_info['id']: sample_info}) | |||
| self._uncertainty_enabled |= has_uncertainty | |||
| def get_all_samples(self): | |||
| """ | |||
Return a list of sample information cached in the explain job.
| Returns: | |||
| sample_list (List[SampleObj]): a list of sample objects, each object | |||
| consists of: | |||
| - id (int): sample id | |||
| - name (str): basename of image | |||
| - labels (list[str]): list of labels | |||
- inferences (list[dict])
| """ | |||
| samples_in_list = list(self._samples_info.values()) | |||
| return samples_in_list | |||
| def _is_metadata_empty(self): | |||
| """Check whether metadata is loaded first.""" | |||
| if not self._explainers or not self._metrics or not self._labels: | |||
| return True | |||
| return False | |||
| def _import_data_from_event(self, event): | |||
| """Parse and import data from the event data.""" | |||
| tags = { | |||
| 'sample_id': PluginNameEnum.SAMPLE_ID, | |||
| 'benchmark': PluginNameEnum.BENCHMARK, | |||
| 'metadata': PluginNameEnum.METADATA | |||
| } | |||
| if 'metadata' not in event and self._is_metadata_empty(): | |||
| raise ValueError('metadata is empty, should write metadata first in the summary.') | |||
| for tag in tags: | |||
| if tag not in event: | |||
| continue | |||
| if tag == PluginNameEnum.SAMPLE_ID.value: | |||
| sample_event = event[tag] | |||
| sample_data = self._event_parser.parse_sample(sample_event) | |||
| if sample_data is not None: | |||
| self._import_sample(sample_data) | |||
| continue | |||
| if tag == PluginNameEnum.BENCHMARK.value: | |||
| benchmark_event = event[tag].benchmark | |||
| explain_score_dict, label_score_dict = EventParser.parse_benchmark(benchmark_event) | |||
| self._update_benchmark(explain_score_dict, label_score_dict) | |||
| elif tag == PluginNameEnum.METADATA.value: | |||
| metadata_event = event[tag].metadata | |||
| metadata = EventParser.parse_metadata(metadata_event) | |||
| self._explainers, self._metrics, self._labels = metadata | |||
| self._initialize_labels_info() | |||
| def load(self): | |||
| """ | |||
| Start loading data from parser. | |||
| """ | |||
| valid_file_names = [] | |||
| for filename in FileHandler.list_dir(self._summary_dir): | |||
| if FileHandler.is_file( | |||
| FileHandler.join(self._summary_dir, filename)): | |||
| valid_file_names.append(filename) | |||
| if not valid_file_names: | |||
raise TrainJobNotExistError('No summary file found in %s, explain job will be deleted.' % self._summary_dir)
| is_end = False | |||
| while not is_end: | |||
| is_clean, is_end, event = self._parser.parse_explain(valid_file_names) | |||
| if is_clean: | |||
logger.info('Summary file in %s updated, clean the loaded data and reload.', self._summary_dir)
| self._clean_job() | |||
| if event: | |||
| self._import_data_from_event(event) | |||
| def _clean_job(self): | |||
| """Clean the cached data in job.""" | |||
| self._latest_update_time = ExplainJob.get_update_time(self._summary_dir) | |||
| self._create_time = ExplainJob.get_update_time(self._summary_dir) | |||
| self._labels.clear() | |||
| self._metrics.clear() | |||
| self._explainers.clear() | |||
| self._samples_info.clear() | |||
| self._labels_info.clear() | |||
| self._explainer_score_dict.clear() | |||
| self._label_score_dict.clear() | |||
| self._event_parser.clear() | |||
| def _update_benchmark(self, explainer_score_dict, labels_score_dict): | |||
| """Update the benchmark info.""" | |||
| for explainer, score in explainer_score_dict.items(): | |||
| self._explainer_score_dict[explainer].extend(score) | |||
| for explainer, score in labels_score_dict.items(): | |||
| for label, score_of_label in score.items(): | |||
| self._label_score_dict[explainer][label] = (self._label_score_dict[explainer].get(label, []) | |||
| + score_of_label) | |||
| @@ -0,0 +1,580 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ExplainLoader.""" | |||
| import os | |||
| import re | |||
| from collections import defaultdict | |||
| from datetime import datetime | |||
| from typing import Dict, Iterable, List, Optional, Union | |||
| from mindinsight.explainer.common.enums import ExplainFieldsEnum | |||
| from mindinsight.explainer.common.log import logger | |||
| from mindinsight.explainer.manager.explain_parser import ExplainParser | |||
| from mindinsight.datavisual.data_access.file_handler import FileHandler | |||
| from mindinsight.datavisual.common.exceptions import TrainJobNotExistError | |||
| from mindinsight.utils.exceptions import ParamValueError, UnknownError | |||
| _NUM_DIGITS = 6 | |||
| _EXPLAIN_FIELD_NAMES = [ | |||
| ExplainFieldsEnum.SAMPLE_ID, | |||
| ExplainFieldsEnum.BENCHMARK, | |||
| ExplainFieldsEnum.METADATA, | |||
| ] | |||
| _SAMPLE_FIELD_NAMES = [ | |||
| ExplainFieldsEnum.GROUND_TRUTH_LABEL, | |||
| ExplainFieldsEnum.INFERENCE, | |||
| ExplainFieldsEnum.EXPLANATION, | |||
| ] | |||
| def _round(score): | |||
| """Take round of a number to given precision.""" | |||
| return round(score, _NUM_DIGITS) | |||
| class ExplainLoader: | |||
| """ExplainLoader which manage the record in the summary file.""" | |||
| def __init__(self, | |||
| loader_id: str, | |||
| summary_dir: str): | |||
| self._parser = ExplainParser(summary_dir) | |||
| self._loader_info = { | |||
| 'loader_id': loader_id, | |||
| 'summary_dir': summary_dir, | |||
| 'create_time': os.stat(summary_dir).st_ctime, | |||
| 'update_time': os.stat(summary_dir).st_mtime, | |||
| 'query_time': os.stat(summary_dir).st_ctime, | |||
| 'uncertainty_enabled': False, | |||
| } | |||
| self._samples = defaultdict(dict) | |||
| self._metadata = {'explainers': [], 'metrics': [], 'labels': []} | |||
| self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)} | |||
| @property | |||
| def all_classes(self) -> List[Dict]: | |||
| """ | |||
| Return a list of detailed label information, including label id, label name and sample count of each label. | |||
| Returns: | |||
| list[dict], a list of dict, each dict contains: | |||
| - id (int): label id | |||
| - label (str): label name | |||
| - sample_count (int): number of samples for each label | |||
| """ | |||
| sample_count_per_label = defaultdict(int) | |||
| samples_copy = self._samples.copy() | |||
| for sample in samples_copy.values(): | |||
| if sample.get('image', False) and sample.get('ground_truth_label', False): | |||
| for label in sample['ground_truth_label']: | |||
| sample_count_per_label[label] += 1 | |||
| all_classes_return = [] | |||
| for label_id, label_name in enumerate(self._metadata['labels']): | |||
| single_info = { | |||
| 'id': label_id, | |||
| 'label': label_name, | |||
| 'sample_count': sample_count_per_label[label_id] | |||
| } | |||
| all_classes_return.append(single_info) | |||
| return all_classes_return | |||
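# Illustrative return value, assuming labels ['cat', 'dog'] and three cached samples
# (two labelled 'cat', one labelled 'dog'):
#     [{'id': 0, 'label': 'cat', 'sample_count': 2},
#      {'id': 1, 'label': 'dog', 'sample_count': 1}]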
| @property | |||
| def query_time(self) -> float: | |||
| """Return query timestamp of explain loader.""" | |||
| return self._loader_info['query_time'] | |||
| @query_time.setter | |||
| def query_time(self, new_time: Union[datetime, float]): | |||
| """ | |||
| Update the query_time timestamp manually. | |||
| Args: | |||
| new_time (datetime.datetime or float): Updated query_time for the explain loader. | |||
| """ | |||
| if isinstance(new_time, datetime): | |||
| self._loader_info['query_time'] = new_time.timestamp() | |||
| elif isinstance(new_time, float): | |||
| self._loader_info['query_time'] = new_time | |||
| else: | |||
| raise TypeError('new_time should have type of datetime.datetime or float, but receive {}' | |||
| .format(type(new_time))) | |||
| @property | |||
| def create_time(self) -> float: | |||
| """Return the create timestamp of summary file.""" | |||
| return self._loader_info['create_time'] | |||
| @create_time.setter | |||
| def create_time(self, new_time: Union[datetime, float]): | |||
| """ | |||
| Update the create_time manually | |||
| Args: | |||
| new_time (datetime.datetime or float): Updated create_time of summary_file. | |||
| """ | |||
| if isinstance(new_time, datetime): | |||
| self._loader_info['create_time'] = new_time.timestamp() | |||
| elif isinstance(new_time, float): | |||
| self._loader_info['create_time'] = new_time | |||
| else: | |||
| raise TypeError('new_time should have type of datetime.datetime or float, but receive {}' | |||
| .format(type(new_time))) | |||
| @property | |||
| def explainers(self) -> List[str]: | |||
| """Return a list of explainer names recorded in the summary file.""" | |||
| return self._metadata['explainers'] | |||
| @property | |||
| def explainer_scores(self) -> List[Dict]: | |||
| """ | |||
| Return evaluation results for every explainer. | |||
| Returns: | |||
| list[dict], A list of evaluation results of each explainer. Each item contains: | |||
| - explainer (str): Name of evaluated explainer. | |||
- evaluations (list[dict]): A list of evaluation results by different metrics.
| - class_scores (list[dict]): A list of evaluation results on different labels. | |||
| Each item in the evaluations contains: | |||
| - metric (str): name of metric method | |||
| - score (float): evaluation result | |||
| Each item in the class_scores contains: | |||
| - label (str): Name of label | |||
- evaluations (list[dict]): A list of evaluation results on different labels by different metrics.
| Each item in evaluations contains: | |||
| - metric (str): Name of metric method | |||
| - score (float): Evaluation scores of explainer on specific label by the metric. | |||
| """ | |||
| explainer_scores = [] | |||
| for explainer, explainer_score_on_metric in self._benchmark['explainer_score'].copy().items(): | |||
| metric_scores = [{'metric': metric, 'score': _round(score)} | |||
| for metric, score in explainer_score_on_metric.items()] | |||
| label_scores = [] | |||
| for label, label_score_on_metric in self._benchmark['label_score'][explainer].copy().items(): | |||
| score_of_single_label = { | |||
| 'label': self._metadata['labels'][label], | |||
| 'evaluations': [ | |||
| {'metric': metric, 'score': _round(score)} for metric, score in label_score_on_metric.items() | |||
| ], | |||
| } | |||
| label_scores.append(score_of_single_label) | |||
| explainer_scores.append({ | |||
| 'explainer': explainer, | |||
| 'evaluations': metric_scores, | |||
| 'class_scores': label_scores, | |||
| }) | |||
| return explainer_scores | |||
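# Illustrative return value with made-up explainer, metric and score names:
#     [{'explainer': 'Gradient',
#       'evaluations': [{'metric': 'Localization', 'score': 0.523456}],
#       'class_scores': [{'label': 'cat',
#                         'evaluations': [{'metric': 'Localization', 'score': 0.612345}]}]}]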
| @property | |||
| def labels(self) -> List[str]: | |||
| """Return the label recorded in the summary.""" | |||
| return self._metadata['labels'] | |||
| @property | |||
| def metrics(self) -> List[str]: | |||
| """Return a list of metric names recorded in the summary file.""" | |||
| return self._metadata['metrics'] | |||
| @property | |||
| def min_confidence(self) -> Optional[float]: | |||
| """Return minimum confidence used to filter the predicted labels.""" | |||
| return None | |||
| @property | |||
| def sample_count(self) -> int: | |||
| """ | |||
| Return total number of samples in the loader. | |||
| Since the loader only return available samples (i.e. with original image data and ground_truth_label loaded in | |||
| cache), the returned count only takes the available samples into account. | |||
| Return: | |||
| int, total number of available samples in the loading job. | |||
| """ | |||
| sample_count = 0 | |||
| samples_copy = self._samples.copy() | |||
| for sample in samples_copy.values(): | |||
| if sample.get('image', False) and sample.get('ground_truth_label', False): | |||
| sample_count += 1 | |||
| return sample_count | |||
| @property | |||
| def samples(self) -> List[Dict]: | |||
| """Return the information of all samples in the job.""" | |||
| return self.get_all_samples() | |||
| @property | |||
| def train_id(self) -> str: | |||
| """Return ID of explain loader.""" | |||
| return self._loader_info['loader_id'] | |||
| @property | |||
| def uncertainty_enabled(self): | |||
| """Whethter uncertainty is enabled.""" | |||
| return self._loader_info['uncertainty_enabled'] | |||
| @property | |||
| def update_time(self) -> float: | |||
| """Return latest modification timestamp of summary file.""" | |||
| return self._loader_info['update_time'] | |||
| @update_time.setter | |||
| def update_time(self, new_time: Union[datetime, float]): | |||
| """ | |||
| Update the update_time manually. | |||
| Args: | |||
new_time (datetime.datetime or float): Updated time for the summary file.
| """ | |||
| if isinstance(new_time, datetime): | |||
| self._loader_info['update_time'] = new_time.timestamp() | |||
| elif isinstance(new_time, float): | |||
| self._loader_info['update_time'] = new_time | |||
| else: | |||
| raise TypeError('new_time should have type of datetime.datetime or float, but receive {}' | |||
| .format(type(new_time))) | |||
| def load(self): | |||
| """Start loading data from the latest summary file to the loader.""" | |||
| filenames = [] | |||
| for filename in FileHandler.list_dir(self._loader_info['summary_dir']): | |||
| if FileHandler.is_file(FileHandler.join(self._loader_info['summary_dir'], filename)): | |||
| filenames.append(filename) | |||
| filenames = ExplainLoader._filter_files(filenames) | |||
| if not filenames: | |||
raise TrainJobNotExistError('No summary file found in %s, explain job will be deleted.'
| % self._loader_info['summary_dir']) | |||
| is_end = False | |||
| while not is_end: | |||
| is_clean, is_end, event_dict = self._parser.parse_explain(filenames) | |||
| if is_clean: | |||
logger.info('Summary file in %s updated, reload the data from the summary.',
| self._loader_info['summary_dir']) | |||
| self._clear_job() | |||
| if event_dict: | |||
| self._import_data_from_event(event_dict) | |||
| def get_all_samples(self) -> List[Dict]: | |||
| """ | |||
Return a list of sample information cached in the explain job.
| Returns: | |||
| sample_list (List[SampleObj]): a list of sample objects, each object | |||
| consists of: | |||
| - id (int): sample id | |||
| - name (str): basename of image | |||
| - labels (list[str]): list of labels | |||
- inferences (list[dict])
| """ | |||
| returned_samples = [] | |||
| samples_copy = self._samples.copy() | |||
| for sample_id, sample_info in samples_copy.items(): | |||
| if not sample_info.get('image', False) and not sample_info.get('ground_truth_label', False): | |||
| continue | |||
| returned_sample = { | |||
| 'id': sample_id, | |||
| 'name': str(sample_id), | |||
| 'image': sample_info['image'], | |||
| 'labels': sample_info['ground_truth_label'], | |||
| } | |||
| if not ExplainLoader._is_inference_valid(sample_info): | |||
| continue | |||
| inferences = {} | |||
| for label, prob in zip(sample_info['ground_truth_label'] + sample_info['predicted_label'], | |||
| sample_info['ground_truth_prob'] + sample_info['predicted_prob']): | |||
| inferences[label] = { | |||
| 'label': self._metadata['labels'][label], | |||
| 'confidence': _round(prob), | |||
| 'saliency_maps': [] | |||
| } | |||
| if sample_info['ground_truth_prob_sd'] or sample_info['predicted_prob_sd']: | |||
| for label, std, low, high in zip( | |||
| sample_info['ground_truth_label'] + sample_info['predicted_label'], | |||
| sample_info['ground_truth_prob_sd'] + sample_info['predicted_prob_sd'], | |||
| sample_info['ground_truth_prob_itl95_low'] + sample_info['predicted_prob_itl95_low'], | |||
| sample_info['ground_truth_prob_itl95_hi'] + sample_info['predicted_prob_itl95_hi'] | |||
| ): | |||
| inferences[label]['confidence_sd'] = _round(std) | |||
| inferences[label]['confidence_itl95'] = [_round(low), _round(high)] | |||
| for explainer, label_heatmap_path_dict in sample_info['explanation'].items(): | |||
| for label, heatmap_path in label_heatmap_path_dict.items(): | |||
| if label in inferences: | |||
| inferences[label]['saliency_maps'].append({'explainer': explainer, 'overlay': heatmap_path}) | |||
| returned_sample['inferences'] = list(inferences.values()) | |||
| returned_samples.append(returned_sample) | |||
| return returned_samples | |||
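# Illustrative shape of one returned sample; the ids, paths and scores below are made up.
# Note that 'labels' carries the raw ground-truth label ids, while each inference entry
# carries the resolved label name:
#     {'id': 3, 'name': '3', 'image': '3/original', 'labels': [0],
#      'inferences': [{'label': 'cat', 'confidence': 0.912345,
#                      'saliency_maps': [{'explainer': 'Gradient', 'overlay': '3/0/Gradient'}]}]}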
| def _import_data_from_event(self, event_dict: Dict): | |||
| """Parse and import data from the event data.""" | |||
| if 'metadata' not in event_dict and self._is_metadata_empty(): | |||
raise ParamValueError('metadata is incomplete, should write metadata first in the summary.')
| for tag, event in event_dict.items(): | |||
| if tag == ExplainFieldsEnum.METADATA.value: | |||
| self._import_metadata_from_event(event.metadata) | |||
| elif tag == ExplainFieldsEnum.BENCHMARK.value: | |||
| self._import_benchmark_from_event(event.benchmark) | |||
| elif tag == ExplainFieldsEnum.SAMPLE_ID.value: | |||
| self._import_sample_from_event(event) | |||
| else: | |||
| logger.info('Unknown ExplainField: %s', tag) | |||
| def _is_metadata_empty(self): | |||
| """Check whether metadata is completely loaded first.""" | |||
| if not self._metadata['labels']: | |||
| return True | |||
| return False | |||
| def _import_metadata_from_event(self, metadata_event): | |||
| """Import the metadata from event into loader.""" | |||
| def take_union(existed_list, imported_data): | |||
| """Take union of existed_list and imported_data.""" | |||
| if isinstance(imported_data, Iterable): | |||
| for sample in imported_data: | |||
| if sample not in existed_list: | |||
| existed_list.append(sample) | |||
| take_union(self._metadata['explainers'], metadata_event.explain_method) | |||
| take_union(self._metadata['metrics'], metadata_event.benchmark_method) | |||
| take_union(self._metadata['labels'], metadata_event.label) | |||
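# take_union merges repeated metadata events without duplicates while preserving order,
# e.g. (values assumed) existing labels ['cat', 'dog'] merged with an event carrying
# ['dog', 'bird'] yields ['cat', 'dog', 'bird'].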
| def _import_benchmark_from_event(self, benchmarks): | |||
| """ | |||
| Parse the benchmark event. | |||
Benchmark data are separated into 'explainer_score' and 'label_score'. 'explainer_score' contains overall
evaluation results of each explainer by different metrics, while 'label_score' additionally divides the results
w.r.t different labels.
| The structure of self._benchmark['explainer_score'] demonstrates below: | |||
| { | |||
| explainer_1: {metric_name_1: score_1, ...}, | |||
| explainer_2: {metric_name_1: score_1, ...}, | |||
| ... | |||
| } | |||
| The structure of self._benchmark['label_score'] is: | |||
| { | |||
| explainer_1: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...}, | |||
| explainer_2: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...}, | |||
| ... | |||
| } | |||
| Args: | |||
| benchmarks (benchmark_container): Parsed benchmarks data from summary file. | |||
| """ | |||
| explainer_score = self._benchmark['explainer_score'] | |||
| label_score = self._benchmark['label_score'] | |||
| for benchmark in benchmarks: | |||
| explainer = benchmark.explain_method | |||
| metric = benchmark.benchmark_method | |||
| metric_score = benchmark.total_score | |||
| label_score_event = benchmark.label_score | |||
| explainer_score[explainer][metric] = metric_score | |||
| new_label_score_dict = ExplainLoader._score_event_to_dict(label_score_event, metric) | |||
| for label, scores_of_metric in new_label_score_dict.items(): | |||
| if label not in label_score[explainer]: | |||
| label_score[explainer][label] = {} | |||
| label_score[explainer][label].update(scores_of_metric) | |||
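# After importing a single benchmark event (explainer, metric and scores are made up),
# the cache could look like:
#     self._benchmark['explainer_score'] == {'Gradient': {'Localization': 0.42}}
#     self._benchmark['label_score'] == {'Gradient': {0: {'Localization': 0.40},
#                                                     1: {'Localization': 0.44}}}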
| def _import_sample_from_event(self, sample): | |||
| """ | |||
| Parse the sample event. | |||
Detailed data of each sample are stored in self._samples, identified by sample_id. Each sample's data is stored
in the following structure:
- ground_truth_label (list[int]): A list of ground truth labels of the sample.
- ground_truth_prob (list[float]): A list of confidences of the ground-truth labels from the black-box model.
- predicted_label (list[int]): A list of predicted labels from the black-box model.
- predicted_prob (list[float]): A list of confidences w.r.t the predicted labels.
- explanation (dict): A dictionary where each explainer name maps to a dictionary
of saliency maps. The structure of explanation is shown below:
| { | |||
explainer_name_1: {label_1: heatmap_path_1, label_2: heatmap_path_2, ...},
explainer_name_2: {label_1: heatmap_path_1, label_2: heatmap_path_2, ...},
| ... | |||
| } | |||
| """ | |||
| if not getattr(sample, 'sample_id', None): | |||
| raise ParamValueError('sample_event has no sample_id') | |||
| sample_id = sample.sample_id | |||
| samples_copy = self._samples.copy() | |||
| if sample_id not in samples_copy: | |||
| self._samples[sample_id] = { | |||
| 'ground_truth_label': [], | |||
| 'ground_truth_prob': [], | |||
| 'ground_truth_prob_sd': [], | |||
| 'ground_truth_prob_itl95_low': [], | |||
| 'ground_truth_prob_itl95_hi': [], | |||
| 'predicted_label': [], | |||
| 'predicted_prob': [], | |||
| 'predicted_prob_sd': [], | |||
| 'predicted_prob_itl95_low': [], | |||
| 'predicted_prob_itl95_hi': [], | |||
| 'explanation': defaultdict(dict) | |||
| } | |||
| if sample.image_path: | |||
| self._samples[sample_id]['image'] = sample.image_path | |||
| for tag in _SAMPLE_FIELD_NAMES: | |||
| try: | |||
| if ExplainLoader._is_attr_empty(sample, tag.value): | |||
| continue | |||
| if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL: | |||
| self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label)) | |||
| elif tag == ExplainFieldsEnum.INFERENCE: | |||
| self._import_inference_from_event(sample, sample_id) | |||
| elif tag == ExplainFieldsEnum.EXPLANATION: | |||
| self._import_explanation_from_event(sample, sample_id) | |||
| except UnknownError as ex: | |||
| logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex)) | |||
| def _import_inference_from_event(self, event, sample_id): | |||
| """Parse the inference event.""" | |||
| inference = event.inference | |||
| self._samples[sample_id]['ground_truth_prob'].extend(list(inference.ground_truth_prob)) | |||
| self._samples[sample_id]['ground_truth_prob_sd'].extend(list(inference.ground_truth_prob_sd)) | |||
| self._samples[sample_id]['ground_truth_prob_itl95_low'].extend(list(inference.ground_truth_prob_itl95_low)) | |||
| self._samples[sample_id]['ground_truth_prob_itl95_hi'].extend(list(inference.ground_truth_prob_itl95_hi)) | |||
| self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label)) | |||
| self._samples[sample_id]['predicted_prob'].extend(list(inference.predicted_prob)) | |||
| self._samples[sample_id]['predicted_prob_sd'].extend(list(inference.predicted_prob_sd)) | |||
| self._samples[sample_id]['predicted_prob_itl95_low'].extend(list(inference.predicted_prob_itl95_low)) | |||
| self._samples[sample_id]['predicted_prob_itl95_hi'].extend(list(inference.predicted_prob_itl95_hi)) | |||
| if self._samples[sample_id]['ground_truth_prob_sd'] or self._samples[sample_id]['predicted_prob_sd']: | |||
| self._loader_info['uncertainty_enabled'] = True | |||
| def _import_explanation_from_event(self, event, sample_id): | |||
| """Parse the explanation event.""" | |||
| if self._samples[sample_id]['explanation'] is None: | |||
| self._samples[sample_id]['explanation'] = defaultdict(dict) | |||
| sample_explanation = self._samples[sample_id]['explanation'] | |||
| for explanation_item in event.explanation: | |||
| explainer = explanation_item.explain_method | |||
| label = explanation_item.label | |||
| sample_explanation[explainer][label] = explanation_item.heatmap_path | |||
| def _clear_job(self): | |||
| """Clear the cached data and update the time info of the loader.""" | |||
| self._samples.clear() | |||
| self._loader_info['create_time'] = os.stat(self._loader_info['summary_dir']).st_ctime | |||
| self._loader_info['update_time'] = os.stat(self._loader_info['summary_dir']).st_mtime | |||
| self._loader_info['query_time'] = max(self._loader_info['update_time'], self._loader_info['query_time']) | |||
| def clear_inner_dict(outer_dict): | |||
| """Clear the inner structured data of the given dict.""" | |||
| for item in outer_dict.values(): | |||
| item.clear() | |||
# map() is lazy in Python 3 and its result is never consumed here, so iterate explicitly
# to actually clear the nested metadata and benchmark dicts.
for inner_dict in (self._metadata, self._benchmark):
    clear_inner_dict(inner_dict)
| @staticmethod | |||
| def _filter_files(filenames): | |||
| """ | |||
| Gets a list of summary files. | |||
| Args: | |||
| filenames (list[str]): File name list, like [filename1, filename2]. | |||
| Returns: | |||
| list[str], filename list. | |||
| """ | |||
| return list(filter(lambda filename: (re.search(r'summary\.\d+', filename) and filename.endswith("_explain")), | |||
| filenames)) | |||
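# Example (file names are assumed): 'events.summary.1600000000.host_explain' passes the
# filter because it contains 'summary.' followed by digits and ends with '_explain', while
# 'events.summary.1600000000.host' is filtered out.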
| @staticmethod | |||
| def _is_attr_empty(event, attr_name) -> bool: | |||
| if not getattr(event, attr_name): | |||
| return True | |||
| for item in getattr(event, attr_name): | |||
| if not isinstance(item, list) or item: | |||
| return False | |||
| return True | |||
| @staticmethod | |||
| def _is_ground_truth_label_valid(sample_id: str, sample_info: Dict) -> bool: | |||
| if len(sample_info['ground_truth_label']) != len(sample_info['ground_truth_prob']): | |||
logger.info('Length of ground_truth_prob does not match the length of ground_truth_label, '
            'length of ground_truth_label is: %s but length of ground_truth_prob is: %s, '
            'sample_id is: %s.',
            len(sample_info['ground_truth_label']), len(sample_info['ground_truth_prob']), sample_id)
| return False | |||
| return True | |||
| @staticmethod | |||
| def _is_inference_valid(sample): | |||
| """ | |||
| Check whether the inference data is empty or have the same length. | |||
If the probs have a different length from the labels, it can be confusing when assigning each prob to a label.
'_is_inference_valid' returns True only when the data sizes match each other. Note that prob data could be
empty, so an empty prob list will pass the check.
| """ | |||
| ground_truth_len = len(sample['ground_truth_label']) | |||
| for name in ['ground_truth_prob', 'ground_truth_prob_sd', | |||
| 'ground_truth_prob_itl95_low', 'ground_truth_prob_itl95_hi']: | |||
| if sample[name] and len(sample[name]) != ground_truth_len: | |||
| return False | |||
| predicted_len = len(sample['predicted_label']) | |||
| for name in ['predicted_prob', 'predicted_prob_sd', | |||
| 'predicted_prob_itl95_low', 'predicted_prob_itl95_hi']: | |||
| if sample[name] and len(sample[name]) != predicted_len: | |||
| return False | |||
| return True | |||
| @staticmethod | |||
| def _is_predicted_label_valid(sample_id: str, sample_info: Dict) -> bool: | |||
| if len(sample_info['predicted_label']) != len(sample_info['predicted_prob']): | |||
logger.info('Length of predicted_prob does not match the length of predicted_label, '
            'length of predicted_prob: %s but length of predicted_label: %s, sample_id: %s.',
| len(sample_info['predicted_prob']), len(sample_info['predicted_label']), sample_id) | |||
| return False | |||
| return True | |||
| @staticmethod | |||
| def _score_event_to_dict(label_score_event, metric) -> Dict: | |||
| """Transfer metric scores per label to pre-defined structure.""" | |||
| new_label_score_dict = defaultdict(dict) | |||
| for label_id, label_score in enumerate(label_score_event): | |||
| new_label_score_dict[label_id][metric] = label_score | |||
| return new_label_score_dict | |||
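# For example (scores are made up), label_score_event == [0.40, 0.44] with
# metric == 'Localization' is transferred to:
#     {0: {'Localization': 0.40}, 1: {'Localization': 0.44}}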
| @@ -17,17 +17,20 @@ | |||
| import os | |||
| import threading | |||
| import time | |||
| from collections import OrderedDict | |||
| from datetime import datetime | |||
| from typing import Optional | |||
| from mindinsight.conf import settings | |||
| from mindinsight.datavisual.common import exceptions | |||
| from mindinsight.datavisual.common.enums import BaseEnum | |||
| from mindinsight.explainer.common.log import logger | |||
| from mindinsight.explainer.manager.explain_job import ExplainJob | |||
| from mindinsight.explainer.manager.explain_loader import ExplainLoader | |||
| from mindinsight.datavisual.data_access.file_handler import FileHandler | |||
| from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher | |||
| from mindinsight.utils.exceptions import MindInsightException, ParamValueError, UnknownError | |||
| _MAX_LOADER_NUM = 3 | |||
| _MAX_INTERVAL = 3 | |||
| _MAX_LOADERS_NUM = 3 | |||
| class _ExplainManagerStatus(BaseEnum): | |||
| @@ -43,245 +46,63 @@ class ExplainManager: | |||
| def __init__(self, summary_base_dir: str): | |||
| self._summary_base_dir = summary_base_dir | |||
| self._loader_pool = {} | |||
| self._deleted_ids = [] | |||
| self._status = _ExplainManagerStatus.INIT.value | |||
| self._loader_pool = OrderedDict() | |||
| self._loading_status = _ExplainManagerStatus.INIT.value | |||
| self._status_mutex = threading.Lock() | |||
| self._loader_pool_mutex = threading.Lock() | |||
| self._max_loader_num = _MAX_LOADER_NUM | |||
| self._reload_interval = None | |||
| self._max_loaders_num = _MAX_LOADERS_NUM | |||
| self._summary_watcher = SummaryWatcher() | |||
| def _reload_data(self): | |||
| """periodically load summary from file.""" | |||
| while True: | |||
| try: | |||
| self._load_data() | |||
| if not self._reload_interval: | |||
| break | |||
| time.sleep(self._reload_interval) | |||
| except UnknownError as ex: | |||
| logger.exception(ex) | |||
logger.error('Unknown error raised when loading summary files, status: %r, and loader pool size is %r. '
             'Detail: %s', self._status, len(self._loader_pool), str(ex))
| self._status = _ExplainManagerStatus.INVALID.value | |||
| def _load_data(self): | |||
| """Loading the summary in the given base directory.""" | |||
| logger.info('Start to load data, reload interval: %r.', self._reload_interval) | |||
| with self._status_mutex: | |||
| if self._status == _ExplainManagerStatus.LOADING.value: | |||
| logger.info('Current status is %s, will ignore to load data.', self._status) | |||
| return | |||
| self._status = _ExplainManagerStatus.LOADING.value | |||
| try: | |||
| self._generate_loaders() | |||
| self._execute_load_data() | |||
| except Exception as ex: | |||
| raise UnknownError(ex) | |||
| if not self._loader_pool: | |||
| self._status = _ExplainManagerStatus.INVALID.value | |||
| else: | |||
| self._status = _ExplainManagerStatus.DONE.value | |||
| logger.info('Load event data end, status: %r, and loader pool size is %r', | |||
| self._status, len(self._loader_pool)) | |||
| def _update_loader_latest_update_time(self, loader_id, latest_update_time=None): | |||
| """update the update time of loader of given id.""" | |||
| if latest_update_time is None: | |||
| latest_update_time = time.time() | |||
| self._loader_pool[loader_id].latest_update_time = latest_update_time | |||
| def _delete_loader(self, loader_id): | |||
| """delete loader given loader_id""" | |||
| if self._loader_pool.get(loader_id, None) is not None: | |||
| self._loader_pool.pop(loader_id) | |||
| logger.debug('delete loader %s', loader_id) | |||
| def _add_loader(self, loader): | |||
| """add loader to the loader_pool.""" | |||
| if len(self._loader_pool) >= _MAX_LOADER_NUM: | |||
| delete_num = len(self._loader_pool) - _MAX_LOADER_NUM + 1 | |||
| sorted_loaders = sorted( | |||
| self._loader_pool.items(), | |||
| key=lambda x: x[1].latest_update_time) | |||
| for index in range(delete_num): | |||
| delete_loader_id = sorted_loaders[index][0] | |||
| self._delete_loader(delete_loader_id) | |||
| self._loader_pool.update({loader.loader_id: loader}) | |||
| def _deal_loaders(self, latest_loaders): | |||
| """"update the loader pool.""" | |||
| with self._loader_pool_mutex: | |||
| for loader_id, loader in latest_loaders: | |||
| if self._loader_pool.get(loader_id, None) is None: | |||
| self._add_loader(loader) | |||
| continue | |||
| if (self._loader_pool[loader_id].latest_update_time | |||
| < loader.latest_update_time): | |||
| self._update_loader_latest_update_time( | |||
| loader_id, loader.latest_update_time) | |||
| @staticmethod | |||
| def _generate_loader_id(relative_path): | |||
| """Generate loader id for given path""" | |||
| loader_id = relative_path | |||
| return loader_id | |||
| @staticmethod | |||
| def _generate_loader_name(relative_path): | |||
| """Generate_loader name for given path.""" | |||
| loader_name = relative_path | |||
| return loader_name | |||
| def _generate_loader_by_relative_path(self, relative_path: str) -> ExplainJob: | |||
| """Generate explain job from given relative path.""" | |||
| current_dir = os.path.realpath(FileHandler.join( | |||
| self._summary_base_dir, relative_path | |||
| )) | |||
| loader_id = self._generate_loader_id(relative_path) | |||
| loader = ExplainJob( | |||
| job_id=loader_id, | |||
| summary_dir=current_dir, | |||
| create_time=ExplainJob.get_create_time(current_dir), | |||
| latest_update_time=ExplainJob.get_update_time(current_dir)) | |||
| return loader | |||
| def _generate_loaders(self): | |||
| """Generate job loaders from the summary watcher.""" | |||
| dir_map_mtime_dict = {} | |||
| loader_dict = {} | |||
| min_modify_time = None | |||
| _, summaries = SummaryWatcher().list_explain_directories( | |||
| self._summary_base_dir) | |||
| for item in summaries: | |||
| relative_path = item.get('relative_path') | |||
| modify_time = item.get('update_time').timestamp() | |||
| loader_id = self._generate_loader_id(relative_path) | |||
| loader = self._loader_pool.get(loader_id, None) | |||
| if loader is not None and loader.latest_update_time > modify_time: | |||
| modify_time = loader.latest_update_time | |||
| if min_modify_time is None: | |||
| min_modify_time = modify_time | |||
| if len(dir_map_mtime_dict) < _MAX_LOADER_NUM: | |||
| if modify_time < min_modify_time: | |||
| min_modify_time = modify_time | |||
| dir_map_mtime_dict.update({relative_path: modify_time}) | |||
| else: | |||
| if modify_time >= min_modify_time: | |||
| dir_map_mtime_dict.update({relative_path: modify_time}) | |||
| sorted_dir_tuple = sorted(dir_map_mtime_dict.items(), | |||
| key=lambda d: d[1])[-_MAX_LOADER_NUM:] | |||
| for relative_path, modify_time in sorted_dir_tuple: | |||
| loader_id = self._generate_loader_id(relative_path) | |||
| loader = self._generate_loader_by_relative_path(relative_path) | |||
| loader_dict.update({loader_id: loader}) | |||
| sorted_loaders = sorted(loader_dict.items(), | |||
| key=lambda x: x[1].latest_update_time) | |||
| latest_loaders = sorted_loaders[-_MAX_LOADER_NUM:] | |||
| self._deal_loaders(latest_loaders) | |||
| def _execute_loader(self, loader_id): | |||
| """Execute the data loading.""" | |||
| try: | |||
| with self._loader_pool_mutex: | |||
| loader = self._loader_pool.get(loader_id, None) | |||
| if loader is None: | |||
| logger.debug('Loader %r has been deleted, will not load data', loader_id) | |||
| return | |||
| loader.load() | |||
| except MindInsightException as ex: | |||
| logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, ex) | |||
| with self._loader_pool_mutex: | |||
| self._delete_loader(loader_id) | |||
| @property | |||
| def summary_base_dir(self): | |||
| """Return the base directory for summary records.""" | |||
| return self._summary_base_dir | |||
| def start_load_data(self, reload_interval: int = 0): | |||
| """ | |||
| Start an individual thread to cache explain_jobs and load summary data periodically. | |||
| Args: | |||
| reload_interval (int): Specify the loading period in seconds. If interval == 0, data will only be loaded | |||
| once. Default: 0. | |||
| """ | |||
| thread = threading.Thread(target=self._repeat_loading, | |||
| name='start_load_thread', | |||
| args=(reload_interval,), | |||
| daemon=True) | |||
| time.sleep(1) | |||
| thread.start() | |||
| def _execute_load_data(self): | |||
| """Execute the loader in the pool to load data.""" | |||
| loader_pool = self._get_snapshot_loader_pool() | |||
| for loader_id in loader_pool: | |||
| self._execute_loader(loader_id) | |||
| def _get_snapshot_loader_pool(self): | |||
| """Get snapshot of loader_pool.""" | |||
| with self._loader_pool_mutex: | |||
| return dict(self._loader_pool) | |||
| def get_job(self, loader_id: str) -> Optional[ExplainLoader]: | |||
| """ | |||
| Return ExplainLoader given loader_id. | |||
| If explain job w.r.t given loader_id is not found, None will be returned. | |||
| Args: | |||
| loader_id (str): The id of expected ExplainLoader | |||
| Return: | |||
| explain_job | |||
| """ | |||
| self._check_status_valid() | |||
| with self._loader_pool_mutex: | |||
| if loader_id in self._loader_pool: | |||
| self._loader_pool[loader_id].query_time = datetime.now().timestamp() | |||
| self._loader_pool.move_to_end(loader_id, last=False) | |||
| return self._loader_pool[loader_id] | |||
| try: | |||
| loader = self._generate_loader_from_relative_path(loader_id) | |||
| loader.query_time = datetime.now().timestamp() | |||
| self._add_loader(loader) | |||
| self._reload_data_again() | |||
| except ParamValueError: | |||
| logger.warning('Cannot find summary in path: %s. No explain_job will be returned.', loader_id) | |||
| return None | |||
| return loader | |||
| def _check_status_valid(self): | |||
| """Check manager status.""" | |||
| if self._status == _ExplainManagerStatus.INIT.value: | |||
| raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._status) | |||
| @staticmethod | |||
| def _check_train_id_valid(train_id: str): | |||
| """Verify the train_id is valid.""" | |||
| if not train_id.startswith('./'): | |||
| logger.warning('train_id does not start with "./"') | |||
| return False | |||
| if len(train_id.split('/')) > 2: | |||
| logger.warning('train_id contains multiple "/"') | |||
| return False | |||
| return True | |||
| def _check_train_job_exist(self, train_id): | |||
| """Verify the train_job exists for the given train_id.""" | |||
| if train_id in self._loader_pool: | |||
| return | |||
| self._check_train_id_valid(train_id) | |||
| if SummaryWatcher().is_summary_directory(self._summary_base_dir, train_id): | |||
| return | |||
| raise ParamValueError('Can not find the train job in the manager, train_id: %s' % train_id) | |||
| def _reload_data_again(self): | |||
| """Reload the data one more time.""" | |||
| logger.debug('Start to reload data again.') | |||
| thread = threading.Thread(target=self._load_data, | |||
| name='reload_data_thread') | |||
| thread.daemon = False | |||
| thread.start() | |||
| def _get_job(self, train_id): | |||
| """Retrieve train_job given train_id.""" | |||
| is_reload = False | |||
| with self._loader_pool_mutex: | |||
| loader = self._loader_pool.get(train_id, None) | |||
| if loader is None: | |||
| relative_path = train_id | |||
| temp_loader = self._generate_loader_by_relative_path( | |||
| relative_path) | |||
| if temp_loader is None: | |||
| return None | |||
| self._add_loader(temp_loader) | |||
| is_reload = True | |||
| if is_reload: | |||
| self._reload_data_again() | |||
| return loader | |||
| @property | |||
| def summary_base_dir(self): | |||
| """Return the base directory for summary records.""" | |||
| return self._summary_base_dir | |||
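Editor's note: the loader cache touched by get_job above relies on collections.OrderedDict semantics (that self._loader_pool is an OrderedDict is an assumption; its construction is not shown in this hunk): move_to_end(key, last=False) moves an entry to the front of the ordering, and popitem(last=False) removes the front entry. A minimal illustration with made-up ids:

    from collections import OrderedDict

    pool = OrderedDict([('job_a', 1), ('job_b', 2), ('job_c', 3)])
    pool.move_to_end('job_c', last=False)   # 'job_c' now sits at the front of the ordering
    print(list(pool))                       # ['job_c', 'job_a', 'job_b']
    pool.popitem(last=False)                # removes the front entry ('job_c')
    print(list(pool))                       # ['job_a', 'job_b']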
| def get_job_list(self, offset=0, limit=None): | |||
| """ | |||
| Return a list of explain jobs, including job ID, create time and update time. | |||
| @@ -298,44 +119,146 @@ class ExplainManager: | |||
| - create_time (datetime): Creation time of summary file. | |||
| - update_time (datetime): Modification time of summary file. | |||
| """ | |||
| watcher = SummaryWatcher() | |||
| total, dir_infos = \ | |||
| watcher.list_explain_directories(self._summary_base_dir, | |||
| offset=offset, limit=limit) | |||
| self._summary_watcher.list_explain_directories(self._summary_base_dir, offset=offset, limit=limit) | |||
| return total, dir_infos | |||
| def get_job(self, train_id): | |||
| """ | |||
| Return ExplainJob given train_id. | |||
| If explain job w.r.t given train_id is not found, None will be returned. | |||
| Args: | |||
| train_id (str): The id of expected ExplainJob | |||
| Return: | |||
| explain_job | |||
| """ | |||
| self._check_status_valid() | |||
| self._check_train_job_exist(train_id) | |||
| loader = self._get_job(train_id) | |||
| if loader is None: | |||
| return None | |||
| return loader | |||
| def _repeat_loading(self, repeat_interval): | |||
| """Periodically load the summary data.""" | |||
| while True: | |||
| try: | |||
| logger.info('Start to load data, repeat interval: %r.', repeat_interval) | |||
| self._load_data() | |||
| if not repeat_interval: | |||
| return | |||
| time.sleep(repeat_interval) | |||
| except UnknownError as ex: | |||
| logger.exception(ex) | |||
| logger.error('Unexpected error occurred when loading data. Loading status: %s, loading pool size: %d. ' | |||
| 'Detail: %s', self._loading_status, len(self._loader_pool), str(ex)) | |||
| def _load_data(self): | |||
| """ | |||
| Prepare loaders in cache and start loading the data from summaries. | |||
| Only a limited number of loaders will be cached, ranked by updated_time or query_time. The size of the cache | |||
| pool is determined by _MAX_LOADERS_NUM. When the manager starts loading data, only the latest _MAX_LOADERS_NUM | |||
| summaries will be loaded into the cache. If a cached loader is queried by 'get_job', the query_time of the | |||
| loader is updated and the loader is moved to the end of the cache. If an uncached summary is queried, | |||
| a new loader instance will be generated and put at the end of the cache. | |||
| """ | |||
| try: | |||
| with self._status_mutex: | |||
| if self._loading_status == _ExplainManagerStatus.LOADING.value: | |||
| logger.info('Current status is %s, loading will be skipped.', self._loading_status) | |||
| return | |||
| self._loading_status = _ExplainManagerStatus.LOADING.value | |||
| self._cache_loaders() | |||
| self._execute_loading() | |||
| if not self._loader_pool: | |||
| self._loading_status = _ExplainManagerStatus.INVALID.value | |||
| else: | |||
| self._loading_status = _ExplainManagerStatus.DONE.value | |||
| logger.info('Load event data end, status: %s, and loader pool size: %d', | |||
| self._loading_status, len(self._loader_pool)) | |||
| except Exception as ex: | |||
| self._loading_status = _ExplainManagerStatus.INVALID.value | |||
| logger.exception(ex) | |||
| raise UnknownError(str(ex)) | |||
| def _cache_loaders(self): | |||
| """Cache explain loaders in the cache pool.""" | |||
| dir_map_mtime_dict = [] | |||
| _, summaries_info = self._summary_watcher.list_explain_directories(self._summary_base_dir) | |||
| for summary_info in summaries_info: | |||
| summary_path = summary_info.get('relative_path') | |||
| summary_update_time = summary_info.get('update_time').timestamp() | |||
| if summary_path in self._loader_pool: | |||
| summary_update_time = max(summary_update_time, self._loader_pool[summary_path].query_time) | |||
| dir_map_mtime_dict.append((summary_info, summary_update_time)) | |||
| sorted_summaries_info = sorted(dir_map_mtime_dict, key=lambda x: x[1])[-_MAX_LOADERS_NUM:] | |||
| with self._loader_pool_mutex: | |||
| for summary_info, query_time in sorted_summaries_info: | |||
| summary_path = summary_info['relative_path'] | |||
| if summary_path not in self._loader_pool: | |||
| loader = self._generate_loader_from_relative_path(summary_path) | |||
| self._add_loader(loader) | |||
| else: | |||
| self._loader_pool[summary_path].query_time = query_time | |||
| self._loader_pool.move_to_end(summary_path, last=False) | |||
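Editor's note: a small self-contained sketch of the ranking step above (toy paths and timestamps; `_MAX_LOADERS_NUM` is redefined locally only for the sketch). Each summary is ranked by the later of its file update time and the query_time of an already-cached loader, and only the newest entries are kept.

    _MAX_LOADERS_NUM = 2
    cached_query_time = {'./job_b': 50.0}                                   # query_time of loaders already cached
    summaries = [('./job_a', 10.0), ('./job_b', 20.0), ('./job_c', 30.0)]   # (relative_path, update_time)

    ranked = sorted(
        ((path, max(mtime, cached_query_time.get(path, mtime))) for path, mtime in summaries),
        key=lambda item: item[1])
    newest = ranked[-_MAX_LOADERS_NUM:]
    print(newest)   # [('./job_c', 30.0), ('./job_b', 50.0)] -- job_b survives because it was queried recently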
| def _generate_loader_from_relative_path(self, relative_path: str) -> ExplainLoader: | |||
| """Generate explain loader from the given relative path.""" | |||
| self._check_summary_exist(relative_path) | |||
| current_dir = os.path.realpath(FileHandler.join(self._summary_base_dir, relative_path)) | |||
| loader_id = self._generate_loader_id(relative_path) | |||
| loader = ExplainLoader(loader_id=loader_id, summary_dir=current_dir) | |||
| return loader | |||
| def start_load_data(self, reload_interval=_MAX_INTERVAL): | |||
| """ | |||
| Start threads for loading data. | |||
| Args: | |||
| reload_interval (int): Interval for reloading the summary from file, in seconds. | |||
| """ | |||
| self._reload_interval = reload_interval | |||
| thread = threading.Thread(target=self._reload_data, name='start_load_data_thread') | |||
| thread.daemon = True | |||
| thread.start() | |||
| # wait for data loading | |||
| time.sleep(1) | |||
| def _add_loader(self, loader): | |||
| """Add a loader to the loader_pool.""" | |||
| if loader.train_id not in self._loader_pool: | |||
| self._loader_pool[loader.train_id] = loader | |||
| else: | |||
| self._loader_pool.move_to_end(loader.train_id) | |||
| while len(self._loader_pool) > self._max_loaders_num: | |||
| self._loader_pool.popitem(last=False) | |||
| def _execute_loading(self): | |||
| """Execute the data loading.""" | |||
| for loader_id in list(self._loader_pool.keys()): | |||
| try: | |||
| with self._loader_pool_mutex: | |||
| loader = self._loader_pool.get(loader_id, None) | |||
| if loader is None: | |||
| logger.debug('Loader %r has been deleted, will not load data', loader_id) | |||
| return | |||
| loader.load() | |||
| except MindInsightException as ex: | |||
| logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, ex) | |||
| with self._loader_pool_mutex: | |||
| self._delete_loader(loader_id) | |||
| def _delete_loader(self, loader_id): | |||
| """Delete the loader with the given loader_id.""" | |||
| if loader_id in self._loader_pool: | |||
| self._loader_pool.pop(loader_id) | |||
| logger.debug('delete loader %s', loader_id) | |||
| def _check_status_valid(self): | |||
| """Check manager status.""" | |||
| if self._loading_status == _ExplainManagerStatus.INIT.value: | |||
| raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._loading_status) | |||
| def _check_summary_exist(self, loader_id): | |||
| """Verify the train_job exists for the given loader_id.""" | |||
| if not self._summary_watcher.is_summary_directory(self._summary_base_dir, loader_id): | |||
| raise ParamValueError('Can not find the train job in the manager.') | |||
| def _reload_data_again(self): | |||
| """Reload the data one more time.""" | |||
| logger.debug('Start to reload data again.') | |||
| thread = threading.Thread(target=self._load_data, name='reload_data_thread') | |||
| thread.daemon = False | |||
| thread.start() | |||
| @staticmethod | |||
| def _generate_loader_id(relative_path): | |||
| """Generate loader id for given path""" | |||
| loader_id = relative_path | |||
| return loader_id | |||
| EXPLAIN_MANAGER = ExplainManager(summary_base_dir=settings.SUMMARY_BASE_DIR) | |||
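Editor's note: a hedged usage sketch of the module-level singleton defined above, assuming it runs in the same module right after the singleton is created; the train_id value and the reload interval are made up for illustration.

    # Start the background loading thread once at application start-up,
    # then query jobs on demand; get_job returns None when nothing is found.
    EXPLAIN_MANAGER.start_load_data(reload_interval=3)
    total, dir_infos = EXPLAIN_MANAGER.get_job_list(offset=0, limit=10)
    loader = EXPLAIN_MANAGER.get_job('./explain_job_1')   # './explain_job_1' is a hypothetical train_id
    if loader is None:
        print('no such explain job')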
| @@ -17,49 +17,41 @@ File parser for MindExplain data. | |||
| This module is used to parse the MindExplain log file. | |||
| """ | |||
| import re | |||
| import collections | |||
| from collections import namedtuple | |||
| from google.protobuf.message import DecodeError | |||
| from mindinsight.datavisual.common import exceptions | |||
| from mindinsight.explainer.common.enums import PluginNameEnum | |||
| from mindinsight.explainer.common.enums import ExplainFieldsEnum | |||
| from mindinsight.explainer.common.log import logger | |||
| from mindinsight.datavisual.data_access.file_handler import FileHandler | |||
| from mindinsight.datavisual.data_transform.ms_data_loader import _SummaryParser | |||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | |||
| from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Explain | |||
| from mindinsight.utils.exceptions import UnknownError | |||
| HEADER_SIZE = 8 | |||
| CRC_STR_SIZE = 4 | |||
| MAX_EVENT_STRING = 500000000 | |||
| BenchmarkContainer = collections.namedtuple('BenchmarkContainer', ['benchmark', 'status']) | |||
| MetadataContainer = collections.namedtuple('MetadataContainer', ['metadata', 'status']) | |||
| class ImageDataContainer: | |||
| """ | |||
| Container for image data to allow pickling. | |||
| Args: | |||
| explain_message (Explain): Explain proto buffer message. | |||
| """ | |||
| def __init__(self, explain_message: Explain): | |||
| self.sample_id = explain_message.sample_id | |||
| self.image_path = explain_message.image_path | |||
| self.ground_truth_label = explain_message.ground_truth_label | |||
| self.inference = explain_message.inference | |||
| self.explanation = explain_message.explanation | |||
| self.status = explain_message.status | |||
| class _ExplainParser(_SummaryParser): | |||
| BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status']) | |||
| MetadataContainer = namedtuple('MetadataContainer', ['metadata', 'status']) | |||
| InferfenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob', | |||
| 'ground_truth_prob_sd', | |||
| 'ground_truth_prob_itl95_low', | |||
| 'ground_truth_prob_itl95_hi', | |||
| 'predicted_label', | |||
| 'predicted_prob', | |||
| 'predicted_prob_sd', | |||
| 'predicted_prob_itl95_low', | |||
| 'predicted_prob_itl95_hi']) | |||
| SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label', 'inference', | |||
| 'explanation', 'status']) | |||
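Editor's note: these containers are defined at module level so that, like the removed ImageDataContainer whose docstring mentions pickling, their instances can be pickled. A trivial self-contained check (dummy values; the namedtuple is redefined locally only for the sketch):

    import pickle
    from collections import namedtuple

    BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status'])  # same shape as above
    sample = BenchmarkContainer(benchmark=[], status=0)                             # dummy payload
    assert pickle.loads(pickle.dumps(sample)) == sample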
| class ExplainParser(_SummaryParser): | |||
| """The summary file parser.""" | |||
| def __init__(self, summary_dir): | |||
| super(_ExplainParser, self).__init__(summary_dir) | |||
| super(ExplainParser, self).__init__(summary_dir) | |||
| self._latest_filename = '' | |||
| def parse_explain(self, filenames): | |||
| @@ -71,8 +63,7 @@ class _ExplainParser(_SummaryParser): | |||
| Returns: | |||
| bool, True if all the summary files are finished loading. | |||
| """ | |||
| summary_files = self.filter_files(filenames) | |||
| summary_files = self.sort_files(summary_files) | |||
| summary_files = self.sort_files(filenames) | |||
| is_end = False | |||
| is_clean = False | |||
| @@ -125,20 +116,6 @@ class _ExplainParser(_SummaryParser): | |||
| logger.exception(ex) | |||
| raise UnknownError(str(ex)) | |||
| def filter_files(self, filenames): | |||
| """ | |||
| Gets a list of summary files. | |||
| Args: | |||
| filenames (list[str]): File name list, like [filename1, filename2]. | |||
| Returns: | |||
| list[str], filename list. | |||
| """ | |||
| return list(filter( | |||
| lambda filename: (re.search(r'summary\.\d+', filename) | |||
| and filename.endswith("_explain")), filenames)) | |||
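Editor's note: for context, a quick standalone check of what the removed filter used to accept (file names are made up):

    import re

    def _is_explain_file(filename):
        # mirrors the removed filter: 'summary.<digits>' somewhere in the name and an '_explain' suffix
        return bool(re.search(r'summary\.\d+', filename)) and filename.endswith('_explain')

    print(_is_explain_file('events.summary.1600000000.host0_explain'))  # True
    print(_is_explain_file('events.summary.1600000000.host0'))          # False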
| @staticmethod | |||
| def _event_decode(event_str): | |||
| """ | |||
| @@ -153,9 +130,9 @@ class _ExplainParser(_SummaryParser): | |||
| logger.debug("Deserialize event string completed.") | |||
| fields = { | |||
| 'sample_id': PluginNameEnum.SAMPLE_ID, | |||
| 'benchmark': PluginNameEnum.BENCHMARK, | |||
| 'metadata': PluginNameEnum.METADATA | |||
| 'sample_id': ExplainFieldsEnum.SAMPLE_ID, | |||
| 'benchmark': ExplainFieldsEnum.BENCHMARK, | |||
| 'metadata': ExplainFieldsEnum.METADATA | |||
| } | |||
| tensor_event_value = getattr(event, 'explain') | |||
| @@ -163,19 +140,19 @@ class _ExplainParser(_SummaryParser): | |||
| field_list = [] | |||
| tensor_value_list = [] | |||
| for field in fields: | |||
| if not getattr(tensor_event_value, field): | |||
| if not getattr(tensor_event_value, field, False): | |||
| continue | |||
| if PluginNameEnum.METADATA.value == field and not tensor_event_value.metadata.label: | |||
| if ExplainFieldsEnum.METADATA.value == field and not tensor_event_value.metadata.label: | |||
| continue | |||
| tensor_value = None | |||
| if field == PluginNameEnum.SAMPLE_ID.value: | |||
| tensor_value = _ExplainParser._add_image_data(tensor_event_value) | |||
| elif field == PluginNameEnum.BENCHMARK.value: | |||
| tensor_value = _ExplainParser._add_benchmark(tensor_event_value) | |||
| elif field == PluginNameEnum.METADATA.value: | |||
| tensor_value = _ExplainParser._add_metadata(tensor_event_value) | |||
| if field == ExplainFieldsEnum.SAMPLE_ID.value: | |||
| tensor_value = ExplainParser._add_image_data(tensor_event_value) | |||
| elif field == ExplainFieldsEnum.BENCHMARK.value: | |||
| tensor_value = ExplainParser._add_benchmark(tensor_event_value) | |||
| elif field == ExplainFieldsEnum.METADATA.value: | |||
| tensor_value = ExplainParser._add_metadata(tensor_event_value) | |||
| logger.debug("Event generated, label is %s, step is %s.", field, event.step) | |||
| field_list.append(field) | |||
| tensor_value_list.append(tensor_value) | |||
| @@ -189,8 +166,26 @@ class _ExplainParser(_SummaryParser): | |||
| Args: | |||
| tensor_event_value: the object of Explain message | |||
| """ | |||
| image_data = ImageDataContainer(tensor_event_value) | |||
| return image_data | |||
| inference = InferfenceContainer( | |||
| ground_truth_prob=tensor_event_value.inference.ground_truth_prob, | |||
| ground_truth_prob_sd=tensor_event_value.inference.ground_truth_prob_sd, | |||
| ground_truth_prob_itl95_low=tensor_event_value.inference.ground_truth_prob_itl95_low, | |||
| ground_truth_prob_itl95_hi=tensor_event_value.inference.ground_truth_prob_itl95_hi, | |||
| predicted_label=tensor_event_value.inference.predicted_label, | |||
| predicted_prob=tensor_event_value.inference.predicted_prob, | |||
| predicted_prob_sd=tensor_event_value.inference.predicted_prob_sd, | |||
| predicted_prob_itl95_low=tensor_event_value.inference.predicted_prob_itl95_low, | |||
| predicted_prob_itl95_hi=tensor_event_value.inference.predicted_prob_itl95_hi | |||
| ) | |||
| sample_data = SampleContainer( | |||
| sample_id=tensor_event_value.sample_id, | |||
| image_path=tensor_event_value.image_path, | |||
| ground_truth_label=tensor_event_value.ground_truth_label, | |||
| inference=inference, | |||
| explanation=tensor_event_value.explanation, | |||
| status=tensor_event_value.status | |||
| ) | |||
| return sample_data | |||
| @staticmethod | |||
| def _add_benchmark(tensor_event_value): | |||
| @@ -25,7 +25,7 @@ class MockExplainJob: | |||
| self.create_time = datetime.timestamp( | |||
| datetime.strptime("2020-10-01 20:21:23", | |||
| ExplainJobEncap.DATETIME_FORMAT)) | |||
| self.latest_update_time = self.create_time | |||
| self.update_time = self.create_time | |||
| self.sample_count = 1999 | |||
| self.min_confidence = 0.5 | |||
| self.explainers = ["Gradient"] | |||