You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 6.6 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. This file is used to define the node of graph and associated base types.
  17. """
  18. from enum import Enum
  class NodeTypeEnum(Enum):
      """
      Enumerate the custom node types defined by this graph module.

      Each member's value is the plain string spelling of the type,
      e.g. ``NodeTypeEnum.CONST.value == 'Const'``.
      """

      NAME_SCOPE = "name_scope"
      AGGREGATION_SCOPE = "aggregation_scope"
      PARAMETER = "Parameter"
      CONST = "Const"
  class Node:
      """
      Define a node object.

      Args:
          name (str): Name of new node.
          node_id (str): The id of this node, and node id is unique in graph.
      """

      def __init__(self, name, node_id):
          self._node_id = node_id
          self.name = name
          self.type = ""
          self._attr = {}
          # src node name -> input attribute dict (see `add_input`).
          self._input = {}
          self.output_i = 0
          # dst node name -> output attribute dict (see `add_output`).
          self._output = {}
          # src/dst node name -> proxy edge attribute dict.
          self._proxy_input = {}
          self._proxy_output = {}
          self.subnode_count = 0
          self.scope = ""
          self.independent_layout = False
          self.output_shape = []
          self.output_data_type = ""

      def to_dict(self):
          """
          Convert the node object to dictionary format.

          Note: the 'attr', 'input', 'output', 'proxy_input' and 'proxy_output'
          values are the node's internal dicts (not copies), so mutating them
          mutates the node. `scope`, `output_shape` and `output_data_type` are
          not included.

          Returns:
              dict, the serializable view of this node.
          """
          return {
              'name': self.name,
              'type': self.type,
              'attr': self._attr,
              'input': self._input,
              'output': self._output,
              'output_i': self.output_i,
              'proxy_input': self._proxy_input,
              'proxy_output': self._proxy_output,
              'subnode_count': self.subnode_count,
              'independent_layout': self.independent_layout
          }

      @property
      def node_id(self):
          """The id of this node, and id is unique in graph."""
          return self._node_id

      @staticmethod
      def create_node_name(scope, base_name):
          """
          Build a node name from a scope and a base name.

          Args:
              scope (str): The scope of node, such as 'Default/Conv2D'.
              base_name (str): The base name of node, such as 'Add11'.

          Returns:
              str, `scope/base_name`, or just `base_name` when `scope` is empty.
          """
          return f'{scope}/{base_name}' if scope else base_name

      @property
      def attr(self):
          """Get node attr."""
          return self._attr

      def add_attr(self, attr_dict):
          """
          Update node attr.

          Args:
              attr_dict (dict[str, str]): The attr of node; merged into the
                  existing attrs, overwriting keys that already exist.
          """
          self._attr.update(attr_dict)

      @property
      def input(self):
          """
          Get all input of current node.

          Returns:
              dict[str, dict], refer to the input attr.
          """
          return self._input

      def add_input(self, src_name, input_attr):
          """
          Add or replace the input coming from `src_name`.

          Args:
              src_name (str): The source node name.
              input_attr (dict): The attribute of the input.
                  - shape (list): The shape of input tensor.
                  - edge_type (str): The type of edge, optional value refer to `EdgeTypeEnum`.
                  - data_type (str): The data type of the input.
                  - independent_layout (bool): Indicates whether the source nodes are laid out independently.
          """
          self._input.update({src_name: input_attr})

      def delete_input(self, src_name):
          """
          Delete input attribute by the given source name.

          Args:
              src_name (str): The source node name.

          Raises:
              KeyError: If `src_name` is not an input of this node.
          """
          self._input.pop(src_name)

      @property
      def output(self):
          """The output node of this node."""
          return self._output

      def add_output(self, dst_name, output_attr):
          """
          Add an output node to this node.

          Args:
              dst_name (str): The name of the output node.
              output_attr (dict): Same as the input attribute.
          """
          self._output.update({dst_name: output_attr})

      def delete_output(self, dst_name):
          """
          Delete an output node.

          Args:
              dst_name (str): The name of the node to be deleted.

          Raises:
              KeyError: If `dst_name` is not an output of this node.
          """
          self._output.pop(dst_name)

      @property
      def proxy_input(self):
          """Return proxy input, type is dict."""
          return self._proxy_input

      def add_proxy_input(self, src_name, attr):
          """
          Add a proxy input to node.

          Args:
              src_name (str): The name of the input node.
              attr (dict): The attr of the input.
                  - edge_type (str): The edge type, refer to `EdgeTypeEnum`.
          """
          self._proxy_input.update({src_name: attr})

      def delete_proxy_input(self, src_name):
          """
          Delete a proxy input by the src name.

          Raises:
              KeyError: If `src_name` is not a proxy input of this node.
          """
          self._proxy_input.pop(src_name)

      @property
      def proxy_output(self):
          """Get proxy output, data type is dict."""
          return self._proxy_output

      def add_proxy_output(self, dst_name, attr):
          """
          Add a proxy output to node.

          Args:
              dst_name (str): The name of the output node.
              attr (dict): The attr of the output.
                  - edge_type (str): The edge type, refer to `EdgeTypeEnum`.
          """
          self._proxy_output.update({dst_name: attr})

      def delete_proxy_output(self, dst_name):
          """
          Delete a proxy output by dst name.

          Raises:
              KeyError: If `dst_name` is not a proxy output of this node.
          """
          self._proxy_output.pop(dst_name)

      @staticmethod
      def copy_node_without_input_output(src_node, dst_node):
          """
          Copy a source node's attributes to another node, but not its inputs/outputs.

          `name` and `node_id` of `dst_node` are left untouched, and attrs are
          merged into `dst_node`'s existing attrs.

          Args:
              src_node (Node): The copied node.
              dst_node (Node): The destination node.
          """
          import copy

          dst_node.type = src_node.type
          dst_node.output_i = src_node.output_i
          dst_node.subnode_count = src_node.subnode_count
          dst_node.scope = src_node.scope
          dst_node.independent_layout = src_node.independent_layout
          # Copy the (possibly nested) shape list so the two nodes do not share
          # one mutable object; otherwise in-place edits on either node's
          # shape would silently change the other.
          dst_node.output_shape = copy.deepcopy(src_node.output_shape)
          dst_node.output_data_type = src_node.output_data_type
          dst_node.add_attr(src_node.attr)

      def __str__(self):
          return f'<Node, name: {self.name}, type: {self.type}>'

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中，可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中，通过MindInsight可视化页面进行查看及分析。