You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 7.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. This file is used to define the node of graph and associated base types.
  17. """
  18. from enum import Enum
  class NodeTypeEnum(Enum):
      """Custom node type names defined by this module.

      Each member wraps a plain string value; look members up by value with
      ``NodeTypeEnum("<value>")``.
      """

      # Values spelled in lowercase snake_case.
      NAME_SCOPE = "name_scope"
      AGGREGATION_SCOPE = "aggregation_scope"

      # Values spelled in PascalCase.
      PARAMETER = "Parameter"
      CONST = "Const"
  class Node:
      """
      Define a node object.

      Args:
          name (str): Name of new node.
          node_id (str): The id of this node, and node id is unique in graph.
          topological_index (int): Stored on the node and used as its priority
              field. Default: -1.
      """

      def __init__(self, name, node_id, topological_index=-1):
          self._node_id = node_id
          self.name = name
          self.type = ""
          self._attr = {}
          self._input = {}
          self.output_i = 0
          self._output = {}
          self._proxy_input = {}
          self._proxy_output = {}
          self.subnode_count = 0
          self.scope = ""
          self.independent_layout = False
          self.output_shape = []
          self.output_data_type = ""
          self.output_nums = 0
          self.elem_types = []
          self.full_name = ""
          # This value will be used as the priority field.
          self.topological_index = topological_index
          # Each item must provide a ``to_dict()`` method (called by ``to_dict``).
          self.stack = []

      def to_dict(self):
          """
          Convert the node object to dictionary format.

          Note:
              The ``attr``, ``input``, ``output``, ``proxy_input`` and
              ``proxy_output`` entries are the node's own internal dicts, not
              copies, so mutating them mutates the node.

          Returns:
              dict, the serializable view of this node.
          """
          return {
              'name': self.name,
              'type': self.type,
              'attr': self._attr,
              'input': self._input,
              'output': self._output,
              'output_i': self.output_i,
              'proxy_input': self._proxy_input,
              'proxy_output': self._proxy_output,
              'subnode_count': self.subnode_count,
              'independent_layout': self.independent_layout,
              'stack_info': [source.to_dict() for source in self.stack]
          }

      @property
      def node_id(self):
          """The id of this node, and id is unique in graph."""
          return self._node_id

      @staticmethod
      def create_node_name(scope, base_name):
          """
          Build a node name from the scope and the base name.

          Args:
              scope (str): The scope of node, such as 'Default/Conv2D'.
              base_name (str): The base name of node, such as 'Add11'.

          Returns:
              str, ``'<scope>/<base_name>'``, or just ``base_name`` when scope is empty.
          """
          return f'{scope}/{base_name}' if scope else base_name

      @property
      def attr(self):
          """Get node attr."""
          return self._attr

      def add_attr(self, attr_dict):
          """
          Merge the given attributes into the node attr, overwriting existing keys.

          Args:
              attr_dict (dict[str, str]): The attr of node.
          """
          self._attr.update(attr_dict)

      @property
      def inputs(self):
          """
          Get all inputs of the current node.

          Returns:
              dict[str, dict], mapping source node name to its input attribute.
          """
          return self._input

      def add_inputs(self, src_name, input_attr):
          """
          Add or replace the input coming from ``src_name``.

          Args:
              src_name (str): The source node name.
              input_attr (dict): The attribute of the input.

                  - shape (list): The shape of input tensor.
                  - edge_type (str): The type of edge, optional value refer to `EdgeTypeEnum`.
                  - data_type (str): The data type of the input.
                  - independent_layout (bool): Indicates whether the source nodes are laid out independently.
          """
          self._input.update({src_name: input_attr})

      def delete_inputs(self, src_name):
          """
          Delete input attribute by the given source name.

          Args:
              src_name (str): The source node name.

          Raises:
              KeyError: If ``src_name`` is not an input of this node.
          """
          self._input.pop(src_name)

      @property
      def outputs(self):
          """The output nodes of this node, keyed by destination node name."""
          return self._output

      def add_outputs(self, dst_name, output_attr):
          """
          Add or replace an output node of this node.

          Args:
              dst_name (str): The name of the output node.
              output_attr (dict): Same as the input attribute.
          """
          self._output.update({dst_name: output_attr})

      def delete_outputs(self, dst_name):
          """
          Delete an output node.

          Args:
              dst_name (str): The name of the node to be deleted.

          Raises:
              KeyError: If ``dst_name`` is not an output of this node.
          """
          self._output.pop(dst_name)

      @property
      def proxy_inputs(self):
          """Return proxy input, type is dict."""
          return self._proxy_input

      def add_proxy_inputs(self, src_name, attr):
          """
          Add a proxy input to node.

          Args:
              src_name (str): The name of the input node.
              attr (dict): The attr of the input.

                  - edge_type (str): The edge type, refer to `EdgeTypeEnum`.
          """
          self._proxy_input.update({src_name: attr})

      def delete_proxy_inputs(self, src_name):
          """Delete a proxy input by the src name; raises KeyError if absent."""
          self._proxy_input.pop(src_name)

      @property
      def proxy_outputs(self):
          """Get proxy output, data type is dict."""
          return self._proxy_output

      def add_proxy_outputs(self, dst_name, attr):
          """
          Add a proxy output to node.

          Args:
              dst_name (str): The name of the output node.
              attr (dict): The attr of the output.

                  - edge_type (str): The edge type, refer to `EdgeTypeEnum`.
          """
          self._proxy_output.update({dst_name: attr})

      def delete_proxy_outputs(self, dst_name):
          """Delete a proxy output by dst name; raises KeyError if absent."""
          self._proxy_output.pop(dst_name)

      @staticmethod
      def copy_node_without_input_output(src_node, dst_node):
          """
          Copy a source node attribute to a new node, but not input and output.

          The destination keeps its own name, node id and topological index.
          Inputs, outputs, proxy inputs and proxy outputs are not copied.
          Attributes are merged into the destination's own attr dict.
          List-valued fields (``output_shape``, ``elem_types``, ``stack``) are
          shallow-copied, so appending to or removing from one node's list
          does not affect the other node.

          Args:
              src_node (Node): The copied node.
              dst_node (Node): The destination node.
          """
          def _detach(value):
              # A fresh list for list values; non-list values are assigned unchanged.
              return list(value) if isinstance(value, list) else value

          dst_node.full_name = src_node.full_name
          dst_node.type = src_node.type
          dst_node.output_i = src_node.output_i
          dst_node.subnode_count = src_node.subnode_count
          dst_node.scope = src_node.scope
          dst_node.independent_layout = src_node.independent_layout
          dst_node.output_shape = _detach(src_node.output_shape)
          dst_node.output_data_type = src_node.output_data_type
          dst_node.output_nums = src_node.output_nums
          dst_node.elem_types = _detach(src_node.elem_types)
          dst_node.add_attr(src_node.attr)
          dst_node.stack = _detach(src_node.stack)

      def __str__(self):
          return f'<Node, name: {self.name}, type: {self.type}>'