Browse Source

fix the bug of step_trace incorrect content in inference scene

tags/v1.1.0
gzhcv 5 years ago
parent
commit
44b683122e
3 changed files with 36 additions and 13 deletions
  1. +15
    -4
      mindinsight/profiler/analyser/step_trace_analyser.py
  2. +19
    -7
      mindinsight/profiler/common/util.py
  3. +2
    -2
      mindinsight/profiler/common/validator/validate.py

+ 15
- 4
mindinsight/profiler/analyser/step_trace_analyser.py View File

@@ -217,13 +217,24 @@ class StepTraceAnalyser(BaseAnalyser):
start_point = row_info_dict.get('start_point', 0) start_point = row_info_dict.get('start_point', 0)
fp_point = row_info_dict.get('fp_point', 0) fp_point = row_info_dict.get('fp_point', 0)
bp_point = row_info_dict.get('bp_point', 0) bp_point = row_info_dict.get('bp_point', 0)
points = [
points_part = [
self._construct_time_point( self._construct_time_point(
'iteration_interval', 0, row_info_dict.get('iteration_interval', 0)), 'iteration_interval', 0, row_info_dict.get('iteration_interval', 0)),
self._construct_time_point(
'fp_and_bp', fp_point - start_point, row_info_dict.get('fp_and_bp', 0)),
self._construct_time_point('tail', bp_point - start_point, row_info_dict.get('tail', 0))
] ]
# if fp key exist, inference scene
if 'fp' in row_info_dict.keys():
points = [
self._construct_time_point(
'fp', fp_point - start_point, row_info_dict.get('fp', 0)),
]
# training scene
else:
points = [
self._construct_time_point(
'fp_and_bp', fp_point - start_point, row_info_dict.get('fp_and_bp', 0)),
self._construct_time_point('tail', bp_point - start_point, row_info_dict.get('tail', 0))
]
points = points_part + points
return points return points


def _get_reduce_time_in_order(self, row_info_dict): def _get_reduce_time_in_order(self, row_info_dict):


+ 19
- 7
mindinsight/profiler/common/util.py View File

@@ -125,17 +125,29 @@ def get_summary_for_step_trace(average_info, header):
total_time = get_field_value(average_info, 'total', header) total_time = get_field_value(average_info, 'total', header)
iteration_interval = get_field_value(average_info, 'iteration_interval', iteration_interval = get_field_value(average_info, 'iteration_interval',
header) header)
fp_and_bp = get_field_value(average_info, 'fp_and_bp', header)
tail = get_field_value(average_info, 'tail', header)
summary = {
summary_part = {
'total_time': total_time, 'total_time': total_time,
'iteration_interval': iteration_interval, 'iteration_interval': iteration_interval,
'iteration_interval_percent': calculate_percent(iteration_interval, total_time), 'iteration_interval_percent': calculate_percent(iteration_interval, total_time),
'fp_and_bp': fp_and_bp,
'fp_and_bp_percent': calculate_percent(fp_and_bp, total_time),
'tail': tail,
'tail_percent': calculate_percent(tail, total_time)
} }
# training scene data for ui display
if 'fp_and_bp' in header:
fp_and_bp = get_field_value(average_info, 'fp_and_bp', header)
tail = get_field_value(average_info, 'tail', header)
summary = {
'fp_and_bp': fp_and_bp,
'fp_and_bp_percent': calculate_percent(fp_and_bp, total_time),
'tail': tail,
'tail_percent': calculate_percent(tail, total_time)
}
# inference scene data for ui display
else:
fp = get_field_value(average_info, 'fp', header)
summary = {
'fp': fp,
'fp_percent': calculate_percent(fp, total_time)
}
summary.update(summary_part)
return summary return summary


+ 2
- 2
mindinsight/profiler/common/validator/validate.py View File

@@ -250,12 +250,12 @@ def validate_ui_proc(proc_name):
Args: Args:
proc_name (str): The proc name to query. Acceptable value is in proc_name (str): The proc name to query. Acceptable value is in
[`iteration_interval`, `fp_and_bp`, `tail`].
[`iteration_interval`, `fp_and_bp`, `fp`, `tail`].
Raises: Raises:
ProfilerParamValueErrorException: If the proc_name is invalid. ProfilerParamValueErrorException: If the proc_name is invalid.
""" """
accept_names = ['iteration_interval', 'fp_and_bp', 'tail']
accept_names = ['iteration_interval', 'fp_and_bp', 'fp', 'tail']
if proc_name not in accept_names: if proc_name not in accept_names:
log.error("Invalid proc_name. The proc_name for restful api is in %s", accept_names) log.error("Invalid proc_name. The proc_name for restful api is in %s", accept_names)
raise ProfilerParamValueErrorException(f'proc_name should be in {accept_names}.') raise ProfilerParamValueErrorException(f'proc_name should be in {accept_names}.')


Loading…
Cancel
Save