/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "external/graph/operator.h" #include "external/graph/operator_factory.h" #include #include #include #include #include #include "./array_ops.h" #include "debug/ge_log.h" #include "debug/ge_op_types.h" #include "debug/ge_util.h" #include "external/graph/attr_value.h" #include "external/graph/types.h" #include "framework/common/debug/ge_log.h" #include "graph/compute_graph.h" #include "graph/ge_attr_value.h" #include "graph/ge_context.h" #include "graph/ge_tensor.h" #include "graph/node.h" #include "graph/op_desc.h" #include "graph/runtime_inference_context.h" #include "graph/usr_types.h" #include "graph/utils/node_utils.h" #include "graph/debug/ge_attr_define.h" #include "utils/graph_utils.h" #include "utils/op_desc_utils.h" #include "utils/tensor_adapter.h" #include "utils/tensor_utils.h" #include "utils/type_utils.h" #include #include #include #include #include using std::enable_shared_from_this; using std::make_pair; using std::shared_ptr; using std::string; using std::to_string; using std::vector; /*lint -save -e529 -e728*/ /*lint -e446 -e732*/ /*lint -e665*/ namespace ge { class OpIO { public: OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {} ~OpIO() = default; string GetName() const { return name_; } int GetIndex() const { return index_; } OperatorImplPtr GetOwner() const { return owner_; } bool 
operator==(const OpIO &r_value) const { return (this->name_ == r_value.GetName()) && (this->index_ == r_value.GetIndex()) && (this->GetOwner() == r_value.GetOwner()); } private: string name_; int index_; std::shared_ptr owner_; }; class TensorTypeImpl { public: TensorTypeImpl() = default; ~TensorTypeImpl() = default; std::vector dt_vec_; }; TensorType::TensorType(DataType dt) { tensor_type_impl_ = ComGraphMakeShared(); if (tensor_type_impl_ != nullptr) { tensor_type_impl_->dt_vec_.push_back(dt); } } TensorType::TensorType(const std::initializer_list &types) { tensor_type_impl_ = ComGraphMakeShared(); if (tensor_type_impl_ != nullptr) { tensor_type_impl_->dt_vec_ = types; } } class OperatorImpl : public std::enable_shared_from_this { friend class GraphBuilderImpl; friend class OpDescUtils; public: explicit OperatorImpl(const string &name, const string &type) : op_desc_(ComGraphMakeShared(name, type)) { if (op_desc_ == nullptr) { GELOGW("OpDesc make shared failed"); } } explicit OperatorImpl(const OpDescPtr &op_desc) : op_desc_(op_desc) {} explicit OperatorImpl(ge::ConstNodePtr node) : node_(std::move(node)) { if (node_ != nullptr && node_->GetOpDesc() != nullptr) { op_desc_ = node_->GetOpDesc(); } } ~OperatorImpl() {} void SetInputImpl(const string &dst_name, const ge::Operator &src_oprt) { GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty"); GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr."); GE_CHK_BOOL_EXEC(src_oprt.operator_impl_ != nullptr, return, "operator_impl_ is nullptr."); GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_ != nullptr, return, "op_desc_ is nullptr."); auto src_op_impl = src_oprt.GetOperatorImplPtr(); GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return, "Src impl is null."); GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return, "Src impl's opdesc is null."); GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_->GetOutputsSize() == 1, return, "The source operator[%s] must has one output", 
src_oprt.operator_impl_->op_desc_->GetName().c_str()) uint32_t src_index = 0; string src_name = src_op_impl->op_desc_->GetOutputNameByIndex(src_index); GE_CHK_BOOL_EXEC(!src_name.empty(), return, "Src output's name is empty."); OpIO out_handler(src_name, src_index, src_op_impl); input_link_.insert(std::make_pair(dst_name, out_handler)); int dst_index = op_desc_->GetInputIndexByName(dst_name); GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), op_desc_->GetName().c_str()); bool is_const = false; if (src_oprt.GetOpType() == CONSTANT) { is_const = true; } auto is_input_const = op_desc_->GetIsInputConst(); for (int i = static_cast(is_input_const.size()); i <= dst_index; ++i) { is_input_const.push_back(false); } is_input_const[dst_index] = is_const; op_desc_->SetIsInputConst(is_input_const); OpIO op_dst(dst_name, dst_index, shared_from_this()); src_op_impl->UpdateLinkMapImpl(src_name, op_dst); auto output_desc = src_op_impl->GetOutputDesc(src_name); auto input_desc = op_desc_->GetInputDesc(dst_name); if (input_desc.GetFormat() == FORMAT_RESERVED) { output_desc.SetFormat(FORMAT_ND); } else { output_desc.SetFormat(input_desc.GetFormat()); } // Fix for linking opdesc if (op_desc_->UpdateInputDesc(dst_name, output_desc) != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "Update inputdesc failed,dst name is %s, src name is %s", dst_name.c_str(), src_name.c_str()); return; } } void SetInputImpl(const string &dst_name, const ge::OutHandler &out_handler) { GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty"); GE_CHK_BOOL_EXEC(out_handler != nullptr, return, "SetInputImpl faild, out_handler is nullptr."); GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr."); input_link_.insert(std::make_pair(dst_name, *out_handler)); string src_name = out_handler->GetName(); int dst_index = op_desc_->GetInputIndexByName(dst_name); GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. 
name[%s], op name:%s", dst_name.c_str(), op_desc_->GetName().c_str()); auto out_op_impl = out_handler->GetOwner(); GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return, "out_handler invalid. name[%s]", dst_name.c_str()); bool is_const = false; if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { is_const = true; } auto is_input_const = op_desc_->GetIsInputConst(); for (int i = static_cast(is_input_const.size()); i <= dst_index; ++i) { is_input_const.push_back(false); } is_input_const[dst_index] = is_const; op_desc_->SetIsInputConst(is_input_const); OpIO in_handler(dst_name, dst_index, shared_from_this()); GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed."); out_op_impl->UpdateLinkMapImpl(src_name, in_handler); auto src_output_desc = out_op_impl->GetOutputDesc(src_name); auto dst_input_desc = op_desc_->GetInputDesc(dst_name); if (dst_input_desc.GetFormat() == FORMAT_RESERVED) { src_output_desc.SetFormat(FORMAT_ND); } else { src_output_desc.SetFormat(dst_input_desc.GetFormat()); } GE_CHK_BOOL_EXEC(op_desc_->UpdateInputDesc(dst_name, src_output_desc) == GRAPH_SUCCESS, return, "Update input desc failed,dst name is %s,src name is %s", dst_name.c_str(), src_name.c_str()); // fix for linking opdesc } void AddControlInputImp(const ge::Operator &src_oprt) { if (src_oprt.operator_impl_ == nullptr) { GELOGE(FAILED, "Src operator impl is nullptr"); return; } for (auto &input : control_input_link_) { if (input.lock() == src_oprt.operator_impl_) { return; } } control_input_link_.push_back(src_oprt.operator_impl_); src_oprt.operator_impl_->control_output_link_.push_back(shared_from_this()); } graphStatus GetInputImpl(const string &dst_name, ge::OpIO &out_handler) { auto out = input_link_.find(dst_name); if (out == input_link_.end()) { return GRAPH_FAILED; } out_handler = out->second; return GRAPH_SUCCESS; } bool InputIsSet(const string &name) { GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return false, "op_desc_ 
is nullptr."); return op_desc_->InputIsSet(name); } string GetName() const { GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return string(), "op_desc_ is nullptr."); return op_desc_->GetName(); } GeTensorDesc GetInputDesc(const string &name) const { GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); return op_desc_->GetInputDesc(name); } GeTensorDesc GetInputDesc(uint32_t index) const { GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); return op_desc_->GetInputDesc(index); } graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc) { GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GRAPH_FAILED, "op_desc_ is nullptr."); return op_desc_->UpdateInputDesc(name, tensor_desc); } OutHandler GetOutput(const string &name) { GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr."); int src_index = op_desc_->GetOutputIndexByName(name); GE_CHK_BOOL_EXEC(src_index >= 0, return nullptr, "Find src index by name failed. name[%s]", name.c_str()); shared_ptr output_ptr = ComGraphMakeShared(name, src_index, shared_from_this()); if (output_ptr == nullptr) { GELOGE(GRAPH_FAILED, "OpIO make shared failed"); return nullptr; } return output_ptr; } OutHandler GetOutput(uint32_t index) { GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr."); string name = op_desc_->GetOutputNameByIndex(index); if (name.empty()) { GELOGE(GRAPH_FAILED, "Find src name by index failed. 
index[%u]", index); return nullptr; } shared_ptr output_ptr = ComGraphMakeShared(name, index, shared_from_this()); if (output_ptr == nullptr) { GELOGE(GRAPH_FAILED, "OpIO make shared failed"); return nullptr; } return output_ptr; } GeTensorDesc GetOutputDesc(const string &name) const { GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); return op_desc_->GetOutputDesc(name); } GeTensorDesc GetOutputDesc(uint32_t index) const { GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); return op_desc_->GetOutputDesc(index); } graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc) { GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); auto res = op_desc_->UpdateOutputDesc(name, tensor_desc); if (res == GRAPH_SUCCESS) { for (auto ol : output_links_[name]) { if (ol.GetOwner() == nullptr) { GELOGW("%s get owner is nullptr", ol.GetName().c_str()); continue; } GE_CHK_BOOL_RET_STATUS(ol.GetOwner()->UpdateInputDesc(ol.GetName(), tensor_desc) == GRAPH_SUCCESS, GRAPH_FAILED, "Could not update next operator's input %s.", ol.GetName().c_str()); } } return res; } size_t GetInputsSize() const { GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0); return op_desc_->GetInputsSize(); } size_t GetOutputsSize() const { GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0); return op_desc_->GetOutputsSize(); } graphStatus SetAttr(const string &name, GeAttrValue &&attr_value) { GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); return op_desc_->SetAttr(name, std::move(attr_value)); } graphStatus GetAttr(const string &name, GeAttrValue &attr_value) const { GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); return op_desc_->GetAttr(name, attr_value); } OpDescPtr GetOpDescImpl() const { return op_desc_; } void UpdateLinkMapImpl(const string &src_name, OpIO &op_dst) { auto it_find = output_links_.find(src_name); if (it_find == 
output_links_.end()) { std::vector dsts{op_dst}; output_links_.insert(std::make_pair(src_name, dsts)); } else { it_find->second.push_back(op_dst); } } Operator ToOperator() { return Operator(shared_from_this()); } static OpDescPtr GetOpDesc(const Operator &oprt) { GE_IF_BOOL_EXEC(oprt.operator_impl_ == nullptr, return nullptr); return oprt.operator_impl_->op_desc_; } void ClearOutputLinks() noexcept { output_links_.clear(); } void ClearInputLinks() noexcept { input_link_.clear(); } ge::ConstNodePtr GetNode() { return node_; } void SetInferenceContext(const InferenceContextPtr &inference_context) { inference_context_ = inference_context; } InferenceContextPtr GetInferenceContext() const { return inference_context_; } void SubgraphRegister(const std::string &ir_name, bool dynamic) { op_desc_->RegisterSubgraphIrName(ir_name, dynamic ? kDynamic : kStatic); } void SubgraphCountRegister(const std::string &ir_name, uint32_t count) { if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kStatic) { op_desc_->AddSubgraphName(ir_name); subgraph_names_to_builders_[ir_name] = nullptr; } else { for (uint32_t i = 0; i < count; ++i) { string key_name = ir_name + std::to_string(i); op_desc_->AddSubgraphName(key_name); subgraph_names_to_builders_[key_name] = nullptr; } } } void SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { string key_name = ir_name; if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { key_name += std::to_string(index); } auto it = subgraph_names_to_builders_.find(key_name); if (it == subgraph_names_to_builders_.end()) { GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), index); return; } it->second = builder; } SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, uint32_t index) const { string key_name = ir_name; if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { key_name += std::to_string(index); } return GetSubgraphBuilder(key_name); } 
SubgraphBuilder GetSubgraphBuilder(const std::string &name) const { auto iter = subgraph_names_to_builders_.find(name); if (iter == subgraph_names_to_builders_.end()) { GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s", name.c_str()); return nullptr; } return iter->second; } std::vector GetSubgraphNames() const { std::vector names; for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) { names.emplace_back(subgraph_name_to_type.first); } return names; } size_t GetSubgraphNamesCount() const { return op_desc_->GetSubgraphIrNames().size(); } OpDescPtr op_desc_ = nullptr; private: ge::ConstNodePtr node_{nullptr}; ge::InferenceContextPtr inference_context_; std::map> output_links_{}; std::map input_link_{}; std::vector> control_input_link_{}; std::vector> control_output_link_{}; std::map subgraph_names_to_builders_; }; // Used to manage OperatorImpl instances created by ge api. class OperatorKeeper { private: OperatorKeeper() = default; ~OperatorKeeper() { for (const auto &iter : operators_) { if (iter) { iter->ClearInputLinks(); iter->ClearOutputLinks(); } } } std::set operators_; std::mutex mutex_; public: static OperatorKeeper &GetInstance() { static OperatorKeeper instance; return instance; } void CheckInOperator(const OperatorImplPtr &op_impl) { if (op_impl) { std::lock_guard lock(mutex_); operators_.insert(op_impl); } } void CheckOutOperator(const OperatorImplPtr &op_impl) { if (op_impl) { std::lock_guard lock(mutex_); operators_.erase(op_impl); } } }; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromNode(ge::ConstNodePtr node_ptr) { ge::OperatorImplPtr operator_impl_ptr = ComGraphMakeShared(node_ptr); if (operator_impl_ptr == nullptr) { GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); return Operator("default"); } return operator_impl_ptr->ToOperator(); } Operator::Operator(const std::string &type) { static uint32_t index = 0; string name = type + "_" + std::to_string(index++); 
operator_impl_ = ComGraphMakeShared(name, type); if (operator_impl_ == nullptr) { GELOGW("OperatorImpl make shared failed"); } OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromOpDesc(OpDescPtr op_desc) { shared_ptr operator_impl_ptr; operator_impl_ptr = ComGraphMakeShared(op_desc); if (operator_impl_ptr == nullptr) { GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); return Operator("default"); } OperatorKeeper::GetInstance().CheckInOperator(operator_impl_ptr); return operator_impl_ptr->ToOperator(); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) { return OperatorImpl::GetOpDesc(oprt); } GE_FUNC_HOST_VISIBILITY Operator::Operator(const string &name, const string &type) { operator_impl_ = ComGraphMakeShared(name, type); if (operator_impl_ == nullptr) { GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); return; } OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); } Operator::Operator(ge::OperatorImplPtr &&op_impl) { operator_impl_ = std::move(op_impl); } bool Operator::IsEmpty() const { if (operator_impl_ == nullptr) { return true; } return false; } string Operator::GetName() const { if (operator_impl_ != nullptr) { return operator_impl_->GetName(); } return ""; } GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const string &dst_name, const ge::Operator &src_oprt) { // Describe the connection relationship between operators, no create action GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); operator_impl_->SetInputImpl(dst_name, src_oprt); return *this; } Operator &Operator::SetInput(const string &dst_name, const ge::OutHandler &out_handler) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); operator_impl_->SetInputImpl(dst_name, out_handler); return *this; } Operator &Operator::SetInput(const 
std::string &dst_name, const ge::Operator &src_oprt, const std::string &name) { auto out_handler = src_oprt.GetOutput(name); GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr."); (void)SetInput(dst_name, out_handler); return *this; } Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, uint32_t index) { auto out_handler = src_oprt.GetOutput(index); GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr."); (void)SetInput(dst_name, out_handler); return *this; } Operator &Operator::AddControlInput(const Operator &src_oprt) { if (operator_impl_ == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr."); return *this; } operator_impl_->AddControlInputImp(src_oprt); return *this; } graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const { GE_CHECK_NOTNULL(operator_impl_); auto node_ptr = operator_impl_->GetNode(); if (node_ptr != nullptr) { // For inner compute graph auto op_desc = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); auto index = op_desc->GetInputIndexByName(dst_name); auto in_data_anchor = node_ptr->GetInDataAnchor(index); GE_CHECK_NOTNULL(in_data_anchor); auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(out_data_anchor); auto peer_node = out_data_anchor->GetOwnerNode(); GE_CHECK_NOTNULL(peer_node); auto peer_op_desc = peer_node->GetOpDesc(); GE_CHECK_NOTNULL(peer_op_desc); auto peer_op_type = peer_op_desc->GetType(); if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { auto const_op_impl = ComGraphMakeShared(peer_node); GE_CHECK_NOTNULL(const_op_impl); Operator const_op(std::move(const_op_impl)); return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); } else if (peer_op_type == DATA) { auto parent_node = NodeUtils::GetParentInput(peer_node); while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { parent_node = NodeUtils::GetParentInput(parent_node); } if ((parent_node != nullptr) && 
((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { auto const_op_impl = ComGraphMakeShared(parent_node); GE_CHECK_NOTNULL(const_op_impl); Operator const_op(std::move(const_op_impl)); return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); } } // Try get from runtime inference context auto session_id = std::to_string(GetContext().SessionId()); RuntimeInferenceContext *runtime_infer_ctx = nullptr; if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); if (ret == GRAPH_SUCCESS) { return GRAPH_SUCCESS; } } } else { // For outer graph return GetInputConstDataOut(dst_name, data); } auto op_name = operator_impl_->GetName(); GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str()); return GRAPH_FAILED; } graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const { ge::OpIO out_handle("", 0, nullptr); GE_CHECK_NOTNULL(operator_impl_); if (operator_impl_->GetInputImpl(dst_name, out_handle) != GRAPH_SUCCESS) { GELOGE(FAILED, "%s get input impl failed", dst_name.c_str()); return GRAPH_FAILED; } if (out_handle.GetOwner() != nullptr && out_handle.GetOwner()->GetOpDescImpl() != nullptr) { Operator const_op(out_handle.GetOwner()); const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); if (op_desc_impl_type == CONSTANTOP) { return const_op.GetAttr(op::Constant::name_attr_value(), data); } else if (op_desc_impl_type == CONSTANT) { return const_op.GetAttr(op::Const::name_attr_value(), data); } } return GRAPH_FAILED; } std::shared_ptr Operator::GetNode() const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); return operator_impl_->GetNode(); } TensorDesc Operator::GetInputDesc(const 
std::string &name) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); } void Operator::SetInferenceContext(const InferenceContextPtr &inference_context) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); operator_impl_->SetInferenceContext(inference_context); } InferenceContextPtr Operator::GetInferenceContext() const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); return operator_impl_->GetInferenceContext(); } TensorDesc Operator::GetInputDesc(uint32_t index) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index)); } graphStatus Operator::TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); auto check = operator_impl_->InputIsSet(name); if (check) tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); return check ? 
GRAPH_SUCCESS : GRAPH_FAILED; } graphStatus Operator::UpdateInputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); return operator_impl_->UpdateInputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); } OutHandler Operator::GetOutput(const string &name) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); return operator_impl_->GetOutput(name); } OutHandler Operator::GetOutput(uint32_t index) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); return operator_impl_->GetOutput(index); } TensorDesc Operator::GetOutputDesc(const std::string &name) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name)); } TensorDesc Operator::GetOutputDesc(uint32_t index) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(index)); } graphStatus Operator::UpdateOutputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); return operator_impl_->UpdateOutputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); } TensorDesc Operator::GetDynamicInputDesc(const string &name, uint32_t index) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name + std::to_string(index))); } graphStatus Operator::UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); return 
operator_impl_->UpdateInputDesc(name + std::to_string(index), TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); } TensorDesc Operator::GetDynamicOutputDesc(const string &name, uint32_t index) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name + std::to_string(index))); } graphStatus Operator::UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); return operator_impl_->UpdateOutputDesc(name + std::to_string(index), TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); } graphStatus Operator::InferShapeAndType() { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); return operator_impl_->GetOpDescImpl()->CallInferFunc(*this); } graphStatus Operator::VerifyAllAttr(bool disable_common_verifier) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); if (!disable_common_verifier && (graphStatus)Operator::VerifyAll() == GRAPH_FAILED) { return GRAPH_FAILED; } else { return (graphStatus)operator_impl_->GetOpDescImpl()->OpVerify(); } } GE_FUNC_HOST_VISIBILITY size_t Operator::GetInputsSize() const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr"); return operator_impl_->GetInputsSize(); } GE_FUNC_HOST_VISIBILITY size_t Operator::GetOutputsSize() const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr"); return operator_impl_->GetOutputsSize(); } // According to op get the attrs name and type namespace { const std::map kAttrTypesMap = { 
{GeAttrValue::VT_NONE, "VT_STRING"}, {GeAttrValue::VT_STRING, "VT_STRING"}, {GeAttrValue::VT_FLOAT, "VT_FLOAT"}, {GeAttrValue::VT_BOOL, "VT_BOOL"}, {GeAttrValue::VT_INT, "VT_INT"}, {GeAttrValue::VT_TENSOR_DESC, "VT_TENSOR_DESC"}, {GeAttrValue::VT_TENSOR, "VT_TENSOR"}, {GeAttrValue::VT_BYTES, "VT_BYTES"}, {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, {GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"}, {GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"}, {GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"}, {GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"}, {GeAttrValue::VT_LIST_BOOL, "VT_LIST_BOOL"}, {GeAttrValue::VT_LIST_INT, "VT_LIST_INT"}, {GeAttrValue::VT_LIST_TENSOR_DESC, "VT_LIST_TENSOR_DESC"}, {GeAttrValue::VT_LIST_TENSOR, "VT_LIST_TENSOR"}, {GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"}, {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, {GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"}, }; } // namespace const std::map Operator::GetAllAttrNamesAndTypes() const { std::map attr_types; GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return attr_types, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return attr_types, "GetOpDescImpl is nullptr."); std::map attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs(); map::iterator iter; for (iter = attr_map.begin(); iter != attr_map.end(); ++iter) { string name = iter->first; GeAttrValue attr_value = iter->second; GeAttrValue::ValueType type = attr_value.GetValueType(); auto iter2 = kAttrTypesMap.find(type); if (iter2 != kAttrTypesMap.end()) { attr_types[name] = iter2->second; } } return attr_types; } void Operator::InputRegister(const string &name) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); (void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); } void Operator::OptionalInputRegister(const string &name) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, 
return, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); // [No need to verify return value] (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); } void Operator::InferFuncRegister(const std::function &func) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); // [No need to verify return value] (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func); } void Operator::InferFormatFuncRegister(const std::function &func) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); // [No need to verify return value] (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); } void Operator::VerifierFuncRegister(const std::function &func) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); // [No need to verify return value] (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); } void Operator::OutputRegister(const string &name) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); // [No need to verify return value] (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); } void Operator::DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); 
GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return, "set int failed"); (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back); } void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) { GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr."); operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index); } int Operator::GetDynamicInputNum(const string &name) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); int num = 0; GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return num, "Get %s int failed", name.c_str()); return num; } void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return, "Set %s int failed", name.c_str()); (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); } int Operator::GetDynamicOutputNum(const string &name) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); int num = 0; GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num, "Get %s int failed", name.c_str()); return num; } void Operator::RequiredAttrRegister(const string &name) { 
GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); operator_impl_->GetOpDescImpl()->AddRequiredAttr(name); } graphStatus Operator::VerifyAll() { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); // Check all inputs defined for (const string &iname : operator_impl_->GetOpDescImpl()->GetAllInputNames()) { GE_CHK_BOOL_RET_STATUS(operator_impl_->GetOpDescImpl()->IsOptionalInput(iname) || operator_impl_->InputIsSet(iname), GRAPH_FAILED, "operator input %s is not linked.", iname.c_str()); vector ishape = operator_impl_->GetOpDescImpl()->GetInputDesc(iname).GetShape().GetDims(); for (int64_t dim : ishape) { GE_CHK_BOOL_RET_STATUS(dim > 0, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", iname.c_str()); } } // Check all attributes defined const auto all_attributes = operator_impl_->GetOpDescImpl()->GetAllAttrs(); for (const auto &name : operator_impl_->GetOpDescImpl()->GetAllAttrNames()) { GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, "operator attribute %s is empty.", name.c_str()); } return GRAPH_SUCCESS; } string Operator::GetOpType() const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return "Data", "operator impl is nullptr."); return OperatorImpl::GetOpDesc(*this)->GetType(); } Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt) { string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index); return SetInput(dynamic_dst_name, src_oprt); } Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt, const std::string &name) { string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index); return 
SetInput(dynamic_dst_name, src_oprt, name);
}

// Returns the shared implementation object held by this operator (may be null).
OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; }

// Generates Operator::SetAttr(name, ArgType).
// The generated function forwards to AttrUtils::Set<AttrUtilsFun> on the op description and always
// returns *this; a missing impl/op description is logged as an error, a failed set only as a warning.
#define OP_ATTR_SET_IMP(ArgType, AttrUtilsFun) \
  Operator &Operator::SetAttr(const string &name, ArgType attr_value) { \
    if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \
      GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \
      return *this; \
    } \
    if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \
      GELOGW("set attr name %s failed.", name.c_str()); \
    } \
    return *this; \
  }  // lint !e665

// Generates Operator::GetAttr(name, ArgType) const.
// The generated function forwards to AttrUtils::Get<AttrUtilsFun> and returns GRAPH_SUCCESS on
// success, GRAPH_FAILED when the impl/op description is null or the lookup fails.
#define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \
  graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \
    if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \
      GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \
      return GRAPH_FAILED; \
    } \
    if (!AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \
      GELOGW("get attr name %s failed.", name.c_str()); \
      return GRAPH_FAILED; \
    } \
    return GRAPH_SUCCESS; \
  }  // lint !e665

// Clears this operator's input and output links and checks its impl out of the OperatorKeeper.
// Only warns when the impl is null.
void Operator::BreakConnect() const {
  if (operator_impl_ == nullptr) {
    GELOGW("operator impl is nullptr.");
    return;
  }
  operator_impl_->ClearInputLinks();
  operator_impl_->ClearOutputLinks();
  OperatorKeeper::GetInstance().CheckOutOperator(operator_impl_);
}

// Generates Operator::AttrRegister(name, ArgType).
// Same storage path as OP_ATTR_SET_IMP (AttrUtils::Set<AttrUtilsFun>) but returns void.
#define OP_ATTR_REG_IMP(ArgType, AttrUtilsFun) \
  void Operator::AttrRegister(const string &name, ArgType attr_value) { \
    if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \
      GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \
      return; \
    } \
    if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \
      GELOGW("reg attr name %s failed.", name.c_str()); \
    } \
  }  // lint !e665

// Integer attribute accessors; every width is stored through the AttrUtils Int functions.
OP_ATTR_SET_IMP(int64_t, Int)
OP_ATTR_SET_IMP(int32_t, Int)
OP_ATTR_SET_IMP(uint32_t, Int)
OP_ATTR_GET_IMP(int64_t &, Int)
OP_ATTR_GET_IMP(int32_t &, Int)
// Remaining SetAttr/GetAttr/AttrRegister overloads generated from the OP_ATTR_*_IMP macros.
// NOTE(review): several `vector` instantiations below are textually identical; presumably they
// differ in an element type that is missing from this copy - confirm against the class declaration.
OP_ATTR_GET_IMP(uint32_t &, Int)
OP_ATTR_SET_IMP(const vector &, ListInt)
OP_ATTR_SET_IMP(const vector &, ListInt)
OP_ATTR_SET_IMP(const vector &, ListInt)
OP_ATTR_SET_IMP(std::initializer_list &&, ListInt)
OP_ATTR_GET_IMP(vector &, ListInt)
OP_ATTR_GET_IMP(vector &, ListInt)
OP_ATTR_GET_IMP(vector &, ListInt)
OP_ATTR_GET_IMP(vector> &, ListListInt)
OP_ATTR_SET_IMP(const vector> &, ListListInt)

// Float attributes.
OP_ATTR_SET_IMP(float, Float)
OP_ATTR_GET_IMP(float &, Float)
OP_ATTR_SET_IMP(const vector &, ListFloat)
OP_ATTR_GET_IMP(vector &, ListFloat)  // lint !e665

// Bool attributes.
OP_ATTR_SET_IMP(bool, Bool)
OP_ATTR_GET_IMP(bool &, Bool)
OP_ATTR_SET_IMP(const vector &, ListBool)
OP_ATTR_GET_IMP(vector &, ListBool)  // lint !e665

// String attributes.
OP_ATTR_SET_IMP(const string &, Str)
OP_ATTR_GET_IMP(string &, Str)
OP_ATTR_SET_IMP(const vector &, ListStr)
OP_ATTR_GET_IMP(vector &, ListStr)  // lint !e665

// Named-attribute groups.
OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs)
OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs)
OP_ATTR_SET_IMP(const vector &, ListNamedAttrs)
OP_ATTR_GET_IMP(vector &, ListNamedAttrs)  // lint !e665

// AttrRegister overloads.
OP_ATTR_REG_IMP(int64_t, Int)
OP_ATTR_REG_IMP(const vector &, ListInt)
OP_ATTR_REG_IMP(float, Float)
OP_ATTR_REG_IMP(const vector &, ListFloat)
OP_ATTR_REG_IMP(const string &, Str)
OP_ATTR_REG_IMP(const vector &, ListStr)
OP_ATTR_REG_IMP(bool, Bool)
OP_ATTR_REG_IMP(const vector &, ListBool)
OP_ATTR_REG_IMP(const vector> &, ListListInt)
OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs)
OP_ATTR_REG_IMP(const vector &, ListNamedAttrs)

// The generator macros are private to this translation unit from here on.
#undef OP_ATTR_SET_IMP
#undef OP_ATTR_GET_IMP
#undef OP_ATTR_REG_IMP

// Stores a Tensor attribute: the tensor is converted with TensorAdapter::AsGeTensor and written
// through AttrUtils::SetTensor. Always returns *this.
Operator &Operator::SetAttr(const string &name, const Tensor &attr_value) {
  if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) {
    GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str());
    return *this;
  }
  GeTensor tensor = TensorAdapter::AsGeTensor(attr_value);
  if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) {
    GELOGW("set attr 
name %s failed.", name.c_str()); } return *this; } Operator &Operator::SetAttr(const string &name, const vector &attr_value) { if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return *this; } vector val_list; for (const auto &item : attr_value) { auto tensor = TensorAdapter::AsGeTensor(item); val_list.push_back(tensor); } if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { GELOGW("set attr name %s failed.", name.c_str()); } return *this; } graphStatus Operator::GetAttr(const string &name, Tensor &attr_value) const { if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return GRAPH_FAILED; } ConstGeTensorPtr tensor; if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { GELOGW("get attr name %s failed.", name.c_str()); return GRAPH_FAILED; } attr_value = TensorAdapter::GeTensor2Tensor(tensor); return GRAPH_SUCCESS; } graphStatus Operator::GetAttr(const string &name, vector &attr_value) const { attr_value.clear(); if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return GRAPH_FAILED; } vector val_list; if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { GELOGW("get attr name %s failed.", name.c_str()); return GRAPH_FAILED; } for (auto &tensor : val_list) { attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor)); } return GRAPH_SUCCESS; } Operator &Operator::SetAttr(const string &name, const OpBytes &attr_value) { if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return *this; } if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, 
Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { GELOGW("set attr name %s failed.", name.c_str()); } return *this; } graphStatus Operator::GetAttr(const string &name, OpBytes &attr_value) const { if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return GRAPH_FAILED; } Buffer buffer; if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, buffer)) { GELOGW("get attr name %s failed.", name.c_str()); return GRAPH_FAILED; } attr_value.clear(); if (buffer.data() == nullptr) { GELOGE(GRAPH_FAILED, "buffer data is null."); return GRAPH_FAILED; } attr_value.assign(buffer.data(), buffer.data() + buffer.size()); return GRAPH_SUCCESS; } Operator &Operator::SetAttr(const string &name, ge::AttrValue &&attrValue) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); (void)operator_impl_->SetAttr(name, std::move(attrValue.impl->geAttrValue_)); return *this; } graphStatus Operator::GetAttr(const string &name, ge::AttrValue &attrValue) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); return operator_impl_->GetAttr(name, attrValue.impl->geAttrValue_); } Operator &Operator::SetAttr(const string &name, const std::vector &attr_value) { if (operator_impl_ == nullptr || !operator_impl_->GetOpDescImpl()) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return *this; } if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { GELOGW("set attr name %s failed.", name.c_str()); } return *this; } graphStatus Operator::GetAttr(const string &name, std::vector &attr_value) const { attr_value.clear(); if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return GRAPH_FAILED; } if 
(!AttrUtils::GetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { GELOGW("get attr name %s failed.", name.c_str()); return GRAPH_FAILED; } return GRAPH_SUCCESS; } Operator &Operator::SetAttr(const string &name, const ge::DataType &attr_value) { if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return *this; } if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { GELOGW("set attr name %s failed.", name.c_str()); } return *this; } graphStatus Operator::GetAttr(const string &name, ge::DataType &attr_value) const { if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return GRAPH_FAILED; } if (!AttrUtils::GetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { GELOGW("get attr name %s failed.", name.c_str()); return GRAPH_FAILED; } return GRAPH_SUCCESS; } void Operator::AttrRegister(const string &name, const std::vector &attr_value) { if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return; } if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { GELOGW("set attr name %s failed.", name.c_str()); } } void Operator::AttrRegister(const string &name, const ge::DataType &attr_value) { if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return; } if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { GELOGW("set attr name %s failed.", name.c_str()); } } void Operator::AttrRegister(const string &name, const Tensor &attr_value) { if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", 
name.c_str()); return; } auto tensor = TensorAdapter::AsGeTensor(attr_value); if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { GELOGW("reg attr name %s failed.", name.c_str()); } } void Operator::AttrRegister(const string &name, const vector &attr_value) { if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return; } vector val_list; for (const auto &item : attr_value) { val_list.push_back(TensorAdapter::AsGeTensor(item)); } if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { GELOGW("reg attr name %s failed.", name.c_str()); } } void Operator::AttrRegister(const string &name, const OpBytes &attr_value) { if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return; } if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { GELOGW("reg attr name %s failed.", name.c_str()); } } void Operator::SubgraphRegister(const std::string &name, bool dynamic) { if (operator_impl_ == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return; } operator_impl_->SubgraphRegister(name, dynamic ? 
kDynamic : kStatic); } void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) { if (operator_impl_ == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); return; } operator_impl_->SubgraphCountRegister(name, count); } void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { if (operator_impl_ == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", ir_name.c_str()); return; } operator_impl_->SetSubgraphBuilder(ir_name, index, builder); } std::vector Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); } SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &ir_name, uint32_t index) const { if (operator_impl_ == nullptr) { GELOGE(GRAPH_FAILED, "operator impl is nullptr."); return nullptr; } return operator_impl_->GetSubgraphBuilder(ir_name, index); } SubgraphBuilder Operator::GetSubgraphBuilder(const string &ir_name) const { return GetDynamicSubgraphBuilder(ir_name, 0); } Graph Operator::GetSubgraph(const string &name) const { if (operator_impl_ == nullptr) { GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str()); return Graph(""); } auto op_desc = OpDescUtils::GetOpDescFromOperator(*this); if (op_desc == nullptr) { GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str()); return Graph(""); } const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); auto iter = subgraph_names_to_index.find(name); if (iter == subgraph_names_to_index.end()) { GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str()); return Graph(""); } auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second); if (subgraph_instance_name.empty()) { GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added", name.c_str(), iter->second); return Graph(""); } auto node = operator_impl_->GetNode(); if (node == nullptr) { GE_LOGE("Failed to 
get subgraph %s, the node is null", name.c_str()); return Graph(""); } auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); if (root_graph == nullptr) { GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str()); return Graph(""); } auto subgraph = root_graph->GetSubgraph(subgraph_instance_name); if (subgraph == nullptr) { GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph", name.c_str(), iter->second, subgraph_instance_name.c_str()); return Graph(""); } return GraphUtils::CreateGraphFromComputeGraph(subgraph); } Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const { return GetSubgraph(name + std::to_string(index)); } size_t Operator::GetSubgraphNamesCount() const { if (operator_impl_ == nullptr) { GE_LOGE("Failed to get subgraph names count, the operator impl is null"); return 0; } return operator_impl_->GetSubgraphNamesCount(); } class GraphBuilderImpl { public: explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared(name)) { if (graph_ == nullptr) { GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); return; } } ~GraphBuilderImpl() {} ComputeGraphPtr BuildGraph(const std::vector &inputs) { std::vector vec_inputs; for (auto &it : inputs) { auto src_op_impl = it.operator_impl_; GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return nullptr, "Operator Impl is null."); GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return nullptr, "Operator impl's opdesc is null."); string type = src_op_impl->op_desc_->GetType(); auto node_op = ge::OperatorFactory::CreateOperator("node_op", type); auto tensor_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); node_op.BreakConnect(); GE_CHK_BOOL_EXEC(tensor_desc != nullptr, continue, "tensor_desc is null."); if ((tensor_desc->GetInputsSize() == 0 && tensor_desc->GetOutputsSize() > 0) || type == DATA || type == VARIABLE || type == INITDATA || type == GETNEXT) { vec_inputs.push_back(it.operator_impl_); 
} else { GELOGW("Input operator should be Data, Variable operator or operator that has output but no input."); } } GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, "User Input do not include operator such as " "Data, Variable operator or operator that has output but no input."); auto ret = WalkAllOperators(vec_inputs); GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); ret = AddEdge(); GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "AddEdge failed."); return graph_; } const std::map &GetAllNodesInfo() const { return all_nodes_info_; } private: graphStatus WalkAllOperators(const std::vector &vec_ops) { GE_CHK_BOOL_EXEC(graph_ != nullptr, return GRAPH_FAILED, "graph_ is null.") std::queue> que; que.push(vec_ops); while (!que.empty()) { auto vec_tem = que.front(); que.pop(); for (const auto &op_impl : vec_tem) { GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.") GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue, "This node %s has created.", op_impl->GetName().c_str()) auto node_ptr = graph_->AddNode(op_impl->op_desc_); GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed."); all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); auto &out_links = op_impl->output_links_; std::vector vec_op_forward{}; for (const auto &out_link : out_links) { for (const auto &op_forward : out_link.second) { vec_op_forward.push_back(op_forward.GetOwner()); } } auto &out_control_links = op_impl->control_output_link_; for (const auto &out_link : out_control_links) { vec_op_forward.push_back(out_link.lock()); } que.push(vec_op_forward); auto &in_links = op_impl->input_link_; std::vector vec_op_back_forward{}; for (const auto &in_link : in_links) { vec_op_back_forward.push_back(in_link.second.GetOwner()); } auto &in_control_links = op_impl->control_input_link_; for (const auto &in_link : in_control_links) { 
vec_op_back_forward.push_back(in_link.lock()); } que.push(vec_op_back_forward); if (WalkAllSubgraphs(node_ptr, op_impl) != GRAPH_SUCCESS) { return GRAPH_FAILED; } } } return MoveSubgraphToRoot(graph_); } graphStatus WalkAllSubgraphs(const NodePtr &node, const OperatorImplPtr &op_impl) { const string name = node->GetName(); for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) { const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first); GE_CHK_BOOL_EXEC(builder != nullptr, return GRAPH_FAILED, "Node: %s, Get builder failed.", name.c_str()); Graph graph = builder(); // Build subgraph from user define builder. const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph); GE_CHK_BOOL_EXEC(subgraph != nullptr, return GRAPH_FAILED, "Node: %s, Build graph failed.", name.c_str()); subgraph->SetParentNode(node); subgraph->SetParentGraph(graph_); if (graph_->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { return GRAPH_FAILED; } if (op_impl->op_desc_->SetSubgraphInstanceName(name_idx.second, subgraph->GetName()) != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "Failed to set subgraph %s index %u", subgraph->GetName().c_str(), name_idx.second); return GRAPH_FAILED; } } return GRAPH_SUCCESS; } graphStatus MoveSubgraphToRoot(const ComputeGraphPtr &graph) { const ComputeGraphPtr &root_graph = GraphUtils::FindRootGraph(graph); if (root_graph == nullptr) { GELOGE(GRAPH_FAILED, "Graph: %s, Find root graph failed.", graph->GetName().c_str()); return GRAPH_FAILED; } if (root_graph == graph) { auto subgraphs = graph->GetAllSubgraphs(); for (auto &subgraph : subgraphs) { if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { return GRAPH_FAILED; } } } else { auto subgraphs = graph->GetAllSubgraphs(); for (auto &subgraph : subgraphs) { if (root_graph->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { return GRAPH_FAILED; } graph->RemoveSubgraph(subgraph->GetName()); if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { return 
GRAPH_FAILED; } } } return GRAPH_SUCCESS; } graphStatus AddEdge() { for (const auto &node_info : all_nodes_info_) { auto src_op_impl_ptr = node_info.first; auto src_node_ptr = node_info.second; GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue); auto out_links = src_op_impl_ptr->output_links_; GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED, "Src operator impl's op_desc is null."); auto &op_desc = src_op_impl_ptr->op_desc_; GE_IF_BOOL_EXEC(op_desc == nullptr, continue); for (const auto &out : out_links) { auto src_idx = op_desc->GetOutputIndexByName(out.first); GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed"); auto src_anchor = src_node_ptr->GetOutDataAnchor(src_idx); GE_CHK_BOOL_EXEC(src_anchor != nullptr, return GRAPH_FAILED, "GetOutDataAnchor failed."); for (const auto &dst_opio : out.second) { auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner()); GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed."); GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed."); auto ret = GraphUtils::AddEdge(src_anchor, dst_anchor); GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, "from node[%s][%d] to node[%s][%d]AddEdge failed.", src_node_ptr->GetName().c_str(), src_anchor->GetIdx(), dst_node_info->second->GetName().c_str(), dst_anchor->GetIdx()); } } auto out_control_anchor = src_node_ptr->GetOutControlAnchor(); for (const auto &control_out : src_op_impl_ptr->control_output_link_) { auto dst_node_info = all_nodes_info_.find(control_out.lock()); if (dst_node_info == all_nodes_info_.end()) { GELOGE(GRAPH_FAILED, "Find Dst node failed."); return GRAPH_FAILED; } GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); auto in_control_anchor = 
dst_node_info->second->GetInControlAnchor(); auto ret = GraphUtils::AddEdge(out_control_anchor, in_control_anchor); if (ret != GRAPH_SUCCESS) { GELOGE(ret, "AddEdge failed. srcNode %s:%s, dstNode %s:%s", op_desc->GetName().c_str(), op_desc->GetType().c_str(), dst_node_info->second->GetName().c_str(), dst_node_info->second->GetType().c_str()); return ret; } } } return GRAPH_SUCCESS; } ComputeGraphPtr graph_ = nullptr; std::map all_nodes_info_{}; }; inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { for (const auto &graph : compute_graph->GetAllSubgraphs()) { std::set node_names; for (auto const &node : graph->GetDirectNode()) { auto result = node_names.insert(node->GetName()); if (!result.second) { GELOGE(GRAPH_FAILED, "graph %s has same name node%s", graph->GetName().c_str(), node->GetName().c_str()); return true; } } } std::set node_names; for (auto const &node : compute_graph->GetDirectNode()) { auto result = node_names.insert(node->GetName()); if (!result.second) { GELOGE(GRAPH_FAILED, "graph %s has same name node%s", compute_graph->GetName().c_str(), node->GetName().c_str()); return true; } } return false; } ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector &inputs) { auto graph_builder_impl = GraphBuilderImpl(name); ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Computer graph is nullptr"); compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); if (HasSameNameNode(compute_graph)) { GELOGW("Compute do not allow has same name nodes."); compute_graph = nullptr; } return compute_graph; } void GraphUtils::BreakConnect(const std::map &all_nodes_infos) { for (const auto &it : all_nodes_infos) { OperatorImplPtr op_impl = it.first; if (op_impl == nullptr) { GELOGW("operator impl is nullptr."); continue; } op_impl->ClearOutputLinks(); op_impl->ClearInputLinks(); 
OperatorKeeper::GetInstance().CheckOutOperator(op_impl); } } } // namespace ge /*lint +e446 +e732*/ /*lint +e665*/