@@ -13,7 +13,9 @@
 # limitations under the License.
 # ============================================================================
 """Module init file."""
+from mindinsight.conf import settings
 from mindinsight.backend.explainer.explainer_api import init_module as init_query_module
+from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER


 def init_module(app):
@@ -27,3 +29,4 @@ def init_module(app):
     """
     init_query_module(app)
+    EXPLAIN_MANAGER.start_load_data(reload_interval=settings.RELOAD_INTERVAL)
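For orientation, here is a minimal, hypothetical sketch of the pattern this hunk switches to: a process-wide manager created at import time, whose start_load_data spawns a background reload loop. Only the names EXPLAIN_MANAGER, start_load_data, and reload_interval come from the diff; everything else is illustrative.

    # Hypothetical sketch; not MindInsight's actual internals.
    import threading
    import time


    class _Manager:
        def start_load_data(self, reload_interval=None):
            """Kick off loading in the background and return immediately."""
            threading.Thread(target=self._reload, args=(reload_interval,), daemon=True).start()

        def _reload(self, interval):
            while True:
                print('loading summary files ...')  # placeholder for the real scan
                if not interval:                    # one-shot when no interval is configured
                    break
                time.sleep(interval)


    EXPLAIN_MANAGER = _Manager()  # importing the module yields the shared instance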
@@ -29,32 +29,16 @@ from mindinsight.datavisual.common.exceptions import ImageNotExistError
 from mindinsight.datavisual.common.validation import Validation
 from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
 from mindinsight.datavisual.utils.tools import get_train_id
-from mindinsight.explainer.manager.explain_manager import ExplainManager
+from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER
 from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
 from mindinsight.explainer.encapsulator.datafile_encap import DatafileEncap
 from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
 from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap

-URL_PREFIX = settings.URL_PATH_PREFIX+settings.API_PREFIX
+URL_PREFIX = settings.URL_PATH_PREFIX + settings.API_PREFIX
 BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX)


-class ExplainManagerHolder:
-    """ExplainManager instance holder."""
-    static_instance = None
-
-    @classmethod
-    def get_instance(cls):
-        return cls.static_instance
-
-    @classmethod
-    def initialize(cls):
-        cls.static_instance = ExplainManager(settings.SUMMARY_BASE_DIR)
-        cls.static_instance.start_load_data()
-
-
 def _image_url_formatter(train_id, image_path, image_type):
     """Returns image url."""
     data = {
@@ -91,7 +75,7 @@ def query_explain_jobs():
     offset = Validation.check_offset(offset=offset)
     limit = Validation.check_limit(limit, min_value=1, max_value=SummaryWatcher.MAX_SUMMARY_DIR_COUNT)
-    encapsulator = ExplainJobEncap(ExplainManagerHolder.get_instance())
+    encapsulator = ExplainJobEncap(EXPLAIN_MANAGER)
     total, jobs = encapsulator.query_explain_jobs(offset, limit)
     return jsonify({
@@ -107,7 +91,7 @@ def query_explain_job():
     train_id = get_train_id(request)
     if train_id is None:
         raise ParamMissError("train_id")
-    encapsulator = ExplainJobEncap(ExplainManagerHolder.get_instance())
+    encapsulator = ExplainJobEncap(EXPLAIN_MANAGER)
     metadata = encapsulator.query_meta(train_id)
     return jsonify(metadata)
@@ -139,7 +123,7 @@ def query_saliency():
     encapsulator = SaliencyEncap(
         _image_url_formatter,
-        ExplainManagerHolder.get_instance())
+        EXPLAIN_MANAGER)
     count, samples = encapsulator.query_saliency_maps(train_id=train_id,
                                                       labels=labels,
                                                       explainers=explainers,
@@ -160,7 +144,7 @@ def query_evaluation():
     train_id = get_train_id(request)
     if train_id is None:
         raise ParamMissError("train_id")
-    encapsulator = EvaluationEncap(ExplainManagerHolder.get_instance())
+    encapsulator = EvaluationEncap(EXPLAIN_MANAGER)
     scores = encapsulator.query_explainer_scores(train_id)
     return jsonify({
         "explainer_scores": scores,
@@ -182,7 +166,7 @@ def query_image():
     if image_type not in ("original", "overlay"):
         raise ParamValueError(f"type:{image_type}, valid options: 'original' 'overlay'")
-    encapsulator = DatafileEncap(ExplainManagerHolder.get_instance())
+    encapsulator = DatafileEncap(EXPLAIN_MANAGER)
     image = encapsulator.query_image_binary(train_id, image_path, image_type)
     if image is None:
         raise ImageNotExistError(f"{image_path}")
@@ -198,5 +182,4 @@ def init_module(app):
         app: the application obj.
     """
-    ExplainManagerHolder.initialize()
     app.register_blueprint(BLUEPRINT)
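The holder class can go because the blueprint's view functions now close over a module-level singleton directly. A self-contained sketch of how such a Blueprint wires into the app; the prefix value and route path are made up, only the Blueprint/register_blueprint usage mirrors the code above:

    from flask import Blueprint, Flask, jsonify

    URL_PREFIX = '/mindinsight/api'  # stand-in for settings.URL_PATH_PREFIX + settings.API_PREFIX
    BLUEPRINT = Blueprint('explainer', __name__, url_prefix=URL_PREFIX)


    @BLUEPRINT.route('/explainer/jobs')
    def query_explain_jobs():
        # In the real module this delegates to ExplainJobEncap(EXPLAIN_MANAGER).
        return jsonify({'total': 0, 'explain_jobs': []})


    app = Flask(__name__)
    app.register_blueprint(BLUEPRINT)  # what init_module(app) now reduces to
    # GET /mindinsight/api/explainer/jobs -> query_explain_jobs()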
@@ -31,7 +31,7 @@ class DataManagerStatus(BaseEnum):
     INVALID = 'INVALID'


-class PluginNameEnum(BaseEnum):
+class ExplainFieldsEnum(BaseEnum):
     """Plugin Name Enum."""
     EXPLAIN = 'explain'
     SAMPLE_ID = 'sample_id'
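The rename is semantic: the members are summary-event field names, not visualization plugins. A quick illustration of how such an enum is consumed downstream; BaseEnum here is a bare stand-in for MindInsight's helper base class:

    from enum import Enum


    class BaseEnum(Enum):  # stand-in for mindinsight.datavisual.common.enums.BaseEnum
        pass


    class ExplainFieldsEnum(BaseEnum):
        EXPLAIN = 'explain'
        SAMPLE_ID = 'sample_id'


    tag = 'sample_id'  # a tag parsed from a summary event
    print(tag == ExplainFieldsEnum.SAMPLE_ID.value)  # True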
@@ -70,7 +70,7 @@ class ExplainJobEncap(ExplainDataEncap):
         info["train_id"] = job.train_id
         info["create_time"] = datetime.fromtimestamp(job.create_time)\
             .strftime(cls.DATETIME_FORMAT)
-        info["update_time"] = datetime.fromtimestamp(job.latest_update_time)\
+        info["update_time"] = datetime.fromtimestamp(job.update_time)\
             .strftime(cls.DATETIME_FORMAT)
         return info
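Both timestamps are rendered the same way; for example, assuming a conventional DATETIME_FORMAT like the one below (the actual constant lives on ExplainJobEncap and is not shown in this hunk):

    from datetime import datetime

    DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S'  # assumed value, for illustration only
    update_time = 1598436000.0             # POSIX timestamp, as stored on the loader
    print(datetime.fromtimestamp(update_time).strftime(DATETIME_FORMAT))
    # prints the local-time rendering of the timestamp, e.g. '2020-08-26 18:00:00'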
@@ -1,168 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""EventParser for summary event."""
-from collections import namedtuple, defaultdict
-from typing import Dict, List, Optional, Tuple
-
-from mindinsight.explainer.common.enums import PluginNameEnum
-from mindinsight.explainer.common.log import logger
-from mindinsight.utils.exceptions import UnknownError
-
-_IMAGE_DATA_TAGS = {
-    'sample_id': PluginNameEnum.SAMPLE_ID.value,
-    'ground_truth_label': PluginNameEnum.GROUND_TRUTH_LABEL.value,
-    'inference': PluginNameEnum.INFERENCE.value,
-    'explanation': PluginNameEnum.EXPLANATION.value
-}
-
-_NUM_DIGIT = 7
-
-
-class EventParser:
-    """Parser for event data."""
-
-    def __init__(self, job):
-        self._job = job
-        self._sample_pool = {}
-
-    @staticmethod
-    def parse_metadata(metadata) -> Tuple[List, List, List]:
-        """Parse the metadata event."""
-        explainers = list(metadata.explain_method)
-        metrics = list(metadata.benchmark_method)
-        labels = list(metadata.label)
-        return explainers, metrics, labels
-
-    @staticmethod
-    def parse_benchmark(benchmarks) -> Tuple[Dict, Dict]:
-        """Parse the benchmark event."""
-        explainer_score_dict = defaultdict(list)
-        label_score_dict = defaultdict(dict)
-
-        for benchmark in benchmarks:
-            explainer = benchmark.explain_method
-            metric = benchmark.benchmark_method
-            metric_score = benchmark.total_score
-            label_score_event = benchmark.label_score
-
-            explainer_score_dict[explainer].append({
-                'metric': metric,
-                'score': round(metric_score, _NUM_DIGIT)})
-            new_label_score_dict = EventParser._score_event_to_dict(label_score_event, metric)
-            for label, label_scores in new_label_score_dict.items():
-                label_score_dict[explainer][label] = label_score_dict[explainer].get(label, []) + label_scores
-
-        return explainer_score_dict, label_score_dict
-
-    def parse_sample(self, sample: namedtuple) -> Optional[namedtuple]:
-        """Parse the sample event."""
-        sample_id = sample.sample_id
-        if sample_id not in self._sample_pool:
-            self._sample_pool[sample_id] = sample
-            return None
-
-        for tag in _IMAGE_DATA_TAGS:
-            try:
-                if tag == PluginNameEnum.INFERENCE.value:
-                    self._parse_inference(sample, sample_id)
-                elif tag == PluginNameEnum.EXPLANATION.value:
-                    self._parse_explanation(sample, sample_id)
-                else:
-                    self._parse_sample_info(sample, sample_id, tag)
-            except UnknownError as ex:
-                logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex))
-                continue
-
-        if EventParser._is_ready_for_display(self._sample_pool[sample_id]):
-            return self._sample_pool[sample_id]
-        return None
-
-    def clear(self):
-        """Clear the loaded data."""
-        self._sample_pool.clear()
-
-    @staticmethod
-    def _is_ready_for_display(image_container: namedtuple) -> bool:
-        """
-        Check whether the image_container is ready for frontend display.
-
-        Args:
-            image_container (namedtuple): container that consists of sample data
-
-        Return:
-            bool, whether the image_container is ready for display
-        """
-        required_attrs = ['image_path', 'ground_truth_label', 'inference']
-        for attr in required_attrs:
-            if not EventParser.is_attr_ready(image_container, attr):
-                return False
-        return True
-
-    @staticmethod
-    def is_attr_ready(image_container: namedtuple, attr: str) -> bool:
-        """
-        Check whether the given attribute is ready in image_container.
-
-        Args:
-            image_container (namedtuple): container that consists of sample data
-            attr (str): attribute to check
-
-        Returns:
-            bool, whether the attr is ready
-        """
-        if getattr(image_container, attr, False):
-            return True
-        return False
-
-    @staticmethod
-    def _score_event_to_dict(label_score_event, metric):
-        """Transfer metric scores per label to pre-defined structure."""
-        new_label_score_dict = defaultdict(list)
-        for label_id, label_score in enumerate(label_score_event):
-            new_label_score_dict[label_id].append({
-                'metric': metric,
-                'score': round(label_score, _NUM_DIGIT),
-            })
-        return new_label_score_dict
-
-    def _parse_inference(self, event, sample_id):
-        """Parse the inference event."""
-        self._sample_pool[sample_id].inference.ground_truth_prob.extend(event.inference.ground_truth_prob)
-        self._sample_pool[sample_id].inference.ground_truth_prob_sd.extend(event.inference.ground_truth_prob_sd)
-        self._sample_pool[sample_id].inference.ground_truth_prob_itl95_low.\
-            extend(event.inference.ground_truth_prob_itl95_low)
-        self._sample_pool[sample_id].inference.ground_truth_prob_itl95_hi.\
-            extend(event.inference.ground_truth_prob_itl95_hi)
-        self._sample_pool[sample_id].inference.predicted_label.extend(event.inference.predicted_label)
-        self._sample_pool[sample_id].inference.predicted_prob.extend(event.inference.predicted_prob)
-        self._sample_pool[sample_id].inference.predicted_prob_sd.extend(event.inference.predicted_prob_sd)
-        self._sample_pool[sample_id].inference.predicted_prob_itl95_low.extend(event.inference.predicted_prob_itl95_low)
-        self._sample_pool[sample_id].inference.predicted_prob_itl95_hi.extend(event.inference.predicted_prob_itl95_hi)
-
-    def _parse_explanation(self, event, sample_id):
-        """Parse the explanation event."""
-        if event.explanation:
-            for explanation_item in event.explanation:
-                new_explanation = self._sample_pool[sample_id].explanation.add()
-                new_explanation.explain_method = explanation_item.explain_method
-                new_explanation.label = explanation_item.label
-                new_explanation.heatmap_path = explanation_item.heatmap_path
-
-    def _parse_sample_info(self, event, sample_id, tag):
-        """Parse the event containing image info."""
-        if not getattr(self._sample_pool[sample_id], tag):
-            setattr(self._sample_pool[sample_id], tag, getattr(event, tag))
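For reference, the removed parse_benchmark accumulated scores with nested defaultdicts, so repeated benchmark records for the same explainer append rather than overwrite. The pattern in isolation, with toy values:

    from collections import defaultdict

    explainer_score_dict = defaultdict(list)
    label_score_dict = defaultdict(dict)

    for explainer, metric, score in [('GradCAM', 'Faithfulness', 0.61),
                                     ('GradCAM', 'Localization', 0.47)]:
        explainer_score_dict[explainer].append({'metric': metric, 'score': round(score, 7)})
        # per-label scores merge into lists keyed by label id (label 0 shown here)
        label_score_dict[explainer][0] = label_score_dict[explainer].get(0, []) + \
            [{'metric': metric, 'score': round(score, 7)}]

    print(len(explainer_score_dict['GradCAM']))  # 2: both metrics accumulated
    print(len(label_score_dict['GradCAM'][0]))   # 2: both entries kept for label 0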
@@ -1,398 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""ExplainJob."""
-import os
-from collections import defaultdict
-from datetime import datetime
-from typing import Union
-
-from mindinsight.explainer.common.enums import PluginNameEnum
-from mindinsight.explainer.common.log import logger
-from mindinsight.explainer.manager.explain_parser import _ExplainParser
-from mindinsight.explainer.manager.event_parse import EventParser
-from mindinsight.datavisual.data_access.file_handler import FileHandler
-from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
-
-_NUM_DIGIT = 7
-
-
-class ExplainJob:
-    """ExplainJob which manages the record in the summary file."""
-
-    def __init__(self,
-                 job_id: str,
-                 summary_dir: str,
-                 create_time: float,
-                 latest_update_time: float):
-        self._job_id = job_id
-        self._summary_dir = summary_dir
-        self._parser = _ExplainParser(summary_dir)
-        self._event_parser = EventParser(self)
-        self._latest_update_time = latest_update_time
-        self._create_time = create_time
-        self._uncertainty_enabled = False
-        self._labels = []
-        self._metrics = []
-        self._explainers = []
-        self._samples_info = {}
-        self._labels_info = {}
-        self._explainer_score_dict = defaultdict(list)
-        self._label_score_dict = defaultdict(dict)
-
-    @property
-    def all_classes(self):
-        """
-        Return a list of label info.
-
-        Returns:
-            class_objs (List[ClassObj]): a list of class objects, each object
-                contains:
-
-                - id (int): label id
-                - label (str): label name
-                - sample_count (int): number of samples for each label
-        """
-        all_classes_return = []
-        for label_id, label_info in self._labels_info.items():
-            single_info = {
-                'id': label_id,
-                'label': label_info['label'],
-                'sample_count': len(label_info['sample_ids'])}
-            all_classes_return.append(single_info)
-        return all_classes_return
-
-    @property
-    def explainers(self):
-        """
-        Return a list of explainer names.
-
-        Returns:
-            list(str), explainer names
-        """
-        return self._explainers
-
-    @property
-    def explainer_scores(self):
-        """Return evaluation results for every explainer."""
-        merged_scores = []
-        for explainer, explainer_score_on_metric in self._explainer_score_dict.items():
-            label_scores = []
-            for label, label_score_on_metric in self._label_score_dict[explainer].items():
-                score_single_label = {
-                    'label': self._labels[label],
-                    'evaluations': label_score_on_metric,
-                }
-                label_scores.append(score_single_label)
-            merged_scores.append({
-                'explainer': explainer,
-                'evaluations': explainer_score_on_metric,
-                'class_scores': label_scores,
-            })
-        return merged_scores
-
-    @property
-    def sample_count(self):
-        """
-        Return total number of samples in the job.
-
-        Return:
-            int, total number of samples
-        """
-        return len(self._samples_info)
-
-    @property
-    def train_id(self):
-        """
-        Return ID of explain job.
-
-        Returns:
-            str, id of ExplainJob object
-        """
-        return self._job_id
-
-    @property
-    def metrics(self):
-        """
-        Return a list of metric names.
-
-        Returns:
-            list(str), metric names
-        """
-        return self._metrics
-
-    @property
-    def min_confidence(self):
-        """
-        Return minimum confidence.
-
-        Returns:
-            min_confidence (float)
-        """
-        return None
-
-    @property
-    def uncertainty_enabled(self):
-        """Whether uncertainty is enabled."""
-        return self._uncertainty_enabled
-
-    @property
-    def create_time(self):
-        """
-        Return the creation time of the summary file.
-
-        Returns:
-            float, creation timestamp
-        """
-        return self._create_time
-
-    @property
-    def labels(self):
-        """Return the labels contained in the job."""
-        return self._labels
-
-    @property
-    def latest_update_time(self):
-        """
-        Return the last modification timestamp of the summary file.
-
-        Returns:
-            float, last modification timestamp
-        """
-        return self._latest_update_time
-
-    @latest_update_time.setter
-    def latest_update_time(self, new_time: Union[float, datetime]):
-        """
-        Update the latest_update_time timestamp manually.
-
-        Args:
-            new_time (Union[float, datetime]): updated time for the job
-        """
-        if isinstance(new_time, datetime):
-            self._latest_update_time = new_time.timestamp()
-        elif isinstance(new_time, float):
-            self._latest_update_time = new_time
-        else:
-            raise TypeError('new_time should have type of float or datetime')
-
-    @property
-    def loader_id(self):
-        """Return the job id."""
-        return self._job_id
-
-    @property
-    def samples(self):
-        """Return the information of all samples in the job."""
-        return self._samples_info
-
-    @staticmethod
-    def get_create_time(file_path: str) -> float:
-        """Return the creation timestamp of the specific path."""
-        create_time = os.stat(file_path).st_ctime
-        return create_time
-
-    @staticmethod
-    def get_update_time(file_path: str) -> float:
-        """Return the update timestamp of the specific path."""
-        update_time = os.stat(file_path).st_mtime
-        return update_time
-
-    def _initialize_labels_info(self):
-        """Initialize a dict for labels in the job."""
-        if self._labels is None:
-            logger.warning('No labels are provided in job %s', self._job_id)
-            return
-
-        for label_id, label in enumerate(self._labels):
-            self._labels_info[label_id] = {'label': label,
-                                           'sample_ids': set()}
-
-    def _explanation_to_dict(self, explanation):
-        """Transfer the explanation from event to dict storage."""
-        explain_info = {
-            'explainer': explanation.explain_method,
-            'overlay': explanation.heatmap_path,
-        }
-        return explain_info
-
-    def _image_container_to_dict(self, sample_data):
-        """Transfer the image container to dict storage."""
-        has_uncertainty = False
-        sample_id = sample_data.sample_id
-        sample_info = {
-            'id': sample_id,
-            'image': sample_data.image_path,
-            'name': str(sample_id),
-            'labels': [self._labels_info[x]['label']
-                       for x in sample_data.ground_truth_label],
-            'inferences': []}
-
-        ground_truth_labels = list(sample_data.ground_truth_label)
-        ground_truth_probs = list(sample_data.inference.ground_truth_prob)
-        predicted_labels = list(sample_data.inference.predicted_label)
-        predicted_probs = list(sample_data.inference.predicted_prob)
-
-        if sample_data.inference.predicted_prob_sd or sample_data.inference.ground_truth_prob_sd:
-            ground_truth_prob_sds = list(sample_data.inference.ground_truth_prob_sd)
-            ground_truth_prob_lows = list(sample_data.inference.ground_truth_prob_itl95_low)
-            ground_truth_prob_his = list(sample_data.inference.ground_truth_prob_itl95_hi)
-            predicted_prob_sds = list(sample_data.inference.predicted_prob_sd)
-            predicted_prob_lows = list(sample_data.inference.predicted_prob_itl95_low)
-            predicted_prob_his = list(sample_data.inference.predicted_prob_itl95_hi)
-            has_uncertainty = True
-        else:
-            ground_truth_prob_sds = ground_truth_prob_lows = ground_truth_prob_his = None
-            predicted_prob_sds = predicted_prob_lows = predicted_prob_his = None
-
-        inference_info = {}
-        for label, prob in zip(
-                ground_truth_labels + predicted_labels,
-                ground_truth_probs + predicted_probs):
-            inference_info[label] = {
-                'label': self._labels_info[label]['label'],
-                'confidence': round(prob, _NUM_DIGIT),
-                'saliency_maps': []}
-
-        if ground_truth_prob_sds or predicted_prob_sds:
-            for label, sd, low, hi in zip(
-                    ground_truth_labels + predicted_labels,
-                    ground_truth_prob_sds + predicted_prob_sds,
-                    ground_truth_prob_lows + predicted_prob_lows,
-                    ground_truth_prob_his + predicted_prob_his):
-                inference_info[label]['confidence_sd'] = sd
-                inference_info[label]['confidence_itl95'] = [low, hi]
-
-        if EventParser.is_attr_ready(sample_data, 'explanation'):
-            for explanation in sample_data.explanation:
-                explanation_dict = self._explanation_to_dict(explanation)
-                inference_info[explanation.label]['saliency_maps'].append(explanation_dict)
-
-        sample_info['inferences'] = list(inference_info.values())
-        return sample_info, has_uncertainty
-
-    def _import_sample(self, sample):
-        """Add sample object of given sample id."""
-        for label_id in sample.ground_truth_label:
-            self._labels_info[label_id]['sample_ids'].add(sample.sample_id)
-
-        sample_info, has_uncertainty = self._image_container_to_dict(sample)
-        self._samples_info.update({sample_info['id']: sample_info})
-        self._uncertainty_enabled |= has_uncertainty
-
-    def get_all_samples(self):
-        """
-        Return a list of sample information cached in the explain job.
-
-        Returns:
-            sample_list (List[SampleObj]): a list of sample objects, each object
-                consists of:
-
-                - id (int): sample id
-                - name (str): basename of image
-                - labels (list[str]): list of labels
-                - inferences (list[dict])
-        """
-        samples_in_list = list(self._samples_info.values())
-        return samples_in_list
-
-    def _is_metadata_empty(self):
-        """Check whether metadata is loaded first."""
-        if not self._explainers or not self._metrics or not self._labels:
-            return True
-        return False
-
-    def _import_data_from_event(self, event):
-        """Parse and import data from the event data."""
-        tags = {
-            'sample_id': PluginNameEnum.SAMPLE_ID,
-            'benchmark': PluginNameEnum.BENCHMARK,
-            'metadata': PluginNameEnum.METADATA
-        }
-        if 'metadata' not in event and self._is_metadata_empty():
-            raise ValueError('metadata is empty, should write metadata first in the summary.')
-
-        for tag in tags:
-            if tag not in event:
-                continue
-
-            if tag == PluginNameEnum.SAMPLE_ID.value:
-                sample_event = event[tag]
-                sample_data = self._event_parser.parse_sample(sample_event)
-                if sample_data is not None:
-                    self._import_sample(sample_data)
-                continue
-
-            if tag == PluginNameEnum.BENCHMARK.value:
-                benchmark_event = event[tag].benchmark
-                explain_score_dict, label_score_dict = EventParser.parse_benchmark(benchmark_event)
-                self._update_benchmark(explain_score_dict, label_score_dict)
-            elif tag == PluginNameEnum.METADATA.value:
-                metadata_event = event[tag].metadata
-                metadata = EventParser.parse_metadata(metadata_event)
-                self._explainers, self._metrics, self._labels = metadata
-                self._initialize_labels_info()
-
-    def load(self):
-        """Start loading data from parser."""
-        valid_file_names = []
-        for filename in FileHandler.list_dir(self._summary_dir):
-            if FileHandler.is_file(
-                    FileHandler.join(self._summary_dir, filename)):
-                valid_file_names.append(filename)
-
-        if not valid_file_names:
-            raise TrainJobNotExistError('No summary file found in %s, explain job will be deleted.' % self._summary_dir)
-
-        is_end = False
-        while not is_end:
-            is_clean, is_end, event = self._parser.parse_explain(valid_file_names)
-
-            if is_clean:
-                logger.info('Summary file in %s updated, clean the loaded data and reload.', self._summary_dir)
-                self._clean_job()
-
-            if event:
-                self._import_data_from_event(event)
-
-    def _clean_job(self):
-        """Clean the cached data in job."""
-        self._latest_update_time = ExplainJob.get_update_time(self._summary_dir)
-        self._create_time = ExplainJob.get_update_time(self._summary_dir)
-        self._labels.clear()
-        self._metrics.clear()
-        self._explainers.clear()
-        self._samples_info.clear()
-        self._labels_info.clear()
-        self._explainer_score_dict.clear()
-        self._label_score_dict.clear()
-        self._event_parser.clear()
-
-    def _update_benchmark(self, explainer_score_dict, labels_score_dict):
-        """Update the benchmark info."""
-        for explainer, score in explainer_score_dict.items():
-            self._explainer_score_dict[explainer].extend(score)
-
-        for explainer, score in labels_score_dict.items():
-            for label, score_of_label in score.items():
-                self._label_score_dict[explainer][label] = (self._label_score_dict[explainer].get(label, [])
-                                                            + score_of_label)
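One detail worth keeping in mind from the deleted _image_container_to_dict (the replacement loader's get_all_samples does the same, see below): ground-truth and predicted entries are concatenated and zipped, so each label maps to exactly one confidence record. With toy values in place of the parsed inference event:

    # Toy data; real values come from the parsed inference event.
    ground_truth_labels, predicted_labels = [2], [5, 1]
    ground_truth_probs, predicted_probs = [0.91], [0.05, 0.03]

    inference_info = {}
    for label, prob in zip(ground_truth_labels + predicted_labels,
                           ground_truth_probs + predicted_probs):
        inference_info[label] = {'confidence': round(prob, 7), 'saliency_maps': []}

    print(sorted(inference_info))  # [1, 2, 5]: one record per distinct label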
@@ -0,0 +1,580 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""ExplainLoader."""
+import os
+import re
+from collections import defaultdict
+from datetime import datetime
+from typing import Dict, Iterable, List, Optional, Union
+
+from mindinsight.explainer.common.enums import ExplainFieldsEnum
+from mindinsight.explainer.common.log import logger
+from mindinsight.explainer.manager.explain_parser import ExplainParser
+from mindinsight.datavisual.data_access.file_handler import FileHandler
+from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
+from mindinsight.utils.exceptions import ParamValueError, UnknownError
+
+_NUM_DIGITS = 6
+
+_EXPLAIN_FIELD_NAMES = [
+    ExplainFieldsEnum.SAMPLE_ID,
+    ExplainFieldsEnum.BENCHMARK,
+    ExplainFieldsEnum.METADATA,
+]
+
+_SAMPLE_FIELD_NAMES = [
+    ExplainFieldsEnum.GROUND_TRUTH_LABEL,
+    ExplainFieldsEnum.INFERENCE,
+    ExplainFieldsEnum.EXPLANATION,
+]
+
+
+def _round(score):
+    """Round a number to the given precision."""
+    return round(score, _NUM_DIGITS)
+
+
+class ExplainLoader:
+    """ExplainLoader which manages the record in the summary file."""
+
+    def __init__(self,
+                 loader_id: str,
+                 summary_dir: str):
+        self._parser = ExplainParser(summary_dir)
+        self._loader_info = {
+            'loader_id': loader_id,
+            'summary_dir': summary_dir,
+            'create_time': os.stat(summary_dir).st_ctime,
+            'update_time': os.stat(summary_dir).st_mtime,
+            'query_time': os.stat(summary_dir).st_ctime,
+            'uncertainty_enabled': False,
+        }
+        self._samples = defaultdict(dict)
+        self._metadata = {'explainers': [], 'metrics': [], 'labels': []}
+        self._benchmark = {'explainer_score': defaultdict(dict), 'label_score': defaultdict(dict)}
+
+    @property
+    def all_classes(self) -> List[Dict]:
+        """
+        Return a list of detailed label information, including label id, label name and sample count of each label.
+
+        Returns:
+            list[dict], a list of dict, each dict contains:
+
+                - id (int): label id
+                - label (str): label name
+                - sample_count (int): number of samples for each label
+        """
+        sample_count_per_label = defaultdict(int)
+        samples_copy = self._samples.copy()
+        for sample in samples_copy.values():
+            if sample.get('image', False) and sample.get('ground_truth_label', False):
+                for label in sample['ground_truth_label']:
+                    sample_count_per_label[label] += 1
+
+        all_classes_return = []
+        for label_id, label_name in enumerate(self._metadata['labels']):
+            single_info = {
+                'id': label_id,
+                'label': label_name,
+                'sample_count': sample_count_per_label[label_id]
+            }
+            all_classes_return.append(single_info)
+        return all_classes_return
+
+    @property
+    def query_time(self) -> float:
+        """Return the query timestamp of the explain loader."""
+        return self._loader_info['query_time']
+
+    @query_time.setter
+    def query_time(self, new_time: Union[datetime, float]):
+        """
+        Update the query_time timestamp manually.
+
+        Args:
+            new_time (datetime.datetime or float): Updated query_time for the explain loader.
+        """
+        if isinstance(new_time, datetime):
+            self._loader_info['query_time'] = new_time.timestamp()
+        elif isinstance(new_time, float):
+            self._loader_info['query_time'] = new_time
+        else:
+            raise TypeError('new_time should have type of datetime.datetime or float, but receive {}'
+                            .format(type(new_time)))
+
+    @property
+    def create_time(self) -> float:
+        """Return the creation timestamp of the summary file."""
+        return self._loader_info['create_time']
+
+    @create_time.setter
+    def create_time(self, new_time: Union[datetime, float]):
+        """
+        Update the create_time manually.
+
+        Args:
+            new_time (datetime.datetime or float): Updated create_time of summary_file.
+        """
+        if isinstance(new_time, datetime):
+            self._loader_info['create_time'] = new_time.timestamp()
+        elif isinstance(new_time, float):
+            self._loader_info['create_time'] = new_time
+        else:
+            raise TypeError('new_time should have type of datetime.datetime or float, but receive {}'
+                            .format(type(new_time)))
+
+    @property
+    def explainers(self) -> List[str]:
+        """Return a list of explainer names recorded in the summary file."""
+        return self._metadata['explainers']
+
+    @property
+    def explainer_scores(self) -> List[Dict]:
+        """
+        Return evaluation results for every explainer.
+
+        Returns:
+            list[dict], a list of evaluation results of each explainer. Each item contains:
+
+                - explainer (str): Name of evaluated explainer.
+                - evaluations (list[dict]): A list of evaluation results by different metrics.
+                - class_scores (list[dict]): A list of evaluation results on different labels.
+
+                Each item in evaluations contains:
+
+                - metric (str): Name of metric method.
+                - score (float): Evaluation result.
+
+                Each item in class_scores contains:
+
+                - label (str): Name of label.
+                - evaluations (list[dict]): A list of evaluation results on this label by different metrics,
+                  each with a metric name and the explainer's score on that label by the metric.
+        """
+        explainer_scores = []
+        for explainer, explainer_score_on_metric in self._benchmark['explainer_score'].copy().items():
+            metric_scores = [{'metric': metric, 'score': _round(score)}
+                             for metric, score in explainer_score_on_metric.items()]
+            label_scores = []
+            for label, label_score_on_metric in self._benchmark['label_score'][explainer].copy().items():
+                score_of_single_label = {
+                    'label': self._metadata['labels'][label],
+                    'evaluations': [
+                        {'metric': metric, 'score': _round(score)} for metric, score in label_score_on_metric.items()
+                    ],
+                }
+                label_scores.append(score_of_single_label)
+            explainer_scores.append({
+                'explainer': explainer,
+                'evaluations': metric_scores,
+                'class_scores': label_scores,
+            })
+        return explainer_scores
+
+    @property
+    def labels(self) -> List[str]:
+        """Return the labels recorded in the summary."""
+        return self._metadata['labels']
+
+    @property
+    def metrics(self) -> List[str]:
+        """Return a list of metric names recorded in the summary file."""
+        return self._metadata['metrics']
+
+    @property
+    def min_confidence(self) -> Optional[float]:
+        """Return the minimum confidence used to filter the predicted labels."""
+        return None
+
+    @property
+    def sample_count(self) -> int:
+        """
+        Return the total number of samples in the loader.
+
+        Since the loader only returns available samples (i.e. with original image data and ground_truth_label
+        loaded in cache), the returned count only takes the available samples into account.
+
+        Return:
+            int, total number of available samples in the loading job.
+        """
+        sample_count = 0
+        samples_copy = self._samples.copy()
+        for sample in samples_copy.values():
+            if sample.get('image', False) and sample.get('ground_truth_label', False):
+                sample_count += 1
+        return sample_count
+
+    @property
+    def samples(self) -> List[Dict]:
+        """Return the information of all samples in the job."""
+        return self.get_all_samples()
+
+    @property
+    def train_id(self) -> str:
+        """Return the ID of the explain loader."""
+        return self._loader_info['loader_id']
+
+    @property
+    def uncertainty_enabled(self):
+        """Whether uncertainty is enabled."""
+        return self._loader_info['uncertainty_enabled']
+
+    @property
+    def update_time(self) -> float:
+        """Return the latest modification timestamp of the summary file."""
+        return self._loader_info['update_time']
+
+    @update_time.setter
+    def update_time(self, new_time: Union[datetime, float]):
+        """
+        Update the update_time manually.
+
+        Args:
+            new_time (datetime.datetime or float): Updated time for the summary file.
+        """
+        if isinstance(new_time, datetime):
+            self._loader_info['update_time'] = new_time.timestamp()
+        elif isinstance(new_time, float):
+            self._loader_info['update_time'] = new_time
+        else:
+            raise TypeError('new_time should have type of datetime.datetime or float, but receive {}'
+                            .format(type(new_time)))
+
+    def load(self):
+        """Start loading data from the latest summary file to the loader."""
+        filenames = []
+        for filename in FileHandler.list_dir(self._loader_info['summary_dir']):
+            if FileHandler.is_file(FileHandler.join(self._loader_info['summary_dir'], filename)):
+                filenames.append(filename)
+        filenames = ExplainLoader._filter_files(filenames)
+
+        if not filenames:
+            raise TrainJobNotExistError('No summary file found in %s, explain job will be deleted.'
+                                        % self._loader_info['summary_dir'])
+
+        is_end = False
+        while not is_end:
+            is_clean, is_end, event_dict = self._parser.parse_explain(filenames)
+
+            if is_clean:
+                logger.info('Summary file in %s updated, reload the data in the summary.',
+                            self._loader_info['summary_dir'])
+                self._clear_job()
+            if event_dict:
+                self._import_data_from_event(event_dict)
+
+    def get_all_samples(self) -> List[Dict]:
+        """
+        Return a list of sample information cached in the explain job.
+
+        Returns:
+            sample_list (List[SampleObj]): a list of sample objects, each object
+                consists of:
+
+                - id (int): sample id
+                - name (str): basename of image
+                - labels (list[str]): list of labels
+                - inferences (list[dict])
+        """
+        returned_samples = []
+        samples_copy = self._samples.copy()
+        for sample_id, sample_info in samples_copy.items():
+            # Skip samples whose image or ground-truth label has not arrived yet,
+            # consistent with the availability rule used in sample_count.
+            if not sample_info.get('image', False) or not sample_info.get('ground_truth_label', False):
+                continue
+            returned_sample = {
+                'id': sample_id,
+                'name': str(sample_id),
+                'image': sample_info['image'],
+                'labels': sample_info['ground_truth_label'],
+            }
+
+            if not ExplainLoader._is_inference_valid(sample_info):
+                continue
+
+            inferences = {}
+            for label, prob in zip(sample_info['ground_truth_label'] + sample_info['predicted_label'],
+                                   sample_info['ground_truth_prob'] + sample_info['predicted_prob']):
+                inferences[label] = {
+                    'label': self._metadata['labels'][label],
+                    'confidence': _round(prob),
+                    'saliency_maps': []
+                }
+
+            if sample_info['ground_truth_prob_sd'] or sample_info['predicted_prob_sd']:
+                for label, std, low, high in zip(
+                        sample_info['ground_truth_label'] + sample_info['predicted_label'],
+                        sample_info['ground_truth_prob_sd'] + sample_info['predicted_prob_sd'],
+                        sample_info['ground_truth_prob_itl95_low'] + sample_info['predicted_prob_itl95_low'],
+                        sample_info['ground_truth_prob_itl95_hi'] + sample_info['predicted_prob_itl95_hi']
+                ):
+                    inferences[label]['confidence_sd'] = _round(std)
+                    inferences[label]['confidence_itl95'] = [_round(low), _round(high)]
+
+            for explainer, label_heatmap_path_dict in sample_info['explanation'].items():
+                for label, heatmap_path in label_heatmap_path_dict.items():
+                    if label in inferences:
+                        inferences[label]['saliency_maps'].append({'explainer': explainer, 'overlay': heatmap_path})
+
+            returned_sample['inferences'] = list(inferences.values())
+            returned_samples.append(returned_sample)
+        return returned_samples
+
+    def _import_data_from_event(self, event_dict: Dict):
+        """Parse and import data from the event data."""
+        if 'metadata' not in event_dict and self._is_metadata_empty():
+            raise ParamValueError('metadata is incomplete, should write metadata first in the summary.')
+
+        for tag, event in event_dict.items():
+            if tag == ExplainFieldsEnum.METADATA.value:
+                self._import_metadata_from_event(event.metadata)
+            elif tag == ExplainFieldsEnum.BENCHMARK.value:
+                self._import_benchmark_from_event(event.benchmark)
+            elif tag == ExplainFieldsEnum.SAMPLE_ID.value:
+                self._import_sample_from_event(event)
+            else:
+                logger.info('Unknown ExplainField: %s', tag)
+
+    def _is_metadata_empty(self):
+        """Check whether metadata is completely loaded first."""
+        if not self._metadata['labels']:
+            return True
+        return False
+
+    def _import_metadata_from_event(self, metadata_event):
+        """Import the metadata from the event into the loader."""
+
+        def take_union(existed_list, imported_data):
+            """Take the union of existed_list and imported_data."""
+            if isinstance(imported_data, Iterable):
+                for sample in imported_data:
+                    if sample not in existed_list:
+                        existed_list.append(sample)
+
+        take_union(self._metadata['explainers'], metadata_event.explain_method)
+        take_union(self._metadata['metrics'], metadata_event.benchmark_method)
+        take_union(self._metadata['labels'], metadata_event.label)
+
+    def _import_benchmark_from_event(self, benchmarks):
+        """
+        Parse the benchmark event.
+
+        Benchmark data are separated into 'explainer_score' and 'label_score'. 'explainer_score' contains overall
+        evaluation results of each explainer by different metrics, while 'label_score' additionally divides the
+        results w.r.t different labels.
+
+        The structure of self._benchmark['explainer_score'] is:
+
+            {
+                explainer_1: {metric_name_1: score_1, ...},
+                explainer_2: {metric_name_1: score_1, ...},
+                ...
+            }
+
+        The structure of self._benchmark['label_score'] is:
+
+            {
+                explainer_1: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...},
+                explainer_2: {label_id: {metric_1: score_1, metric_2: score_2, ...}, ...},
+                ...
+            }
+
+        Args:
+            benchmarks (benchmark_container): Parsed benchmark data from the summary file.
+        """
+        explainer_score = self._benchmark['explainer_score']
+        label_score = self._benchmark['label_score']
+
+        for benchmark in benchmarks:
+            explainer = benchmark.explain_method
+            metric = benchmark.benchmark_method
+            metric_score = benchmark.total_score
+            label_score_event = benchmark.label_score
+
+            explainer_score[explainer][metric] = metric_score
+            new_label_score_dict = ExplainLoader._score_event_to_dict(label_score_event, metric)
+            for label, scores_of_metric in new_label_score_dict.items():
+                if label not in label_score[explainer]:
+                    label_score[explainer][label] = {}
+                label_score[explainer][label].update(scores_of_metric)
+
+    def _import_sample_from_event(self, sample):
+        """
+        Parse the sample event.
+
+        Detailed data of each sample are stored in self._samples, identified by sample_id. Each sample is stored
+        with the following structure:
+
+            - ground_truth_label (list[int]): A list of ground-truth labels of the sample.
+            - ground_truth_prob (list[float]): A list of confidences of the ground-truth labels from the
+              black-box model.
+            - predicted_label (list[int]): A list of predicted labels from the black-box model.
+            - predicted_prob (list[float]): A list of confidences w.r.t the predicted labels.
+            - explanation (dict): A dictionary mapping each explainer name to a dictionary of saliency maps:
+
+                {
+                    explainer_name_1: {label_1: saliency_id_1, label_2: saliency_id_2, ...},
+                    explainer_name_2: {label_1: saliency_id_1, label_2: saliency_id_2, ...},
+                    ...
+                }
+        """
+        if not getattr(sample, 'sample_id', None):
+            raise ParamValueError('sample_event has no sample_id')
+
+        sample_id = sample.sample_id
+        samples_copy = self._samples.copy()
+        if sample_id not in samples_copy:
+            self._samples[sample_id] = {
+                'ground_truth_label': [],
+                'ground_truth_prob': [],
+                'ground_truth_prob_sd': [],
+                'ground_truth_prob_itl95_low': [],
+                'ground_truth_prob_itl95_hi': [],
+                'predicted_label': [],
+                'predicted_prob': [],
+                'predicted_prob_sd': [],
+                'predicted_prob_itl95_low': [],
+                'predicted_prob_itl95_hi': [],
+                'explanation': defaultdict(dict)
+            }
+
+        if sample.image_path:
+            self._samples[sample_id]['image'] = sample.image_path
+
+        for tag in _SAMPLE_FIELD_NAMES:
+            try:
+                if ExplainLoader._is_attr_empty(sample, tag.value):
+                    continue
+                if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL:
+                    self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label))
+                elif tag == ExplainFieldsEnum.INFERENCE:
+                    self._import_inference_from_event(sample, sample_id)
+                elif tag == ExplainFieldsEnum.EXPLANATION:
+                    self._import_explanation_from_event(sample, sample_id)
+            except UnknownError as ex:
+                logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex))
+
+    def _import_inference_from_event(self, event, sample_id):
+        """Parse the inference event."""
+        inference = event.inference
+        self._samples[sample_id]['ground_truth_prob'].extend(list(inference.ground_truth_prob))
+        self._samples[sample_id]['ground_truth_prob_sd'].extend(list(inference.ground_truth_prob_sd))
+        self._samples[sample_id]['ground_truth_prob_itl95_low'].extend(list(inference.ground_truth_prob_itl95_low))
+        self._samples[sample_id]['ground_truth_prob_itl95_hi'].extend(list(inference.ground_truth_prob_itl95_hi))
+        self._samples[sample_id]['predicted_label'].extend(list(inference.predicted_label))
+        self._samples[sample_id]['predicted_prob'].extend(list(inference.predicted_prob))
+        self._samples[sample_id]['predicted_prob_sd'].extend(list(inference.predicted_prob_sd))
+        self._samples[sample_id]['predicted_prob_itl95_low'].extend(list(inference.predicted_prob_itl95_low))
+        self._samples[sample_id]['predicted_prob_itl95_hi'].extend(list(inference.predicted_prob_itl95_hi))
+
+        if self._samples[sample_id]['ground_truth_prob_sd'] or self._samples[sample_id]['predicted_prob_sd']:
+            self._loader_info['uncertainty_enabled'] = True
+
+    def _import_explanation_from_event(self, event, sample_id):
+        """Parse the explanation event."""
+        if self._samples[sample_id]['explanation'] is None:
+            self._samples[sample_id]['explanation'] = defaultdict(dict)
+        sample_explanation = self._samples[sample_id]['explanation']
+
+        for explanation_item in event.explanation:
+            explainer = explanation_item.explain_method
+            label = explanation_item.label
+            sample_explanation[explainer][label] = explanation_item.heatmap_path
+
+    def _clear_job(self):
+        """Clear the cached data and update the time info of the loader."""
+        self._samples.clear()
+        self._loader_info['create_time'] = os.stat(self._loader_info['summary_dir']).st_ctime
+        self._loader_info['update_time'] = os.stat(self._loader_info['summary_dir']).st_mtime
+        self._loader_info['query_time'] = max(self._loader_info['update_time'], self._loader_info['query_time'])
+
+        def clear_inner_dict(outer_dict):
+            """Clear the inner structured data of the given dict."""
+            for item in outer_dict.values():
+                item.clear()
+
+        # A plain loop instead of map(): map() is lazy in Python 3 and would
+        # never actually run clear_inner_dict.
+        for outer_dict in (self._metadata, self._benchmark):
+            clear_inner_dict(outer_dict)
+
+    @staticmethod
+    def _filter_files(filenames):
+        """
+        Get a list of summary files.
+
+        Args:
+            filenames (list[str]): File name list, like [filename1, filename2].
+
+        Returns:
+            list[str], filename list.
+        """
+        return list(filter(lambda filename: (re.search(r'summary\.\d+', filename) and filename.endswith("_explain")),
+                           filenames))
+
+    @staticmethod
+    def _is_attr_empty(event, attr_name) -> bool:
+        """Check whether the given attribute of the event is empty."""
+        if not getattr(event, attr_name):
+            return True
+        for item in getattr(event, attr_name):
+            if not isinstance(item, list) or item:
+                return False
+        return True
+
+    @staticmethod
+    def _is_ground_truth_label_valid(sample_id: str, sample_info: Dict) -> bool:
+        """Check whether the ground-truth labels and probabilities have matching lengths."""
+        if len(sample_info['ground_truth_label']) != len(sample_info['ground_truth_prob']):
+            logger.info('Length of ground_truth_prob does not match the length of ground_truth_label, '
+                        'length of ground_truth_label is: %s but length of ground_truth_prob is: %s, '
+                        'sample_id is: %s.',
+                        len(sample_info['ground_truth_label']), len(sample_info['ground_truth_prob']), sample_id)
+            return False
+        return True
+
+    @staticmethod
+    def _is_inference_valid(sample):
+        """
+        Check whether the inference data is empty or has consistent lengths.
+
+        If the probs have a different length from the labels, it can be confusing when assigning each prob to a
+        label. '_is_inference_valid' returns True only when the data sizes match each other. Note that prob data
+        could be empty, so an empty prob will pass the check.
+        """
+        ground_truth_len = len(sample['ground_truth_label'])
+        for name in ['ground_truth_prob', 'ground_truth_prob_sd',
+                     'ground_truth_prob_itl95_low', 'ground_truth_prob_itl95_hi']:
+            if sample[name] and len(sample[name]) != ground_truth_len:
+                return False
+
+        predicted_len = len(sample['predicted_label'])
+        for name in ['predicted_prob', 'predicted_prob_sd',
+                     'predicted_prob_itl95_low', 'predicted_prob_itl95_hi']:
+            if sample[name] and len(sample[name]) != predicted_len:
+                return False
+        return True
+
+    @staticmethod
+    def _is_predicted_label_valid(sample_id: str, sample_info: Dict) -> bool:
+        """Check whether the predicted labels and probabilities have matching lengths."""
+        if len(sample_info['predicted_label']) != len(sample_info['predicted_prob']):
+            logger.info('Length of predicted_prob does not match the length of predicted_label, '
+                        'length of predicted_prob: %s but length of predicted_label: %s, sample_id: %s.',
+                        len(sample_info['predicted_prob']), len(sample_info['predicted_label']), sample_id)
+            return False
+        return True
+
+    @staticmethod
+    def _score_event_to_dict(label_score_event, metric) -> Dict:
+        """Transfer metric scores per label to a pre-defined structure."""
+        new_label_score_dict = defaultdict(dict)
+        for label_id, label_score in enumerate(label_score_event):
+            new_label_score_dict[label_id][metric] = label_score
+        return new_label_score_dict
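Two easy-to-miss details in the new loader: _filter_files only admits files that both contain a summary.<digits> segment and end in _explain, so ordinary training summaries in the same directory are ignored. The filter in isolation, with invented filenames:

    import re

    def filter_files(filenames):
        """Mirror of ExplainLoader._filter_files, for demonstration."""
        return list(filter(lambda name: re.search(r'summary\.\d+', name) and name.endswith('_explain'),
                           filenames))

    names = [
        'events.summary.1598400000.hostname_explain',  # kept: summary.<digits> plus _explain suffix
        'events.summary.1598400000.hostname',          # dropped: ordinary training summary
        'notes_explain',                               # dropped: no summary.<digits> segment
    ]
    print(filter_files(names))  # ['events.summary.1598400000.hostname_explain']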
@@ -17,17 +17,20 @@
 import os
 import threading
 import time
+from collections import OrderedDict
+from datetime import datetime
+from typing import Optional
+from mindinsight.conf import settings
 from mindinsight.datavisual.common import exceptions
 from mindinsight.datavisual.common.enums import BaseEnum
 from mindinsight.explainer.common.log import logger
-from mindinsight.explainer.manager.explain_job import ExplainJob
+from mindinsight.explainer.manager.explain_loader import ExplainLoader
 from mindinsight.datavisual.data_access.file_handler import FileHandler
 from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
 from mindinsight.utils.exceptions import MindInsightException, ParamValueError, UnknownError

-_MAX_LOADER_NUM = 3
-_MAX_INTERVAL = 3
+_MAX_LOADERS_NUM = 3


 class _ExplainManagerStatus(BaseEnum):
@@ -43,245 +46,63 @@ class ExplainManager:

     def __init__(self, summary_base_dir: str):
         self._summary_base_dir = summary_base_dir
-        self._loader_pool = {}
-        self._deleted_ids = []
-        self._status = _ExplainManagerStatus.INIT.value
+        self._loader_pool = OrderedDict()
+        self._loading_status = _ExplainManagerStatus.INIT.value
         self._status_mutex = threading.Lock()
         self._loader_pool_mutex = threading.Lock()
-        self._max_loader_num = _MAX_LOADER_NUM
-        self._reload_interval = None
+        self._max_loaders_num = _MAX_LOADERS_NUM
+        self._summary_watcher = SummaryWatcher()

-    def _reload_data(self):
-        """Periodically load summary from file."""
-        while True:
-            try:
-                self._load_data()
-
-                if not self._reload_interval:
-                    break
-                time.sleep(self._reload_interval)
-            except UnknownError as ex:
-                logger.exception(ex)
-                logger.error('Unknown error raised when loading summary files, status: %r, loader pool size: %r. '
-                             'Detail: %s', self._status, len(self._loader_pool), str(ex))
-                self._status = _ExplainManagerStatus.INVALID.value
-
-    def _load_data(self):
-        """Load the summary in the given base directory."""
-        logger.info('Start to load data, reload interval: %r.', self._reload_interval)
-        with self._status_mutex:
-            if self._status == _ExplainManagerStatus.LOADING.value:
-                logger.info('Current status is %s, will ignore to load data.', self._status)
-                return
-
-            self._status = _ExplainManagerStatus.LOADING.value
-
-            try:
-                self._generate_loaders()
-                self._execute_load_data()
-            except Exception as ex:
-                raise UnknownError(ex)
-
-            if not self._loader_pool:
-                self._status = _ExplainManagerStatus.INVALID.value
-            else:
-                self._status = _ExplainManagerStatus.DONE.value
-
-            logger.info('Load event data end, status: %r, and loader pool size is %r',
-                        self._status, len(self._loader_pool))
-
-    def _update_loader_latest_update_time(self, loader_id, latest_update_time=None):
-        """Update the update time of the loader of the given id."""
-        if latest_update_time is None:
-            latest_update_time = time.time()
-        self._loader_pool[loader_id].latest_update_time = latest_update_time
-
-    def _delete_loader(self, loader_id):
-        """Delete the loader with the given loader_id."""
-        if self._loader_pool.get(loader_id, None) is not None:
-            self._loader_pool.pop(loader_id)
-            logger.debug('delete loader %s', loader_id)
-
-    def _add_loader(self, loader):
-        """Add a loader to the loader_pool."""
-        if len(self._loader_pool) >= _MAX_LOADER_NUM:
-            delete_num = len(self._loader_pool) - _MAX_LOADER_NUM + 1
-            sorted_loaders = sorted(
-                self._loader_pool.items(),
-                key=lambda x: x[1].latest_update_time)
-
-            for index in range(delete_num):
-                delete_loader_id = sorted_loaders[index][0]
-                self._delete_loader(delete_loader_id)
-        self._loader_pool.update({loader.loader_id: loader})
-
-    def _deal_loaders(self, latest_loaders):
-        """Update the loader pool."""
-        with self._loader_pool_mutex:
-            for loader_id, loader in latest_loaders:
-                if self._loader_pool.get(loader_id, None) is None:
-                    self._add_loader(loader)
-                    continue
-
-                if (self._loader_pool[loader_id].latest_update_time
-                        < loader.latest_update_time):
-                    self._update_loader_latest_update_time(
-                        loader_id, loader.latest_update_time)
-
-    @staticmethod
-    def _generate_loader_id(relative_path):
-        """Generate a loader id for the given path."""
-        loader_id = relative_path
-        return loader_id
-
-    @staticmethod
-    def _generate_loader_name(relative_path):
-        """Generate a loader name for the given path."""
-        loader_name = relative_path
-        return loader_name
| def _generate_loader_by_relative_path(self, relative_path: str) -> ExplainJob: | |||||
| """Generate explain job from given relative path.""" | |||||
| current_dir = os.path.realpath(FileHandler.join( | |||||
| self._summary_base_dir, relative_path | |||||
| )) | |||||
| loader_id = self._generate_loader_id(relative_path) | |||||
| loader = ExplainJob( | |||||
| job_id=loader_id, | |||||
| summary_dir=current_dir, | |||||
| create_time=ExplainJob.get_create_time(current_dir), | |||||
| latest_update_time=ExplainJob.get_update_time(current_dir)) | |||||
| return loader | |||||
| def _generate_loaders(self): | |||||
| """Generate job loaders from the summary watcher.""" | |||||
| dir_map_mtime_dict = {} | |||||
| loader_dict = {} | |||||
| min_modify_time = None | |||||
| _, summaries = SummaryWatcher().list_explain_directories( | |||||
| self._summary_base_dir) | |||||
| for item in summaries: | |||||
| relative_path = item.get('relative_path') | |||||
| modify_time = item.get('update_time').timestamp() | |||||
| loader_id = self._generate_loader_id(relative_path) | |||||
| loader = self._loader_pool.get(loader_id, None) | |||||
| if loader is not None and loader.latest_update_time > modify_time: | |||||
| modify_time = loader.latest_update_time | |||||
| if min_modify_time is None: | |||||
| min_modify_time = modify_time | |||||
| if len(dir_map_mtime_dict) < _MAX_LOADER_NUM: | |||||
| if modify_time < min_modify_time: | |||||
| min_modify_time = modify_time | |||||
| dir_map_mtime_dict.update({relative_path: modify_time}) | |||||
| else: | |||||
| if modify_time >= min_modify_time: | |||||
| dir_map_mtime_dict.update({relative_path: modify_time}) | |||||
| sorted_dir_tuple = sorted(dir_map_mtime_dict.items(), | |||||
| key=lambda d: d[1])[-_MAX_LOADER_NUM:] | |||||
| for relative_path, modify_time in sorted_dir_tuple: | |||||
| loader_id = self._generate_loader_id(relative_path) | |||||
| loader = self._generate_loader_by_relative_path(relative_path) | |||||
| loader_dict.update({loader_id: loader}) | |||||
| sorted_loaders = sorted(loader_dict.items(), | |||||
| key=lambda x: x[1].latest_update_time) | |||||
| latest_loaders = sorted_loaders[-_MAX_LOADER_NUM:] | |||||
| self._deal_loaders(latest_loaders) | |||||
| def _execute_loader(self, loader_id): | |||||
| """Execute the data loading.""" | |||||
| try: | |||||
| with self._loader_pool_mutex: | |||||
| loader = self._loader_pool.get(loader_id, None) | |||||
| if loader is None: | |||||
| logger.debug('Loader %r has been deleted, will not load' | |||||
| 'data', loader_id) | |||||
| return | |||||
| loader.load() | |||||
| @property | |||||
| def summary_base_dir(self): | |||||
| """Return the base directory for summary records.""" | |||||
| return self._summary_base_dir | |||||
| except MindInsightException as ex: | |||||
| logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, ex) | |||||
| with self._loader_pool_mutex: | |||||
| self._delete_loader(loader_id) | |||||
| def start_load_data(self, reload_interval: int = 0): | |||||
| """ | |||||
| Start individual thread to cache explain_jobs and loading summary data periodically. | |||||
| def _execute_load_data(self): | |||||
| """Execute the loader in the pool to load data.""" | |||||
| loader_pool = self._get_snapshot_loader_pool() | |||||
| for loader_id in loader_pool: | |||||
| self._execute_loader(loader_id) | |||||
| Args: | |||||
| reload_interval (int): Specify the loading period in seconds. If interval == 0, data will only be loaded | |||||
| once. Default: 0. | |||||
| """ | |||||
| thread = threading.Thread(target=self._repeat_loading, | |||||
| name='start_load_thread', | |||||
| args=(reload_interval,), | |||||
| daemon=True) | |||||
| thread.start() | |||||
| # wait briefly so the first load can begin before requests arrive | |||||
| time.sleep(1) | |||||
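A minimal usage sketch of the new entry point; the base directory below is hypothetical. With reload_interval=0 the summaries are loaded once, while a positive value keeps reloading them in the daemon thread:

    from mindinsight.explainer.manager.explain_manager import ExplainManager

    manager = ExplainManager(summary_base_dir='/tmp/summary_base_dir')  # hypothetical path
    manager.start_load_data(reload_interval=0)   # load once
    # manager.start_load_data(reload_interval=3) # or reload every 3 seconds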
| def _get_snapshot_loader_pool(self): | |||||
| """Get snapshot of loader_pool.""" | |||||
| with self._loader_pool_mutex: | |||||
| return dict(self._loader_pool) | |||||
| def get_job(self, loader_id: str) -> Optional[ExplainLoader]: | |||||
| """ | |||||
| Return ExplainLoader given loader_id. | |||||
| def _check_status_valid(self): | |||||
| """Check manager status.""" | |||||
| if self._status == _ExplainManagerStatus.INIT.value: | |||||
| raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._status) | |||||
| If the explain job for the given loader_id is not found, None is returned. | |||||
| @staticmethod | |||||
| def _check_train_id_valid(train_id: str): | |||||
| """Verify the train_id is valid.""" | |||||
| if not train_id.startswith('./'): | |||||
| logger.warning('train_id does not start with "./"') | |||||
| return False | |||||
| if len(train_id.split('/')) > 2: | |||||
| logger.warning('train_id contains multiple "/"') | |||||
| return False | |||||
| return True | |||||
| def _check_train_job_exist(self, train_id): | |||||
| """Verify thee train_job is existed given train_id.""" | |||||
| if train_id in self._loader_pool: | |||||
| return | |||||
| self._check_train_id_valid(train_id) | |||||
| if SummaryWatcher().is_summary_directory(self._summary_base_dir, train_id): | |||||
| return | |||||
| raise ParamValueError('Can not find the train job in the manager, train_id: %s' % train_id) | |||||
| Args: | |||||
| loader_id (str): The id of the expected ExplainLoader. | |||||
| def _reload_data_again(self): | |||||
| """Reload the data one more time.""" | |||||
| logger.debug('Start to reload data again.') | |||||
| thread = threading.Thread(target=self._load_data, | |||||
| name='reload_data_thread') | |||||
| thread.daemon = False | |||||
| thread.start() | |||||
| Returns: | |||||
| Optional[ExplainLoader], the loader for the explain job, or None if it is not found. | |||||
| """ | |||||
| self._check_status_valid() | |||||
| def _get_job(self, train_id): | |||||
| """Retrieve train_job given train_id.""" | |||||
| is_reload = False | |||||
| with self._loader_pool_mutex: | with self._loader_pool_mutex: | ||||
| loader = self._loader_pool.get(train_id, None) | |||||
| if loader is None: | |||||
| relative_path = train_id | |||||
| temp_loader = self._generate_loader_by_relative_path( | |||||
| relative_path) | |||||
| if loader_id in self._loader_pool: | |||||
| self._loader_pool[loader_id].query_time = datetime.now().timestamp() | |||||
| # keep the most recently queried loader at the end, so it is evicted last | |||||
| self._loader_pool.move_to_end(loader_id) | |||||
| return self._loader_pool[loader_id] | |||||
| if temp_loader is None: | |||||
| return None | |||||
| self._add_loader(temp_loader) | |||||
| is_reload = True | |||||
| if is_reload: | |||||
| self._reload_data_again() | |||||
| try: | |||||
| loader = self._generate_loader_from_relative_path(loader_id) | |||||
| loader.query_time = datetime.now().timestamp() | |||||
| self._add_loader(loader) | |||||
| self._reload_data_again() | |||||
| except ParamValueError: | |||||
| logger.warning('Cannot find summary in path: %s. No explain_job will be returned.', loader_id) | |||||
| return None | |||||
| return loader | return loader | ||||
| @property | |||||
| def summary_base_dir(self): | |||||
| """Return the base directory for summary records.""" | |||||
| return self._summary_base_dir | |||||
| def get_job_list(self, offset=0, limit=None): | def get_job_list(self, offset=0, limit=None): | ||||
| """ | """ | ||||
| Return a list of explain jobs, including job ID, create time and update time. | Return a list of explain jobs, including job ID, create time and update time. | ||||
| @@ -298,44 +119,146 @@ class ExplainManager: | |||||
| - create_time (datetime): Creation time of summary file. | - create_time (datetime): Creation time of summary file. | ||||
| - update_time (datetime): Modification time of summary file. | - update_time (datetime): Modification time of summary file. | ||||
| """ | """ | ||||
| watcher = SummaryWatcher() | |||||
| total, dir_infos = \ | total, dir_infos = \ | ||||
| watcher.list_explain_directories(self._summary_base_dir, | |||||
| offset=offset, limit=limit) | |||||
| self._summary_watcher.list_explain_directories(self._summary_base_dir, offset=offset, limit=limit) | |||||
| return total, dir_infos | return total, dir_infos | ||||
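A short pagination sketch against the module-level singleton defined at the end of this file; the printed keys follow the fields documented above, and the result counts are hypothetical:

    from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER

    total, dir_infos = EXPLAIN_MANAGER.get_job_list(offset=0, limit=10)
    print('total explain jobs:', total)
    for info in dir_infos:
        # Each entry describes one summary directory.
        print(info['relative_path'], info['create_time'], info['update_time'])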
| def get_job(self, train_id): | |||||
| def _repeat_loading(self, repeat_interval): | |||||
| """Periodically loading summary.""" | |||||
| while True: | |||||
| try: | |||||
| logger.info('Start to load data, repeat interval: %r.', repeat_interval) | |||||
| self._load_data() | |||||
| if not repeat_interval: | |||||
| return | |||||
| time.sleep(repeat_interval) | |||||
| except UnknownError as ex: | |||||
| logger.exception(ex) | |||||
| logger.error('Unexpected error occurred when loading data. Loading status: %s, loader pool size: %d. ' | |||||
| 'Detail: %s', self._loading_status, len(self._loader_pool), str(ex)) | |||||
| def _load_data(self): | |||||
| """ | |||||
| Prepare loaders in cache and start loading the data from summaries. | |||||
| Only a limited number of loaders are cached, ranked by update_time or query_time. The size of the | |||||
| cache pool is determined by _MAX_LOADERS_NUM. When the manager starts loading data, only the latest | |||||
| _MAX_LOADERS_NUM summaries are loaded into the cache. If a cached loader is queried by 'get_job', | |||||
| the query_time of the loader is updated and the loader is moved to the end of the cache. If an | |||||
| uncached summary is queried, a new loader instance is generated and appended to the end of the cache. | |||||
| """ | """ | ||||
| Return ExplainJob given train_id. | |||||
| try: | |||||
| with self._status_mutex: | |||||
| if self._loading_status == _ExplainManagerStatus.LOADING.value: | |||||
| logger.info('Current status is %s, skipping data loading.', self._loading_status) | |||||
| return | |||||
| If explain job w.r.t given train_id is not found, None will be returned. | |||||
| self._loading_status = _ExplainManagerStatus.LOADING.value | |||||
| Args: | |||||
| train_id (str): The id of expected ExplainJob | |||||
| self._cache_loaders() | |||||
| self._execute_loading() | |||||
| Return: | |||||
| explain_job | |||||
| """ | |||||
| self._check_status_valid() | |||||
| self._check_train_job_exist(train_id) | |||||
| if not self._loader_pool: | |||||
| self._loading_status = _ExplainManagerStatus.INVALID.value | |||||
| else: | |||||
| self._loading_status = _ExplainManagerStatus.DONE.value | |||||
| logger.info('Loading summary data finished, status: %s, loader pool size: %d.', | |||||
| self._loading_status, len(self._loader_pool)) | |||||
| except Exception as ex: | |||||
| self._loading_status = _ExplainManagerStatus.INVALID.value | |||||
| logger.exception(ex) | |||||
| raise UnknownError(str(ex)) | |||||
| def _cache_loaders(self): | |||||
| """Cache explain loader in cache pool.""" | |||||
| summary_mtime_list = [] | |||||
| _, summaries_info = self._summary_watcher.list_explain_directories(self._summary_base_dir) | |||||
| for summary_info in summaries_info: | |||||
| summary_path = summary_info.get('relative_path') | |||||
| summary_update_time = summary_info.get('update_time').timestamp() | |||||
| if summary_path in self._loader_pool: | |||||
| summary_update_time = max(summary_update_time, self._loader_pool[summary_path].query_time) | |||||
| summary_mtime_list.append((summary_info, summary_update_time)) | |||||
| sorted_summaries_info = sorted(summary_mtime_list, key=lambda x: x[1])[-_MAX_LOADERS_NUM:] | |||||
| loader = self._get_job(train_id) | |||||
| if loader is None: | |||||
| return None | |||||
| with self._loader_pool_mutex: | |||||
| for summary_info, query_time in sorted_summaries_info: | |||||
| summary_path = summary_info['relative_path'] | |||||
| if summary_path not in self._loader_pool: | |||||
| loader = self._generate_loader_from_relative_path(summary_path) | |||||
| self._add_loader(loader) | |||||
| else: | |||||
| self._loader_pool[summary_path].query_time = query_time | |||||
| self._loader_pool.move_to_end(summary_path) | |||||
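The selection step above can be shown in isolation: sort the candidate summaries by timestamp and keep only the newest _MAX_LOADERS_NUM. A standalone sketch with hypothetical paths and timestamps:

    _MAX_LOADERS_NUM = 3

    # (summary_path, mtime-or-query-time) pairs, as gathered from the watcher.
    candidates = [('./job_a', 100.0), ('./job_b', 250.0), ('./job_c', 180.0), ('./job_d', 300.0)]
    kept = sorted(candidates, key=lambda x: x[1])[-_MAX_LOADERS_NUM:]
    print(kept)  # [('./job_c', 180.0), ('./job_b', 250.0), ('./job_d', 300.0)]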
| def _generate_loader_from_relative_path(self, relative_path: str) -> ExplainLoader: | |||||
| """Generate explain loader from the given relative path.""" | |||||
| self._check_summary_exist(relative_path) | |||||
| current_dir = os.path.realpath(FileHandler.join(self._summary_base_dir, relative_path)) | |||||
| loader_id = self._generate_loader_id(relative_path) | |||||
| loader = ExplainLoader(loader_id=loader_id, summary_dir=current_dir) | |||||
| return loader | return loader | ||||
| def start_load_data(self, reload_interval=_MAX_INTERVAL): | |||||
| """ | |||||
| Start threads for loading data. | |||||
| def _add_loader(self, loader): | |||||
| """add loader to the loader_pool.""" | |||||
| if loader.train_id not in self._loader_pool: | |||||
| self._loader_pool[loader.train_id] = loader | |||||
| else: | |||||
| self._loader_pool.move_to_end(loader.train_id) | |||||
| Args: | |||||
| reload_interval (int): interval to reload the summary from file | |||||
| """ | |||||
| self._reload_interval = reload_interval | |||||
| while len(self._loader_pool) > self._max_loaders_num: | |||||
| self._loader_pool.popitem(last=False) | |||||
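_add_loader and _cache_loaders together implement a small LRU cache over an OrderedDict: fresh or re-queried entries sit at the end, and overflow is evicted from the front. The same pattern in a standalone sketch (train ids hypothetical):

    from collections import OrderedDict

    pool, max_num = OrderedDict(), 3
    for train_id in ['./job_a', './job_b', './job_c']:
        pool[train_id] = object()      # newest entries sit at the end
    pool.move_to_end('./job_a')        # a query refreshes ./job_a
    pool['./job_d'] = object()         # a fourth entry overflows the pool
    while len(pool) > max_num:
        pool.popitem(last=False)       # evict from the front: ./job_b goes first
    print(list(pool))                  # ['./job_c', './job_a', './job_d']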
| def _execute_loading(self): | |||||
| """Execute the data loading.""" | |||||
| for loader_id in list(self._loader_pool.keys()): | |||||
| try: | |||||
| with self._loader_pool_mutex: | |||||
| loader = self._loader_pool.get(loader_id, None) | |||||
| if loader is None: | |||||
| logger.debug('Loader %r has been deleted, will not load data', loader_id) | |||||
| continue | |||||
| loader.load() | |||||
| except MindInsightException as ex: | |||||
| logger.warning('Data loader %r load data failed. Delete data_loader. Detail: %s', loader_id, ex) | |||||
| with self._loader_pool_mutex: | |||||
| self._delete_loader(loader_id) | |||||
| def _delete_loader(self, loader_id): | |||||
| """delete loader given loader_id""" | |||||
| if loader_id in self._loader_pool: | |||||
| self._loader_pool.pop(loader_id) | |||||
| logger.debug('delete loader %s', loader_id) | |||||
| thread = threading.Thread(target=self._reload_data, name='start_load_data_thread') | |||||
| thread.daemon = True | |||||
| def _check_status_valid(self): | |||||
| """Check manager status.""" | |||||
| if self._loading_status == _ExplainManagerStatus.INIT.value: | |||||
| raise exceptions.SummaryLogIsLoading('Data is loading, current status is %s' % self._loading_status) | |||||
| def _check_summary_exist(self, loader_id): | |||||
| """Verify thee train_job is existed given loader_id.""" | |||||
| if not self._summary_watcher.is_summary_directory(self._summary_base_dir, loader_id): | |||||
| raise ParamValueError('Cannot find the train job in the manager.') | |||||
| def _reload_data_again(self): | |||||
| """Reload the data one more time.""" | |||||
| logger.debug('Start to reload data again.') | |||||
| thread = threading.Thread(target=self._load_data, name='reload_data_thread') | |||||
| thread.daemon = False | |||||
| thread.start() | thread.start() | ||||
| # wait for data loading | |||||
| time.sleep(1) | |||||
| @staticmethod | |||||
| def _generate_loader_id(relative_path): | |||||
| """Generate loader id for given path""" | |||||
| loader_id = relative_path | |||||
| return loader_id | |||||
| EXPLAIN_MANAGER = ExplainManager(summary_base_dir=settings.SUMMARY_BASE_DIR) | |||||
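A hedged sketch of querying the singleton; the train id is hypothetical. A cached loader returns immediately, an uncached but existing summary is loaded on demand, and a missing path yields None:

    from mindinsight.explainer.manager.explain_manager import EXPLAIN_MANAGER

    loader = EXPLAIN_MANAGER.get_job('./hypothetical_job')
    if loader is None:
        print('no explain job found under that path')
    else:
        print('explain job ready:', loader.train_id)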
| @@ -17,49 +17,41 @@ File parser for MindExplain data. | |||||
| This module is used to parse the MindExplain log file. | This module is used to parse the MindExplain log file. | ||||
| """ | """ | ||||
| import re | |||||
| import collections | |||||
| from collections import namedtuple | |||||
| from google.protobuf.message import DecodeError | from google.protobuf.message import DecodeError | ||||
| from mindinsight.datavisual.common import exceptions | from mindinsight.datavisual.common import exceptions | ||||
| from mindinsight.explainer.common.enums import PluginNameEnum | |||||
| from mindinsight.explainer.common.enums import ExplainFieldsEnum | |||||
| from mindinsight.explainer.common.log import logger | from mindinsight.explainer.common.log import logger | ||||
| from mindinsight.datavisual.data_access.file_handler import FileHandler | from mindinsight.datavisual.data_access.file_handler import FileHandler | ||||
| from mindinsight.datavisual.data_transform.ms_data_loader import _SummaryParser | from mindinsight.datavisual.data_transform.ms_data_loader import _SummaryParser | ||||
| from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2 | ||||
| from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Explain | |||||
| from mindinsight.utils.exceptions import UnknownError | from mindinsight.utils.exceptions import UnknownError | ||||
| HEADER_SIZE = 8 | HEADER_SIZE = 8 | ||||
| CRC_STR_SIZE = 4 | CRC_STR_SIZE = 4 | ||||
| MAX_EVENT_STRING = 500000000 | MAX_EVENT_STRING = 500000000 | ||||
| BenchmarkContainer = collections.namedtuple('BenchmarkContainer', ['benchmark', 'status']) | |||||
| MetadataContainer = collections.namedtuple('MetadataContainer', ['metadata', 'status']) | |||||
| class ImageDataContainer: | |||||
| """ | |||||
| Container for image data to allow pickling. | |||||
| Args: | |||||
| explain_message (Explain): Explain proto buffer message. | |||||
| """ | |||||
| def __init__(self, explain_message: Explain): | |||||
| self.sample_id = explain_message.sample_id | |||||
| self.image_path = explain_message.image_path | |||||
| self.ground_truth_label = explain_message.ground_truth_label | |||||
| self.inference = explain_message.inference | |||||
| self.explanation = explain_message.explanation | |||||
| self.status = explain_message.status | |||||
| class _ExplainParser(_SummaryParser): | |||||
| BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status']) | |||||
| MetadataContainer = namedtuple('MetadataContainer', ['metadata', 'status']) | |||||
| InferenceContainer = namedtuple('InferenceContainer', ['ground_truth_prob', | |||||
| 'ground_truth_prob_sd', | |||||
| 'ground_truth_prob_itl95_low', | |||||
| 'ground_truth_prob_itl95_hi', | |||||
| 'predicted_label', | |||||
| 'predicted_prob', | |||||
| 'predicted_prob_sd', | |||||
| 'predicted_prob_itl95_low', | |||||
| 'predicted_prob_itl95_hi']) | |||||
| SampleContainer = namedtuple('SampleContainer', ['sample_id', 'image_path', 'ground_truth_label', 'inference', | |||||
| 'explanation', 'status']) | |||||
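The namedtuple containers take over from the removed ImageDataContainer, whose stated purpose was keeping parsed data picklable; module-level namedtuples satisfy the same requirement. A quick self-contained check with hypothetical values:

    import pickle
    from collections import namedtuple

    BenchmarkContainer = namedtuple('BenchmarkContainer', ['benchmark', 'status'])

    container = BenchmarkContainer(benchmark={'Faithfulness': 0.7}, status=1)
    restored = pickle.loads(pickle.dumps(container))
    assert restored == container  # namedtuples defined at module level pickle cleanly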
| class ExplainParser(_SummaryParser): | |||||
| """The summary file parser.""" | """The summary file parser.""" | ||||
| def __init__(self, summary_dir): | def __init__(self, summary_dir): | ||||
| super(_ExplainParser, self).__init__(summary_dir) | |||||
| super(ExplainParser, self).__init__(summary_dir) | |||||
| self._latest_filename = '' | self._latest_filename = '' | ||||
| def parse_explain(self, filenames): | def parse_explain(self, filenames): | ||||
| @@ -71,8 +63,7 @@ class _ExplainParser(_SummaryParser): | |||||
| Returns: | Returns: | ||||
| bool, True if all the summary files are finished loading. | bool, True if all the summary files are finished loading. | ||||
| """ | """ | ||||
| summary_files = self.filter_files(filenames) | |||||
| summary_files = self.sort_files(summary_files) | |||||
| summary_files = self.sort_files(filenames) | |||||
| is_end = False | is_end = False | ||||
| is_clean = False | is_clean = False | ||||
| @@ -125,20 +116,6 @@ class _ExplainParser(_SummaryParser): | |||||
| logger.exception(ex) | logger.exception(ex) | ||||
| raise UnknownError(str(ex)) | raise UnknownError(str(ex)) | ||||
| def filter_files(self, filenames): | |||||
| """ | |||||
| Gets a list of summary files. | |||||
| Args: | |||||
| filenames (list[str]): File name list, like [filename1, filename2]. | |||||
| Returns: | |||||
| list[str], filename list. | |||||
| """ | |||||
| return list(filter( | |||||
| lambda filename: (re.search(r'summary\.\d+', filename) | |||||
| and filename.endswith("_explain")), filenames)) | |||||
| @staticmethod | @staticmethod | ||||
| def _event_decode(event_str): | def _event_decode(event_str): | ||||
| """ | """ | ||||
| @@ -153,9 +130,9 @@ class _ExplainParser(_SummaryParser): | |||||
| logger.debug("Deserialize event string completed.") | logger.debug("Deserialize event string completed.") | ||||
| fields = { | fields = { | ||||
| 'sample_id': PluginNameEnum.SAMPLE_ID, | |||||
| 'benchmark': PluginNameEnum.BENCHMARK, | |||||
| 'metadata': PluginNameEnum.METADATA | |||||
| 'sample_id': ExplainFieldsEnum.SAMPLE_ID, | |||||
| 'benchmark': ExplainFieldsEnum.BENCHMARK, | |||||
| 'metadata': ExplainFieldsEnum.METADATA | |||||
| } | } | ||||
| tensor_event_value = getattr(event, 'explain') | tensor_event_value = getattr(event, 'explain') | ||||
| @@ -163,19 +140,19 @@ class _ExplainParser(_SummaryParser): | |||||
| field_list = [] | field_list = [] | ||||
| tensor_value_list = [] | tensor_value_list = [] | ||||
| for field in fields: | for field in fields: | ||||
| if not getattr(tensor_event_value, field): | |||||
| if not getattr(tensor_event_value, field, False): | |||||
| continue | continue | ||||
| if PluginNameEnum.METADATA.value == field and not tensor_event_value.metadata.label: | |||||
| if ExplainFieldsEnum.METADATA.value == field and not tensor_event_value.metadata.label: | |||||
| continue | continue | ||||
| tensor_value = None | tensor_value = None | ||||
| if field == PluginNameEnum.SAMPLE_ID.value: | |||||
| tensor_value = _ExplainParser._add_image_data(tensor_event_value) | |||||
| elif field == PluginNameEnum.BENCHMARK.value: | |||||
| tensor_value = _ExplainParser._add_benchmark(tensor_event_value) | |||||
| elif field == PluginNameEnum.METADATA.value: | |||||
| tensor_value = _ExplainParser._add_metadata(tensor_event_value) | |||||
| if field == ExplainFieldsEnum.SAMPLE_ID.value: | |||||
| tensor_value = ExplainParser._add_image_data(tensor_event_value) | |||||
| elif field == ExplainFieldsEnum.BENCHMARK.value: | |||||
| tensor_value = ExplainParser._add_benchmark(tensor_event_value) | |||||
| elif field == ExplainFieldsEnum.METADATA.value: | |||||
| tensor_value = ExplainParser._add_metadata(tensor_event_value) | |||||
| logger.debug("Event generated, label is %s, step is %s.", field, event.step) | logger.debug("Event generated, label is %s, step is %s.", field, event.step) | ||||
| field_list.append(field) | field_list.append(field) | ||||
| tensor_value_list.append(tensor_value) | tensor_value_list.append(tensor_value) | ||||
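The decode step yields two parallel lists, one of field names and one of parsed containers; a consumer re-associates them pairwise. A minimal sketch with hypothetical decoded values:

    # Hypothetical output of _event_decode for a single event.
    field_list = ['sample_id', 'benchmark']
    tensor_value_list = [{'image_path': 'img/1.jpg'}, {'Faithfulness': 0.7}]

    for field, value in zip(field_list, tensor_value_list):
        print('%s -> %r' % (field, value))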
| @@ -189,8 +166,26 @@ class _ExplainParser(_SummaryParser): | |||||
| Args: | Args: | ||||
| tensor_event_value: the Explain message object. | tensor_event_value: the Explain message object. | ||||
| """ | """ | ||||
| image_data = ImageDataContainer(tensor_event_value) | |||||
| return image_data | |||||
| inference = InferenceContainer( | |||||
| ground_truth_prob=tensor_event_value.inference.ground_truth_prob, | |||||
| ground_truth_prob_sd=tensor_event_value.inference.ground_truth_prob_sd, | |||||
| ground_truth_prob_itl95_low=tensor_event_value.inference.ground_truth_prob_itl95_low, | |||||
| ground_truth_prob_itl95_hi=tensor_event_value.inference.ground_truth_prob_itl95_hi, | |||||
| predicted_label=tensor_event_value.inference.predicted_label, | |||||
| predicted_prob=tensor_event_value.inference.predicted_prob, | |||||
| predicted_prob_sd=tensor_event_value.inference.predicted_prob_sd, | |||||
| predicted_prob_itl95_low=tensor_event_value.inference.predicted_prob_itl95_low, | |||||
| predicted_prob_itl95_hi=tensor_event_value.inference.predicted_prob_itl95_hi | |||||
| ) | |||||
| sample_data = SampleContainer( | |||||
| sample_id=tensor_event_value.sample_id, | |||||
| image_path=tensor_event_value.image_path, | |||||
| ground_truth_label=tensor_event_value.ground_truth_label, | |||||
| inference=inference, | |||||
| explanation=tensor_event_value.explanation, | |||||
| status=tensor_event_value.status | |||||
| ) | |||||
| return sample_data | |||||
| @staticmethod | @staticmethod | ||||
| def _add_benchmark(tensor_event_value): | def _add_benchmark(tensor_event_value): | ||||
| @@ -25,7 +25,7 @@ class MockExplainJob: | |||||
| self.create_time = datetime.timestamp( | self.create_time = datetime.timestamp( | ||||
| datetime.strptime("2020-10-01 20:21:23", | datetime.strptime("2020-10-01 20:21:23", | ||||
| ExplainJobEncap.DATETIME_FORMAT)) | ExplainJobEncap.DATETIME_FORMAT)) | ||||
| self.latest_update_time = self.create_time | |||||
| self.update_time = self.create_time | |||||
| self.sample_count = 1999 | self.sample_count = 1999 | ||||
| self.min_confidence = 0.5 | self.min_confidence = 0.5 | ||||
| self.explainers = ["Gradient"] | self.explainers = ["Gradient"] | ||||