You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_loader.py 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the DataLoader."""
  16. import os
  17. import json
  18. from mindinsight.debugger.proto.ms_graph_pb2 import ModelProto
  19. from mindinsight.debugger.common.log import LOGGER as log
  20. from mindinsight.utils.exceptions import ParamValueError
  21. from mindinsight.debugger.common.utils import DumpSettings
  22. class DataLoader:
  23. """The DataLoader object provides interface to load graphs and device information from base_dir."""
  24. def __init__(self, base_dir):
  25. self._debugger_base_dir = os.path.realpath(base_dir)
  26. self._graph_protos = []
  27. self._device_info = {}
  28. self._step_num = {}
  29. # flag for whether the data is from sync dump or async dump, True for sync dump, False for async dump.
  30. self._is_sync = None
  31. self._net_dir = ""
  32. self._net_name = ""
  33. self.initialize()
  34. def initialize(self):
  35. """Initialize the data_mode and net_dir of DataLoader."""
  36. dump_config_file = os.path.join(self._debugger_base_dir, os.path.join(".metadata", "data_dump.json"))
  37. with open(dump_config_file, 'r') as load_f:
  38. dump_config = json.load(load_f)
  39. common_settings = dump_config.get(DumpSettings.COMMON_DUMP_SETTINGS.value)
  40. if not common_settings:
  41. raise ParamValueError('common_dump_settings not found in dump_config file.')
  42. self._net_name = common_settings['net_name']
  43. if dump_config.get(DumpSettings.E2E_DUMP_SETTINGS.value) and \
  44. dump_config[DumpSettings.E2E_DUMP_SETTINGS.value]['enable']:
  45. self._is_sync = True
  46. self._net_dir = os.path.realpath(os.path.join(self._debugger_base_dir, self._net_name))
  47. elif dump_config.get(DumpSettings.ASYNC_DUMP_SETTINGS.value) and \
  48. dump_config[DumpSettings.ASYNC_DUMP_SETTINGS.value]['enable']:
  49. self._is_sync = False
  50. self._net_dir = self._debugger_base_dir
  51. else:
  52. raise ParamValueError('The data must be generated from sync dump or async dump.')
  53. def load_graphs(self):
  54. """Load graphs from the debugger_base_dir."""
  55. files = os.listdir(self._net_dir)
  56. for file in files:
  57. if not self.is_device_dir(file):
  58. continue
  59. device_id, device_dir = self.get_device_id_and_dir(file)
  60. graphs_dir = os.path.join(device_dir, 'graphs')
  61. if not os.path.exists(graphs_dir) or not os.path.isdir(graphs_dir):
  62. log.debug("Directory '%s' not exist.", graphs_dir)
  63. self._graph_protos.append({'device_id': device_id, 'graph_protos': []})
  64. continue
  65. graph_protos = get_graph_protos_from_dir(graphs_dir)
  66. self._graph_protos.append({'device_id': device_id, 'graph_protos': graph_protos})
  67. return self._graph_protos
  68. def load_device_info(self):
  69. """Load device_info from file"""
  70. hccl_json_file = os.path.join(self._debugger_base_dir, '.metadata/hccl.json')
  71. if not os.path.isfile(hccl_json_file):
  72. device = []
  73. device_ids = self.get_all_device_id()
  74. device_ids.sort()
  75. for i, device_id in enumerate(device_ids):
  76. rank_id = i
  77. device.append({'device_id': str(device_id), 'rank_id': str(rank_id)})
  78. device_target = 'Ascend'
  79. self._device_info = {'device_target': device_target,
  80. 'server_list': [{'server_id': 'localhost', 'device': device}]}
  81. else:
  82. with open(hccl_json_file, 'r') as load_f:
  83. load_dict = json.load(load_f)
  84. self._device_info = {'device_target': 'Ascend', 'server_list': load_dict['server_list']}
  85. return self._device_info
  86. def load_step_number(self):
  87. """Load step number in the directory"""
  88. files = os.listdir(self._net_dir)
  89. for file in files:
  90. if not self.is_device_dir(file):
  91. continue
  92. device_id, device_dir = self.get_device_id_and_dir(file)
  93. max_step = 0
  94. files_in_device = os.listdir(device_dir)
  95. if self._is_sync:
  96. for file_in_device in files_in_device:
  97. abs_file_in_device = os.path.join(device_dir, file_in_device)
  98. if os.path.isdir(abs_file_in_device) and file_in_device.startswith("iteration_"):
  99. step_id_str = file_in_device.split('_')[-1]
  100. max_step = update_max_step(step_id_str, max_step)
  101. self._step_num[str(device_id)] = max_step
  102. else:
  103. net_graph_dir = []
  104. for file_in_device in files_in_device:
  105. abs_file_in_device = os.path.join(device_dir, file_in_device)
  106. if os.path.isdir(abs_file_in_device) and file_in_device.startswith(self._net_name):
  107. net_graph_dir.append(abs_file_in_device)
  108. if len(net_graph_dir) > 1:
  109. log.warning("There are more than one graph directory in device_dir: %s. "
  110. "OfflineDebugger use data in %s.", device_dir, net_graph_dir[0])
  111. net_graph_dir_to_use = net_graph_dir[0]
  112. graph_id = net_graph_dir_to_use.split('_')[-1]
  113. graph_id_dir = os.path.join(net_graph_dir_to_use, graph_id)
  114. step_ids = os.listdir(graph_id_dir)
  115. for step_id_str in step_ids:
  116. max_step = update_max_step(step_id_str, max_step)
  117. self._step_num[str(device_id)] = max_step
  118. return self._step_num
  119. def is_device_dir(self, file_name):
  120. """Judge if the file_name is a sub directory named 'device_x'."""
  121. if not file_name.startswith("device_"):
  122. return False
  123. id_str = file_name.split("_")[-1]
  124. if not id_str.isdigit():
  125. return False
  126. device_dir = os.path.join(self._net_dir, file_name)
  127. if not os.path.isdir(device_dir):
  128. return False
  129. return True
  130. def get_device_id_and_dir(self, file_name):
  131. """Get device_id and absolute directory of file_name."""
  132. id_str = file_name.split("_")[-1]
  133. device_id = int(id_str)
  134. device_dir = os.path.join(self._net_dir, file_name)
  135. return device_id, device_dir
  136. def get_all_device_id(self):
  137. """Get all device_id int the debugger_base_dir"""
  138. device_ids = []
  139. files = os.listdir(self._net_dir)
  140. for file in files:
  141. if not self.is_device_dir(file):
  142. continue
  143. id_str = file.split("_")[-1]
  144. device_id = int(id_str)
  145. device_ids.append(device_id)
  146. return device_ids
  147. def get_net_dir(self):
  148. """Get graph_name directory of the data."""
  149. return self._net_dir
  150. def get_sync_flag(self):
  151. """Get the sync flag of the data."""
  152. return self._is_sync
  153. def get_net_name(self):
  154. """Get net_name of the data."""
  155. return self._net_name
  156. def load_graph_from_file(graph_file_name):
  157. """Load graph from file."""
  158. with open(graph_file_name, 'rb') as file_handler:
  159. model_bytes = file_handler.read()
  160. model = ModelProto.FromString(model_bytes)
  161. graph = model.graph
  162. return graph
  163. def get_graph_protos_from_dir(graphs_dir):
  164. """
  165. Get graph from graph directory.
  166. Args:
  167. graph_dir (str): The absolute directory of graph files.
  168. Returns:
  169. list, list of 'GraphProto' object.
  170. """
  171. files_in_graph_dir = os.listdir(graphs_dir)
  172. graph_protos = []
  173. pre_file_name = "ms_output_trace_code_graph_"
  174. for file_in_device in files_in_graph_dir:
  175. if file_in_device.startswith(pre_file_name) and file_in_device.endswith(".pb"):
  176. abs_graph_file = os.path.join(graphs_dir, file_in_device)
  177. graph_proto = load_graph_from_file(abs_graph_file)
  178. graph_protos.append(graph_proto)
  179. return graph_protos
  180. def update_max_step(step_id_str, max_step):
  181. """Update max_step by compare step_id_str and max_step."""
  182. res = max_step
  183. if step_id_str.isdigit():
  184. step_id = int(step_id_str)
  185. if step_id > max_step:
  186. res = step_id
  187. return res