You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

step_trace_analyser.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """The StepTraceAnalyser analyser class."""
  16. import csv
  17. import json
  18. import os
  19. from mindinsight.datavisual.utils.tools import to_int
  20. from mindinsight.profiler.analyser.base_analyser import BaseAnalyser
  21. from mindinsight.profiler.common.exceptions.exceptions import ProfilerParamValueErrorException, \
  22. ProfilerFileNotFoundException, StepNumNotSupportedException, ProfilerRawFileException
  23. from mindinsight.profiler.common.log import logger as log
  24. from mindinsight.profiler.common.util import query_latest_trace_time_file, get_field_value, \
  25. get_summary_for_step_trace, to_millisecond
  26. class StepTraceAnalyser(BaseAnalyser):
  27. """The analyser for analyzing training steps."""
  28. _col_names = []
  29. _attr_ui_name = 'name'
  30. _attr_ui_start = 'start'
  31. _attr_ui_duration = 'duration'
  32. _point_info = {}
  33. @property
  34. def summary(self):
  35. """The property of summary info."""
  36. summary = get_summary_for_step_trace(self._data[-1], self.__column__)
  37. summary['total_steps'] = self._size
  38. return summary
  39. @property
  40. def point_info(self):
  41. """The property of point info."""
  42. return self._point_info
  43. def query(self, condition=None):
  44. """
  45. Query data according to the condition.
  46. Args:
  47. condition (dict): The search condition, only contains `filter_condition` parameter.
  48. Default: None.
  49. Returns:
  50. dict, the result after filtered, sorted and grouped.
  51. """
  52. if condition is None:
  53. condition = {}
  54. filter_condition = condition.get('filter_condition', {})
  55. log.info("Receive query request. %s", filter_condition)
  56. self._validate_filter_condition(filter_condition)
  57. self._result = {'size': self._size}
  58. self._filter(filter_condition)
  59. return self._result
  60. def query_for_all_reduce(self, min_cycle_counter):
  61. """
  62. Query for all reduce info.
  63. Returns:
  64. list[dict], reduce information. Each item is the reduce info for one step.
  65. The reduce info is format like:
  66. {stream_id: List[Tuple(start_point, end_point, duration, field_name)]}.
  67. """
  68. reduce_infos = []
  69. for row_info in self._data[:-1]:
  70. row_info_dict = self._get_info_dict_from_row_data(row_info, 'systime')
  71. reduce_info = self._sort_reduce_by_time(row_info_dict, min_cycle_counter)
  72. if reduce_info:
  73. reduce_infos.append(reduce_info)
  74. return reduce_infos
  75. def _load(self):
  76. """Load data according to the parsed AICORE operator types file."""
  77. file_path = query_latest_trace_time_file(self._profiling_dir, self._device_id)
  78. if not file_path:
  79. log.error("Failed to find parsed trace time file.")
  80. raise ProfilerFileNotFoundException('parsed step trace time file')
  81. with open(file_path, 'r') as handle:
  82. csv_reader = csv.reader(handle)
  83. self.__column__ = next(csv_reader)
  84. self._data = list(csv_reader)
  85. self._size = len(self._data) - 1
  86. self._display_col_names = self._col_names[:]
  87. self._load_point_info()
  88. def _load_point_info(self):
  89. """Load point info."""
  90. file_path = os.path.join(self._profiling_dir, 'step_trace_point_info.json')
  91. if os.path.isfile(file_path):
  92. with open(file_path, 'r', encoding='utf-8') as file:
  93. try:
  94. self._point_info = json.load(file)
  95. except (json.JSONDecodeError, TypeError) as err:
  96. log.exception(err)
  97. raise ProfilerRawFileException('Fail to parse point info file.')
  98. def _filter(self, filter_condition):
  99. """
  100. Filter the profiling data according to the filter condition.
  101. Args:
  102. filter_condition (dict): The filter condition.
  103. - mode (str): The kind of information. `step` return the info about specific
  104. step. `proc` return the info about specific field in parsed trace file.
  105. - step_id (int): The selected step_id. If not given, it means all steps is required.
  106. If the value is 0, it means average info for all steps except the first is
  107. required.
  108. - proc_name (str): The selected field name.
  109. - time_type (str): The value type. `systime` keeps the original value.
  110. `realtime` transforms the value in millisecond. Default: `realtime`.
  111. """
  112. mode = filter_condition.get('mode', 'step')
  113. if mode == 'step':
  114. self._get_step_details(step_id=filter_condition.get('step_id'),
  115. time_type=filter_condition.get('time_type', 'realtime'))
  116. else:
  117. self._get_proc_details(step_id=filter_condition.get('step_id'),
  118. proc_name=filter_condition.get('proc_name'),
  119. time_type=filter_condition.get('time_type', 'realtime'))
  120. def _construct_time_point(self, name, start, duration):
  121. """Construct time point."""
  122. point = {}
  123. if start >= 0 and duration >= 0:
  124. point = {
  125. self._attr_ui_name: name,
  126. self._attr_ui_start: round(start, 4),
  127. self._attr_ui_duration: round(duration, 4)
  128. }
  129. else:
  130. log.warning("Not invalid point info: "
  131. "name: %s, start: %s, duration: %s", name, start, duration)
  132. return point
  133. def _get_step_details(self, step_id, time_type='realtime'):
  134. """
  135. Get step trace info for selected step and save the result.
  136. Args:
  137. step_id (int): The selected step_id. If the value is 0, it means average info
  138. for all steps except the first is required.
  139. time_type (str): The value type. `systime` keeps the original value.
  140. `realtime` transforms the value in millisecond. Default: `realtime`.
  141. """
  142. if step_id is None:
  143. step_id = 0
  144. row_info = self._data[step_id - 1]
  145. row_info_dict = self._get_info_dict_from_row_data(row_info, time_type)
  146. # first line only contains total time
  147. first_line = [self._construct_time_point('', 0, row_info_dict.get('total', 0))]
  148. # second line contains iteration_interval, fp_and_bp and tail
  149. second_line = self._get_main_proc_points(row_info_dict)
  150. # construct reduces lines
  151. reduce_lines = self._construct_reduce_lines(row_info_dict)
  152. graph = [first_line, second_line]
  153. graph.extend(reduce_lines)
  154. self._result['training_trace_graph'] = graph
  155. def _get_info_dict_from_row_data(self, row_info, time_type):
  156. """
  157. Get step info in dict format.
  158. Args:
  159. row_info (list[str]): Step info, the value is corresponding to `__column__`.
  160. time_type (str): The value type. `systime` keeps the original value.
  161. `realtime` transforms the value in millisecond. Default: `realtime`.
  162. Returns:
  163. dict, step trace information. The key is in `__column__`.
  164. """
  165. row_info_dict = {}
  166. for key, value in zip(self.__column__, row_info):
  167. if key == 'step_num':
  168. continue
  169. value = to_int(value, key)
  170. row_info_dict[key] = to_millisecond(value) if time_type == 'realtime' else value
  171. return row_info_dict
  172. def _get_main_proc_points(self, row_info_dict):
  173. """
  174. Get iteration_interval, fp_and_bp and tail points.
  175. Args:
  176. row_info_dict (dict): Step trace information.
  177. Returns:
  178. list[dict], the list of time points.
  179. """
  180. start_point = row_info_dict.get('start_point', 0)
  181. fp_point = row_info_dict.get('fp_point', 0)
  182. bp_point = row_info_dict.get('bp_point', 0)
  183. points = [
  184. self._construct_time_point(
  185. 'iteration_interval', 0, row_info_dict.get('iteration_interval', 0)),
  186. self._construct_time_point(
  187. 'fp_and_bp', fp_point - start_point, row_info_dict.get('fp_and_bp', 0)),
  188. self._construct_time_point('tail', bp_point - start_point, row_info_dict.get('tail', 0))
  189. ]
  190. return points
  191. def _get_reduce_time_in_order(self, row_info_dict):
  192. """
  193. Get reduce time in order.
  194. Args:
  195. row_info_dict (dict): Step trace information.
  196. Returns:
  197. dict, sorted reduce information. The reduce info is format like:
  198. {stream_id: List[Tuple(start_point, end_point, duration, field_name)]}
  199. """
  200. reduce_info = {}
  201. reduce_fields = [field_name for field_name in self.__column__
  202. if field_name.startswith('stream_') and not field_name.endswith('point')]
  203. for reduce_field in reduce_fields:
  204. reduce_start = row_info_dict.get(reduce_field + '_start_point', 0)
  205. reduce_end = row_info_dict.get(reduce_field + '_end_point', 0)
  206. reduce_duration = row_info_dict.get(reduce_field, 0)
  207. if not (reduce_start and reduce_end and reduce_duration):
  208. log.info("Reduce event missing value.")
  209. continue
  210. cur_stream_id = reduce_field.split('_', 2)[1]
  211. cur_stream = reduce_info.get(cur_stream_id)
  212. if not cur_stream:
  213. cur_stream = []
  214. reduce_info[cur_stream_id] = cur_stream
  215. cur_stream.append((reduce_start, reduce_end, reduce_duration, reduce_field))
  216. for _, reduce_events in reduce_info.items():
  217. reduce_events.sort(key=lambda elem: elem[1])
  218. return reduce_info
  219. def _sort_reduce_by_time(self, row_info_dict, min_cycle_counter):
  220. """
  221. Sort reduce info by time.
  222. Args:
  223. row_info_dict (dict): Step trace information.
  224. min_cycle_counter (int): The minimum cycle counter.
  225. Returns:
  226. list, including the all reduce info sorted by start time only.
  227. [
  228. [reduce_field, stream_id, reduce_start, reduce_duration],
  229. [...],
  230. [...]
  231. ]
  232. """
  233. factor = 1e5 # convert time unit from 10ns to 1ms
  234. reduce_pid = 10000
  235. reduce_info = []
  236. reduce_fields = [field_name for field_name in self.__column__
  237. if field_name.startswith('stream_') and not field_name.endswith('point')]
  238. for reduce_field in reduce_fields:
  239. reduce_start = row_info_dict.get(reduce_field + '_start_point')
  240. reduce_start = (reduce_start - min_cycle_counter) / factor \
  241. if reduce_start else 0
  242. reduce_duration = row_info_dict.get(reduce_field)
  243. reduce_duration = reduce_duration / factor if reduce_duration else 0
  244. if not (reduce_start and reduce_duration):
  245. log.info("Reduce event missing value.")
  246. continue
  247. cur_stream_id = reduce_field.split('_', 2)[1]
  248. reduce_info = [reduce_field, int(cur_stream_id), reduce_start,
  249. reduce_duration, reduce_pid]
  250. return reduce_info
  251. def _construct_reduce_lines(self, row_info_dict):
  252. """
  253. Contruct first line in detailed graph.
  254. Args:
  255. row_info_dict (dict): Step trace information.
  256. Returns:
  257. list, list of reduce information of each stream. Each item is a list of time points.
  258. """
  259. reduce_lines = []
  260. start_point = row_info_dict.get('start_point', 0)
  261. fp_point = row_info_dict.get('fp_point', 0)
  262. end_point = row_info_dict.get('end_point', 0)
  263. reduce_info = self._get_reduce_time_in_order(row_info_dict)
  264. # construct time point for each line
  265. for _, reduce_events in reduce_info.items():
  266. current_line = self._construct_reduce_line(
  267. start_point, end_point, fp_point, reduce_events)
  268. reduce_lines.append(current_line)
  269. return reduce_lines
  270. def _construct_reduce_line(self, start_point, end_point, fp_point, reduce_events):
  271. """
  272. Construct list of time points for reduce line.
  273. Args:
  274. start_point (int): The start point of current step.
  275. end_point (int): The end point of current step.
  276. fp_point (int): The fp point of current step.
  277. reduce_events (list[Tuple]): The reduce information of current step. Each item
  278. contains the start, end duration and name of one reduce event.
  279. Returns:
  280. list[dict], list of time points.
  281. """
  282. current_line = []
  283. previous_start = fp_point
  284. for start, end, duration, field_name in reduce_events:
  285. current_line.extend([
  286. self._construct_time_point(
  287. '', previous_start - start_point, start - previous_start),
  288. self._construct_time_point(
  289. field_name, start - start_point, duration)
  290. ])
  291. previous_start = end
  292. current_line.append(self._construct_time_point(
  293. '', previous_start - start_point, end_point - previous_start))
  294. return current_line
  295. def _get_proc_details(self, proc_name, step_id=None, time_type='realtime'):
  296. """
  297. Get step trace info for selected step and save the result.
  298. Args:
  299. proc_name (str): The selected field name.
  300. step_id (int): The selected step_id. If not given, it means all steps is required.
  301. If the value is 0, it means average info for all steps except the first is
  302. required. Default: None.
  303. time_type (str): The value type. `systime` keeps the original value.
  304. `realtime` transforms the value in millisecond. Default: `realtime`.
  305. """
  306. if proc_name is None:
  307. log.error('`proc_name` is required for query.')
  308. raise ProfilerParamValueErrorException('`proc_name` is required for query.')
  309. if step_id is None:
  310. rows_info = self._data[:-1]
  311. else:
  312. rows_info = [self._data[step_id - 1]]
  313. proc_info = [get_field_value(row_info, proc_name, self.__column__, time_type)
  314. for row_info in rows_info]
  315. self._result['info'] = {proc_name: proc_info}
  316. def _validate_filter_condition(self, filter_condition):
  317. """Validate step trace filter_condition."""
  318. mode = filter_condition.get('mode', 'step')
  319. self._validate_str_param(mode, ['step', 'proc'], 'mode')
  320. step_id = filter_condition.get('step_id')
  321. self._validate_step_id(step_id)
  322. proc_name = filter_condition.get('proc_name')
  323. self._validate_str_param(proc_name, self.__column__, 'proc_name')
  324. time_type = filter_condition.get('time_type', 'realtime')
  325. self._validate_str_param(time_type, ['realtime', 'systime'], 'time_type')
  326. def _validate_step_id(self, step_id):
  327. """Validate step_id."""
  328. if step_id is None or isinstance(step_id, int) and 0 <= step_id <= self._size:
  329. return
  330. log.error("Invalid step_id in request. step_id should be in [0, %d].", self._size)
  331. raise StepNumNotSupportedException([0, self._size])
  332. @staticmethod
  333. def _validate_str_param(proc_name, accept_param, error_name=''):
  334. """Validate proc_name."""
  335. if proc_name is None or isinstance(proc_name, str) and proc_name in accept_param:
  336. return
  337. log.error("Invalid param %s in request. Acceptable value is %s.", error_name, accept_param)
  338. raise ProfilerParamValueErrorException(f"Invalid {error_name}.")