/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/gnode.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/anchor.h"
#include "graph/node.h"
#include "graph/utils/node_adapter.h"
#include "graph/utils/tensor_adapter.h"
#include "graph/utils/graph_utils.h"  // GraphUtils::CreateGraphPtrFromComputeGraph
#include "graph/debug/ge_attr_define.h"
#include "graph/debug/ge_op_types.h"
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"

namespace ge {
// Private implementation of GNode. It only holds a weak reference to the
// underlying graph Node, so a GNode never keeps the Node alive; every accessor
// below must lock() it and handle the "node already destroyed" case.
class NodeImpl {
 public:
  NodeImpl() = default;
  ~NodeImpl() = default;

  NodeImpl(const NodeImpl &) = delete;
  NodeImpl &operator=(const NodeImpl &) = delete;

  std::weak_ptr<Node> node_ptr_;
};

// Returns the Node referenced by `graph_node`, or nullptr when the GNode has no
// impl or the referenced Node has already been released.
NodePtr NodeAdapter::GNode2Node(const ge::GNode &graph_node) {
  if (graph_node.impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GNode2Node: gnode impl is nullptr.");
    return nullptr;
  }

  return graph_node.impl_->node_ptr_.lock();
}

// Wraps `node` in a GNode value. Returns a default-constructed GNode when `node`
// is nullptr or when the GNode's impl could not be allocated.
GNode NodeAdapter::Node2GNode(const ge::NodePtr &node) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "Node2GNode: node is nullptr");
    return GNode();
  }

  GNode graph_node;
  if (graph_node.impl_ == nullptr) {
    GELOGW("Node2GNode: gnode impl is nullptr, node[%s].", node->GetName().c_str());
    return graph_node;
  }
  graph_node.impl_->node_ptr_ = node;

  return graph_node;
}

// Wraps `node` in a heap-allocated GNode. Returns nullptr when `node` is nullptr
// or when the GNode (or its impl) could not be allocated.
GNodePtr NodeAdapter::Node2GNodePtr(const ge::NodePtr &node) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "Node2GNodePtr: node is nullptr");
    return nullptr;
  }

  GNodePtr gnode = std::shared_ptr<GNode>(new (std::nothrow) GNode());
  if (gnode == nullptr) {
    GELOGE(GRAPH_FAILED, "Node2GNodePtr: gnode is nullptr, node[%s].", node->GetName().c_str());
    return nullptr;
  }

  if (gnode->impl_ == nullptr) {
    GELOGW("Node2GNodePtr: gnode impl is nullptr, node[%s].", node->GetName().c_str());
    return nullptr;
  }
  gnode->impl_->node_ptr_ = node;

  return gnode;
}

// Allocates the impl; impl_ stays nullptr if the allocation fails, which every
// method below checks for.
GNode::GNode() : impl_(ComGraphMakeShared<NodeImpl>()) {}

// Copies the node's type string into `type`.
graphStatus GNode::GetType(AscendString &type) const {
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetType: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetType: the shared ptr is not valid.");
    return GRAPH_FAILED;
  }
  const std::string node_type = node_ptr->GetType();
  type = AscendString(node_type.c_str());

  return GRAPH_SUCCESS;
}

// Copies the node's name into `name`.
graphStatus GNode::GetName(AscendString &name) const {
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetName: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetName: the shared ptr is not valid.");
    return GRAPH_FAILED;
  }
  const std::string node_name = node_ptr->GetName();
  name = AscendString(node_name.c_str());

  return GRAPH_SUCCESS;
}

// Returns the producer of data input `index` together with the producer's output
// port index. On any failure returns {nullptr, 0xFF}.
std::pair<GNodePtr, int32_t> GNode::GetInDataNodesAndPortIndexs(const int32_t index) const {
  std::pair<GNodePtr, int32_t> gnode_idx = {nullptr, 0xFF};
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
    return gnode_idx;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
    return gnode_idx;
  }

  auto in_anchor = node_ptr->GetInDataAnchor(index);
  if (in_anchor == nullptr) {
    GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node[%s], the anchor does not exist",
           index, node_ptr->GetName().c_str());
    return gnode_idx;
  }

  auto out_anchor = in_anchor->GetPeerOutAnchor();
  if (out_anchor == nullptr) {
    GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node [%s], the data input does not exist",
           index, node_ptr->GetName().c_str());
    return gnode_idx;
  }

  NodePtr peer_node_ptr = out_anchor->GetOwnerNode();
  GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr);
  if (gnode == nullptr) {
    GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
    return gnode_idx;
  }

  return {gnode, out_anchor->GetIdx()};
}

// Returns the nodes connected to this node's control input. Returns an empty
// vector on failure (indistinguishable from "no control inputs").
std::vector<GNodePtr> GNode::GetInControlNodes() const {
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
    return {};
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
    return {};
  }

  std::vector<GNodePtr> gnodes;
  auto in_control_nodes = node_ptr->GetInControlNodes();
  for (auto &in_control_node : in_control_nodes) {
    GNodePtr gnode = NodeAdapter::Node2GNodePtr(in_control_node);
    if (gnode == nullptr) {
      GELOGE(GRAPH_FAILED, "In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
      return {};
    }
    gnodes.emplace_back(gnode);
  }

  return gnodes;
}

// Returns every consumer of data output `index`, each paired with the
// consumer's input port index. Returns an empty vector on failure.
std::vector<std::pair<GNodePtr, int32_t>> GNode::GetOutDataNodesAndPortIndexs(const int32_t index) const {
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
    return {};
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
    return {};
  }

  auto out_anchor = node_ptr->GetOutDataAnchor(index);
  if (out_anchor == nullptr) {
    GELOGE(GRAPH_FAILED, "Failed to get out data node of index %d from node %s, the anchor does not exists",
           index, node_ptr->GetName().c_str());
    return {};
  }

  std::vector<std::pair<GNodePtr, int32_t>> gnode_index;
  auto in_data_anchors = out_anchor->GetPeerInDataAnchors();
  for (auto &in_data_anchor : in_data_anchors) {
    if (in_data_anchor == nullptr) {
      GELOGE(GRAPH_FAILED, "In data anchor of node[%s] is nullptr.", node_ptr->GetName().c_str());
      return {};
    }
    NodePtr peer_node_ptr = in_data_anchor->GetOwnerNode();
    GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr);
    if (gnode == nullptr) {
      GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
      return {};
    }
    gnode_index.emplace_back(gnode, in_data_anchor->GetIdx());
  }

  return gnode_index;
}

// Returns the nodes connected to this node's control output. Returns an empty
// vector on failure (indistinguishable from "no control outputs").
std::vector<GNodePtr> GNode::GetOutControlNodes() const {
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetOutControlNodes: node impl is nullptr.");
    return {};
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetOutControlNodes: the node shared ptr is not valid.");
    return {};
  }

  std::vector<GNodePtr> gnodes;
  auto out_control_nodes = node_ptr->GetOutControlNodes();
  for (auto &out_control_node : out_control_nodes) {
    GNodePtr gnode = NodeAdapter::Node2GNodePtr(out_control_node);
    if (gnode == nullptr) {
      GELOGE(GRAPH_FAILED, "Out control_node of node[%s] to gnode failed.", node_ptr->GetName().c_str());
      return {};
    }
    gnodes.emplace_back(gnode);
  }

  return gnodes;
}

// Reads the constant tensor feeding data input `index` into `data`.
// The input is accepted when it is a CONSTANT/CONSTANTOP node directly, or a
// DATA node whose chain of NodeUtils::GetParentInput() results ends at a
// CONSTANT/CONSTANTOP node (presumably a subgraph Data resolved through its
// parent graphs — TODO confirm against NodeUtils). The tensor is read from the
// const node's ATTR_NAME_WEIGHTS attribute.
// Returns GRAPH_NODE_WITHOUT_CONST_INPUT when no such const producer is found.
graphStatus GNode::GetInputConstData(const int32_t index, Tensor &data) const {
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetInputConstData: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetInputConstData: the node shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  NodePtr input_data_node = NodeUtils::GetInDataNodeByIndex(*node_ptr, index);
  GE_CHECK_NOTNULL(input_data_node);
  const std::string op_type = input_data_node->GetType();
  if (op_type == CONSTANT || op_type == CONSTANTOP) {
    Operator const_op = OpDescUtils::CreateOperatorFromNode(input_data_node);
    if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.",
             input_data_node->GetName().c_str(), node_ptr->GetName().c_str());
      return GRAPH_FAILED;
    }
    return GRAPH_SUCCESS;
  } else if (op_type == DATA) {
    // Walk up through consecutive DATA nodes until a non-DATA producer (or none).
    auto parent_node = NodeUtils::GetParentInput(input_data_node);
    while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) {
      parent_node = NodeUtils::GetParentInput(parent_node);
    }
    if ((parent_node != nullptr) &&
        ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) {
      Operator const_op = OpDescUtils::CreateOperatorFromNode(parent_node);
      if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) {
        GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.",
               parent_node->GetName().c_str(), node_ptr->GetName().c_str());
        return GRAPH_FAILED;
      }
      return GRAPH_SUCCESS;
    }
  }
  GELOGE(GRAPH_NODE_WITHOUT_CONST_INPUT, "Node[%s] has no const input.", node_ptr->GetName().c_str());

  return GRAPH_NODE_WITHOUT_CONST_INPUT;
}

// Looks up the input index registered under `name` in the node's OpDesc and
// stores whatever OpDesc::GetInputIndexByName returns into `index`.
graphStatus GNode::GetInputIndexByName(const AscendString &name, int32_t &index) {
  const char *ascend_name = name.GetString();
  if (ascend_name == nullptr) {
    GELOGE(GRAPH_PARAM_INVALID, "GetInputIndexByName: ascend string error.");
    return GRAPH_PARAM_INVALID;
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetInputIndexByName: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetInputIndexByName: the node shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  OpDescPtr op_desc = node_ptr->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  const std::string input_name = ascend_name;
  index = op_desc->GetInputIndexByName(input_name);

  return GRAPH_SUCCESS;
}

// Looks up the output index registered under `name` in the node's OpDesc and
// stores whatever OpDesc::GetOutputIndexByName returns into `index`.
graphStatus GNode::GetOutputIndexByName(const AscendString &name, int32_t &index) {
  const char *ascend_name = name.GetString();
  if (ascend_name == nullptr) {
    GELOGE(GRAPH_PARAM_INVALID, "GetOutputIndexByName: ascend string error.");
    return GRAPH_PARAM_INVALID;
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetOutputIndexByName: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetOutputIndexByName: the node shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  OpDescPtr op_desc = node_ptr->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  const std::string output_name = ascend_name;
  index = op_desc->GetOutputIndexByName(output_name);

  return GRAPH_SUCCESS;
}

// Returns OpDesc::GetInputsSize() for the node.
// NOTE(review): on failure this returns GRAPH_FAILED converted to size_t, which
// callers cannot tell apart from a real size; kept for API compatibility.
size_t GNode::GetInputsSize() const {
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetInputsSize: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetInputsSize: the node shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  OpDescPtr op_desc = node_ptr->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  return op_desc->GetInputsSize();
}

// Returns OpDesc::GetOutputsSize() for the node.
// NOTE(review): on failure this returns GRAPH_FAILED converted to size_t, which
// callers cannot tell apart from a real size; kept for API compatibility.
size_t GNode::GetOutputsSize() const {
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetOutputsSize: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetOutputsSize: the shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  OpDescPtr op_desc = node_ptr->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  return op_desc->GetOutputsSize();
}

// Copies the descriptor of input `index` into `tensor_desc`.
// Returns GRAPH_PARAM_INVALID for a negative index.
graphStatus GNode::GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const {
  if (index < 0) {
    GELOGE(GRAPH_PARAM_INVALID, "GetInputDesc: index[%d] cannot be less than zero.", index);
    return GRAPH_PARAM_INVALID;
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetInputDesc: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetInputDesc: the node shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  OpDescPtr op_desc = node_ptr->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  // Safe: index was checked non-negative above.
  ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetInputDescPtr(static_cast<uint32_t>(index));
  if (ge_tensor_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }
  tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc);

  return GRAPH_SUCCESS;
}

// Replaces the descriptor of input `index` with `tensor_desc`.
// Returns GRAPH_PARAM_INVALID for a negative index.
graphStatus GNode::UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc) {
  if (index < 0) {
    GELOGE(GRAPH_PARAM_INVALID, "UpdateInputDesc: index[%d] cannot be less than zero.", index);
    return GRAPH_PARAM_INVALID;
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "UpdateInputDesc: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "UpdateInputDesc: the node shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  OpDescPtr op_desc = node_ptr->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc);
  if (op_desc->UpdateInputDesc(static_cast<uint32_t>(index), ge_tensor_desc) != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Update input desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

// Copies the descriptor of output `index` into `tensor_desc`.
// Returns GRAPH_PARAM_INVALID for a negative index.
graphStatus GNode::GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const {
  if (index < 0) {
    GELOGE(GRAPH_PARAM_INVALID, "GetOutputDesc: index[%d] cannot be less than zero.", index);
    return GRAPH_PARAM_INVALID;
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetOutputDesc: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetOutputDesc: the node shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  OpDescPtr op_desc = node_ptr->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetOutputDescPtr(static_cast<uint32_t>(index));
  if (ge_tensor_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }
  tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc);

  return GRAPH_SUCCESS;
}

// Replaces the descriptor of output `index` with `tensor_desc`.
// Returns GRAPH_PARAM_INVALID for a negative index.
graphStatus GNode::UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc) {
  if (index < 0) {
    GELOGE(GRAPH_PARAM_INVALID, "UpdateOutputDesc: index[%d] cannot be less than zero.", index);
    return GRAPH_PARAM_INVALID;
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "UpdateOutputDesc: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "UpdateOutputDesc: the shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  OpDescPtr op_desc = node_ptr->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc);
  if (op_desc->UpdateOutputDesc(static_cast<uint32_t>(index), ge_tensor_desc) != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Update output desc of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

// Defines GNode::GetAttr(name, ArgType&): reads attribute `name` through an
// Operator created from the node and fails if Operator::GetAttr fails.
#define NODE_ATTR_GET_IMP(ArgType)                                                        \
  graphStatus GNode::GetAttr(const AscendString &name, ArgType &attr_value) const {      \
    const char *ascend_name = name.GetString();                                           \
    if (ascend_name == nullptr) {                                                         \
      GELOGE(GRAPH_PARAM_INVALID, "GetAttr: ascend string error.");                       \
      return GRAPH_PARAM_INVALID;                                                         \
    }                                                                                     \
                                                                                          \
    if (impl_ == nullptr) {                                                               \
      GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr.");                             \
      return GRAPH_FAILED;                                                                \
    }                                                                                     \
                                                                                          \
    std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();                             \
    if (node_ptr == nullptr) {                                                            \
      GELOGE(GRAPH_FAILED, "GetAttr: the shared ptr is not valid.");                      \
      return GRAPH_FAILED;                                                                \
    }                                                                                     \
                                                                                          \
    const std::string attr_name = ascend_name;                                            \
    Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);                          \
    if (op.GetAttr(attr_name, attr_value) != GRAPH_SUCCESS) {                             \
      GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str());  \
      return GRAPH_FAILED;                                                                \
    }                                                                                     \
                                                                                          \
    return GRAPH_SUCCESS;                                                                 \
  }

// Defines GNode::SetAttr(name, ArgType&): writes attribute `name` through an
// Operator created from the node. Operator::SetAttr's result is discarded.
#define NODE_ATTR_SET_IMP(ArgType)                                                        \
  graphStatus GNode::SetAttr(const AscendString &name, ArgType &attr_value) const {      \
    const char *ascend_name = name.GetString();                                           \
    if (ascend_name == nullptr) {                                                         \
      GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error.");                       \
      return GRAPH_PARAM_INVALID;                                                         \
    }                                                                                     \
                                                                                          \
    if (impl_ == nullptr) {                                                               \
      GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");                             \
      return GRAPH_FAILED;                                                                \
    }                                                                                     \
                                                                                          \
    std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();                             \
    if (node_ptr == nullptr) {                                                            \
      GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");                      \
      return GRAPH_FAILED;                                                                \
    }                                                                                     \
                                                                                          \
    const std::string attr_name = ascend_name;                                            \
    Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);                          \
    (void)op.SetAttr(attr_name, attr_value);                                              \
    return GRAPH_SUCCESS;                                                                 \
  }

NODE_ATTR_GET_IMP(int64_t)
NODE_ATTR_GET_IMP(int32_t)
NODE_ATTR_GET_IMP(uint32_t)
NODE_ATTR_GET_IMP(float)
NODE_ATTR_GET_IMP(bool)
NODE_ATTR_GET_IMP(Tensor)
NODE_ATTR_GET_IMP(std::vector<int64_t>)
NODE_ATTR_GET_IMP(std::vector<int32_t>)
NODE_ATTR_GET_IMP(std::vector<uint32_t>)
NODE_ATTR_GET_IMP(std::vector<float>)
NODE_ATTR_GET_IMP(std::vector<bool>)
NODE_ATTR_GET_IMP(std::vector<Tensor>)
NODE_ATTR_GET_IMP(OpBytes)
NODE_ATTR_GET_IMP(std::vector<std::vector<int64_t>>)
NODE_ATTR_GET_IMP(std::vector<ge::DataType>)
NODE_ATTR_GET_IMP(ge::DataType)
NODE_ATTR_GET_IMP(AttrValue)

NODE_ATTR_SET_IMP(int64_t)
NODE_ATTR_SET_IMP(int32_t)
NODE_ATTR_SET_IMP(uint32_t)
NODE_ATTR_SET_IMP(float)
NODE_ATTR_SET_IMP(bool)
NODE_ATTR_SET_IMP(Tensor)
NODE_ATTR_SET_IMP(std::vector<int64_t>)
NODE_ATTR_SET_IMP(std::vector<int32_t>)
NODE_ATTR_SET_IMP(std::vector<uint32_t>)
NODE_ATTR_SET_IMP(std::vector<float>)
NODE_ATTR_SET_IMP(std::vector<bool>)
NODE_ATTR_SET_IMP(std::vector<Tensor>)
NODE_ATTR_SET_IMP(OpBytes)
NODE_ATTR_SET_IMP(std::vector<std::vector<int64_t>>)
NODE_ATTR_SET_IMP(std::vector<ge::DataType>)
NODE_ATTR_SET_IMP(ge::DataType)

// Sets attribute `name` from an AttrValue.
// NOTE: `attr_value` is passed to Operator::SetAttr via std::move, so the
// caller's object is left in a moved-from state after this call.
graphStatus GNode::SetAttr(const AscendString &name, AttrValue &attr_value) const {
  const char *ascend_name = name.GetString();
  if (ascend_name == nullptr) {
    GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error.");
    return GRAPH_PARAM_INVALID;
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  const std::string attr_name = ascend_name;
  Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  (void)op.SetAttr(attr_name, std::move(attr_value));
  return GRAPH_SUCCESS;
}

// Sets string attribute `name`. Both `name` and `attr_value` must hold a
// non-null string, otherwise GRAPH_PARAM_INVALID is returned.
graphStatus GNode::SetAttr(const AscendString &name, AscendString &attr_value) const {
  const char *ascend_name = name.GetString();
  if (ascend_name == nullptr) {
    GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error.");
    return GRAPH_PARAM_INVALID;
  }

  const char *ascend_attr_value = attr_value.GetString();
  if (ascend_attr_value == nullptr) {
    GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr value ascend string error.");
    return GRAPH_PARAM_INVALID;
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  const std::string attr_name = ascend_name;
  const std::string node_attr_value = ascend_attr_value;
  Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  (void)op.SetAttr(attr_name, node_attr_value);
  return GRAPH_SUCCESS;
}

// Sets string-list attribute `name`. Every element of `attr_values` must hold a
// non-null string; the whole call is rejected (nothing is written) otherwise.
graphStatus GNode::SetAttr(const AscendString &name, std::vector<AscendString> &attr_values) const {
  const char *ascend_name = name.GetString();
  if (ascend_name == nullptr) {
    GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error.");
    return GRAPH_PARAM_INVALID;
  }

  for (const auto &attr_val : attr_values) {
    if (attr_val.GetString() == nullptr) {
      GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr val error.");
      return GRAPH_PARAM_INVALID;
    }
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  // All elements were validated as non-null above, so convert them directly.
  std::vector<std::string> node_attr_vals;
  node_attr_vals.reserve(attr_values.size());
  for (const auto &attr_val : attr_values) {
    node_attr_vals.emplace_back(attr_val.GetString());
  }

  const std::string attr_name = ascend_name;
  Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  (void)op.SetAttr(attr_name, node_attr_vals);
  return GRAPH_SUCCESS;
}

// Reads string attribute `name` into `attr_value`.
graphStatus GNode::GetAttr(const AscendString &name, AscendString &attr_value) const {
  const char *ascend_name = name.GetString();
  if (ascend_name == nullptr) {
    GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error.");
    return GRAPH_PARAM_INVALID;
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  const std::string attr_name = ascend_name;
  Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  std::string value;
  if (op.GetAttr(attr_name, value) != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }
  attr_value = AscendString(value.c_str());

  return GRAPH_SUCCESS;
}

// Reads string-list attribute `name` and APPENDS each element to `attr_values`
// (existing contents of `attr_values` are not cleared).
graphStatus GNode::GetAttr(const AscendString &name, std::vector<AscendString> &attr_values) const {
  const char *ascend_name = name.GetString();
  if (ascend_name == nullptr) {
    GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error.");
    return GRAPH_PARAM_INVALID;
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  const std::string attr_name = ascend_name;
  Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
  std::vector<std::string> values;
  if (op.GetAttr(attr_name, values) != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  for (const auto &value : values) {
    attr_values.emplace_back(value.c_str());
  }

  return GRAPH_SUCCESS;
}

// Returns true if the node's OpDesc has attribute `name`; false on any failure
// or when the attribute is absent (both cases are logged via GELOGE).
bool GNode::HasAttr(const AscendString &name) {
  const char *ascend_name = name.GetString();
  if (ascend_name == nullptr) {
    GELOGE(GRAPH_PARAM_INVALID, "HasAttr: ascend string error.");
    return false;
  }

  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "HasAttr: node impl is nullptr.");
    return false;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "HasAttr: the node shared ptr is not valid.");
    return false;
  }

  OpDescPtr op_desc = node_ptr->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
    return false;
  }

  const std::string attr_name = ascend_name;
  if (!op_desc->HasAttr(attr_name)) {
    GELOGE(GRAPH_FAILED, "Node[%s] has no attr name[%s]", node_ptr->GetName().c_str(), attr_name.c_str());
    return false;
  }

  return true;
}

// Stores subgraph `index` of this node, wrapped as a Graph, into `graph`.
graphStatus GNode::GetSubgraph(uint32_t index, GraphPtr &graph) const {
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetSubgraph: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetSubgraph: the node shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  ComputeGraphPtr compute_graph_ptr = NodeUtils::GetSubgraph(*node_ptr, index);
  if (compute_graph_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed from node[%s].",
           index, node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }
  graph = GraphUtils::CreateGraphPtrFromComputeGraph(compute_graph_ptr);
  if (graph == nullptr) {
    GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed from node[%s].",
           index, node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

// APPENDS every subgraph of this node, wrapped as a Graph, to `graph_list`.
// A node with no subgraphs is reported as GRAPH_FAILED (existing behavior).
graphStatus GNode::GetALLSubgraphs(std::vector<GraphPtr> &graph_list) const {
  if (impl_ == nullptr) {
    GELOGE(GRAPH_FAILED, "GetALLSubgraphs: node impl is nullptr.");
    return GRAPH_FAILED;
  }

  std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "GetALLSubgraphs: the node shared ptr is not valid.");
    return GRAPH_FAILED;
  }

  std::vector<ComputeGraphPtr> sub_graphs = NodeUtils::GetAllSubgraphs(*node_ptr);
  if (sub_graphs.empty()) {
    GELOGE(GRAPH_FAILED, "GetALLSubgraphs: get all subgraphs failed from node[%s].",
           node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }

  for (const auto &sub_graph : sub_graphs) {
    if (sub_graph == nullptr) {
      GELOGE(GRAPH_FAILED, "Get subgraph failed from node[%s].", node_ptr->GetName().c_str());
      return GRAPH_FAILED;
    }
    GraphPtr graph = GraphUtils::CreateGraphPtrFromComputeGraph(sub_graph);
    if (graph == nullptr) {
      GELOGE(GRAPH_FAILED, "Subgraph create compute graph failed from node[%s].", node_ptr->GetName().c_str());
      return GRAPH_FAILED;
    }
    graph_list.emplace_back(graph);
  }

  if (graph_list.empty()) {
    GELOGW("Node[%s] has no subgraph.", node_ptr->GetName().c_str());
  }

  return GRAPH_SUCCESS;
}
}  // namespace ge