|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321 |
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Define the graph stream handler."""
- from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError, \
- DebuggerNodeNotInGraphError, DebuggerGraphNotExistError
- from mindinsight.debugger.common.log import logger as log
- from mindinsight.debugger.stream_cache.debugger_graph import DebuggerGraph
- from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
-
-
class GraphHandler(StreamHandlerBase):
    """
    Graph stream handler.

    Caches the graph proto passed to `put`, builds a `DebuggerGraph` from it
    and answers graph queries (node listing, node search, tensor history and
    stepping through `bfs_order`).
    """

    def __init__(self):
        # Raw proto message exactly as passed to `put`.
        self._graph_proto = None
        # `DebuggerGraph` built from the proto; stays None until `put` is called.
        self._graph = None
        # Result of the most recent `search_nodes` call.
        self._searched_node_list = []
        # Value returned by `DebuggerGraph.get_bfs_order()` for the cached graph.
        self.bfs_order = []

    @property
    def graph(self):
        """The graph proto most recently passed to `put`, or None."""
        return self._graph_proto

    def put(self, value):
        """
        Put value into graph cache. Called by grpc server.

        Args:
            value (GraphProto): The Graph proto message.
        """
        self._graph_proto = value
        log.info("Put graph into cache.")

        # Build the graph first and publish it only once it is complete.
        debugger_graph = DebuggerGraph()
        debugger_graph.build_graph(value)
        self._graph = debugger_graph
        self.bfs_order = self._graph.get_bfs_order()

    def get(self, filter_condition=None):
        """
        Get the graph of specific node.

        Args:
            filter_condition (dict):

                - name (str): The full debug node name.

                - single_node (bool): If True, return the graph from root
                  to the specific node; else, return the sublayer of the
                  graph. Default: False.

        Returns:
            dict, format is {'graph': <dict>}. The inner dict is empty when
            no graph has been cached yet; otherwise it holds the result of
            `get_single_node` or `list_nodes`.
        """
        try:
            self._graph_exists()
        except DebuggerGraphNotExistError:
            # A missing graph is an expected state before training starts,
            # so report an empty graph instead of propagating the error.
            log.warning('The graph is empty. To view a graph, '
                        'please start the training script first.')
            return {'graph': {}}

        if filter_condition is None:
            filter_condition = {}
        single_node = filter_condition.get('single_node', False)
        name = filter_condition.get('name')

        # Only the literal True selects the single-node view.
        if single_node is True:
            nodes = self.get_single_node(name)
        else:
            nodes = self.list_nodes(name)

        graph = {}
        graph.update(nodes)
        return {'graph': graph}

    def get_tensor_history(self, node_name, depth=0):
        """
        Get the tensor history of a specified node.

        Args:
            node_name (str): The debug name of the node.
            depth (int): The number of layers the user
                wants to trace. Default is 0.

        Returns:
            dict, basic tensor history, only including tensor name and tensor type and node type.

        Raises:
            DebuggerGraphNotExistError: If the graph does not exist.
            DebuggerNodeNotInGraphError: If the node is not in the graph.
        """
        self._graph_exists()
        if not self._graph.exist_node(node_name):
            raise DebuggerNodeNotInGraphError(node_name)

        tensor_history, cur_outputs_nums = self._graph.get_tensor_history(
            node_name, depth
        )
        # The first `cur_outputs_nums` entries are tagged as outputs and the
        # remaining ones as inputs. The slices are new lists, but they share
        # the same dict objects, so the tags land in `tensor_history`.
        self._update_tensor_history(tensor_history[0:cur_outputs_nums], 'output')
        self._update_tensor_history(tensor_history[cur_outputs_nums:], 'input')
        log.debug("Get %d tensors in tensor history for node <%s>.", len(tensor_history), node_name)
        return {'tensor_history': tensor_history}

    @staticmethod
    def _update_tensor_history(tensor_history, tensor_type):
        """
        Add tensor source type for tensor history.

        Each element is modified in place by setting its `type` key.

        Args:
            tensor_history (list[dict]): Tensor history from Graph stream. Each element has two
                keys: `node_type` and `name`. `node_type` refers to the type of the node which
                the tensor come from. `name` refers to the tensor name.
            tensor_type (str): The source type of the tensor. `input` or `output`.
        """
        for single_tensor_info in tensor_history:
            single_tensor_info['type'] = tensor_type

    def search_nodes(self, pattern):
        """
        Search nodes by given pattern.

        The matched list is also remembered and can be read back through
        `get_searched_node_list`.

        Args:
            pattern (Union[str, None]): The pattern of the node to search,
                if None, return all node names.

        Returns:
            dict, the searched node, format is {'nodes': <nodes>}.

        Raises:
            DebuggerGraphNotExistError: If the graph does not exist.
        """
        self._graph_exists()
        self._searched_node_list = self._graph.search_nodes_by_pattern(pattern)
        nodes = self._graph.get_nodes(self._searched_node_list)

        return {'nodes': nodes}

    def get_nodes_by_scope(self, scope_name):
        """
        Get node by a given scope name.

        Args:
            scope_name (str): The name of scope.

        Returns:
            list[Node], a list of node.

        Raises:
            DebuggerGraphNotExistError: If the graph does not exist.
        """
        # Fail with the handler's own error rather than an AttributeError on
        # None when no graph has been received yet.
        self._graph_exists()
        return self._graph.search_leaf_nodes_by_pattern(scope_name)

    def get_searched_node_list(self):
        """Get the node list stored by the most recent `search_nodes` call."""
        return self._searched_node_list

    def get_node_type(self, node_name):
        """
        Get the type of the specified node.

        Args:
            node_name (str): The debug name of the node.

        Returns:
            A string of the node type, name_scope or leaf.

        Raises:
            DebuggerGraphNotExistError: If the graph does not exist.
        """
        self._graph_exists()
        return self._graph.get_node_type(node_name)

    def get_full_name(self, node_name):
        """
        Get full name according to ui node name.

        Args:
            node_name (str): The UI node name. A falsy value yields ''.

        Returns:
            The value returned by the graph for `node_name`, or '' when
            `node_name` is falsy.

        Raises:
            DebuggerGraphNotExistError: If `node_name` is given but the
                graph does not exist.
        """
        if not node_name:
            return ''
        # The graph is only needed (and only checked) for a real lookup.
        self._graph_exists()
        return self._graph.get_full_name_by_node_name(node_name)

    def get_node_name_by_full_name(self, full_name):
        """
        Get UI node name by full name.

        Args:
            full_name (str): The full name of the node.

        Returns:
            The value returned by the graph for `full_name`, or '' when no
            graph is available.
        """
        if self._graph:
            node_name = self._graph.get_node_name_by_full_name(full_name)
        else:
            node_name = ''
            log.info("No graph received yet.")
        return node_name

    def list_nodes(self, scope):
        """
        Get the nodes of every layer in graph.

        Args:
            scope (str): The name of a scope.

        Returns:
            TypedDict('Nodes', {'nodes': list[Node]}), format is {'nodes': [<Node object>]}.
                example:
                {
                    "nodes" : [
                        {
                            "attr" :
                            {
                                "index" : "i: 0\\n"
                            },
                            "input" : {},
                            "name" : "input_tensor",
                            "output" :
                            {
                                "Default/TensorAdd-op17" :
                                {
                                    "edge_type" : "data",
                                    "scope" : "name_scope",
                                    "shape" : [1, 16, 128, 128]
                                }
                            },
                            "output_i" : -1,
                            "proxy_input" : {},
                            "proxy_output" : {},
                            "independent_layout" : False,
                            "subnode_count" : 0,
                            "type" : "Data"
                        }
                    ]
                }

        Raises:
            DebuggerGraphNotExistError: If the graph does not exist.
            DebuggerNodeNotInGraphError: If `scope` is given but is not a
                node of the graph.
        """
        # Guard direct callers; `get` has already performed this check.
        self._graph_exists()
        # A falsy scope skips the existence check and is passed through as is.
        if scope and not self._graph.exist_node(scope):
            raise DebuggerNodeNotInGraphError(node_name=scope)

        nodes = self._graph.list_node_by_scope(scope=scope)
        return {'nodes': nodes}

    def get_node_by_bfs_order(self, node_name=None, ascend=True):
        """
        Traverse the graph in order of breadth-first search by given node.

        Args:
            node_name (str): The name of current chosen leaf node. If None,
                the first entry of `bfs_order` is returned, unless `ascend`
                is False, in which case None is returned.
            ascend (bool): If True, traverse the input nodes;
                If False, traverse the output nodes. Default is True.
                In terms of `bfs_order`, True steps to the following entry
                and False to the preceding one.

        Returns:
            Union[None, dict], the next node object in dict type or None.

        Raises:
            DebuggerGraphNotExistError: If the graph does not exist.
            DebuggerParamValueError: If `bfs_order` is empty or `node_name`
                is not found in it.
        """
        self._graph_exists()
        bfs_order = self.bfs_order

        if not bfs_order:
            msg = 'Cannot get the BFS order of the graph!'
            log.error(msg)
            raise DebuggerParamValueError(msg)

        if node_name is None:
            # No current node: start from the head, or nothing when stepping back.
            return None if ascend is False else bfs_order[0]

        try:
            index = bfs_order.index(node_name)
        except ValueError as err:
            log.error('Cannot find the node: %s. Please check '
                      'the node name: %s', node_name, err)
            msg = f'Cannot find the node: {node_name}. ' \
                  f'Please check the node name {err}.'
            # Chain the lookup failure so the traceback shows the real cause.
            raise DebuggerParamValueError(msg) from err
        log.debug("The index of the node in BFS list is: %d", index)

        return self.get_next_node_in_bfs(index, len(bfs_order), ascend)

    def get_next_node_in_bfs(self, index, length, ascend):
        """
        Get the next node in bfs order.

        Args:
            index (int): The current index.
            length (int): The number of all leaf nodes.
            ascend (bool): Whether get the node in ascend order or not.
                Only the literal True or False selects a direction.

        Returns:
            Union[None, dict], the next node object in dict type or None.
            None is returned when `index` is out of range, when there is no
            neighbour in the requested direction, or when `ascend` is
            neither True nor False.
        """
        if not 0 <= index < length:
            return None
        if ascend is True and index < length - 1:
            return self.bfs_order[index + 1]
        if ascend is False and index > 0:
            return self.bfs_order[index - 1]
        return None

    def get_single_node(self, name):
        """
        Search node, and return every layer nodes until this node.

        Args:
            name (str): The name of node.

        Returns:
            dict, every layer nodes until this node.

        Raises:
            DebuggerGraphNotExistError: If the graph does not exist.
        """
        # Guard direct callers; `get` has already performed this check.
        self._graph_exists()
        return self._graph.search_single_node(name)

    def _graph_exists(self):
        """
        Check if the graph has been loaded in the debugger cache.

        Raises:
            DebuggerGraphNotExistError: If the graph does not exist.
        """
        if self._graph is None:
            log.error('The graph does not exist. Please start the '
                      'training script and try again.')
            raise DebuggerGraphNotExistError
|