You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

explain_parser.py 8.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. File parser for MindExplain data.
  17. This module is used to parse the MindExplain log file.
  18. """
  19. import re
  20. import collections
  21. from google.protobuf.message import DecodeError
  22. from mindinsight.datavisual.common import exceptions
  23. from mindinsight.explainer.common.enums import PluginNameEnum
  24. from mindinsight.explainer.common.log import logger
  25. from mindinsight.datavisual.data_access.file_handler import FileHandler
  26. from mindinsight.datavisual.data_transform.ms_data_loader import _SummaryParser
  27. from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
  28. from mindinsight.datavisual.proto_files.mindinsight_summary_pb2 import Explain
  29. from mindinsight.utils.exceptions import UnknownError
# Record-framing sizes for summary files; presumably consumed by the base
# parser's event loader — TODO confirm against _SummaryParser._event_load.
HEADER_SIZE = 8
CRC_STR_SIZE = 4
# Upper bound (in bytes) on a single serialized event; larger events are
# logged and dropped by _ExplainParser.parse_explain.
MAX_EVENT_STRING = 500000000
# Plain containers pairing a parsed proto payload with its status field,
# defined at module level so that instances can be pickled.
BenchmarkContainer = collections.namedtuple('BenchmarkContainer', ['benchmark', 'status'])
MetadataContainer = collections.namedtuple('MetadataContainer', ['metadata', 'status'])
  35. class ImageDataContainer:
  36. """
  37. Container for image data to allow pickling.
  38. Args:
  39. explain_message (Explain): Explain proto buffer message.
  40. """
  41. def __init__(self, explain_message: Explain):
  42. self.sample_id = explain_message.sample_id
  43. self.image_path = explain_message.image_path
  44. self.ground_truth_label = explain_message.ground_truth_label
  45. self.inference = explain_message.inference
  46. self.explanation = explain_message.explanation
  47. self.status = explain_message.status
  48. class _ExplainParser(_SummaryParser):
  49. """The summary file parser."""
  50. def __init__(self, summary_dir):
  51. super(_ExplainParser, self).__init__(summary_dir)
  52. self._latest_filename = ''
  53. def parse_explain(self, filenames):
  54. """
  55. Load summary file and parse file content.
  56. Args:
  57. filenames (list[str]): File name list.
  58. Returns:
  59. bool, True if all the summary files are finished loading.
  60. """
  61. summary_files = self.filter_files(filenames)
  62. summary_files = self.sort_files(summary_files)
  63. is_end = False
  64. is_clean = False
  65. event_data = {}
  66. filename = summary_files[-1]
  67. file_path = FileHandler.join(self._summary_dir, filename)
  68. if filename != self._latest_filename:
  69. self._summary_file_handler = FileHandler(file_path, 'rb')
  70. self._latest_filename = filename
  71. self._latest_file_size = 0
  72. is_clean = True
  73. new_size = FileHandler.file_stat(file_path).size
  74. if new_size == self._latest_file_size:
  75. is_end = True
  76. return is_clean, is_end, event_data
  77. while True:
  78. start_offset = self._summary_file_handler.offset
  79. try:
  80. event_str = self._event_load(self._summary_file_handler)
  81. if event_str is None:
  82. self._summary_file_handler.reset_offset(start_offset)
  83. is_end = True
  84. return is_clean, is_end, event_data
  85. if len(event_str) > MAX_EVENT_STRING:
  86. logger.warning("file_path: %s, event string: %d exceeds %d and drop it.",
  87. self._summary_file_handler.file_path, len(event_str), MAX_EVENT_STRING)
  88. continue
  89. field_list, tensor_value_list = self._event_decode(event_str)
  90. for field, tensor_value in zip(field_list, tensor_value_list):
  91. event_data[field] = tensor_value
  92. logger.info("Parse summary file offset %d, file path: %s.", self._latest_file_size, file_path)
  93. return is_clean, is_end, event_data
  94. except exceptions.CRCFailedError:
  95. self._summary_file_handler.reset_offset(start_offset)
  96. is_end = True
  97. logger.warning("Check crc failed and ignore this file, file_path=%s, "
  98. "offset=%s.", self._summary_file_handler.file_path, self._summary_file_handler.offset)
  99. return is_clean, is_end, event_data
  100. except (OSError, DecodeError, exceptions.MindInsightException) as ex:
  101. is_end = True
  102. logger.warning("Parse log file fail, and ignore this file, detail: %r,"
  103. "file path: %s.", str(ex), self._summary_file_handler.file_path)
  104. return is_clean, is_end, event_data
  105. except Exception as ex:
  106. logger.exception(ex)
  107. raise UnknownError(str(ex))
  108. def filter_files(self, filenames):
  109. """
  110. Gets a list of summary files.
  111. Args:
  112. filenames (list[str]): File name list, like [filename1, filename2].
  113. Returns:
  114. list[str], filename list.
  115. """
  116. return list(filter(
  117. lambda filename: (re.search(r'summary\.\d+', filename)
  118. and filename.endswith("_explain")), filenames))
  119. @staticmethod
  120. def _event_decode(event_str):
  121. """
  122. Transform `Event` data to tensor_event and update it to EventsData.
  123. Args:
  124. event_str (str): Message event string in summary proto, data read from file handler.
  125. """
  126. logger.debug("Start to parse event string. Event string len: %s.", len(event_str))
  127. event = summary_pb2.Event.FromString(event_str)
  128. logger.debug("Deserialize event string completed.")
  129. fields = {
  130. 'sample_id': PluginNameEnum.SAMPLE_ID,
  131. 'benchmark': PluginNameEnum.BENCHMARK,
  132. 'metadata': PluginNameEnum.METADATA
  133. }
  134. tensor_event_value = getattr(event, 'explain')
  135. field_list = []
  136. tensor_value_list = []
  137. for field in fields:
  138. if not getattr(tensor_event_value, field):
  139. continue
  140. if PluginNameEnum.METADATA.value == field and not tensor_event_value.metadata.label:
  141. continue
  142. tensor_value = None
  143. if field == PluginNameEnum.SAMPLE_ID.value:
  144. tensor_value = _ExplainParser._add_image_data(tensor_event_value)
  145. elif field == PluginNameEnum.BENCHMARK.value:
  146. tensor_value = _ExplainParser._add_benchmark(tensor_event_value)
  147. elif field == PluginNameEnum.METADATA.value:
  148. tensor_value = _ExplainParser._add_metadata(tensor_event_value)
  149. logger.debug("Event generated, label is %s, step is %s.", field, event.step)
  150. field_list.append(field)
  151. tensor_value_list.append(tensor_value)
  152. return field_list, tensor_value_list
  153. @staticmethod
  154. def _add_image_data(tensor_event_value):
  155. """
  156. Parse image data based on sample_id in Explain message
  157. Args:
  158. tensor_event_value: the object of Explain message
  159. """
  160. image_data = ImageDataContainer(tensor_event_value)
  161. return image_data
  162. @staticmethod
  163. def _add_benchmark(tensor_event_value):
  164. """
  165. Parse benchmark data from Explain message.
  166. Args:
  167. tensor_event_value: the object of Explain message
  168. Returns:
  169. benchmark_data: An object containing benchmark.
  170. """
  171. benchmark_data = BenchmarkContainer(
  172. benchmark=tensor_event_value.benchmark,
  173. status=tensor_event_value.status
  174. )
  175. return benchmark_data
  176. @staticmethod
  177. def _add_metadata(tensor_event_value):
  178. """
  179. Parse metadata from Explain message.
  180. Args:
  181. tensor_event_value: the object of Explain message
  182. Returns:
  183. benchmark_data: An object containing metadata.
  184. """
  185. metadata_value = MetadataContainer(
  186. metadata=tensor_event_value.metadata,
  187. status=tensor_event_value.status
  188. )
  189. return metadata_value