|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """This file is used to define the DataLoader."""
- import os
- import json
- from mindinsight.debugger.proto.ms_graph_pb2 import ModelProto
- from mindinsight.debugger.common.log import LOGGER as log
- from mindinsight.utils.exceptions import ParamValueError
- from mindinsight.debugger.common.utils import DumpSettings
-
-
class DataLoader:
    """
    The DataLoader object provides interface to load graphs and device information from base_dir.

    Args:
        base_dir (str): Directory of the dump data. `.metadata/data_dump.json` is read from it
            when the object is created.

    Raises:
        ParamValueError: If the dump config has no common settings, no `net_name`, or enables
            neither sync dump nor async dump.
    """

    _DEVICE_DIR_PREFIX = "device_"

    def __init__(self, base_dir):
        self._debugger_base_dir = os.path.realpath(base_dir)
        self._graph_protos = []
        self._device_info = {}
        self._step_num = {}
        # True for sync dump, False for async dump, None until initialize() has decided.
        self._is_sync = None
        self._net_dir = ""
        self._net_name = ""
        self.initialize()

    def initialize(self):
        """
        Initialize the data_mode and net_dir of DataLoader.

        Reads `.metadata/data_dump.json` under the base directory. When sync (e2e) dump is
        enabled, net_dir is `<base_dir>/<net_name>`; when async dump is enabled, net_dir is
        the base directory itself. Sync dump takes precedence if both are enabled.

        Raises:
            ParamValueError: If the common settings or `net_name` are missing, or if neither
                sync dump nor async dump is enabled.
        """
        dump_config_file = os.path.join(self._debugger_base_dir, ".metadata", "data_dump.json")
        with open(dump_config_file, 'r', encoding='utf-8') as load_f:
            dump_config = json.load(load_f)
        common_settings = dump_config.get(DumpSettings.COMMON_DUMP_SETTINGS.value)
        if not common_settings:
            raise ParamValueError('common_dump_settings not found in dump_config file.')
        # Report a missing key the same way as every other config problem instead of KeyError.
        if 'net_name' not in common_settings:
            raise ParamValueError('net_name not found in common_dump_settings.')
        self._net_name = common_settings['net_name']
        if self._is_dump_enabled(dump_config, DumpSettings.E2E_DUMP_SETTINGS.value):
            self._is_sync = True
            self._net_dir = os.path.realpath(os.path.join(self._debugger_base_dir, self._net_name))
        elif self._is_dump_enabled(dump_config, DumpSettings.ASYNC_DUMP_SETTINGS.value):
            self._is_sync = False
            self._net_dir = self._debugger_base_dir
        else:
            raise ParamValueError('The data must be generated from sync dump or async dump.')

    @staticmethod
    def _is_dump_enabled(dump_config, settings_key):
        """Return True if `settings_key` is present in dump_config with a truthy 'enable'."""
        settings = dump_config.get(settings_key)
        # A settings section without 'enable' is treated as disabled.
        return bool(settings and settings.get('enable'))

    def load_graphs(self):
        """
        Load graphs from the debugger_base_dir.

        Returns:
            list[dict], one item per device directory, ordered by device id. Each item has
            the keys 'device_id' (int) and 'graph_protos' (list); 'graph_protos' is empty
            when the device directory has no 'graphs' sub directory.
        """
        # Build a fresh list so that calling this method again does not duplicate entries.
        graph_protos_per_device = []
        for device_id, device_dir in self._get_device_dirs():
            graphs_dir = os.path.join(device_dir, 'graphs')
            if os.path.isdir(graphs_dir):
                graph_protos = get_graph_protos_from_dir(graphs_dir)
            else:
                log.debug("Directory '%s' not exist.", graphs_dir)
                graph_protos = []
            graph_protos_per_device.append({'device_id': device_id, 'graph_protos': graph_protos})
        self._graph_protos = graph_protos_per_device

        return self._graph_protos

    def load_device_info(self):
        """
        Load device_info from file.

        If `.metadata/hccl.json` exists under the base directory, its 'server_list' is used.
        Otherwise a single 'localhost' server is built from the device directories, where
        rank ids are assigned in ascending order of device id.

        Returns:
            dict, with the keys 'device_target' (always 'Ascend') and 'server_list'.
        """
        hccl_json_file = os.path.join(self._debugger_base_dir, '.metadata', 'hccl.json')
        if os.path.isfile(hccl_json_file):
            with open(hccl_json_file, 'r', encoding='utf-8') as load_f:
                load_dict = json.load(load_f)
            server_list = load_dict['server_list']
        else:
            device = [{'device_id': str(device_id), 'rank_id': str(rank_id)}
                      for rank_id, device_id in enumerate(sorted(self.get_all_device_id()))]
            server_list = [{'server_id': 'localhost', 'device': device}]
        self._device_info = {'device_target': 'Ascend', 'server_list': server_list}
        return self._device_info

    def load_step_number(self):
        """
        Load step number in the directory.

        Returns:
            dict, maps str(device_id) to the largest step id found for that device, or 0
            when no step is found.
        """
        for device_id, device_dir in self._get_device_dirs():
            if self._is_sync:
                max_step = self._get_max_step_of_sync_dump(device_dir)
            else:
                max_step = self._get_max_step_of_async_dump(device_dir)
            self._step_num[str(device_id)] = max_step

        return self._step_num

    @staticmethod
    def _get_max_step_of_sync_dump(device_dir):
        """Return the largest N among the 'iteration_N' sub directories of device_dir."""
        max_step = 0
        for file_in_device in os.listdir(device_dir):
            abs_file_in_device = os.path.join(device_dir, file_in_device)
            if os.path.isdir(abs_file_in_device) and file_in_device.startswith("iteration_"):
                max_step = update_max_step(file_in_device.split('_')[-1], max_step)
        return max_step

    def _get_max_step_of_async_dump(self, device_dir):
        """Return the largest step id under '<net_name>..._<graph_id>/<graph_id>/' of device_dir."""
        net_graph_dir = []
        for file_in_device in os.listdir(device_dir):
            abs_file_in_device = os.path.join(device_dir, file_in_device)
            if os.path.isdir(abs_file_in_device) and file_in_device.startswith(self._net_name):
                net_graph_dir.append(abs_file_in_device)
        # Without this guard the indexing below raises IndexError for a device directory
        # that holds no dump of this net.
        if not net_graph_dir:
            log.warning("There is no graph directory of net '%s' in device_dir: %s.",
                        self._net_name, device_dir)
            return 0
        if len(net_graph_dir) > 1:
            log.warning("There are more than one graph directory in device_dir: %s. "
                        "OfflineDebugger use data in %s.", device_dir, net_graph_dir[0])
        net_graph_dir_to_use = net_graph_dir[0]
        # The text after the last '_' is taken as the graph id, which is also the name of
        # the sub directory holding the step directories.
        graph_id = net_graph_dir_to_use.split('_')[-1]
        graph_id_dir = os.path.join(net_graph_dir_to_use, graph_id)
        if not os.path.isdir(graph_id_dir):
            log.warning("Directory '%s' not exist.", graph_id_dir)
            return 0
        max_step = 0
        for step_id_str in os.listdir(graph_id_dir):
            max_step = update_max_step(step_id_str, max_step)
        return max_step

    def _get_device_dirs(self):
        """Return (device_id, device_dir) of every device directory in net_dir, sorted by device_id."""
        device_dirs = [self.get_device_id_and_dir(file_name)
                       for file_name in os.listdir(self._net_dir)
                       if self.is_device_dir(file_name)]
        return sorted(device_dirs)

    def is_device_dir(self, file_name):
        """
        Judge if the file_name is a sub directory named 'device_x'.

        Args:
            file_name (str): Name of an entry in net_dir.

        Returns:
            bool, True if file_name is 'device_' followed only by decimal digits and is a
            directory under net_dir.
        """
        prefix = self._DEVICE_DIR_PREFIX
        if not file_name.startswith(prefix):
            return False
        # Everything after the prefix must be the id, so that a name such as 'device_a_1'
        # is not mistaken for device 1. isdecimal() accepts exactly what int() can parse.
        if not file_name[len(prefix):].isdecimal():
            return False
        return os.path.isdir(os.path.join(self._net_dir, file_name))

    def get_device_id_and_dir(self, file_name):
        """
        Get device_id and absolute directory of file_name.

        Args:
            file_name (str): Name of a device directory, e.g. 'device_0'.

        Returns:
            tuple[int, str], the device id parsed from the text after the last '_' and the
            path of file_name joined to net_dir.
        """
        device_id = int(file_name.split("_")[-1])
        device_dir = os.path.join(self._net_dir, file_name)
        return device_id, device_dir

    def get_all_device_id(self):
        """
        Get all device_id in the net_dir.

        Returns:
            list[int], device ids of all device directories, in ascending order.
        """
        return [device_id for device_id, _ in self._get_device_dirs()]

    def get_net_dir(self):
        """Get graph_name directory of the data."""
        return self._net_dir

    def get_sync_flag(self):
        """Get the sync flag of the data. True for sync dump, False for async dump."""
        return self._is_sync

    def get_net_name(self):
        """Get net_name of the data."""
        return self._net_name
-
-
def load_graph_from_file(graph_file_name):
    """
    Read a serialized model file and return the graph stored in it.

    Args:
        graph_file_name (str): Path of the file to read. It is opened in binary mode.

    Returns:
        object, the `graph` attribute of the ModelProto parsed from the file content.
    """
    with open(graph_file_name, 'rb') as graph_file:
        serialized_model = graph_file.read()
        return ModelProto.FromString(serialized_model).graph
-
-
def get_graph_protos_from_dir(graphs_dir):
    """
    Get graph from graph directory.

    Only regular files named 'ms_output_trace_code_graph_*.pb' are loaded. They are
    processed in sorted file-name order so the result does not depend on the arbitrary
    order returned by os.listdir.

    Args:
        graphs_dir (str): The absolute directory of graph files.

    Returns:
        list, list of 'GraphProto' object.
    """
    pre_file_name = "ms_output_trace_code_graph_"
    graph_protos = []
    for file_in_dir in sorted(os.listdir(graphs_dir)):
        if not (file_in_dir.startswith(pre_file_name) and file_in_dir.endswith(".pb")):
            continue
        abs_graph_file = os.path.join(graphs_dir, file_in_dir)
        # A directory whose name matches the pattern cannot be opened as a graph file.
        if not os.path.isfile(abs_graph_file):
            continue
        graph_protos.append(load_graph_from_file(abs_graph_file))
    return graph_protos
-
-
def update_max_step(step_id_str, max_step):
    """
    Update max_step by compare step_id_str and max_step.

    Args:
        step_id_str (str): Candidate step id, usually a file or directory name.
        max_step (int): The largest step id seen so far.

    Returns:
        int, the larger of max_step and the step id in step_id_str. max_step is returned
        unchanged when step_id_str is not a decimal number.
    """
    # isdecimal() accepts exactly the strings int() can parse. isdigit() also accepts
    # characters such as superscript digits, for which int() raises ValueError.
    if not step_id_str.isdecimal():
        return max_step
    return max(max_step, int(step_id_str))
|