You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_processor.py 5.8 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. This file is to process `data_transform.data_manager` to handle graph,
  17. and the status of graph will be checked before calling `Graph` object.
  18. """
  19. from mindinsight.datavisual.common import exceptions
  20. from mindinsight.datavisual.common.enums import PluginNameEnum
  21. from mindinsight.datavisual.common.validation import Validation
  22. from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
  23. from mindinsight.datavisual.processors.base_processor import BaseProcessor
  24. from mindinsight.utils.exceptions import ParamValueError
  class GraphProcessor(BaseProcessor):
      """
      This object is to handle `DataManager` object, and process graph object.

      The graph of the selected train job/tag is loaded once in `__init__` and
      stored on `self._graph`; every query method below reads from it.

      Args:
          train_id (str): To get train job data by this given id.
          data_manager (DataManager): A `DataManager` object.
          tag (str): The tag of graph, if tag is None, will load the first graph.

      Raises:
          SummaryLogPathInvalid: If no graph-plugin train job exists for `train_id`.
          ParamValueError: If the train job has no graph tags, or no graph
              tensor is found for the selected tag.
      """

      def __init__(self, train_id, data_manager, tag=None):
          # Presumably rejects an empty/None train id (raising); see Validation.check_param_empty.
          Validation.check_param_empty(train_id=train_id)
          super(GraphProcessor, self).__init__(data_manager)

          train_job = self._data_manager.get_train_job_by_plugin(train_id, PluginNameEnum.GRAPH.value)
          if train_job is None:
              raise exceptions.SummaryLogPathInvalid()
          if not train_job['tags']:
              raise ParamValueError("Can not find any graph data in the train job.")

          if tag is None:
              # No tag requested: fall back to the first graph tag of this train job.
              tag = train_job['tags'][0]

          tensors = self._data_manager.list_tensors(train_id, tag=tag)
          # Guard the `[0]` access below so an empty result is reported as a
          # parameter error instead of escaping as a bare IndexError.
          if not tensors:
              raise ParamValueError("Can not find any graph data with tag %s in the train job." % tag)
          self._graph = tensors[0].value

      def get_nodes(self, name, node_type):
          """
          Get the nodes of every layer in graph.

          Args:
              name (str): The name of a node. An empty name is accepted only for the
                  name-scope type; its meaning there is defined by
                  `Graph.get_normal_nodes` (presumably the top layer -- verify).
              node_type (str): The type of node, either 'name_scope' or 'polymeric'
                  (the values of `NodeTypeEnum.NAME_SCOPE` / `NodeTypeEnum.POLYMERIC_SCOPE`).

          Returns:
              TypedDict('Nodes', {'nodes': list[Node]}), format is {'nodes': [<Node object>]}.
              example:
                  {
                    "nodes" : [
                      {
                        "attr" :
                        {
                          "index" : "i: 0\n"
                        },
                        "input" : {},
                        "name" : "input_tensor",
                        "output" :
                        {
                          "Default/TensorAdd-op17" :
                          {
                            "edge_type" : "data",
                            "scope" : "name_scope",
                            "shape" : [1, 16, 128, 128]
                          }
                        },
                        "output_i" : -1,
                        "polymeric_input" : {},
                        "polymeric_output" : {},
                        "polymeric_scope_name" : "",
                        "subnode_count" : 0,
                        "type" : "Data"
                      }
                    ]
                  }

          Raises:
              ParamValueError: If `node_type` is not one of the two supported values,
                  if a non-empty `name` does not exist in the graph, or if `name` is
                  empty for the polymeric type.
          """
          if node_type not in (NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.POLYMERIC_SCOPE.value):
              raise ParamValueError(
                  'The node type is not support, only either %s or %s.'
                  % (NodeTypeEnum.NAME_SCOPE.value, NodeTypeEnum.POLYMERIC_SCOPE.value))
          if name and not self._graph.exist_node(name):
              raise ParamValueError("The node name is not in graph.")

          # node_type is known to be one of exactly two values here, so if/else is exhaustive.
          if node_type == NodeTypeEnum.NAME_SCOPE.value:
              nodes = self._graph.get_normal_nodes(name)
          else:
              # A polymeric scope must be addressed by an explicit scope name.
              if not name:
                  raise ParamValueError('The node name "%s" not in graph, node type is %s.' %
                                        (name, node_type))
              nodes = self._graph.get_polymeric_nodes(name)
          return {'nodes': nodes}

      def search_node_names(self, search_content, offset, limit):
          """
          Search node names by search content.

          Args:
              search_content (Any): This content can be the key content of the node to search.
              offset (int): An offset for page. Ex, offset is 0, mean current page is 1.
              limit (int): The max data items for per page; validated to lie in [1, 1000].

          Returns:
              TypedDict('Names', {'names': list[str]}), {"names": ["node_names"]}.
          """
          offset = Validation.check_offset(offset=offset)
          limit = Validation.check_limit(limit, min_value=1, max_value=1000)
          names = self._graph.search_node_names(search_content, offset, limit)
          return {"names": names}

      def search_single_node(self, name):
          """
          Search node by node name.

          Args:
              name (str): The name of node; must not be empty.

          Returns:
              dict, format is:
                  item_object = {'nodes': [<Node object>],
                                 'scope_name': '',
                                 'children': {<item_object>}}
          """
          Validation.check_param_empty(name=name)
          nodes = self._graph.search_single_node(name)
          return nodes

MindInsight为MindSpore提供了简单易用的调优调试能力。在训练过程中，可以将标量、张量、图像、计算图、模型超参、训练耗时等数据记录到文件中，通过MindInsight可视化页面进行查看及分析。

Contributors (1)