You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 7.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. This file is used to define the node of graph and associated base types.
  17. """
  18. from enum import Enum
  19. class NodeTypeEnum(Enum):
  20. """Node type enum. The following types are new to our custom."""
  21. NAME_SCOPE = 'name_scope'
  22. AGGREGATION_SCOPE = 'aggregation_scope'
  23. PARAMETER = 'Parameter'
  24. CONST = 'Const'
  25. class Node:
  26. """
  27. Define a node object.
  28. Args:
  29. name (str): Name of new node.
  30. node_id (str): The id of this node, and node id is unique in graph.
  31. """
  32. def __init__(self, name, node_id, topological_index=-1):
  33. self._node_id = node_id
  34. self.name = name
  35. self.type = ""
  36. self._attr = dict()
  37. self._input = dict()
  38. self.output_i = 0
  39. self._output = {}
  40. self._proxy_input = {}
  41. self._proxy_output = {}
  42. self.subnode_count = 0
  43. self.scope = ""
  44. self.independent_layout = False
  45. self.output_shape = []
  46. self.output_data_type = ""
  47. self.output_nums = 0
  48. self.elem_types = []
  49. self.full_name = ""
  50. # This value will be used as the priority field.
  51. self.topological_index = topological_index
  52. def to_dict(self):
  53. """Converts the node object to dictionary format."""
  54. return {
  55. 'name': self.name,
  56. 'type': self.type,
  57. 'attr': self._attr,
  58. 'input': self._input,
  59. 'output': self._output,
  60. 'output_i': self.output_i,
  61. 'proxy_input': self._proxy_input,
  62. 'proxy_output': self._proxy_output,
  63. 'subnode_count': self.subnode_count,
  64. 'independent_layout': self.independent_layout
  65. }
  66. @property
  67. def node_id(self):
  68. """The id of this node, and id is unique in graph."""
  69. return self._node_id
  70. @staticmethod
  71. def create_node_name(scope, base_name):
  72. """
  73. The name of the node consists of the scope and the basic name.
  74. Args:
  75. scope (str): The scope of node, such as 'Default/Conv2D'
  76. base_name (str): The base name of node, such as 'Add11'.
  77. Returns:
  78. str, a node name.
  79. """
  80. return f'{scope}/{base_name}' if scope else base_name
  81. @property
  82. def attr(self):
  83. """Get node attr."""
  84. return self._attr
  85. def add_attr(self, attr_dict):
  86. """
  87. Update node attr.
  88. Args:
  89. attr_dict (dict[str, str]): The attr of node.
  90. """
  91. self._attr.update(attr_dict)
  92. @property
  93. def inputs(self):
  94. """
  95. Get all input of current node.
  96. Returns:
  97. dict[str, dict], refer to the input attr.
  98. """
  99. return self._input
  100. def add_inputs(self, src_name, input_attr):
  101. """
  102. Update input.
  103. Args:
  104. src_name (stc): The source node name.
  105. input_attr (dict): The attribute of the input.
  106. - shape (list): The shape of input tensor.
  107. - edge_type (str): The type of edge, optional value refer to `EdgeTypeEnum`.
  108. - data_type (str): The data type of the input.
  109. - independent_layout (bool): Indicates whether the source nodes are laid out independently.
  110. """
  111. self._input.update({src_name: input_attr})
  112. def delete_inputs(self, src_name):
  113. """
  114. Delete input attribute by the given source name.
  115. Args:
  116. src_name (str): The source node name.
  117. """
  118. self._input.pop(src_name)
  119. @property
  120. def outputs(self):
  121. """The output node of this node."""
  122. return self._output
  123. def add_outputs(self, dst_name, output_attr):
  124. """
  125. Add a output node to this node.
  126. Args:
  127. dst_name (str): The name of the output node.
  128. output_attr (dict: Same as the input attribute.
  129. """
  130. self._output.update({dst_name: output_attr})
  131. def delete_outputs(self, dst_name):
  132. """
  133. Delete a output node.
  134. Args:
  135. dst_name (str): The name of the node to be deleted.
  136. """
  137. self._output.pop(dst_name)
  138. @property
  139. def proxy_inputs(self):
  140. """Return proxy input, type is dict."""
  141. return self._proxy_input
  142. def add_proxy_inputs(self, src_name, attr):
  143. """
  144. Add a proxy input to node.
  145. Args:
  146. src_name (str): The name of the input node.
  147. attr (dict): The attr of the input.
  148. - edge_type (str): The edge type, refer to `EdgeTypeEnum`.
  149. """
  150. self._proxy_input.update({src_name: attr})
  151. def delete_proxy_inputs(self, src_name):
  152. """Delete a proxy input by the src name."""
  153. self._proxy_input.pop(src_name)
  154. @property
  155. def proxy_outputs(self):
  156. """Get proxy output, data type is dict."""
  157. return self._proxy_output
  158. def add_proxy_outputs(self, dst_name, attr):
  159. """
  160. Add a proxy output to node.
  161. Args:
  162. dst_name (str): The name of the output node.
  163. attr (dict): The attr of the output.
  164. - edge_type (str): The edge type, refer to `EdgeTypeEnum`.
  165. """
  166. self._proxy_output.update({dst_name: attr})
  167. def delete_proxy_outputs(self, dst_name):
  168. """Delete a proxy output by dst name."""
  169. self._proxy_output.pop(dst_name)
  170. @staticmethod
  171. def copy_node_without_input_output(src_node, dst_node):
  172. """
  173. Copy a source node attribute to a new node, but not input and output.
  174. Args:
  175. src_node (Node): The copied node.
  176. dst_node (Node): The destination node.
  177. """
  178. dst_node.full_name = src_node.full_name
  179. dst_node.type = src_node.type
  180. dst_node.output_i = src_node.output_i
  181. dst_node.subnode_count = src_node.subnode_count
  182. dst_node.scope = src_node.scope
  183. dst_node.independent_layout = src_node.independent_layout
  184. dst_node.output_shape = src_node.output_shape
  185. dst_node.output_data_type = src_node.output_data_type
  186. dst_node.output_nums = src_node.output_nums
  187. dst_node.elem_types = src_node.elem_types
  188. dst_node.add_attr(src_node.attr)
  189. def __str__(self):
  190. return f'<Node, name: {self.name}, type: {self.type}>'