|
- /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include "utils/node_utils.h"
- #include "utils/op_desc_utils.h"
- #include "graph/utils/graph_utils.h"
- #include "debug/ge_op_types.h"
- #include "debug/ge_util.h"
- #include "framework/common/debug/ge_log.h"
- #include "graph/anchor.h"
- #include "graph/debug/ge_attr_define.h"
- #include "graph/types.h"
- #include "utils/tensor_utils.h"
- #include "utils/type_utils.h"
-
- namespace ge {
- std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
- std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};
-
- const std::set<std::string> kConstOpTypes = {"Const", "Constant"};
-
- const std::set<std::string> kIfOpTypes = {"If", "_If", "StatelessIf"};
- const std::set<std::string> kWhileOpTypes = {"While", "_While", "StatelessWhile"};
- const std::set<std::string> kCaseOpTypes = {"Case"};
- const std::set<std::string> kForOpTypes = {"For"};
-
- bool OpShapeIsUnknown(const OpDescPtr &desc) {
- for (const auto &ptr : desc->GetAllInputsDescPtr()) {
- auto ge_shape = ptr->GetShape();
- for (const auto &dim : ge_shape.GetDims()) {
- if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
- return true;
- }
- }
- }
- for (const auto &ptr : desc->GetAllOutputsDescPtr()) {
- auto ge_shape = ptr->GetShape();
- for (const auto &dim : ge_shape.GetDims()) {
- if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
- return true;
- }
- }
- }
- return false;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node,
- const uint32_t &event_id) {
- GE_CHECK_NOTNULL(node);
- map_send_info_[node].push_back(event_id);
- return GRAPH_SUCCESS;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node,
- const uint32_t &event_id) {
- GE_CHECK_NOTNULL(node);
- map_recv_info_[node].push_back(event_id);
- return GRAPH_SUCCESS;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
- NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send) {
- GE_CHECK_NOTNULL(node);
- auto find = map_send_info_.find(node);
- if (find == map_send_info_.end()) {
- return GRAPH_FAILED;
- } else {
- vec_send = find->second;
- return GRAPH_SUCCESS;
- }
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
- NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv) {
- GE_CHECK_NOTNULL(node);
- auto find = map_recv_info_.find(node);
- if (find == map_recv_info_.end()) {
- return GRAPH_FAILED;
- } else {
- vec_recv = find->second;
- return GRAPH_SUCCESS;
- }
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() {
- map_send_info_.clear();
- return GRAPH_SUCCESS;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() {
- map_recv_info_.clear();
- return GRAPH_SUCCESS;
- }
-
- graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) {
- GE_CHECK_NOTNULL(src);
- NodePtr cur_ptr;
- if (depth < 1) {
- return GRAPH_FAILED;
- }
- for (int i = 0; i < depth; i++) {
- if (src->GetOutDataNodes().size() != 1) {
- return GRAPH_FAILED;
- }
- cur_ptr = src->GetOutDataNodes().at(0);
- GE_CHECK_NOTNULL(cur_ptr);
- }
- dst = cur_ptr;
- return GRAPH_SUCCESS;
- }
-
- graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data,
- InControlAnchorPtr &in_control) {
- GE_CHECK_NOTNULL(node_ptr);
- for (const auto &p : node_ptr->GetAllOutDataAnchors()) {
- GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr");
- for (const auto &p_in : p->GetPeerInControlAnchors()) {
- GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr");
- out_data = p;
- in_control = p_in;
- return GRAPH_SUCCESS;
- }
- }
- return GRAPH_FAILED;
- }
-
- graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) {
- GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED,
- "node or in_data_anchor is nullptr");
-
- bool find_flag = false;
- uint32_t index = 0;
- vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end();
- for (const auto &tmp : node_ptr->in_data_anchors_) {
- if (tmp == in_data_anchor) {
- find_flag = true;
- auto iter = node_ptr->in_data_anchors_.begin() + index;
- if (iter != node_ptr->in_data_anchors_.end()) {
- it = node_ptr->in_data_anchors_.erase(iter);
- }
- break;
- }
- index++;
- }
- for (; it != node_ptr->in_data_anchors_.end(); ++it) {
- (*it)->SetIdx(index);
- index++;
- }
-
- if (!find_flag) {
- return GRAPH_FAILED;
- }
- return GRAPH_SUCCESS;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) {
- GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr");
- GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed");
- return GRAPH_SUCCESS;
- }
-
- graphStatus NodeUtils::SetAllAnchorStatus(Node &node) {
- node.anchor_status_updated_ = true;
- return GRAPH_SUCCESS;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) {
- GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr");
- return IsAnchorStatusSet(*node_ptr);
- }
-
- bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; }
-
/// Moves every outgoing edge of |origin_node| onto |new_node|.
/// Out-data anchor i of origin_node is paired with out-data anchor i of new_node, so both nodes must
/// expose the same number of out-data anchors. Edges leaving origin_node's out-control anchor are moved too.
/// @param [in] origin_node node whose outgoing edges are taken
/// @param [in] new_node    node that receives those edges
/// @return GRAPH_FAILED on a null node or an out-data anchor count mismatch; the GE_CHECK_NOTNULL failure
///         code if either out-control anchor is null; GRAPH_SUCCESS otherwise. Individual unlink/link
///         failures are only logged and skipped, never reported to the caller.
graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) {
  if ((origin_node == nullptr) || (new_node == nullptr)) {
    return GRAPH_FAILED;
  }
  auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors();
  auto new_out_data_anchors = new_node->GetAllOutDataAnchors();
  if (origin_out_data_anchors.size() != new_out_data_anchors.size()) {
    return GRAPH_FAILED;
  }

  for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) {
    // Data edges: unlink from the origin anchor first, then link from the new anchor.
    // When the unlink fails, the link for that peer is skipped as well.
    for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) {
      GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
                       "unlink peer_anchor failed");
      GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
                       "linkto peer_anchor failed");
    }

    // Control edges that leave this out-data anchor are moved the same way.
    for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) {
      GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue,
                       "unlink peer_anchor failed");
      GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
                       "linkto peer_anchor failed");
    }
  }

  // Out-control anchor: link every peer to the new node first, then drop all of the origin's links
  // in one UnlinkAll() call. A peer whose LinkTo fails is still unlinked here, so that edge is lost.
  auto origin_out_control_anchor = origin_node->GetOutControlAnchor();
  GE_CHECK_NOTNULL(origin_out_control_anchor);
  auto new_out_control_anchor = new_node->GetOutControlAnchor();
  GE_CHECK_NOTNULL(new_out_control_anchor);
  for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) {
    GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
                     "linkto peer_anchor failed");
  }
  for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) {
    GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue,
                     "linkto peer_anchor failed");
  }
  origin_out_control_anchor->UnlinkAll();

  return GRAPH_SUCCESS;
}
-
- bool NodeUtils::IsConst(const Node &node) {
- auto src_node_type = node.GetType();
- bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP));
- return is_const;
- }
-
- void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) {
- if (node_ptr == nullptr) {
- GELOGE(GRAPH_FAILED, "node is null");
- return;
- }
- UpdateIsInputConst(*node_ptr);
- }
-
- ///
- /// update is_input_const
- /// @param node
- /// @return void
- ///
- void NodeUtils::UpdateIsInputConst(Node &node) {
- std::vector<bool> is_input_const;
- size_t anchor_num = node.GetAllInDataAnchors().size();
- for (size_t i = 0; i < anchor_num; i++) {
- auto in_anchor = node.GetInDataAnchor(static_cast<int>(i));
- if (in_anchor == nullptr) {
- is_input_const.push_back(false);
- continue;
- }
- auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
- if (peer_out_anchor == nullptr) {
- is_input_const.push_back(false);
- continue;
- }
- auto src_node = peer_out_anchor->GetOwnerNode();
- if (src_node == nullptr) {
- is_input_const.push_back(false);
- continue;
- }
- if (IsConst(*(src_node))) {
- is_input_const.push_back(true);
- } else {
- is_input_const.push_back(false);
- }
- }
- if (node.GetOpDesc() == nullptr) {
- GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr");
- return;
- }
- node.GetOpDesc()->SetIsInputConst(is_input_const);
- }
-
- void NodeUtils::UnlinkAll(const Node &node) {
- for (const auto &anchor : node.GetAllOutAnchors()) {
- anchor->UnlinkAll();
- }
- for (const auto &anchor : node.GetAllInAnchors()) {
- anchor->UnlinkAll();
- }
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) {
- if (node_ptr == nullptr) {
- GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
- return GRAPH_FAILED;
- }
- auto op_desc = node_ptr->GetOpDesc();
- if (op_desc == nullptr) {
- return GRAPH_FAILED;
- }
- bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag();
- if (is_unknown_graph) {
- return GRAPH_SUCCESS;
- }
- for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
- auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
- auto out_dims = output_tensor->GetShape().GetDims();
- auto out_dtype = output_tensor->GetDataType();
- ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
- output_tensor->SetOriginShape(output_tensor->GetShape());
- output_tensor->SetOriginDataType(output_tensor->GetDataType());
-
- GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
- node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
- TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
- TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
-
- for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
- if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
- GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
- continue;
- }
- auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx());
- if (peer_input_desc == nullptr) {
- GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr");
- continue;
- }
- // check shape and dtype continuity. do not stop process
- auto peer_input_dims = peer_input_desc->GetShape().GetDims();
- auto peer_input_dtype = peer_input_desc->GetDataType();
- if (out_dtype != peer_input_dtype) {
- GELOGW(
- "current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th "
- "input_dtype is [%s].The two dtype should be same! Please check graph and fix it",
- node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(),
- peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(),
- TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str());
- } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) {
- string out_shape_str, peer_in_shape_str;
- out_shape_str += "[";
- for (int64_t dim : out_dims) {
- out_shape_str += std::to_string(dim) + " ";
- }
- out_shape_str += "]";
- peer_in_shape_str += "[";
- for (int64_t dim : peer_input_dims) {
- peer_in_shape_str += std::to_string(dim) + " ";
- }
- peer_in_shape_str += "]";
-
- GELOGW(
- "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th "
- "input_shape is [%s].The two shape should be same! Please check graph and fix it",
- node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(),
- peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str());
- }
- GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
- peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(),
- output_tensor->GetDataType(), output_tensor->GetOriginDataType());
- peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
- peer_input_desc->SetShape(output_tensor->GetShape());
- peer_input_desc->SetDataType(output_tensor->GetDataType());
- peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
- std::vector<std::pair<int64_t, int64_t>> shape_range;
- (void)output_tensor->GetShapeRange(shape_range);
- peer_input_desc->SetShapeRange(shape_range);
- ge::TensorUtils::SetRealDimCnt(*peer_input_desc,
- static_cast<uint32_t>(output_tensor->GetShape().GetDims().size()));
- GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
- peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(),
- peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType());
- }
- }
- return GRAPH_SUCCESS;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node,
- uint32_t num) {
- if (node == nullptr) {
- GELOGE(GRAPH_FAILED, "Input node is null");
- return GRAPH_FAILED;
- }
-
- GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
- const auto &op_desc = node->GetOpDesc();
- for (size_t i = op_desc->GetInputsSize(); i < num; ++i) {
- if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) {
- GELOGE(GRAPH_FAILED, "Add input desc failed");
- return GRAPH_FAILED;
- }
-
- auto anchor = ComGraphMakeShared<InDataAnchor>(node, i);
- if (anchor == nullptr) {
- GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed.");
- return GRAPH_FAILED;
- }
- node->in_data_anchors_.push_back(anchor);
- }
-
- return GRAPH_SUCCESS;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node,
- uint32_t num) {
- if (node == nullptr) {
- GELOGE(GRAPH_FAILED, "Input node is null");
- return GRAPH_FAILED;
- }
-
- const auto &op_desc = node->GetOpDesc();
- while (op_desc->GetInputsSize() > num) {
- if (!OpDescUtils::ClearInputDesc(op_desc, num)) {
- return GRAPH_FAILED;
- }
- }
-
- auto input_names = op_desc->GetAllInputName();
- (void)op_desc->UpdateInputName(input_names);
- auto is_input_const = op_desc->GetIsInputConst();
- is_input_const.resize(num);
- op_desc->SetIsInputConst(is_input_const);
-
- while (node->in_data_anchors_.size() > num) {
- node->in_data_anchors_.pop_back();
- }
-
- return GRAPH_SUCCESS;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node,
- uint32_t num) {
- if (node == nullptr) {
- GELOGE(GRAPH_FAILED, "Input node is null");
- return GRAPH_FAILED;
- }
-
- GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
- const OpDescPtr &op_desc = node->GetOpDesc();
- for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) {
- if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) {
- GELOGE(GRAPH_FAILED, "Add output desc failed");
- return GRAPH_FAILED;
- }
-
- auto anchor = ComGraphMakeShared<OutDataAnchor>(node, i);
- if (anchor == nullptr) {
- GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed.");
- return GRAPH_FAILED;
- }
- node->out_data_anchors_.push_back(anchor);
- }
-
- return GRAPH_SUCCESS;
- }
-
- GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node,
- uint32_t num) {
- if (node == nullptr) {
- GELOGE(GRAPH_FAILED, "Input node is null");
- return GRAPH_FAILED;
- }
-
- const auto &op_desc = node->GetOpDesc();
- auto output_names = op_desc->GetAllOutputName();
- while (op_desc->GetOutputsSize() > num) {
- if (!OpDescUtils::ClearOutputDesc(op_desc, num)) {
- return GRAPH_FAILED;
- }
- }
- (void)op_desc->UpdateOutputName(output_names);
-
- while (node->out_data_anchors_.size() > num) {
- node->out_data_anchors_.pop_back();
- }
-
- return GRAPH_SUCCESS;
- }
-
- bool NodeUtils::IsInNodesEmpty(const Node &node) {
- for (const auto &in_anchor : node.in_data_anchors_) {
- if (in_anchor != nullptr) {
- auto out_anchor = in_anchor->GetPeerOutAnchor();
- if (out_anchor != nullptr) {
- if (out_anchor->GetOwnerNode() != nullptr) {
- return false;
- }
- }
- }
- }
-
- if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) {
- auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors();
- for (const auto &out_control_anchor : peer_out_control_anchors) {
- if (out_control_anchor != nullptr) {
- if (out_control_anchor->GetOwnerNode() != nullptr) {
- return false;
- }
- }
- }
- }
-
- return true;
- }
- GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) {
- auto desc = node.GetOpDesc();
- if (desc == nullptr) {
- return GeTensorDesc();
- }
- return desc->GetOutputDesc(index);
- }
- GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) {
- auto desc = node.GetOpDesc();
- if (desc == nullptr) {
- return GeTensorDesc();
- }
- return desc->GetInputDesc(index);
- }
- graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) {
- auto desc = node.GetOpDesc();
- if (desc == nullptr) {
- return GRAPH_PARAM_INVALID;
- }
- auto output_desc = desc->MutableOutputDesc(index);
- if (output_desc == nullptr) {
- return GRAPH_PARAM_INVALID;
- }
- output_desc->SetShape(shape);
- return GRAPH_SUCCESS;
- }
- graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) {
- auto desc = node.GetOpDesc();
- if (desc == nullptr) {
- return GRAPH_PARAM_INVALID;
- }
- auto input_desc = desc->MutableInputDesc(index);
- if (input_desc == nullptr) {
- return GRAPH_PARAM_INVALID;
- }
- input_desc->SetShape(shape);
- return GRAPH_SUCCESS;
- }
-
- graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
- auto desc = node.GetOpDesc();
- GE_CHECK_NOTNULL(desc);
- // check self
- is_unknow = OpShapeIsUnknown(desc);
- if (is_unknow) {
- return GRAPH_SUCCESS;
- }
- auto sub_graph_names = desc->GetSubgraphInstanceNames();
- if (sub_graph_names.empty()) {
- return GRAPH_SUCCESS;
- } else {
- auto owner_graph = node.GetOwnerComputeGraph();
- GE_CHECK_NOTNULL(owner_graph);
- auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
- if (root_graph == nullptr) {
- GE_LOGE("Node %s gets null root graph", node.GetName().c_str());
- return GRAPH_PARAM_INVALID;
- }
- for (auto &sub_graph_name : sub_graph_names) {
- auto sub_graph = root_graph->GetSubgraph(sub_graph_name);
- GE_CHECK_NOTNULL(sub_graph);
- for (const auto &node_ptr : sub_graph->GetDirectNode()) {
- auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow);
- if (status != GRAPH_SUCCESS) {
- GE_LOGE("get node unknown shape status failed!");
- return status;
- }
- if (is_unknow) {
- return GRAPH_SUCCESS;
- }
- }
- }
- }
- return GRAPH_SUCCESS;
- }
-
- std::string NodeUtils::GetNodeType(const Node &node) {
- if (node.GetType() != FRAMEWORKOP) {
- return node.GetType();
- }
-
- std::string type;
- (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
- return type;
- }
-
- std::string NodeUtils::GetNodeType(const NodePtr &node) { return node == nullptr ? "" : GetNodeType(*node); }
-
- graphStatus NodeUtils::GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor) {
- return GRAPH_SUCCESS;
- }
-
- graphStatus NodeUtils::GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor) {
- return GRAPH_SUCCESS;
- }
-
- ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
- auto op_desc = node.GetOpDesc();
- if (op_desc == nullptr) {
- return nullptr;
- }
- auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
- if (root_graph == nullptr) {
- return nullptr;
- }
- return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index));
- }
-
- graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) {
- if (subgraph == nullptr) {
- GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index);
- return GRAPH_PARAM_INVALID;
- }
- auto op_desc = node.GetOpDesc();
- if (op_desc == nullptr) {
- return GRAPH_PARAM_INVALID;
- }
- auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
- if (root_graph == nullptr) {
- GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str());
- return GRAPH_PARAM_INVALID;
- }
- auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName());
- if (ret != GRAPH_SUCCESS) {
- GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index);
- return ret;
- }
- subgraph->SetParentNode(node.shared_from_this());
- subgraph->SetParentGraph(node.GetOwnerComputeGraph());
- return root_graph->AddSubgraph(subgraph);
- }
-
- ///
- /// Check if node is input of subgraph
- /// @param [in] node
- /// @return bool
- ///
- bool NodeUtils::IsSubgraphInput(const NodePtr &node) {
- if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
- (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
- return false;
- }
-
- auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
- if (parent_op_desc == nullptr) {
- return false;
- }
-
- // dynamic shape unknown graph false
- // dynamic shape known graph with functional subgraph maybe true
- if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
- if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) {
- return false;
- } else {
- if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) {
- return false;
- }
- }
- }
-
- return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
- }
-
- ///
- /// Check if node is output of subgraph
- /// @param [in] node
- /// @return bool
- ///
- bool NodeUtils::IsSubgraphOutput(const NodePtr &node) {
- if ((node == nullptr) || (node->GetOpDesc() == nullptr) ||
- (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) {
- return false;
- }
-
- auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc();
- if (parent_op_desc == nullptr) {
- return false;
- }
-
- if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
- if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) {
- return false;
- } else {
- if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) {
- return false;
- }
- }
- }
-
- for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) {
- if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) {
- return true;
- }
- }
-
- return false;
- }
-
- ///
- /// @brief Get subgraph original input node.
- /// @param [in] node
- /// @return Node
- ///
- NodePtr NodeUtils::GetParentInput(const Node &node) {
- uint32_t parent_index = 0;
- if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
- return nullptr;
- }
-
- // Subgraph Data Node, check for constant input.
- const ComputeGraphPtr &graph = node.GetOwnerComputeGraph();
- GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
-
- const NodePtr &parent_node = graph->GetParentNode();
- GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr);
-
- const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index);
- GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr);
-
- const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor();
- GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr);
-
- return peer_out_anchor->GetOwnerNode();
- }
-
- NodePtr NodeUtils::GetParentInput(const NodePtr &node) { return node == nullptr ? node : GetParentInput(*node); }
-
- ///
- /// @brief Get is dynamic shape graph from node.
- /// @param [in] node
- /// @return bool
- ///
- bool NodeUtils::IsDynamicShape(const Node &node) {
- const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
- if (graph == nullptr) {
- return false;
- }
-
- bool is_dynamic_shape = false;
- (void)AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape);
- return is_dynamic_shape;
- }
-
- bool NodeUtils::IsDynamicShape(const NodePtr &node) { return node == nullptr ? false : IsDynamicShape(*node); }
-
- ///
- /// @brief Check is varying_input for while node
- /// @param [in] node: Data node for subgraph
- /// @return bool
- ///
- bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) {
- if (node == nullptr) {
- return false;
- }
- if (node->GetType() != DATA) {
- return false; // not input_node for subgraph
- }
-
- const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode();
- if (parent_node == nullptr) {
- return false; // root graph
- }
-
- if (kWhileOpTypes.count(parent_node->GetType()) == 0) {
- return false; // not input_node for while subgraph
- }
-
- uint32_t index_i = 0;
- if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) {
- GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str());
- return false;
- }
- bool varying_flag = true;
- for (const auto &item : node->GetOutDataNodesAndAnchors()) {
- if (item.first->GetType() != NETOUTPUT) {
- continue;
- }
- OpDescPtr op_desc = item.first->GetOpDesc();
- uint32_t index_o = 0;
- if ((op_desc == nullptr) ||
- !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) {
- continue; // input for while-cond subgraph
- }
- if (index_i != index_o) {
- continue; // varying input for while-body subgraph
- }
- varying_flag = false;
- break;
- }
- return varying_flag;
- }
-
- ///
- /// @brief Get subgraph input is constant.
- /// @param [in] node
- /// @param [out] string
- /// @return bool
- ///
- bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) {
- if (node == nullptr) {
- return false;
- }
-
- if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) {
- type = node->GetType();
- return true;
- }
-
- if (node->GetType() != DATA) {
- return false; // not subgraph input node
- }
-
- const auto &parent = GetParentInput(node);
- return GetConstOpType(parent, type);
- }
-
///
/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
/// Collects every subgraph name reachable from |node|: its own instance names, then, breadth-first,
/// the instance names of every direct node inside those subgraphs. Each collected subgraph is then
/// removed from the root graph. Only |node|'s own op desc has its instance names removed.
/// @param [in] node
/// @return return GRAPH_SUCCESS if remove successfully, other for failed.
///
Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  // Taken by value, so iterating it is unaffected by the RemoveSubgraphInstanceName calls below.
  auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  if (subgraph_names.empty()) {
    return GRAPH_SUCCESS;
  } else {
    auto owner_graph = node->GetOwnerComputeGraph();
    GE_CHECK_NOTNULL(owner_graph);
    auto root_graph = GraphUtils::FindRootGraph(owner_graph);
    GE_CHECK_NOTNULL(root_graph);

    // Every subgraph name to drop from root_graph; also serves as the visited set for the BFS.
    std::unordered_set<std::string> subgraph_to_remove;
    for (auto &subgraph_name : subgraph_names) {
      std::deque<std::string> queue;
      queue.push_back(subgraph_name);
      subgraph_to_remove.insert(subgraph_name);
      op_desc->RemoveSubgraphInstanceName(subgraph_name);
      while (!queue.empty()) {
        auto graph_name = queue.front();
        queue.pop_front();

        auto subgraph = root_graph->GetSubgraph(graph_name);
        // NOTE(review): GE_CHECK_NOTNULL presumably returns early here. In that case some instance
        // names were already removed from op_desc but no subgraph has been removed from root_graph
        // yet. Confirm that partial state is acceptable to callers.
        GE_CHECK_NOTNULL(subgraph);
        for (const auto &sub_node : subgraph->GetDirectNode()) {
          auto sub_op_desc = sub_node->GetOpDesc();
          GE_CHECK_NOTNULL(sub_op_desc);
          auto sub_names = sub_op_desc->GetSubgraphInstanceNames();
          // Subgraph and all nodes in it will be removed later,
          // no need to remove 'SubgraphInstanceName' in op desc here.
          for (auto &name : sub_names) {
            // insert().second is false for names already collected, so each nested subgraph is queued once.
            if (subgraph_to_remove.insert(name).second) {
              queue.push_back(name);
            }
          }
        }
      }
    }
    // Remove subgraph from root_graph
    for (const auto &name : subgraph_to_remove) {
      GELOGI("Remove subgraph:%s.", name.c_str());
      root_graph->RemoveSubgraph(name);
    }
  }

  return GRAPH_SUCCESS;
}
- ///
- /// @brief Get subgraph input data node by index.
- /// @param [in] node
- /// @return Node
- ///
- vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) {
- vector<NodePtr> in_data_node_vec;
- auto op_desc = node.GetOpDesc();
- GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec);
- auto subgraph_names = op_desc->GetSubgraphInstanceNames();
- if (subgraph_names.empty()) {
- GELOGW("Node %s is single node without sub graph.", node.GetName().c_str());
- return in_data_node_vec;
- }
- auto compute_graph = node.GetOwnerComputeGraph();
- for (const std::string &instance_name : subgraph_names) {
- auto subgraph = compute_graph->GetSubgraph(instance_name);
- for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
- int parent_index = -1;
- if (NodeUtils::IsSubgraphInput(node_in_subgraph)) {
- (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index);
- if (parent_index == index) {
- in_data_node_vec.emplace_back(node_in_subgraph);
- }
- }
- }
- }
- return in_data_node_vec;
- }
- ///
- /// @brief Get subgraph input data node by index.
- /// @param [in] node
- /// @return Node
- ///
- vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) {
- vector<NodePtr> out_data_node_vec;
- auto op_desc = node.GetOpDesc();
- GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec);
- auto subgraph_names = op_desc->GetSubgraphInstanceNames();
- if (subgraph_names.empty()) {
- GELOGI("Node %s is single node without sub graph.", node.GetName().c_str());
- return out_data_node_vec;
- }
- auto compute_graph = node.GetOwnerComputeGraph();
- for (const std::string &instance_name : subgraph_names) {
- auto subgraph = compute_graph->GetSubgraph(instance_name);
- for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
- if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) {
- out_data_node_vec.emplace_back(node_in_subgraph);
- }
- }
- }
- return out_data_node_vec;
- }
-
- NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) {
- if (node.GetInDataAnchor(index) == nullptr) {
- return nullptr;
- }
- if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) {
- return nullptr;
- }
- return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode();
- }
-
- vector<pair<InDataAnchorPtr, NodePtr>> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) {
- vector<pair<InDataAnchorPtr, NodePtr>> out_data_nodes;
- auto out_data_anchor = node.GetOutDataAnchor(index);
- if (out_data_anchor == nullptr) {
- return out_data_nodes;
- }
-
- for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
- if (peer_in_anchor == nullptr) {
- continue;
- }
- if (peer_in_anchor->GetOwnerNode() == nullptr) {
- continue;
- }
- out_data_nodes.emplace_back(std::make_pair(peer_in_anchor, peer_in_anchor->GetOwnerNode()));
- }
- return out_data_nodes;
- }
-
- ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); }
- } // namespace ge
|