Browse Source

!1067 Refactor the explainer manager module part of code

From: @ouwenchang
Reviewed-by: @wangyue01,@tangjr14
Signed-off-by: @tangjr14
tags/v1.1.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
f357d24e9f
4 changed files with 48 additions and 42 deletions
  1. +20
    -22
      mindinsight/explainer/manager/explain_loader.py
  2. +5
    -4
      mindinsight/explainer/manager/explain_manager.py
  3. +19
    -13
      mindinsight/explainer/manager/explain_parser.py
  4. +4
    -3
      tests/ut/explainer/manager/test_explain_loader.py

+ 20
- 22
mindinsight/explainer/manager/explain_loader.py View File

@@ -14,21 +14,22 @@
# ============================================================================ # ============================================================================
"""ExplainLoader.""" """ExplainLoader."""


from collections import defaultdict
from enum import Enum

import math import math
import os import os
import re import re
from collections import defaultdict
import threading
from datetime import datetime from datetime import datetime
from typing import Dict, Iterable, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union
from enum import Enum
import threading


from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.explainer.common.enums import ExplainFieldsEnum from mindinsight.explainer.common.enums import ExplainFieldsEnum
from mindinsight.explainer.common.log import logger from mindinsight.explainer.common.log import logger
from mindinsight.explainer.manager.explain_parser import ExplainParser from mindinsight.explainer.manager.explain_parser import ExplainParser
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.common.exceptions import TrainJobNotExistError
from mindinsight.utils.exceptions import ParamValueError, UnknownError
from mindinsight.utils.exceptions import ParamValueError


_NAN_CONSTANT = 'NaN' _NAN_CONSTANT = 'NaN'
_NUM_DIGITS = 6 _NUM_DIGITS = 6
@@ -166,7 +167,7 @@ class ExplainLoader:
Returns: Returns:
list[dict], A list of evaluation results of each explainer. Each item contains: list[dict], A list of evaluation results of each explainer. Each item contains:
- explainer (str): Name of evaluated explainer. - explainer (str): Name of evaluated explainer.
- evaluations (list[dict]): A list of evlauation results by different metrics.
- evaluations (list[dict]): A list of evaluation results by different metrics.
- class_scores (list[dict]): A list of evaluation results on different labels. - class_scores (list[dict]): A list of evaluation results on different labels.


Each item in the evaluations contains: Each item in the evaluations contains:
@@ -175,7 +176,7 @@ class ExplainLoader:


Each item in the class_scores contains: Each item in the class_scores contains:
- label (str): Name of label - label (str): Name of label
- evaluations (list[dict]): A list of evalution results on different labels by different metrics.
- evaluations (list[dict]): A list of evaluation results on different labels by different metrics.


Each item in evaluations contains: Each item in evaluations contains:
- metric (str): Name of metric method - metric (str): Name of metric method
@@ -247,7 +248,7 @@ class ExplainLoader:


@property @property
def uncertainty_enabled(self): def uncertainty_enabled(self):
"""Whethter uncertainty is enabled."""
"""Whether uncertainty is enabled."""
return self._loader_info['uncertainty_enabled'] return self._loader_info['uncertainty_enabled']


@property @property
@@ -286,7 +287,7 @@ class ExplainLoader:


is_end = False is_end = False
while not is_end and self.status != _LoaderStatus.STOP.value: while not is_end and self.status != _LoaderStatus.STOP.value:
file_changed, is_end, event_dict = self._parser.parse_explain(filenames)
file_changed, is_end, event_dict = self._parser.list_events(filenames)


if file_changed: if file_changed:
logger.info('Summary file in %s update, reload the data in the summary.', logger.info('Summary file in %s update, reload the data in the summary.',
@@ -371,7 +372,7 @@ class ExplainLoader:
def _import_data_from_event(self, event_dict: Dict): def _import_data_from_event(self, event_dict: Dict):
"""Parse and import data from the event data.""" """Parse and import data from the event data."""
if 'metadata' not in event_dict and self._is_metadata_empty(): if 'metadata' not in event_dict and self._is_metadata_empty():
raise ParamValueError('metadata is imcomplete, should write metadata first in the summary.')
raise ParamValueError('metadata is incomplete, should write metadata first in the summary.')


for tag, event in event_dict.items(): for tag, event in event_dict.items():
if tag == ExplainFieldsEnum.METADATA.value: if tag == ExplainFieldsEnum.METADATA.value:
@@ -407,8 +408,8 @@ class ExplainLoader:
""" """
Parse the benchmark event. Parse the benchmark event.


Benchmark data are separeted into 'explainer_score' and 'label_score'. 'explainer_score' contains overall
evaluation results of each explainer by different metrics, while 'label_score' additionally devides the results
Benchmark data are separated into 'explainer_score' and 'label_score'. 'explainer_score' contains overall
evaluation results of each explainer by different metrics, while 'label_score' additionally divides the results
w.r.t different labels. w.r.t different labels.


The structure of self._benchmark['explainer_score'] demonstrates below: The structure of self._benchmark['explainer_score'] demonstrates below:
@@ -487,15 +488,12 @@ class ExplainLoader:
self._samples[sample_id]['image'] = sample.image_path self._samples[sample_id]['image'] = sample.image_path


for tag in _SAMPLE_FIELD_NAMES: for tag in _SAMPLE_FIELD_NAMES:
try:
if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL:
self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label))
elif tag == ExplainFieldsEnum.INFERENCE:
self._import_inference_from_event(sample, sample_id)
else:
self._import_explanation_from_event(sample, sample_id)
except UnknownError as ex:
logger.warning("Parse %s data failed within image related data, detail: %r", tag, str(ex))
if tag == ExplainFieldsEnum.GROUND_TRUTH_LABEL:
self._samples[sample_id]['ground_truth_label'].extend(list(sample.ground_truth_label))
elif tag == ExplainFieldsEnum.INFERENCE:
self._import_inference_from_event(sample, sample_id)
else:
self._import_explanation_from_event(sample, sample_id)


def _import_inference_from_event(self, event, sample_id): def _import_inference_from_event(self, event, sample_id):
"""Parse the inference event.""" """Parse the inference event."""


+ 5
- 4
mindinsight/explainer/manager/explain_manager.py View File

@@ -72,7 +72,6 @@ class ExplainManager:
name='explainer.start_load_thread', name='explainer.start_load_thread',
args=(reload_interval,), args=(reload_interval,),
daemon=True) daemon=True)
time.sleep(1)
thread.start() thread.start()


def get_job(self, loader_id: str) -> Optional[ExplainLoader]: def get_job(self, loader_id: str) -> Optional[ExplainLoader]:
@@ -127,6 +126,8 @@ class ExplainManager:


def _repeat_loading(self, repeat_interval): def _repeat_loading(self, repeat_interval):
"""Periodically loading summary.""" """Periodically loading summary."""
# Allocate CPU resources to enable gunicorn to start the web service.
time.sleep(1)
while True: while True:
try: try:
if self.status == _ExplainManagerStatus.STOPPING.value: if self.status == _ExplainManagerStatus.STOPPING.value:
@@ -178,7 +179,7 @@ class ExplainManager:


def _cache_loaders(self): def _cache_loaders(self):
"""Cache explain loader in cache pool.""" """Cache explain loader in cache pool."""
dir_map_mtime_dict = []
dir_map_mtimes = []
_, summaries_info = self._summary_watcher.list_explain_directories(self._summary_base_dir) _, summaries_info = self._summary_watcher.list_explain_directories(self._summary_base_dir)


for summary_info in summaries_info: for summary_info in summaries_info:
@@ -188,9 +189,9 @@ class ExplainManager:
if summary_path in self._loader_pool: if summary_path in self._loader_pool:
summary_update_time = max(summary_update_time, self._loader_pool[summary_path].query_time) summary_update_time = max(summary_update_time, self._loader_pool[summary_path].query_time)


dir_map_mtime_dict.append((summary_info, summary_update_time))
dir_map_mtimes.append((summary_info, summary_update_time))


sorted_summaries_info = sorted(dir_map_mtime_dict, key=lambda x: x[1])[-_MAX_LOADERS_NUM:]
sorted_summaries_info = sorted(dir_map_mtimes, key=lambda x: x[1])[-_MAX_LOADERS_NUM:]


with self._loader_pool_mutex: with self._loader_pool_mutex:
for summary_info, query_time in sorted_summaries_info: for summary_info, query_time in sorted_summaries_info:


+ 19
- 13
mindinsight/explainer/manager/explain_parser.py View File

@@ -52,21 +52,25 @@ class ExplainParser(_SummaryParser):


def __init__(self, summary_dir): def __init__(self, summary_dir):
super(ExplainParser, self).__init__(summary_dir) super(ExplainParser, self).__init__(summary_dir)
self._latest_filename = ''
self._latest_offset = 0


def parse_explain(self, filenames):
def list_events(self, filenames):
""" """
Load summary file and parse file content. Load summary file and parse file content.


Args: Args:
filenames (list[str]): File name list. filenames (list[str]): File name list.
Returns: Returns:
bool, True if all the summary files are finished loading.
tuple, will return (file_changed, is_end, event_data),

file_changed (bool): True if the 9latest file is changed.
is_end (bool): True if all the summary files are finished loading.
event_data (dict): return an event data, key is field.
""" """
summary_files = self.sort_files(filenames) summary_files = self.sort_files(filenames)


is_end = False is_end = False
is_clean = False
file_changed = False
event_data = {} event_data = {}
filename = summary_files[-1] filename = summary_files[-1]


@@ -74,13 +78,13 @@ class ExplainParser(_SummaryParser):
if filename != self._latest_filename: if filename != self._latest_filename:
self._summary_file_handler = FileHandler(file_path, 'rb') self._summary_file_handler = FileHandler(file_path, 'rb')
self._latest_filename = filename self._latest_filename = filename
self._latest_file_size = 0
is_clean = True
self._latest_offset = 0
file_changed = True


new_size = FileHandler.file_stat(file_path).size new_size = FileHandler.file_stat(file_path).size
if new_size == self._latest_file_size:
if new_size == self._latest_offset:
is_end = True is_end = True
return is_clean, is_end, event_data
return file_changed, is_end, event_data


while True: while True:
start_offset = self._summary_file_handler.offset start_offset = self._summary_file_handler.offset
@@ -89,7 +93,7 @@ class ExplainParser(_SummaryParser):
if event_str is None: if event_str is None:
self._summary_file_handler.reset_offset(start_offset) self._summary_file_handler.reset_offset(start_offset)
is_end = True is_end = True
return is_clean, is_end, event_data
return file_changed, is_end, event_data
if len(event_str) > MAX_EVENT_STRING: if len(event_str) > MAX_EVENT_STRING:
logger.warning("file_path: %s, event string: %d exceeds %d and drop it.", logger.warning("file_path: %s, event string: %d exceeds %d and drop it.",
self._summary_file_handler.file_path, len(event_str), MAX_EVENT_STRING) self._summary_file_handler.file_path, len(event_str), MAX_EVENT_STRING)
@@ -98,24 +102,26 @@ class ExplainParser(_SummaryParser):
field_list, tensor_value_list = self._event_decode(event_str) field_list, tensor_value_list = self._event_decode(event_str)
for field, tensor_value in zip(field_list, tensor_value_list): for field, tensor_value in zip(field_list, tensor_value_list):
event_data[field] = tensor_value event_data[field] = tensor_value

logger.debug("Parse summary file offset %d, file path: %s.", logger.debug("Parse summary file offset %d, file path: %s.",
self._summary_file_handler.offset, file_path) self._summary_file_handler.offset, file_path)
return is_clean, is_end, event_data

return file_changed, is_end, event_data
except (exceptions.CRCFailedError, exceptions.CRCLengthFailedError) as ex: except (exceptions.CRCFailedError, exceptions.CRCLengthFailedError) as ex:
self._summary_file_handler.reset_offset(start_offset) self._summary_file_handler.reset_offset(start_offset)
is_end = True is_end = True
logger.warning("Check crc failed and ignore this file, file_path=%s, offset=%s. Detail: %r.", logger.warning("Check crc failed and ignore this file, file_path=%s, offset=%s. Detail: %r.",
self._summary_file_handler.file_path, self._summary_file_handler.offset, str(ex)) self._summary_file_handler.file_path, self._summary_file_handler.offset, str(ex))
return is_clean, is_end, event_data
return file_changed, is_end, event_data
except (OSError, DecodeError, exceptions.MindInsightException) as ex: except (OSError, DecodeError, exceptions.MindInsightException) as ex:
is_end = True is_end = True
logger.warning("Parse log file fail, and ignore this file, detail: %r," logger.warning("Parse log file fail, and ignore this file, detail: %r,"
"file path: %s.", str(ex), self._summary_file_handler.file_path) "file path: %s.", str(ex), self._summary_file_handler.file_path)
return is_clean, is_end, event_data
return file_changed, is_end, event_data
except Exception as ex: except Exception as ex:
logger.exception(ex) logger.exception(ex)
raise UnknownError(str(ex)) raise UnknownError(str(ex))
finally:
self._latest_offset = self._summary_file_handler.offset


@staticmethod @staticmethod
def _event_decode(event_str): def _event_decode(event_str):


+ 4
- 3
tests/ut/explainer/manager/test_explain_loader.py View File

@@ -28,17 +28,18 @@ def abc():
FileHandler.is_file('aaa') FileHandler.is_file('aaa')
print('after') print('after')



class TestExplainLoader: class TestExplainLoader:
"""Test explain loader class.""" """Test explain loader class."""
@patch.object(ExplainParser, 'parse_explain')
@patch.object(ExplainParser, 'list_events')
@patch.object(FileHandler, 'list_dir') @patch.object(FileHandler, 'list_dir')
@patch.object(FileHandler, 'is_file') @patch.object(FileHandler, 'is_file')
@patch.object(os, 'stat') @patch.object(os, 'stat')
def test_stop(self, mock_stat, mock_is_file, mock_list_dir, mock_parse_explain):
def test_stop(self, mock_stat, mock_is_file, mock_list_dir, mock_list_events):
"""Test stop function.""" """Test stop function."""
mock_is_file.return_value = True mock_is_file.return_value = True
mock_list_dir.return_value = ['events.summary.123.host_explain'] mock_list_dir.return_value = ['events.summary.123.host_explain']
mock_parse_explain.return_value = (True, False, None)
mock_list_events.return_value = (True, False, None)


class _MockStat: class _MockStat:
def __init__(self, _): def __init__(self, _):


Loading…
Cancel
Save