You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

step_trace_analyser.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """The StepTraceAnalyser analyser class."""
  16. import csv
  17. from mindinsight.datavisual.utils.tools import to_int
  18. from mindinsight.profiler.analyser.base_analyser import BaseAnalyser
  19. from mindinsight.profiler.common.exceptions.exceptions import ProfilerParamValueErrorException, \
  20. ProfilerFileNotFoundException, StepNumNotSupportedException
  21. from mindinsight.profiler.common.log import logger as log
  22. from mindinsight.profiler.common.util import query_latest_trace_time_file, get_field_value, \
  23. get_summary_for_step_trace, to_millisecond
  24. class StepTraceAnalyser(BaseAnalyser):
  25. """The analyser for analyzing training steps."""
  26. _col_names = []
  27. _attr_ui_name = 'name'
  28. _attr_ui_start = 'start'
  29. _attr_ui_duration = 'duration'
  30. @property
  31. def summary(self):
  32. """The property of summary info."""
  33. summary = get_summary_for_step_trace(self._data[-1], self.__column__)
  34. summary['total_steps'] = self._size
  35. return summary
  36. def query(self, condition=None):
  37. """
  38. Query data according to the condition.
  39. Args:
  40. condition (dict): The search condition, only contains `filter_condition` parameter.
  41. Default: None.
  42. Returns:
  43. dict, the result after filtered, sorted and grouped.
  44. """
  45. if condition is None:
  46. condition = {}
  47. filter_condition = condition.get('filter_condition', {})
  48. log.info("Receive query request. %s", filter_condition)
  49. self._validate_filter_condition(filter_condition)
  50. self._result = {'size': self._size}
  51. self._filter(filter_condition)
  52. return self._result
  53. def query_for_all_reduce(self):
  54. """
  55. Query for all reduce info.
  56. Returns:
  57. list[dict], reduce information. Each item is the reduce info for one step.
  58. The reduce info is format like:
  59. {stream_id: List[Tuple(start_point, end_point, duration, field_name)]}.
  60. """
  61. reduce_infos = []
  62. for row_info in self._data[:-1]:
  63. row_info_dict = self._get_info_dict_from_row_data(row_info, 'systime')
  64. reduce_info = self._get_reduce_time_in_order(row_info_dict)
  65. reduce_infos.append(reduce_info)
  66. return reduce_infos
  67. def _load(self):
  68. """Load data according to the parsed AICORE operator types file."""
  69. file_path = query_latest_trace_time_file(self._profiling_dir, self._device_id)
  70. if not file_path:
  71. log.error("Failed to find parsed trace time file.")
  72. raise ProfilerFileNotFoundException('parsed step trace time file')
  73. with open(file_path, 'r') as handle:
  74. csv_reader = csv.reader(handle)
  75. self.__column__ = next(csv_reader)
  76. self._data = list(csv_reader)
  77. self._size = len(self._data) - 1
  78. self._display_col_names = self._col_names[:]
  79. def _filter(self, filter_condition):
  80. """
  81. Filter the profiling data according to the filter condition.
  82. Args:
  83. filter_condition (dict): The filter condition.
  84. - mode (str): The kind of information. `step` return the info about specific
  85. step. `proc` return the info about specific field in parsed trace file.
  86. - step_id (int): The selected step_id. If not given, it means all steps is required.
  87. If the value is 0, it means average info for all steps except the first is
  88. required.
  89. - proc_name (str): The selected field name.
  90. - time_type (str): The value type. `systime` keeps the original value.
  91. `realtime` transforms the value in millisecond. Default: `realtime`.
  92. """
  93. mode = filter_condition.get('mode', 'step')
  94. if mode == 'step':
  95. self._get_step_details(step_id=filter_condition.get('step_id'),
  96. time_type=filter_condition.get('time_type', 'realtime'))
  97. else:
  98. self._get_proc_details(step_id=filter_condition.get('step_id'),
  99. proc_name=filter_condition.get('proc_name'),
  100. time_type=filter_condition.get('time_type', 'realtime'))
  101. def _construct_time_point(self, name, start, duration):
  102. """Construct time point."""
  103. point = {}
  104. if start >= 0 and duration >= 0:
  105. point = {
  106. self._attr_ui_name: name,
  107. self._attr_ui_start: round(start, 4),
  108. self._attr_ui_duration: round(duration, 4)
  109. }
  110. else:
  111. log.warning("Not invalid point info: "
  112. "name: %s, start: %s, duration: %s", name, start, duration)
  113. return point
  114. def _get_step_details(self, step_id, time_type='realtime'):
  115. """
  116. Get step trace info for selected step and save the result.
  117. Args:
  118. step_id (int): The selected step_id. If the value is 0, it means average info
  119. for all steps except the first is required.
  120. time_type (str): The value type. `systime` keeps the original value.
  121. `realtime` transforms the value in millisecond. Default: `realtime`.
  122. """
  123. if step_id is None:
  124. step_id = 0
  125. row_info = self._data[step_id - 1]
  126. row_info_dict = self._get_info_dict_from_row_data(row_info, time_type)
  127. # first line only contains total time
  128. first_line = [self._construct_time_point('', 0, row_info_dict.get('total', 0))]
  129. # second line contains iteration_interval, fp_and_bp and tail
  130. second_line = self._get_main_proc_points(row_info_dict)
  131. # construct reduces lines
  132. reduce_lines = self._construct_reduce_lines(row_info_dict)
  133. graph = [first_line, second_line]
  134. graph.extend(reduce_lines)
  135. self._result['training_trace_graph'] = graph
  136. def _get_info_dict_from_row_data(self, row_info, time_type):
  137. """
  138. Get step info in dict format.
  139. Args:
  140. row_info (list[str]): Step info, the value is corresponding to `__column__`.
  141. time_type (str): The value type. `systime` keeps the original value.
  142. `realtime` transforms the value in millisecond. Default: `realtime`.
  143. Returns:
  144. dict, step trace information. The key is in `__column__`.
  145. """
  146. row_info_dict = {}
  147. for key, value in zip(self.__column__, row_info):
  148. if key == 'step_num':
  149. continue
  150. value = to_int(value, key)
  151. row_info_dict[key] = to_millisecond(value) if time_type == 'realtime' else value
  152. return row_info_dict
  153. def _get_main_proc_points(self, row_info_dict):
  154. """
  155. Get iteration_interval, fp_and_bp and tail points.
  156. Args:
  157. row_info_dict (dict): Step trace information.
  158. Returns:
  159. list[dict], the list of time points.
  160. """
  161. start_point = row_info_dict.get('start_point', 0)
  162. fp_point = row_info_dict.get('fp_point', 0)
  163. bp_point = row_info_dict.get('bp_point', 0)
  164. points = [
  165. self._construct_time_point(
  166. 'iteration_interval', 0, row_info_dict.get('iteration_interval', 0)),
  167. self._construct_time_point(
  168. 'fp_and_bp', fp_point - start_point, row_info_dict.get('fp_and_bp', 0)),
  169. self._construct_time_point('tail', bp_point - start_point, row_info_dict.get('tail', 0))
  170. ]
  171. return points
  172. def _get_reduce_time_in_order(self, row_info_dict):
  173. """
  174. Get reduce time in order.
  175. Args:
  176. row_info_dict (dict): Step trace information.
  177. Returns:
  178. dict, sorted reduce information. The reduce info is format like:
  179. {stream_id: List[Tuple(start_point, end_point, duration, field_name)]}
  180. """
  181. reduce_info = {}
  182. reduce_fields = [field_name for field_name in self.__column__
  183. if field_name.startswith('stream_') and not field_name.endswith('point')]
  184. for reduce_field in reduce_fields:
  185. reduce_start = row_info_dict.get(reduce_field + '_start_point', 0)
  186. reduce_end = row_info_dict.get(reduce_field + '_end_point', 0)
  187. reduce_duration = row_info_dict.get(reduce_field, 0)
  188. if not (reduce_start and reduce_end and reduce_duration):
  189. log.info("Reduce event missing value.")
  190. continue
  191. cur_stream_id = reduce_field.split('_', 2)[1]
  192. cur_stream = reduce_info.get(cur_stream_id)
  193. if not cur_stream:
  194. cur_stream = []
  195. reduce_info[cur_stream_id] = cur_stream
  196. cur_stream.append((reduce_start, reduce_end, reduce_duration, reduce_field))
  197. for _, reduce_events in reduce_info.items():
  198. reduce_events.sort(key=lambda elem: elem[1])
  199. return reduce_info
  200. def _construct_reduce_lines(self, row_info_dict):
  201. """
  202. Contruct first line in detailed graph.
  203. Args:
  204. row_info_dict (dict): Step trace information.
  205. Returns:
  206. list, list of reduce information of each stream. Each item is a list of time points.
  207. """
  208. reduce_lines = []
  209. start_point = row_info_dict.get('start_point', 0)
  210. fp_point = row_info_dict.get('fp_point', 0)
  211. end_point = row_info_dict.get('end_point', 0)
  212. reduce_info = self._get_reduce_time_in_order(row_info_dict)
  213. # construct time point for each line
  214. for _, reduce_events in reduce_info.items():
  215. current_line = self._construct_reduce_line(
  216. start_point, end_point, fp_point, reduce_events)
  217. reduce_lines.append(current_line)
  218. return reduce_lines
  219. def _construct_reduce_line(self, start_point, end_point, fp_point, reduce_events):
  220. """
  221. Construct list of time points for reduce line.
  222. Args:
  223. start_point (int): The start point of current step.
  224. end_point (int): The end point of current step.
  225. fp_point (int): The fp point of current step.
  226. reduce_events (list[Tuple]): The reduce information of current step. Each item
  227. contains the start, end duration and name of one reduce event.
  228. Returns:
  229. list[dict], list of time points.
  230. """
  231. current_line = []
  232. previous_start = fp_point
  233. for start, end, duration, field_name in reduce_events:
  234. current_line.extend([
  235. self._construct_time_point(
  236. '', previous_start - start_point, start - previous_start),
  237. self._construct_time_point(
  238. field_name, start - start_point, duration)
  239. ])
  240. previous_start = end
  241. current_line.append(self._construct_time_point(
  242. '', previous_start - start_point, end_point - previous_start))
  243. return current_line
  244. def _get_proc_details(self, proc_name, step_id=None, time_type='realtime'):
  245. """
  246. Get step trace info for selected step and save the result.
  247. Args:
  248. proc_name (str): The selected field name.
  249. step_id (int): The selected step_id. If not given, it means all steps is required.
  250. If the value is 0, it means average info for all steps except the first is
  251. required. Default: None.
  252. time_type (str): The value type. `systime` keeps the original value.
  253. `realtime` transforms the value in millisecond. Default: `realtime`.
  254. """
  255. if proc_name is None:
  256. log.error('`proc_name` is required for query.')
  257. raise ProfilerParamValueErrorException('`proc_name` is required for query.')
  258. if step_id is None:
  259. rows_info = self._data[:-1]
  260. else:
  261. rows_info = [self._data[step_id - 1]]
  262. proc_info = [get_field_value(row_info, proc_name, self.__column__, time_type)
  263. for row_info in rows_info]
  264. self._result['info'] = {proc_name: proc_info}
  265. def _validate_filter_condition(self, filter_condition):
  266. """Validate step trace filter_condition."""
  267. mode = filter_condition.get('mode', 'step')
  268. self._validate_str_param(mode, ['step', 'proc'], 'mode')
  269. step_id = filter_condition.get('step_id')
  270. self._validate_step_id(step_id)
  271. proc_name = filter_condition.get('proc_name')
  272. self._validate_str_param(proc_name, self.__column__, 'proc_name')
  273. time_type = filter_condition.get('time_type', 'realtime')
  274. self._validate_str_param(time_type, ['realtime', 'systime'], 'time_type')
  275. def _validate_step_id(self, step_id):
  276. """Validate step_id."""
  277. if step_id is None or isinstance(step_id, int) and 0 <= step_id <= self._size:
  278. return
  279. log.error("Invalid step_id in request. step_id should be in [0, %d].", self._size)
  280. raise StepNumNotSupportedException([0, self._size])
  281. @staticmethod
  282. def _validate_str_param(proc_name, accept_param, error_name=''):
  283. """Validate proc_name."""
  284. if proc_name is None or isinstance(proc_name, str) and proc_name in accept_param:
  285. return
  286. log.error("Invalid param %s in request. Acceptable value is %s.", error_name, accept_param)
  287. raise ProfilerParamValueErrorException(f"Invalid {error_name}.")