You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

device_handler.py 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the Device stream handler."""
  16. from collections import defaultdict
  17. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, DeviceIdUnregistered, \
  18. DebuggerParamTypeError
  19. from mindinsight.debugger.common.log import LOGGER as log
  20. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  21. class DeviceHandler(StreamHandlerBase):
  22. """Metadata Handler."""
  23. def __init__(self):
  24. # contains all device infos, the format is like Dict[int(<device_id>, <device_info>)]
  25. self._rank_info = defaultdict(DeviceInfo)
  26. self._device_rank_map = {}
  27. @property
  28. def rank_ids(self):
  29. """The rank ids."""
  30. return list(self._rank_info)
  31. @property
  32. def device_amount(self):
  33. """The rank ids."""
  34. return len(self._rank_info)
  35. def put(self, value):
  36. """
  37. Put value into device info cache.
  38. Args:
  39. value (list): The list of server info. Each item is format like:
  40. {
  41. "server_id": str,
  42. "device": list[<Device Info>]
  43. },
  44. The format of <Device Info> is like:
  45. {
  46. "device_id": str,
  47. "device_ip": str,
  48. "rank_id": str
  49. }.
  50. """
  51. if not isinstance(value, list):
  52. log.error("Invalid input type. list object is expected.")
  53. raise DebuggerParamTypeError("List object is expected.")
  54. try:
  55. self._extract_rank_info(value)
  56. except TypeError as err:
  57. log.exception(err)
  58. log.error("Invalid Device info.")
  59. raise DebuggerParamValueError("Invalid device info.")
  60. log.debug("Put Device into cache")
  61. def _extract_rank_info(self, value):
  62. """Extract rank info and save."""
  63. for server_info in value:
  64. server_ip = server_info.get('server_id')
  65. for device_info in server_info.get('device', []):
  66. rank_id = int(device_info.get('rank_id'))
  67. if rank_id in self._rank_info:
  68. log.error("Repeated rank info for rank_id: %d", rank_id)
  69. raise DebuggerParamValueError("Repeated rank info.")
  70. device_info_obj = self._rank_info[rank_id]
  71. device_info_obj.rank_id = rank_id
  72. device_info_obj.server_ip = server_ip
  73. device_info_obj.device_id = int(device_info.get('device_id'))
  74. device_info_obj.device_ip = device_info.get('device_ip')
  75. self._device_rank_map[device_info_obj.device_id] = rank_id
  76. def add_step_num_info(self, step_info):
  77. """
  78. Add step number information for each device.
  79. Args:
  80. step_info (dict): Step info per device. The key is the device id, the value
  81. is the relative step number.
  82. """
  83. if not step_info:
  84. log.warning("No step number information.")
  85. return
  86. if len(step_info) == 1 and not self._rank_info:
  87. device_id = int(list(step_info)[0])
  88. log.info("Default registered device %d as rank 0.", device_id)
  89. self._rank_info[0].device_id = device_id
  90. if len(step_info) > 1 and not self._rank_info:
  91. log.error("Missing device info for multi-card training.")
  92. raise DeviceIdUnregistered("all")
  93. for device_id, step_num in step_info.items():
  94. device_id = int(device_id)
  95. rank_id = self.get_rank_id_by_device_id(device_id)
  96. self._rank_info[rank_id].step_num = step_num
  97. def add_graph_name_info(self, graphs):
  98. """
  99. Add graph name per device.
  100. Args:
  101. graphs (dict): Graph infos of all rank id. Each item is format like
  102. """
  103. for rank_id, graph_info in graphs.items():
  104. graph_names = list(graph_info)
  105. self._rank_info[rank_id].graph_names = graph_names
  106. def get(self, filter_condition=None):
  107. """
  108. Get device information according to filter_condition.
  109. Args:
  110. filter_condition (list): The rank id.
  111. Returns:
  112. dict, the device info.
  113. """
  114. if filter_condition is None:
  115. filter_condition = self.rank_ids
  116. if not isinstance(filter_condition, list):
  117. filter_condition = [filter_condition]
  118. device_infos = []
  119. for rank_id in filter_condition:
  120. device_info = self._rank_info.get(rank_id)
  121. if device_info is None:
  122. log.error("Invalid rank id.")
  123. raise DeviceIdUnregistered(rank_id)
  124. device_infos.append(device_info.to_dict())
  125. return {'devices': device_infos}
  126. def get_rank_id_by_device_id(self, device_id):
  127. """
  128. Get rank id by device id.
  129. Args:
  130. device_id (int): The device id.
  131. Returns:
  132. int, the rank id.
  133. """
  134. rank_id = self._device_rank_map.get(device_id)
  135. if rank_id is None:
  136. log.error("Failed to find rank_id for device_id %s", device_id)
  137. raise DeviceIdUnregistered(device_id)
  138. return rank_id
  139. def get_device_id_by_rank_id(self, rank_id):
  140. """
  141. Get device id by rank id.
  142. Args:
  143. rank_id (int): The rank id.
  144. Returns:
  145. int, the device id.
  146. """
  147. device_info = self._rank_info.get(rank_id)
  148. if device_info:
  149. return device_info.device_id
  150. log.error("Failed to find device id according to rank_id %s", rank_id)
  151. raise DeviceIdUnregistered(rank_id)
  152. class DeviceInfo:
  153. """Device info object."""
  154. def __init__(self):
  155. self.rank_id = 0
  156. self.device_id = 0
  157. self.server_ip = ''
  158. self.graph_names = []
  159. self.device_ip = ''
  160. self.step_num = 0
  161. def to_dict(self):
  162. """Convert device info to dict."""
  163. res = {
  164. 'rank_id': self.rank_id,
  165. 'server_ip': self.server_ip,
  166. 'device_id': self.device_id,
  167. 'device_ip': self.device_ip,
  168. 'graph_names': self.graph_names,
  169. 'total_step_num': self.step_num
  170. }
  171. return res