You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

debugger_multigraph.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """This file is used to define the basic graph."""
  16. import copy
  17. from mindinsight.debugger.common.log import LOGGER as log
  18. from mindinsight.datavisual.data_transform.graph.node import Node, NodeTypeEnum
  19. from .debugger_graph import DebuggerGraph
  20. class DebuggerMultiGraph(DebuggerGraph):
  21. """The `DebuggerMultiGraph` object provides interfaces to describe a debugger multigraph."""
  22. def add_graph(self, graph_dict):
  23. """
  24. Add graphs to DebuggerMultiGraph.
  25. Args:
  26. graph_dict (dict): The <graph_name, graph_object> dict.
  27. """
  28. if len(graph_dict) == 1:
  29. graph = list(graph_dict.values())[0]
  30. self._normal_node_map = graph.normal_node_map
  31. self._node_id_map_name = graph.node_id_map_name
  32. self._const_node_temp_cache = graph.const_node_temp_cache
  33. self._parameter_node_temp_cache = graph.parameter_node_temp_cache
  34. self._leaf_nodes = graph.leaf_nodes
  35. self._full_name_map_name = graph.full_name_map_name
  36. else:
  37. for graph_name, graph in graph_dict.items():
  38. log.debug("add graph %s into whole graph.", graph_name)
  39. # add nodes
  40. normal_nodes = copy.deepcopy(graph.normal_node_map)
  41. for _, node_obj in normal_nodes.items():
  42. self._add_graph_scope(node_obj, graph_name)
  43. self._cache_node(node_obj)
  44. # add graph_node
  45. node = Node(name=graph_name, node_id=graph_name)
  46. node.type = NodeTypeEnum.NAME_SCOPE.value
  47. node.subnode_count = len(graph.list_node_by_scope())
  48. self._cache_node(node)
  49. self._leaf_nodes = self._get_leaf_nodes()
  50. self._full_name_map_name = self._get_leaf_node_full_name_map()
  51. log.info(
  52. "Build multi_graph end, all node count: %s, const count: %s, parameter count: %s.",
  53. self.normal_node_count, len(self._const_node_temp_cache),
  54. len(self._parameter_node_temp_cache))
  55. def _add_graph_scope(self, node, graph_name):
  56. """Add graph scope to the inputs and outputs in node"""
  57. # add graph scope to node name
  58. pre_scope = graph_name + "/"
  59. node.name = pre_scope + node.name
  60. if node.scope:
  61. node.scope = pre_scope + node.scope
  62. else:
  63. node.scope = graph_name
  64. # update inputs
  65. old_inputs = copy.deepcopy(node.inputs)
  66. for src_name, input_attr in old_inputs.items():
  67. new_src_name = graph_name + "/" + src_name
  68. node.add_inputs(new_src_name, input_attr)
  69. node.delete_inputs(src_name)
  70. # update outputs
  71. old_outputs = copy.deepcopy(node.outputs)
  72. for dst_name, output_attr in old_outputs.items():
  73. new_dst_name = graph_name + "/" + dst_name
  74. node.add_outputs(new_dst_name, output_attr)
  75. node.delete_outputs(dst_name)
  76. # update proxy_inputs
  77. old_proxy_inputs = copy.deepcopy(node.proxy_inputs)
  78. for src_name, input_attr in old_proxy_inputs.items():
  79. new_src_name = graph_name + "/" + src_name
  80. node.add_proxy_inputs(new_src_name, input_attr)
  81. node.delete_proxy_inputs(src_name)
  82. # update proxy_outputs
  83. old_proxy_outputs = copy.deepcopy(node.proxy_outputs)
  84. for dst_name, output_attr in old_proxy_outputs.items():
  85. new_dst_name = graph_name + "/" + dst_name
  86. node.add_proxy_outputs(new_dst_name, output_attr)
  87. node.delete_proxy_outputs(dst_name)