You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 6.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. This file is used to define the node of graph and associated base types.
  17. """
  18. from enum import Enum
  19. class NodeTypeEnum(Enum):
  20. """Node type enum. The following types are new to our custom."""
  21. NAME_SCOPE = 'name_scope'
  22. AGGREGATION_SCOPE = 'aggregation_scope'
  23. PARAMETER = 'Parameter'
  24. CONST = 'Const'
  25. class Node:
  26. """
  27. Define a node object.
  28. Args:
  29. name (str): Name of new node.
  30. node_id (str): The id of this node, and node id is unique in graph.
  31. """
  32. def __init__(self, name, node_id):
  33. self._node_id = node_id
  34. self.name = name
  35. self.type = ""
  36. self._attr = dict()
  37. self._input = dict()
  38. self.output_i = 0
  39. self._output = {}
  40. self._proxy_input = {}
  41. self._proxy_output = {}
  42. self.subnode_count = 0
  43. self.scope = ""
  44. self.independent_layout = False
  45. self.output_shape = []
  46. self.output_data_type = ""
  47. self.output_nums = 0
  48. self.elem_types = []
  49. self.full_name = ""
  50. def to_dict(self):
  51. """Converts the node object to dictionary format."""
  52. return {
  53. 'name': self.name,
  54. 'type': self.type,
  55. 'attr': self._attr,
  56. 'input': self._input,
  57. 'output': self._output,
  58. 'output_i': self.output_i,
  59. 'proxy_input': self._proxy_input,
  60. 'proxy_output': self._proxy_output,
  61. 'subnode_count': self.subnode_count,
  62. 'independent_layout': self.independent_layout
  63. }
  64. @property
  65. def node_id(self):
  66. """The id of this node, and id is unique in graph."""
  67. return self._node_id
  68. @staticmethod
  69. def create_node_name(scope, base_name):
  70. """
  71. The name of the node consists of the scope and the basic name.
  72. Args:
  73. scope (str): The scope of node, such as 'Default/Conv2D'
  74. base_name (str): The base name of node, such as 'Add11'.
  75. Returns:
  76. str, a node name.
  77. """
  78. return f'{scope}/{base_name}' if scope else base_name
  79. @property
  80. def attr(self):
  81. """Get node attr."""
  82. return self._attr
  83. def add_attr(self, attr_dict):
  84. """
  85. Update node attr.
  86. Args:
  87. attr_dict (dict[str, str]): The attr of node.
  88. """
  89. self._attr.update(attr_dict)
  90. @property
  91. def inputs(self):
  92. """
  93. Get all input of current node.
  94. Returns:
  95. dict[str, dict], refer to the input attr.
  96. """
  97. return self._input
  98. def add_inputs(self, src_name, input_attr):
  99. """
  100. Update input.
  101. Args:
  102. src_name (stc): The source node name.
  103. input_attr (dict): The attribute of the input.
  104. - shape (list): The shape of input tensor.
  105. - edge_type (str): The type of edge, optional value refer to `EdgeTypeEnum`.
  106. - data_type (str): The data type of the input.
  107. - independent_layout (bool): Indicates whether the source nodes are laid out independently.
  108. """
  109. self._input.update({src_name: input_attr})
  110. def delete_inputs(self, src_name):
  111. """
  112. Delete input attribute by the given source name.
  113. Args:
  114. src_name (str): The source node name.
  115. """
  116. self._input.pop(src_name)
  117. @property
  118. def outputs(self):
  119. """The output node of this node."""
  120. return self._output
  121. def add_outputs(self, dst_name, output_attr):
  122. """
  123. Add a output node to this node.
  124. Args:
  125. dst_name (str): The name of the output node.
  126. output_attr (dict: Same as the input attribute.
  127. """
  128. self._output.update({dst_name: output_attr})
  129. def delete_outputs(self, dst_name):
  130. """
  131. Delete a output node.
  132. Args:
  133. dst_name (str): The name of the node to be deleted.
  134. """
  135. self._output.pop(dst_name)
  136. @property
  137. def proxy_inputs(self):
  138. """Return proxy input, type is dict."""
  139. return self._proxy_input
  140. def add_proxy_inputs(self, src_name, attr):
  141. """
  142. Add a proxy input to node.
  143. Args:
  144. src_name (str): The name of the input node.
  145. attr (dict): The attr of the input.
  146. - edge_type (str): The edge type, refer to `EdgeTypeEnum`.
  147. """
  148. self._proxy_input.update({src_name: attr})
  149. def delete_proxy_inputs(self, src_name):
  150. """Delete a proxy input by the src name."""
  151. self._proxy_input.pop(src_name)
  152. @property
  153. def proxy_outputs(self):
  154. """Get proxy output, data type is dict."""
  155. return self._proxy_output
  156. def add_proxy_outputs(self, dst_name, attr):
  157. """
  158. Add a proxy output to node.
  159. Args:
  160. dst_name (str): The name of the output node.
  161. attr (dict): The attr of the output.
  162. - edge_type (str): The edge type, refer to `EdgeTypeEnum`.
  163. """
  164. self._proxy_output.update({dst_name: attr})
  165. def delete_proxy_outputs(self, dst_name):
  166. """Delete a proxy output by dst name."""
  167. self._proxy_output.pop(dst_name)
  168. @staticmethod
  169. def copy_node_without_input_output(src_node, dst_node):
  170. """
  171. Copy a source node attribute to a new node, but not input and output.
  172. Args:
  173. src_node (Node): The copied node.
  174. dst_node (Node): The destination node.
  175. """
  176. dst_node.full_name = src_node.full_name
  177. dst_node.type = src_node.type
  178. dst_node.output_i = src_node.output_i
  179. dst_node.subnode_count = src_node.subnode_count
  180. dst_node.scope = src_node.scope
  181. dst_node.independent_layout = src_node.independent_layout
  182. dst_node.output_shape = src_node.output_shape
  183. dst_node.output_data_type = src_node.output_data_type
  184. dst_node.output_nums = src_node.output_nums
  185. dst_node.elem_types = src_node.elem_types
  186. dst_node.add_attr(src_node.attr)
  187. def __str__(self):
  188. return f'<Node, name: {self.name}, type: {self.type}>'