You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_handler.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Define the graph stream handler."""
  16. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
  17. DebuggerNodeNotInGraphError, DebuggerGraphNotExistError
  18. from mindinsight.debugger.common.log import logger as log
  19. from mindinsight.debugger.stream_cache.debugger_graph import DebuggerGraph
  20. from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
  21. class GraphHandler(StreamHandlerBase):
  22. """Metadata Handler."""
  23. def __init__(self):
  24. self._graph_proto = None
  25. self._graph = None
  26. self._searched_node_list = []
  27. self.bfs_order = []
  28. @property
  29. def graph(self):
  30. """The property of graph."""
  31. return self._graph_proto
  32. def put(self, value):
  33. """
  34. Put value into graph cache. Called by grpc server.
  35. Args:
  36. value (GraphProto): The Graph proto message.
  37. """
  38. self._graph_proto = value
  39. log.info("Put graph into cache.")
  40. # build graph
  41. graph = DebuggerGraph()
  42. graph.build_graph(value)
  43. self._graph = graph
  44. self.bfs_order = self._graph.get_bfs_order()
  45. def get(self, filter_condition=None):
  46. """
  47. Get the graph of specific node.
  48. Args:
  49. filter_condition (dict):
  50. - name (str): The full debug node name.
  51. - single_node (bool): If True, return the graph from root
  52. to the specific node; else, return the sublayer of the
  53. graph. Default: False.
  54. Returns:
  55. dict, the metadata.
  56. """
  57. try:
  58. self._graph_exists()
  59. except DebuggerGraphNotExistError:
  60. log.warning('The graph is empty. To view a graph, '
  61. 'please start the training script first.')
  62. return {'graph': {}}
  63. if filter_condition is None:
  64. filter_condition = {}
  65. single_node = filter_condition.get('single_node', False)
  66. name = filter_condition.get('name')
  67. graph = {}
  68. if single_node is True:
  69. nodes = self.get_single_node(name)
  70. else:
  71. nodes = self.list_nodes(name)
  72. graph.update(nodes)
  73. return {'graph': graph}
  74. def get_tensor_history(self, node_name, depth=0):
  75. """
  76. Get the tensor history of a specified node.
  77. Args:
  78. node_name (str): The debug name of the node.
  79. depth (int): The number of layers the user
  80. wants to trace. Default is 0.
  81. Returns:
  82. dict, basic tensor history, only including tensor name and tensor type and node type.
  83. """
  84. self._graph_exists()
  85. if not self._graph.exist_node(node_name):
  86. raise DebuggerNodeNotInGraphError(node_name)
  87. tensor_history, cur_outputs_nums = self._graph.get_tensor_history(
  88. node_name, depth
  89. )
  90. # add the tensor type for tensor history
  91. self._update_tensor_history(tensor_history[0:cur_outputs_nums], 'output')
  92. self._update_tensor_history(tensor_history[cur_outputs_nums:], 'input')
  93. log.debug("Get %d tensors in tensor history for node <%s>.", len(tensor_history), node_name)
  94. return {'tensor_history': tensor_history}
  95. @staticmethod
  96. def _update_tensor_history(tensor_history, tensor_type):
  97. """
  98. Add tensor source type for tensor history.
  99. Args:
  100. tensor_history (list[dict]): Tensor history from Graph stream. Each element has two
  101. keys: `node_type` and `name`. `node_type` refers to the type of the node which
  102. the tensor come from. `name` refers to the tensor name.
  103. tensor_type (str): The source type of the tensor. `input` or `output`.
  104. """
  105. for single_tensor_info in tensor_history:
  106. single_tensor_info['type'] = tensor_type
  107. def search_nodes(self, pattern):
  108. """
  109. Search nodes by given pattern.
  110. Args:
  111. pattern (Union[str, None]): The pattern of the node to search,
  112. if None, return all node names.
  113. Returns:
  114. dict, the searched node.
  115. """
  116. self._graph_exists()
  117. self._searched_node_list = self._graph.search_nodes_by_pattern(pattern)
  118. nodes = self._graph.get_nodes(self._searched_node_list)
  119. return {'nodes': nodes}
  120. def get_nodes_by_scope(self, scope_name):
  121. """
  122. Get node by a given scope name.
  123. Args:
  124. scope_name (str): The name of scope.
  125. Returns:
  126. list[Node], a list of node.
  127. """
  128. return self._graph.search_leaf_nodes_by_pattern(scope_name)
  129. def get_searched_node_list(self):
  130. """Get searched node list."""
  131. return self._searched_node_list
  132. def get_node_type(self, node_name):
  133. """
  134. Get the type of the specified node.
  135. Args:
  136. node_name (str): The debug name of the node.
  137. Returns:
  138. A string of the node type, name_scope or leaf.
  139. """
  140. self._graph_exists()
  141. node_type = self._graph.get_node_type(node_name)
  142. return node_type
  143. def get_full_name(self, node_name):
  144. """Get full name according to ui node name."""
  145. full_name = self._graph.get_full_name_by_node_name(node_name) if node_name else ''
  146. return full_name
  147. def get_node_name_by_full_name(self, full_name):
  148. """Get UI node name by full name."""
  149. if self._graph:
  150. node_name = self._graph.get_node_name_by_full_name(full_name)
  151. else:
  152. node_name = ''
  153. log.info("No graph received yet.")
  154. return node_name
  155. def list_nodes(self, scope):
  156. """
  157. Get the nodes of every layer in graph.
  158. Args:
  159. scope (str): The name of a scope.
  160. Returns:
  161. TypedDict('Nodes', {'nodes': list[Node]}), format is {'nodes': [<Node object>]}.
  162. example:
  163. {
  164. "nodes" : [
  165. {
  166. "attr" :
  167. {
  168. "index" : "i: 0\n"
  169. },
  170. "input" : {},
  171. "name" : "input_tensor",
  172. "output" :
  173. {
  174. "Default/TensorAdd-op17" :
  175. {
  176. "edge_type" : "data",
  177. "scope" : "name_scope",
  178. "shape" : [1, 16, 128, 128]
  179. }
  180. },
  181. "output_i" : -1,
  182. "proxy_input" : {},
  183. "proxy_output" : {},
  184. "independent_layout" : False,
  185. "subnode_count" : 0,
  186. "type" : "Data"
  187. }
  188. ]
  189. }
  190. """
  191. if scope and not self._graph.exist_node(scope):
  192. raise DebuggerNodeNotInGraphError(node_name=scope)
  193. nodes = self._graph.list_node_by_scope(scope=scope)
  194. return {'nodes': nodes}
  195. def get_node_by_bfs_order(self, node_name=None, ascend=True):
  196. """
  197. Traverse the graph in order of breath-first search by given node.
  198. Args:
  199. node_name (str): The name of current chosen leaf node.
  200. ascend (bool): If True, traverse the input nodes;
  201. If False, traverse the output nodes. Default is True.
  202. Returns:
  203. Union[None, dict], the next node object in dict type or None.
  204. """
  205. self._graph_exists()
  206. bfs_order = self.bfs_order
  207. length = len(bfs_order)
  208. if not bfs_order:
  209. log.error('Cannot get the BFS order of the graph!')
  210. msg = 'Cannot get the BFS order of the graph!'
  211. raise DebuggerParamValueError(msg)
  212. if node_name is None:
  213. if ascend is False:
  214. next_node = None
  215. else:
  216. next_node = bfs_order[0]
  217. else:
  218. try:
  219. index = bfs_order.index(node_name)
  220. log.debug("The index of the node in BFS list is: %d", index)
  221. except ValueError as err:
  222. log.error('Cannot find the node: %s. Please check '
  223. 'the node name: %s', node_name, err)
  224. msg = f'Cannot find the node: {node_name}. ' \
  225. f'Please check the node name {err}.'
  226. raise DebuggerParamValueError(msg)
  227. next_node = self.get_next_node_in_bfs(index, length, ascend)
  228. return next_node
  229. def get_next_node_in_bfs(self, index, length, ascend):
  230. """
  231. Get the next node in bfs order.
  232. Args:
  233. index (int): The current index.
  234. length (int): The number of all leaf nodes.
  235. ascend (bool): Whether get the node in ascend order or not.
  236. Returns:
  237. Union[None, dict], the next node object in dict type or None.
  238. """
  239. next_node = None
  240. if 0 <= index < length:
  241. if ascend is True and index < length - 1:
  242. next_node = self.bfs_order[index + 1]
  243. elif ascend is False and index > 0:
  244. next_node = self.bfs_order[index - 1]
  245. return next_node
  246. def get_single_node(self, name):
  247. """
  248. Search node, and return every layer nodes until this node.
  249. Args:
  250. name (str): The name of node.
  251. Returns:
  252. dict, every layer nodes until this node.
  253. """
  254. nodes = self._graph.search_single_node(name)
  255. return nodes
  256. def _graph_exists(self):
  257. """
  258. Check if the graph has been loaded in the debugger cache.
  259. Raises:
  260. DebuggerGraphNotExistError: If the graph does not exist.
  261. """
  262. if self._graph is None:
  263. log.error('The graph does not exist. Please start the '
  264. 'training script and try again.')
  265. raise DebuggerGraphNotExistError