- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """ExplainJob."""
-
- import os
- from collections import defaultdict
- from datetime import datetime
- from typing import Union
-
- from mindinsight.explainer.common.enums import PluginNameEnum
- from mindinsight.explainer.common.log import logger
- from mindinsight.explainer.manager.explain_parser import _ExplainParser
- from mindinsight.explainer.manager.event_parse import EventParser
- from mindinsight.datavisual.data_access.file_handler import FileHandler
- from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
-
- _NUM_DIGIT = 7  # number of decimal places kept when rounding confidences
-
-
- class ExplainJob:
- """ExplainJob which manage the record in the summary file."""
-
- def __init__(self,
- job_id: str,
- summary_dir: str,
- create_time: float,
- latest_update_time: float):
-
- self._job_id = job_id
- self._summary_dir = summary_dir
- self._parser = _ExplainParser(summary_dir)
-
- self._event_parser = EventParser(self)
- self._latest_update_time = latest_update_time
- self._create_time = create_time
- self._uncertainty_enabled = False
- self._labels = []
- self._metrics = []
- self._explainers = []
- self._samples_info = {}
- self._labels_info = {}
- self._explainer_score_dict = defaultdict(list)
- self._label_score_dict = defaultdict(dict)
-
- @property
- def all_classes(self):
- """
- Return a list of label information.
-
- Returns:
- class_objs (List[ClassObj]): a list of class_objects, each object
- contains:
-
- - id (int): label id
- - label (str): label name
- - sample_count (int): number of samples for each label
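-
- Example of the returned structure (labels are illustrative):
-
- [{'id': 0, 'label': 'cat', 'sample_count': 5},
- {'id': 1, 'label': 'dog', 'sample_count': 3}]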
- """
- all_classes_return = []
- for label_id, label_info in self._labels_info.items():
- single_info = {
- 'id': label_id,
- 'label': label_info['label'],
- 'sample_count': len(label_info['sample_ids'])}
- all_classes_return.append(single_info)
- return all_classes_return
-
- @property
- def explainers(self):
- """
- Return a list of explainer names.
-
- Returns:
- list(str), explainer names
- """
- return self._explainers
-
- @property
- def explainer_scores(self):
- """Return evaluation results for every explainer."""
- merged_scores = []
- for explainer, explainer_score_on_metric in self._explainer_score_dict.items():
- label_scores = []
- for label, label_score_on_metric in self._label_score_dict[explainer].items():
- score_single_label = {
- 'label': self._labels[label],
- 'evaluations': label_score_on_metric,
- }
- label_scores.append(score_single_label)
- merged_scores.append({
- 'explainer': explainer,
- 'evaluations': explainer_score_on_metric,
- 'class_scores': label_scores,
- })
- return merged_scores
-
- @property
- def sample_count(self):
- """
- Return the total number of samples in the job.
-
- Returns:
- int, total number of samples
- """
- return len(self._samples_info)
-
- @property
- def train_id(self):
- """
- Return the ID of the explain job.
-
- Returns:
- str, id of ExplainJob object
- """
- return self._job_id
-
- @property
- def metrics(self):
- """
- Return a list of metric names.
-
- Returns:
- list(str), metric names
- """
- return self._metrics
-
- @property
- def min_confidence(self):
- """
- Return the minimum confidence threshold.
-
- Returns:
- float or None, the minimum confidence; None when no threshold is recorded.
- """
- return None
-
- @property
- def uncertainty_enabled(self):
- """Whether uncertainty data is recorded for the samples in this job."""
- return self._uncertainty_enabled
-
- @property
- def create_time(self):
- """
- Return the creation time of the summary file.
-
- Returns:
- float, creation timestamp
- """
- return self._create_time
-
- @property
- def labels(self):
- """Return the label contained in the job."""
- return self._labels
-
- @property
- def latest_update_time(self):
- """
- Return last modification time stamp of summary file.
-
- Returns:
- float, last_modification_time stamp
- """
- return self._latest_update_time
-
- @latest_update_time.setter
- def latest_update_time(self, new_time: Union[float, datetime]):
- """
- Update the latest_update_time timestamp manually.
-
- Args:
- new_time (Union[float, datetime]): updated timestamp for the job
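-
- Examples:
- Both accepted forms (timestamp value is illustrative):
-
- >>> job.latest_update_time = 1600000000.0
- >>> job.latest_update_time = datetime(2020, 9, 13)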
- """
- if isinstance(new_time, datetime):
- self._latest_update_time = new_time.timestamp()
- elif isinstance(new_time, float):
- self._latest_update_time = new_time
- else:
- raise TypeError('new_time should be a float or a datetime')
-
- @property
- def loader_id(self):
- """Return the job id."""
- return self._job_id
-
- @property
- def samples(self):
- """Return the information of all samples in the job."""
- return self._samples_info
-
- @staticmethod
- def get_create_time(file_path: str) -> float:
- """Return timestamp of create time of specific path."""
- create_time = os.stat(file_path).st_ctime
- return create_time
-
- @staticmethod
- def get_update_time(file_path: str) -> float:
- """Return timestamp of update time of specific path."""
- update_time = os.stat(file_path).st_mtime
- return update_time
-
- def _initialize_labels_info(self):
- """Initialize a dict for labels in the job."""
- if not self._labels:
- logger.warning('No labels are provided in job %s', self._job_id)
- return
-
- for label_id, label in enumerate(self._labels):
- self._labels_info[label_id] = {'label': label,
- 'sample_ids': set()}
-
- @staticmethod
- def _explanation_to_dict(explanation):
- """Convert an explanation event to its dict representation."""
- explain_info = {
- 'explainer': explanation.explain_method,
- 'overlay': explanation.heatmap_path,
- }
- return explain_info
-
- def _image_container_to_dict(self, sample_data):
- """Transfer the image container to dict storage."""
- has_uncertainty = False
- sample_id = sample_data.sample_id
-
- sample_info = {
- 'id': sample_id,
- 'image': sample_data.image_path,
- 'name': str(sample_id),
- 'labels': [self._labels_info[x]['label']
- for x in sample_data.ground_truth_label],
- 'inferences': []}
-
- ground_truth_labels = list(sample_data.ground_truth_label)
- ground_truth_probs = list(sample_data.inference.ground_truth_prob)
- predicted_labels = list(sample_data.inference.predicted_label)
- predicted_probs = list(sample_data.inference.predicted_prob)
-
- if sample_data.inference.predicted_prob_sd or sample_data.inference.ground_truth_prob_sd:
- ground_truth_prob_sds = list(sample_data.inference.ground_truth_prob_sd)
- ground_truth_prob_lows = list(sample_data.inference.ground_truth_prob_itl95_low)
- ground_truth_prob_his = list(sample_data.inference.ground_truth_prob_itl95_hi)
- predicted_prob_sds = list(sample_data.inference.predicted_prob_sd)
- predicted_prob_lows = list(sample_data.inference.predicted_prob_itl95_low)
- predicted_prob_his = list(sample_data.inference.predicted_prob_itl95_hi)
- has_uncertainty = True
- else:
- ground_truth_prob_sds = ground_truth_prob_lows = ground_truth_prob_his = None
- predicted_prob_sds = predicted_prob_lows = predicted_prob_his = None
-
- inference_info = {}
- for label, prob in zip(
- ground_truth_labels + predicted_labels,
- ground_truth_probs + predicted_probs):
- inference_info[label] = {
- 'label': self._labels_info[label]['label'],
- 'confidence': round(prob, _NUM_DIGIT),
- 'saliency_maps': []}
-
- if ground_truth_prob_sds or predicted_prob_sds:
- for label, sd, low, hi in zip(
- ground_truth_labels + predicted_labels,
- ground_truth_prob_sds + predicted_prob_sds,
- ground_truth_prob_lows + predicted_prob_lows,
- ground_truth_prob_his + predicted_prob_his):
- inference_info[label]['confidence_sd'] = sd
- inference_info[label]['confidence_itl95'] = [low, hi]
-
- if EventParser.is_attr_ready(sample_data, 'explanation'):
- for explanation in sample_data.explanation:
- explanation_dict = self._explanation_to_dict(explanation)
- inference_info[explanation.label]['saliency_maps'].append(explanation_dict)
-
- sample_info['inferences'] = list(inference_info.values())
- return sample_info, has_uncertainty
-
- def _import_sample(self, sample):
- """Add sample object of given sample id."""
- for label_id in sample.ground_truth_label:
- self._labels_info[label_id]['sample_ids'].add(sample.sample_id)
-
- sample_info, has_uncertainty = self._image_container_to_dict(sample)
- self._samples_info.update({sample_info['id']: sample_info})
- self._uncertainty_enabled |= has_uncertainty
-
- def get_all_samples(self):
- """
- Return a list of sample information cached in the explain job.
-
- Returns:
- sample_list (List[SampleObj]): a list of sample objects, each object
- consists of:
-
- - id (int): sample id
- - name (str): basename of image
- - labels (list[str]): list of labels
- - inferences (list[dict]): list of inference information
- """
- samples_in_list = list(self._samples_info.values())
- return samples_in_list
-
- def _is_metadata_empty(self):
- """Check whether metadata is loaded first."""
- if not self._explainers or not self._metrics or not self._labels:
- return True
- return False
-
- def _import_data_from_event(self, event):
- """Parse and import data from the event data."""
- tags = {
- 'sample_id': PluginNameEnum.SAMPLE_ID,
- 'benchmark': PluginNameEnum.BENCHMARK,
- 'metadata': PluginNameEnum.METADATA
- }
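- # The keys above mirror the corresponding PluginNameEnum values, which is
- # why the string comparisons against `.value` in the loop below hold.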
-
- if 'metadata' not in event and self._is_metadata_empty():
- raise ValueError('Metadata is empty. Metadata must be written to the summary file first.')
- for tag in tags:
- if tag not in event:
- continue
-
- if tag == PluginNameEnum.SAMPLE_ID.value:
- sample_event = event[tag]
- sample_data = self._event_parser.parse_sample(sample_event)
- if sample_data is not None:
- self._import_sample(sample_data)
- continue
-
- if tag == PluginNameEnum.BENCHMARK.value:
- benchmark_event = event[tag].benchmark
- explain_score_dict, label_score_dict = EventParser.parse_benchmark(benchmark_event)
- self._update_benchmark(explain_score_dict, label_score_dict)
-
- elif tag == PluginNameEnum.METADATA.value:
- metadata_event = event[tag].metadata
- metadata = EventParser.parse_metadata(metadata_event)
- self._explainers, self._metrics, self._labels = metadata
- self._initialize_labels_info()
-
- def load(self):
- """
- Start loading data from parser.
- """
- valid_file_names = []
- for filename in FileHandler.list_dir(self._summary_dir):
- if FileHandler.is_file(
- FileHandler.join(self._summary_dir, filename)):
- valid_file_names.append(filename)
-
- if not valid_file_names:
- raise TrainJobNotExistError('No summary file found in %s, the explain job will be deleted.' % self._summary_dir)
-
- is_end = False
- while not is_end:
- is_clean, is_end, event = self._parser.parse_explain(valid_file_names)
-
- if is_clean:
- logger.info('Summary file in %s updated, clean the loaded data and reload.', self._summary_dir)
- self._clean_job()
-
- if event:
- self._import_data_from_event(event)
-
- def _clean_job(self):
- """Clean the cached data in job."""
- self._latest_update_time = ExplainJob.get_update_time(self._summary_dir)
- self._create_time = ExplainJob.get_create_time(self._summary_dir)
- self._labels.clear()
- self._metrics.clear()
- self._explainers.clear()
- self._samples_info.clear()
- self._labels_info.clear()
- self._explainer_score_dict.clear()
- self._label_score_dict.clear()
- self._event_parser.clear()
-
- def _update_benchmark(self, explainer_score_dict, labels_score_dict):
- """Update the benchmark info."""
- for explainer, score in explainer_score_dict.items():
- self._explainer_score_dict[explainer].extend(score)
-
- for explainer, score in labels_score_dict.items():
- for label, score_of_label in score.items():
- label_scores = self._label_score_dict[explainer]
- label_scores[label] = label_scores.get(label, []) + score_of_label
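-
-
- # A minimal end-to-end sketch of how this class may be driven (the summary
- # path is hypothetical; in MindInsight the loading is normally orchestrated
- # by the explain manager rather than invoked directly like this):
- #
- #     summary_dir = '/path/to/explain_summary'
- #     job = ExplainJob(job_id=summary_dir,
- #                      summary_dir=summary_dir,
- #                      create_time=ExplainJob.get_create_time(summary_dir),
- #                      latest_update_time=ExplainJob.get_update_time(summary_dir))
- #     job.load()
- #     print(job.sample_count, job.explainers, job.metrics)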