# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Graph processor.

Wraps `data_transform.data_manager` so that graph queries go through a single
object. The train job and its graph tags are checked when the processor is
constructed, before any `Graph` method is called.
"""
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.validation import Validation
from mindinsight.datavisual.data_transform.graph import NodeTypeEnum
from mindinsight.datavisual.processors.base_processor import BaseProcessor
from mindinsight.utils.exceptions import ParamValueError


class GraphProcessor(BaseProcessor):
    """
    Serve graph queries for a single train job.

    The graph object is resolved once, at construction time, through the
    `DataManager`, and every query method reads from that cached graph.

    Args:
        train_id (str): Id of the train job whose graph is read.
        data_manager (DataManager): A `DataManager` object.
        tag (str): Tag of the graph to load. When None, the first graph tag
            recorded for the train job is used.
    """

    def __init__(self, train_id, data_manager, tag=None):
        Validation.check_param_empty(train_id=train_id)
        super().__init__(data_manager)

        # The train job is looked up under the graph plugin only.
        train_job = self._data_manager.get_train_job_by_plugin(train_id, PluginNameEnum.GRAPH.value)
        if train_job is None:
            raise exceptions.SummaryLogPathInvalid()

        graph_tags = train_job['tags']
        if not graph_tags:
            raise ParamValueError("Can not find any graph data in the train job.")

        chosen_tag = graph_tags[0] if tag is None else tag
        # Only the first tensor stored under the chosen tag is used as the graph.
        self._graph = self._data_manager.list_tensors(train_id, tag=chosen_tag)[0].value

    def get_nodes(self, name, node_type):
        """
        Return the nodes of one layer of the graph.

        Args:
            name (str): Name of the scope whose child nodes are requested. It
                may be empty for the 'name_scope' type, but is required for
                the 'polymeric' type.
            node_type (Any): Node type, either 'name_scope' or 'polymeric'.

        Returns:
            TypedDict('Nodes', {'nodes': list[Node]}), format is {'nodes': []}.
            example:
                {
                  "nodes" : [
                    {
                      "attr" :
                      {
                        "index" : "i: 0\n"
                      },
                      "input" : {},
                      "name" : "input_tensor",
                      "output" :
                      {
                        "Default/TensorAdd-op17" :
                        {
                          "edge_type" : "data",
                          "scope" : "name_scope",
                          "shape" : [1, 16, 128, 128]
                        }
                      },
                      "output_i" : -1,
                      "polymeric_input" : {},
                      "polymeric_output" : {},
                      "polymeric_scope_name" : "",
                      "subnode_count" : 0,
                      "type" : "Data"
                    }
                  ]
                }

        Raises:
            ParamValueError: If `node_type` is unsupported, if a non-empty
                `name` is not in the graph, or if `name` is empty for the
                polymeric type.
        """
        name_scope_type = NodeTypeEnum.NAME_SCOPE.value
        polymeric_type = NodeTypeEnum.POLYMERIC_SCOPE.value

        if node_type not in (name_scope_type, polymeric_type):
            raise ParamValueError(
                'The node type is not support, only either %s or %s.' % (name_scope_type, polymeric_type))

        # An empty name is allowed here; it is passed through for the name-scope query below.
        if name and not self._graph.exist_node(name):
            raise ParamValueError("The node name is not in graph.")

        if node_type == name_scope_type:
            layer_nodes = self._graph.get_normal_nodes(name)
        elif node_type == polymeric_type:
            # A polymeric query needs a concrete scope name to expand.
            if not name:
                raise ParamValueError('The node name "%s" not in graph, node type is %s.' % (name, node_type))
            layer_nodes = self._graph.get_polymeric_nodes(name)
        else:
            layer_nodes = []

        return {'nodes': layer_nodes}

    def search_node_names(self, search_content, offset, limit):
        """
        Search the graph for node names matching the given content, one page at a time.

        Args:
            search_content (Any): Key content to match node names against.
            offset (int): Page offset; an offset of 0 means the first page.
            limit (int): Maximum number of items per page, from 1 to 1000.

        Returns:
            TypedDict('Names', {'names': list[str]}), {"names": ["node_names"]}.
        """
        page_offset = Validation.check_offset(offset=offset)
        page_limit = Validation.check_limit(limit, min_value=1, max_value=1000)
        return {"names": self._graph.search_node_names(search_content, page_offset, page_limit)}

    def search_single_node(self, name):
        """
        Look up a single node in the graph by its name.

        Args:
            name (str): Name of the node; must not be empty.

        Returns:
            dict, format is:
                item_object = {'nodes': [],
                               'scope_name': '',
                               'children': {}}
        """
        Validation.check_param_empty(name=name)
        return self._graph.search_single_node(name)