You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 2.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. This file is used to define the node of graph and associated base types.
  17. """
  18. from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
  19. from mindinsight.debugger.common.log import logger as log
  20. class NodeTree:
  21. """A class for building a node tree."""
  22. def __init__(self, node_name='', node_type=None):
  23. self.node_name = node_name
  24. self._node_type = node_type
  25. self._children = {}
  26. @property
  27. def node_type(self):
  28. """The property of node type."""
  29. return self._node_type
  30. @node_type.setter
  31. def node_type(self, value):
  32. """Set the node type."""
  33. self._node_type = value
  34. def add(self, name, node_type=None):
  35. """Add sub node."""
  36. sub_name = '/'.join([self.node_name, name]) if self.node_name else name
  37. sub_node = NodeTree(sub_name, node_type)
  38. self._children[name] = sub_node
  39. return sub_node
  40. def get(self, sub_name):
  41. """Get sub node."""
  42. return self._children.get(sub_name)
  43. def get_children(self):
  44. """Get all childrens."""
  45. for name_scope, sub_node in self._children.items():
  46. yield name_scope, sub_node
  47. def remove(self, sub_name):
  48. """Remove sub node."""
  49. try:
  50. self._children.pop(sub_name)
  51. except KeyError as err:
  52. log.error("Failed to find node %s. %s", sub_name, err)
  53. raise DebuggerParamValueError("Failed to find node {}".format(sub_name))