You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

device_handler.py 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the Device stream handler."""
  16. from collections import defaultdict
  17. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, DeviceIdUnregistered, \
  18. DebuggerParamTypeError
  19. from mindinsight.debugger.common.log import LOGGER as log
  20. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  21. class DeviceHandler(StreamHandlerBase):
  22. """Metadata Handler."""
  23. def __init__(self):
  24. # contains all device infos, the format is like Dict[int(<device_id>, <device_info>)]
  25. self._rank_info = defaultdict(DeviceInfo)
  26. self._device_rank_map = {}
  27. @property
  28. def rank_ids(self):
  29. """The rank ids."""
  30. return list(self._rank_info)
  31. @property
  32. def device_amount(self):
  33. """The rank ids."""
  34. return len(self._rank_info)
  35. def put(self, value):
  36. """
  37. Put value into device info cache.
  38. Args:
  39. value (list): The list of server info. Each item is format like:
  40. {
  41. "server_id": str,
  42. "device": list[<Device Info>]
  43. },
  44. The format of <Device Info> is like:
  45. {
  46. "device_id": str,
  47. "device_ip": str,
  48. "rank_id": str
  49. }.
  50. """
  51. if not isinstance(value, list):
  52. log.error("Invalid input type. list object is expected.")
  53. raise DebuggerParamTypeError("List object is expected.")
  54. try:
  55. self._extract_rank_info(value)
  56. except TypeError as err:
  57. log.exception(err)
  58. log.error("Invalid Device info.")
  59. raise DebuggerParamValueError("Invalid device info.")
  60. log.debug("Put Device into cache")
  61. def _extract_rank_info(self, value):
  62. """Extract rank info and save."""
  63. for server_info in value:
  64. server_ip = server_info.get('server_id')
  65. for device_info in server_info.get('device', []):
  66. rank_id = int(device_info.get('rank_id'))
  67. if rank_id in self._rank_info:
  68. log.error("Repeated rank info for rank_id: %d", rank_id)
  69. raise DebuggerParamValueError("Repeated rank info.")
  70. device_info_obj = self._rank_info[rank_id]
  71. device_info_obj.rank_id = rank_id
  72. device_info_obj.server_ip = server_ip
  73. device_info_obj.device_id = int(device_info.get('device_id'))
  74. device_info_obj.device_ip = device_info.get('device_ip')
  75. self._device_rank_map[device_info_obj.device_id] = rank_id
  76. def add_step_num_info(self, step_info):
  77. """
  78. Add step number information for each device.
  79. Args:
  80. step_info (dict): Step info per device. The key is the device id, the value
  81. is the relative step number.
  82. """
  83. if not step_info:
  84. log.warning("No step number information.")
  85. return
  86. if len(step_info) == 1 and not self._rank_info:
  87. device_id = int(list(step_info)[0])
  88. log.info("Default registered device %d as rank 0.", device_id)
  89. self._rank_info[0].device_id = device_id
  90. if len(step_info) > 1 and not self._rank_info:
  91. log.error("Missing device info for multi-card training.")
  92. raise DeviceIdUnregistered("all")
  93. for device_id, step_num in step_info.items():
  94. device_id = int(device_id)
  95. rank_id = self.get_rank_id_by_device_id(device_id)
  96. self._rank_info[rank_id].step_num = step_num
  97. def add_graph_name_info(self, graphs):
  98. """
  99. Add graph name per device.
  100. Args:
  101. graphs (dict): Graph infos of all rank id. Each item is format like
  102. """
  103. for rank_id, graph_info in graphs.items():
  104. graph_names = list(graph_info)
  105. if len(graph_names) > 1:
  106. # if more than one graphs in a device, sort them
  107. # by the number following the last "_" in the graph_name
  108. graph_names = sorted(graph_names, key=lambda x: x.split("_")[-1])
  109. self._rank_info[rank_id].graph_names = graph_names
  110. def get(self, filter_condition=None):
  111. """
  112. Get device information according to filter_condition.
  113. Args:
  114. filter_condition (list): The rank id.
  115. Returns:
  116. dict, the device info.
  117. """
  118. if filter_condition is None:
  119. filter_condition = self.rank_ids
  120. if not isinstance(filter_condition, list):
  121. filter_condition = [filter_condition]
  122. device_infos = []
  123. for rank_id in filter_condition:
  124. device_info = self._rank_info.get(rank_id)
  125. if device_info is None:
  126. log.error("Invalid rank id.")
  127. raise DeviceIdUnregistered(rank_id)
  128. device_infos.append(device_info.to_dict())
  129. return {'devices': device_infos}
  130. def get_rank_id_by_device_id(self, device_id):
  131. """
  132. Get rank id by device id.
  133. Args:
  134. device_id (int): The device id.
  135. Returns:
  136. int, the rank id.
  137. """
  138. rank_id = self._device_rank_map.get(device_id)
  139. if rank_id is None:
  140. log.error("Failed to find rank_id for device_id %s", device_id)
  141. raise DeviceIdUnregistered(device_id)
  142. return rank_id
  143. def get_device_id_by_rank_id(self, rank_id):
  144. """
  145. Get device id by rank id.
  146. Args:
  147. rank_id (int): The rank id.
  148. Returns:
  149. int, the device id.
  150. """
  151. device_info = self._rank_info.get(rank_id)
  152. if device_info:
  153. return device_info.device_id
  154. log.error("Failed to find device id according to rank_id %s", rank_id)
  155. raise DeviceIdUnregistered(rank_id)
  156. class DeviceInfo:
  157. """Device info object."""
  158. def __init__(self):
  159. self.rank_id = 0
  160. self.device_id = 0
  161. self.server_ip = ''
  162. self.graph_names = []
  163. self.device_ip = ''
  164. self.step_num = 0
  165. def to_dict(self):
  166. """Convert device info to dict."""
  167. res = {
  168. 'rank_id': self.rank_id,
  169. 'server_ip': self.server_ip,
  170. 'device_id': self.device_id,
  171. 'device_ip': self.device_ip,
  172. 'graph_names': self.graph_names,
  173. 'total_step_num': self.step_num
  174. }
  175. return res