@BLUEPRINT.route("/profile/memory-breakdowns", methods=["GET"])
def get_memory_usage_breakdowns():
    """
    Get the memory breakdowns for the node selected by the request.

    The query parameters read here are ``dir``, ``device_id``, ``device_type``,
    ``graph_id`` and ``node_id``; every one except ``dir`` has a default
    ('0', 'ascend', '0' and '0' respectively).

    Returns:
        Response, the memory breakdowns produced by the memory usage analyser.

    Raises:
        ParamValueError: If device_type is anything other than 'ascend'.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/profile/memory-breakdowns
    """
    query = request.args
    profiler_dir_abs = validate_and_normalize_profiler_path(
        query.get("dir"), settings.SUMMARY_BASE_DIR)
    check_train_job_and_profiler_dir(profiler_dir_abs)

    device_id = query.get("device_id", default='0')
    # The converted value is discarded: the call is kept only for whatever
    # checking to_int performs, and the analyser receives the original string.
    to_int(device_id, 'device_id')
    device_type = query.get("device_type", default='ascend')
    graph_id = query.get("graph_id", default='0')
    node_index = to_int(query.get("node_id", default='0'), 'node_id')

    if device_type != 'ascend':
        logger.error("Invalid device_type, Memory Usage only supports Ascend for now.")
        raise ParamValueError("Invalid device_type.")

    memory_analyser = AnalyserFactory.instance().get_analyser(
        'memory_usage', profiler_dir_abs, device_id)
    return memory_analyser.get_memory_usage_breakdowns(device_type, graph_id, node_index)
def get_memory_usage_breakdowns(self, device_type, graph_id, node_id):
    """
    Get memory usage breakdowns for one node of a graph.

    Args:
        device_type (str): Device type, e.g., GPU, Ascend.
        graph_id (str): Graph id, matched against the top-level keys of the
            memory details file content.
        node_id (int): Node id, used as a zero-based index into the graph's
            'breakdowns' entry.

    Returns:
        dict, ``{'breakdowns': <entry selected by node_id>}``.

    Raises:
        ParamValueError: If graph_id is not a key of the memory details, or
            if the graph has no usable 'breakdowns' entry, or if node_id is
            negative or not smaller than the number of breakdown entries.
    """
    memory_details = self._get_file_content(device_type, FileType.DETAILS.value)
    if graph_id not in memory_details:
        logger.error('Invalid graph id: %s', graph_id)
        raise ParamValueError('Invalid graph id.')

    # NOTE(review): 'breakdowns' is assumed to be a list with one entry per
    # node, since it is sized with len() and indexed by an int -- confirm
    # against the profiler's output format.
    breakdowns = memory_details[graph_id].get('breakdowns')

    # The lower bound matters: without it a negative node_id would wrap
    # around to an entry counted from the end, or raise an uncaught
    # IndexError when its magnitude exceeds the list length.
    if not breakdowns or not 0 <= node_id < len(breakdowns):
        logger.error('Invalid node id: %s', node_id)
        raise ParamValueError('Invalid node id.')

    return {'breakdowns': breakdowns[node_id]}
Please check the output path: %s', file_path) + raise ProfilerFileNotFoundException(msg='Invalid memory file path.') - return file_content + try: + with open(file_path, 'r') as f_obj: + file_content = json.load(f_obj) + except (IOError, OSError, json.JSONDecodeError) as err: + logger.error('Error occurred when read memory file: %s', err) + raise ProfilerIOException - @staticmethod - def _process_memory_details(memory_details): - """Process memory details, change the node dict to node list.""" - for key in memory_details.keys(): - if 'nodes' in memory_details[key]: - nodes = list(memory_details[key]['nodes'].values()) - memory_details[key]['nodes'] = nodes + return file_content def _get_file_path(self, device_type, file_type): """ @@ -127,8 +148,8 @@ class MemoryUsageAnalyser(BaseAnalyser): elif file_type is FileType.DETAILS.value: filename = self._details_filename.format(self._device_id) else: - logger.info('Memory Usage only supports Ascend for now. Please check the device type.') - raise ParamValueError("Invalid device_type.") + logger.error('Memory Usage only supports Ascend for now. Please check the device type.') + raise ParamValueError("Invalid device type.") file_path = os.path.join(self._profiling_dir, filename) file_path = validate_and_normalize_path(