You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

step_trace_parser.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """The parser for step trace data."""
  16. import csv
  17. import json
  18. import os
  19. import stat
  20. import struct
  21. from collections import namedtuple
  22. from decimal import Decimal
  23. from mindinsight.profiler.common.exceptions.exceptions import ProfilerPathErrorException, \
  24. JobIdMismatchException, ProfilerIOException
  25. from mindinsight.profiler.common.log import logger as log
  26. from mindinsight.profiler.common.util import get_summary_for_step_trace
  27. StepTraceStruct = namedtuple(
  28. 'TrainingTraceStruct', ['tag_id', 'task_id', 'stream_id', 'sys_count']
  29. )
  30. class StepTraceParser:
  31. """
  32. The parser for step trace data.
  33. Args:
  34. input_dir (str): The directory that contains original step trace data.
  35. output_file_path (str): The output file path.
  36. job_id (int): The job id used to define the start of new step. Default: 0.
  37. skip_first_step (bool): Whether skip the first step or not.
  38. """
  39. _event_size = 20
  40. _fp_tag = 1
  41. _bp_tag = 2
  42. def __init__(self, input_dir, output_file_path, job_id=0, skip_first_step=False):
  43. self._input_dir = input_dir
  44. self._output_path = output_file_path
  45. self._job_id = job_id
  46. self._skip_first_step = skip_first_step
  47. self._result = []
  48. self._header = []
  49. self._step_num = 0
  50. @property
  51. def output_file(self):
  52. """The property of step trace header."""
  53. file_name = self._output_path.rsplit('/', 2)
  54. return file_name[-1] if len(file_name) == 3 else ''
  55. def show(self):
  56. """The property of step trace info."""
  57. summary_info = {}
  58. if self._result:
  59. summary_info = get_summary_for_step_trace(self._result[-1], self._header)
  60. summary_info['total_steps'] = len(self._result) - 1
  61. print('\nStep trace summary info (unit: syscnt):')
  62. print(summary_info)
  63. print('\nThe step trace parse result saves under ${summary_dir}/profiler/%s'
  64. % self.output_file)
  65. def parse_and_save(self):
  66. """Parse step trace files and save the result."""
  67. try:
  68. source_files = self._get_step_trace_files()
  69. self._parse(source_files)
  70. self._save()
  71. except IOError as err:
  72. log.exception(err)
  73. raise ProfilerIOException()
  74. else:
  75. log.info("Finish to save intermediate result for step trace file.")
  76. def record_point_info(self, point_info, output_path):
  77. """
  78. Record point info into json.
  79. Args:
  80. point_info (dict): The point info about tag id and relative op name.
  81. output_path (str): The output path for saving point info.
  82. Returns:
  83. dict, parsed point info.
  84. """
  85. points = {
  86. 'fp_start': point_info.get(self._fp_tag, ''),
  87. 'bp_end': point_info.get(self._bp_tag, '')
  88. }
  89. try:
  90. with open(output_path, 'w') as json_file:
  91. json.dump(points, json_file)
  92. os.chmod(output_path, stat.S_IREAD)
  93. except (IOError, OSError) as err:
  94. log.warning('Failed to save point info. %s', err)
  95. raise ProfilerIOException
  96. return points
  97. def _get_step_trace_files(self):
  98. """Get step trace files."""
  99. # step trace files may under $profiler_dir or $profiler_dir/data
  100. profiler_dir = self._input_dir
  101. step_trace_files = self._search_file(profiler_dir)
  102. if not step_trace_files:
  103. # try to find step trace files under $profiler_dir/data
  104. profiler_dir = os.path.join(profiler_dir, 'data')
  105. step_trace_files = self._search_file(profiler_dir)
  106. if not step_trace_files:
  107. raise ProfilerPathErrorException('Training trace file does not exist.')
  108. return step_trace_files
  109. @staticmethod
  110. def _search_file(input_dir):
  111. """Search step trace file under specific input directory."""
  112. # validate input_dir
  113. if not os.path.isdir(input_dir):
  114. raise ProfilerPathErrorException(
  115. '{} does not exist or is not a dir'.format(input_dir)
  116. )
  117. # get step trace files
  118. files = os.listdir(input_dir)
  119. step_trace_files = list(
  120. filter(
  121. lambda file: file.startswith('training_trace') and not file.endswith('.done'),
  122. files
  123. )
  124. )
  125. # validate result
  126. if len(step_trace_files) > 1:
  127. # the format of file name is like
  128. # `training_trace.46.dev.profiler_default_tag.$id.slice_$number`
  129. # use the $number as the sorted key
  130. try:
  131. step_trace_files.sort(key=lambda path: int(path.rsplit('_', 1)[-1]))
  132. except ValueError as err:
  133. log.warning("Unable to parse file names: %s. %s", step_trace_files, err)
  134. step_trace_files = []
  135. file_paths = [os.path.join(input_dir, file) for file in step_trace_files]
  136. log.info("Find %d step trace files.", len(file_paths))
  137. return file_paths
  138. def _parse(self, source_files):
  139. """Parse source step trace files."""
  140. log.info("Start to parse step trace file.")
  141. event_info = {}
  142. for source_file in source_files:
  143. with open(source_file, 'rb') as handler:
  144. content = handler.read()
  145. for step_trace in self._get_next_step_trace(content, event_info):
  146. if self._skip_first_step:
  147. self._skip_first_step = False
  148. continue
  149. self._record_trace_event(step_trace)
  150. self._record_average_info()
  151. log.info("Finish to parse step trace file.")
  152. def _get_next_step_trace(self, content, event_info):
  153. """
  154. Get next step trace info.
  155. Args:
  156. content (bytes): The input step trace info.
  157. event_info (dict): The event info.
  158. Returns:
  159. Generator, return the step trace one by one.
  160. """
  161. for pos in range(0, len(content), 20):
  162. next_event = self._get_trace_struct(content[pos:pos + self._event_size])
  163. self._construct_event_info(next_event, event_info)
  164. if event_info.get('end'):
  165. yield event_info
  166. def _get_trace_struct(self, bin_info):
  167. """Translate event info to StepTraceStruct."""
  168. if len(bin_info) == self._event_size:
  169. parsed_info = struct.unpack('=QHHQ', bin_info)
  170. return StepTraceStruct(*parsed_info)
  171. return None
  172. def _construct_event_info(self, next_event, event_info):
  173. """Construct event info according to next_event."""
  174. min_job_id = 255
  175. step_flag: bool = lambda tag: tag > min_job_id or tag == 0
  176. end_flag: bool = lambda tag: tag == min_job_id
  177. fp_flag: bool = lambda tag: tag == self._fp_tag
  178. bp_flag: bool = lambda tag: tag == self._bp_tag
  179. def _on_step_event():
  180. """Handle step event."""
  181. self._validate_tag_id(tag_id)
  182. start_time = event_info.get('end', '-')
  183. event_info.clear()
  184. event_info['start'] = start_time
  185. event_info['reduce'] = {}
  186. def _on_reduce_event():
  187. """Handle reduce event."""
  188. stream_id = next_event.stream_id
  189. if event_info['reduce'].get(stream_id):
  190. event_info['reduce'][stream_id].append(sys_count)
  191. else:
  192. event_info['reduce'][stream_id] = [sys_count]
  193. tag_id = next_event.tag_id
  194. sys_count = next_event.sys_count
  195. if end_flag(tag_id):
  196. event_info['end'] = sys_count
  197. elif step_flag(tag_id):
  198. _on_step_event()
  199. elif fp_flag(tag_id):
  200. event_info['fp'] = sys_count
  201. elif bp_flag(tag_id):
  202. event_info['bp'] = sys_count
  203. else:
  204. _on_reduce_event()
  205. def _validate_tag_id(self, job_id):
  206. """Check the job id in source step trace file is same os user set."""
  207. if not self._job_id:
  208. self._job_id = job_id
  209. elif self._job_id != job_id:
  210. raise JobIdMismatchException()
  211. def _record_trace_event(self, step_trace):
  212. """Record trace event."""
  213. self._step_num += 1
  214. start_time = step_trace.get('start')
  215. end_time = step_trace.get('end')
  216. fp_time = step_trace.get('fp')
  217. bp_time = step_trace.get('bp')
  218. if not (start_time and end_time and fp_time and bp_time):
  219. log.warning("The step %d is missing basic time.", self._step_num)
  220. return
  221. if start_time == '-':
  222. start_time = fp_time
  223. row_data = {
  224. 'step_num': self._step_num,
  225. 'start_point': start_time,
  226. 'end_point': end_time,
  227. 'total': end_time - start_time,
  228. 'fp_point': fp_time,
  229. 'bp_point': bp_time,
  230. 'iteration_interval': fp_time - start_time,
  231. 'fp_and_bp': bp_time - fp_time,
  232. 'tail': end_time - bp_time
  233. }
  234. # update reduce info
  235. self._update_reduce_info(step_trace, row_data)
  236. # save the row data
  237. if not self._header:
  238. self._header = list(row_data.keys())
  239. row_data_list = [row_data.get(header_name, 0) for header_name in self._header]
  240. self._result.append(row_data_list)
  241. @staticmethod
  242. def _update_reduce_info(step_trace, row_data):
  243. """Extract reduce info."""
  244. reduce_time = step_trace.get('reduce', {})
  245. for stream_id, time_points in reduce_time.items():
  246. time_point_num = len(time_points)
  247. if time_point_num % 2:
  248. log.warning("Stream %d has %d reduce time points.", stream_id, time_point_num)
  249. continue
  250. for index, point_id in enumerate(range(0, time_point_num, 2)):
  251. field_name = f'stream_{stream_id}_parallel_{index}'
  252. row_data[field_name + '_start_point'] = time_points[point_id]
  253. row_data[field_name + '_end_point'] = time_points[point_id + 1]
  254. row_data[field_name] = time_points[point_id + 1] - time_points[point_id]
  255. def _record_average_info(self):
  256. """Calculate average info."""
  257. result_size = len(self._result)
  258. # calculate average data for each column in result data
  259. average_data = [0] * len(self._header)
  260. if result_size >= 2:
  261. for row_info in self._result[1:]:
  262. average_data = [
  263. Decimal(i) + Decimal(j) for i, j in zip(row_info, average_data)
  264. ]
  265. average_data = [
  266. round((item / (result_size - 1))) for item in average_data
  267. ]
  268. # change step num info in average_data to None
  269. step_num_index = self._header.index('step_num')
  270. average_data[step_num_index] = '-'
  271. self._result.append(average_data)
  272. log.info("Finish add average info for step trace.")
  273. def _save(self):
  274. log.info("Start to save step trace file.")
  275. if not self._header:
  276. return
  277. with open(self._output_path, 'w') as file_handle:
  278. csv_writer = csv.writer(file_handle)
  279. csv_writer.writerow(self._header)
  280. for row_data in self._result:
  281. csv_writer.writerow(row_data)
  282. os.chmod(self._output_path, stat.S_IREAD)