/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/op_desc.h"
#include "debug/ge_attr_define.h"
#include "debug/ge_util.h"
#include "external/graph/operator.h"
#include "framework/common/debug/ge_log.h"
#include "common/util/error_manager/error_manager.h"
#include "graph/ge_attr_value.h"
#include "graph/ge_tensor.h"
#include "graph/operator_factory_impl.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/ge_ir_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "proto/ge_ir.pb.h"

using std::make_pair;
using std::shared_ptr;
using std::string;
using std::vector;

/*lint -save -e521 -e681 -e732 -e737*/

namespace ge {
// Names of the attributes that mirror dedicated OpDef proto fields.
const std::string ATTR_NAME_ID = "id";
const std::string ATTR_NAME_STREAM_ID = "stream_id";
const std::string ATTR_NAME_INPUT_NAME = "input_name";
const std::string ATTR_NAME_SRC_NAME = "src_name";
const std::string ATTR_NAME_SRC_INDEX = "src_index";
const std::string ATTR_NAME_INPUT = "input";
const std::string ATTR_NAME_OUTPUT = "output";
const std::string ATTR_NAME_INPUT_DESC = "input_desc";
const std::string ATTR_NAME_OUTPUT_DESC = "output_desc";
const std::string ATTR_NAME_DST_NAME = "dst_name";
const std::string ATTR_NAME_DST_INDEX = "dst_index";
const std::string ATTR_NAME_WORKSPACE = "workspace";
const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes";
const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const";
const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends";

// Creates an empty OpDesc backed by a default-initialised proto message.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() {
  op_def_.InitDefault();
  if (op_def_.GetProtoMsg() != nullptr) {
    op_def_.GetProtoMsg()->set_has_out_attr(true);
  }
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::~OpDesc() {}

// Creates an OpDesc with the given name and type.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const std::string &name, const std::string &type) {
  op_def_.InitDefault();
  if (op_def_.GetProtoMsg() != nullptr) {
    op_def_.GetProtoMsg()->set_has_out_attr(true);
  }
  SetName(name);
  SetType(type);
}

// Wraps an existing OpDef. When the message is not yet flagged with has_out_attr, the values stored in
// the attribute map under the ATTR_NAME_* keys are copied into the dedicated proto fields and the flag is set.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const ProtoMsgOwner &proto_msg_owner,
                                                              ge::proto::OpDef *op_def)
    : op_def_(proto_msg_owner, op_def) {
  if (op_def != nullptr && !op_def->has_out_attr()) {
    op_def->set_has_out_attr(true);

    int64_t id = 0;
    (void)AttrUtils::GetInt(this, ATTR_NAME_ID, id);
    op_def->set_id(id);

    int64_t stream_id = 0;
    (void)AttrUtils::GetInt(this, ATTR_NAME_STREAM_ID, stream_id);
    op_def->set_stream_id(stream_id);

    vector<string> input_name;
    (void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME, input_name);
    for (auto &item : input_name) {
      op_def->add_input_name(item);
    }

    vector<string> src_name;
    (void)AttrUtils::GetListStr(this, ATTR_NAME_SRC_NAME, src_name);
    for (auto &item : src_name) {
      op_def->add_src_name(item);
    }

    vector<int64_t> src_index;
    (void)AttrUtils::GetListInt(this, ATTR_NAME_SRC_INDEX, src_index);
    for (auto &item : src_index) {
      op_def->add_src_index(item);
    }

    vector<int64_t> input;
    (void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT, input);
    for (auto &item : input) {
      op_def->add_input_i(item);
    }

    vector<int64_t> output;
    (void)AttrUtils::GetListInt(this, ATTR_NAME_OUTPUT, output);
    for (auto &item : output) {
      op_def->add_output_i(item);
    }

    vector<string> dst_name;
    (void)AttrUtils::GetListStr(this, ATTR_NAME_DST_NAME, dst_name);
    for (auto &item : dst_name) {
      op_def->add_dst_name(item);
    }

    vector<int64_t> dst_index;
    (void)AttrUtils::GetListInt(this, ATTR_NAME_DST_INDEX, dst_index);
    for (auto &item : dst_index) {
      op_def->add_dst_index(item);
    }

    vector<int64_t> workspace;
    (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE, workspace);
    for (auto &item : workspace) {
      op_def->add_workspace(item);
    }

    vector<int64_t> workspace_bytes;
    (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE_BYTES, workspace_bytes);
    for (auto &item : workspace_bytes) {
      op_def->add_workspace_bytes(item);
    }

    vector<bool> is_input_const;
    (void)AttrUtils::GetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const);
    for (auto item : is_input_const) {
      op_def->add_is_input_const(item);
    }

    auto input_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_INPUT_DESC].mutable_list();
    if (input_desc_mutable_list != nullptr) {
      *op_def->mutable_input_desc() = *(input_desc_mutable_list->mutable_td());
    }

    auto output_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_OUTPUT_DESC].mutable_list();
    if (output_desc_mutable_list != nullptr) {
      *op_def->mutable_output_desc() = *(output_desc_mutable_list->mutable_td());
    }
  }
}

// Returns the op name stored in the proto message, or "" when no message is attached.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetName() const {
  auto proto_msg = op_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    return proto_msg->name();
  }
  return "";
}

// Stores the op name in the proto message; no-op when no message is attached.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetName(const std::string &name) {
  auto proto_msg = op_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    proto_msg->set_name(name);
  }
}

// Returns the op type stored in the proto message, or "" when no message is attached.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetType() const {
  auto proto_msg = op_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    return proto_msg->type();
  }
  return "";
}

// Stores the op type in the proto message; no-op when no message is attached.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetType(const string &type) {
  auto proto_msg = op_def_.GetProtoMsg();
  if (proto_msg != nullptr) {
    proto_msg->set_type(type);
  }
}

// Appends an input desc under the generated name "__input<N>", N being the current input count.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddInputDesc(const ge::GeTensorDesc &input_desc) {
  int index = static_cast<int>(inputs_desc_.size());
  return AddInputDesc("__input" + std::to_string(index), input_desc);
}

// Updates the input desc at `index` when it exists, otherwise appends a new one.
graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc) {
  graphStatus ret = GRAPH_SUCCESS;
  if (index < inputs_desc_.size()) {
    // InputsDesc[index] exists, so update it
    ret = UpdateInputDesc(index, input_desc);
  } else {
    // InputsDesc[index] does not exist, so add it
    ret = AddInputDesc(input_desc);
  }
  return ret;
}

// Adds a named input desc, or updates it when `name` is already registered.
// A newly added name is also recorded in register_input_name_.
graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) {
  if (input_name_idx_.find(name) != input_name_idx_.end()) {
    GELOGI("input %s is exist, update it", name.c_str());
    graphStatus ret = UpdateInputDesc(name, input_desc);
    return ret;
  } else {
    int index = static_cast<int>(inputs_desc_.size());
    std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(input_desc);
    if (in_desc == nullptr) {
      GELOGE(GRAPH_FAILED, "AddInputDesc failed, malloc shared_ptr failed.");
      return GRAPH_FAILED;
    }
    inputs_desc_.push_back(in_desc);
    (void)input_name_idx_.insert(make_pair(name, index));
    if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) {
      register_input_name_.push_back(name);
    }
    return GRAPH_SUCCESS;
  }
}

// Inserts `num` default input descs named "<name>0".."<name>{num-1}" starting at position `index`,
// shifting the indices of the inputs that follow.
graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) {
  for (unsigned int i = 0; i < num; i++) {
    string input_name = name + std::to_string(i);
    GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED,
                           "Add input tensor_desc is existed. name[%s]", input_name.c_str());

    std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
    if (in_desc == nullptr) {
      GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed.");
      return GRAPH_FAILED;
    }

    if (index > inputs_desc_.size()) {
      GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size.");
      return GRAPH_FAILED;
    }

    (void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc);

    // Update index in input_name_idx
    for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) {
      if (it->second >= (index + i)) {
        it->second += 1;
      }
    }

    (void)input_name_idx_.insert(make_pair(input_name, i + index));
  }

  return GRAPH_SUCCESS;
}

// Inserts `num` default output descs named "<name>0".."<name>{num-1}" starting at position `index`,
// shifting the indices of the outputs that follow.
graphStatus OpDesc::AddOutputDescMiddle(const string &name, const unsigned int num, size_t index) {
  for (unsigned int i = 0; i < num; i++) {
    string output_name = name + std::to_string(i);
    GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED,
                           "Add output tensor_desc is existed. name[%s]", output_name.c_str());

    std::shared_ptr<GeTensorDesc> out_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
    if (out_desc == nullptr) {
      GELOGE(GRAPH_FAILED, "AddOutputDescMiddle failed, malloc shared_ptr failed.");
      return GRAPH_FAILED;
    }

    if (index > outputs_desc_.size()) {
      GELOGE(GRAPH_FAILED, "AddOutputDescMiddle failed, insert index should not more than outputs size.");
      return GRAPH_FAILED;
    }

    (void)outputs_desc_.insert(outputs_desc_.begin() + index + i, out_desc);

    // Update index in output_name_idx
    for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) {
      if (it->second >= (index + i)) {
        it->second += 1;
      }
    }

    (void)output_name_idx_.insert(make_pair(output_name, i + index));
  }

  return GRAPH_SUCCESS;
}

// Inserts `num` default input descs named "<name><i>" at the front of the input list.
// Each insertion goes to position 0, so every existing index is shifted by one per iteration.
graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) {
  for (unsigned int i = 0; i < num; i++) {
    string input_name = name + std::to_string(i);
    GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED,
                           "Add input tensor_desc is existed. name[%s]", input_name.c_str());

    std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
    if (in_desc == nullptr) {
      GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed.");
      return GRAPH_FAILED;
    }
    (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc);

    // Update index in input_name_idx
    for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) {
      it->second += 1;
    }

    (void)input_name_idx_.insert(make_pair(input_name, 0));
  }

  return GRAPH_SUCCESS;
}

// Inserts `num` default output descs named "<name><i>" at the front of the output list.
// Each insertion goes to position 0, so every existing index is shifted by one per iteration.
graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int num) {
  for (unsigned int i = 0; i < num; i++) {
    string output_name = name + std::to_string(i);
    GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED,
                           "Add output tensor_desc is existed. name[%s]", output_name.c_str());

    std::shared_ptr<GeTensorDesc> out_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc());
    if (out_desc == nullptr) {
      GELOGE(GRAPH_FAILED, "AddOutputDescForward failed, malloc shared_ptr failed.");
      return GRAPH_FAILED;
    }

    (void)outputs_desc_.insert(outputs_desc_.begin(), out_desc);

    // Update index in output_name_idx
    for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) {
      it->second += 1;
    }
    (void)output_name_idx_.insert(make_pair(output_name, 0));
  }

  return GRAPH_SUCCESS;
}

// Adds (or updates) a named input desc and marks the name as optional.
graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) {
  if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) {
    return GRAPH_FAILED;
  }
  (void)optional_input_names_.insert(name);
  return GRAPH_SUCCESS;
}

// Replaces the input desc at `index` with a copy of `tensor_Desc`.
// The slot is only overwritten after the copy has been allocated, so a failed
// allocation leaves the previous desc in place instead of a null entry.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) {
  GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index);

  std::shared_ptr<GeTensorDesc> new_desc = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  if (new_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed.");
    return GRAPH_FAILED;
  }
  inputs_desc_[index] = new_desc;
  return GRAPH_SUCCESS;
}

// Compares the non-proto members (name/index maps, optional inputs, engine and kernel-lib names).
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const {
  return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") &&
          IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") &&
          IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") &&
          IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") &&
          IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_"));
}

// Compares the fields of the underlying OpDef messages.
// Two descs without a proto message are considered equal; exactly one missing message is not.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const {
  const auto &op_def = this->op_def_.GetProtoMsg();
  const auto &r_op_def = r_op_desc.op_def_.GetProtoMsg();
  if ((op_def != nullptr) && (r_op_def != nullptr)) {
    // Message OpDef in ge_ir.proto
    return (
      IsEqual(op_def->name(), r_op_def->name(), "OpDef_.name()") &&
      IsEqual(op_def->type(), r_op_def->type(), "OpDef_.type()") &&
      IsEqual(ToString(op_def->input()), ToString(r_op_def->input()), "OpDef_.input()") &&
      IsEqual(op_def->has_out_attr(), r_op_def->has_out_attr(), "OpDef_.has_out_attr()") &&
      IsEqual(op_def->stream_id(), r_op_def->stream_id(), "OpDef_.stream_id()") &&
      IsEqual(ToString(op_def->input_name()), ToString(r_op_def->input_name()), "OpDef_.input_name()") &&
      IsEqual(ToString(op_def->src_name()), ToString(r_op_def->src_name()), "OpDef_.src_name()") &&
      IsEqual(ToString(op_def->dst_name()), ToString(r_op_def->dst_name()), "OpDef_.dst_name()") &&
      IsEqual(ToString(op_def->src_index()), ToString(r_op_def->src_index()), "OpDef_.src_index()") &&
      IsEqual(ToString(op_def->dst_index()), ToString(r_op_def->dst_index()), "OpDef_.dst_index()") &&
      IsEqual(ToString(op_def->input_i()), ToString(r_op_def->input_i()), "OpDef_.input_i()") &&
      IsEqual(ToString(op_def->output_i()), ToString(r_op_def->output_i()), "OpDef_.output_i()") &&
      IsEqual(ToString(op_def->workspace()), ToString(r_op_def->workspace()), "OpDef_.workspace()") &&
      IsEqual(ToString(op_def->workspace_bytes()), ToString(r_op_def->workspace_bytes()),
              "OpDef_.workspace_bytes()") &&
      IsEqual(ToString(op_def->is_input_const()), ToString(r_op_def->is_input_const()), "OpDef_.is_input_const()"));
  } else {
    return ((op_def == nullptr) && (r_op_def == nullptr));
  }
}

// Compares the input and output tensor descs of both ops: first the counts, then each desc pairwise.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescGenTensorDescsAreEqual(
  const OpDesc &r_op_desc) const {
  // 1.Verify inputs and outputs desc size
  const auto inputs_desc_size = this->inputs_desc_.size();
  const auto r_inputs_desc_size = r_op_desc.inputs_desc_.size();
  if (inputs_desc_size != r_inputs_desc_size) {
    GELOGE(GRAPH_FAILED, "Size of OpDesc's inputs desc verify failed, node name: %s.", this->GetName().c_str());
    return false;
  }
  const auto outputs_desc_size = this->outputs_desc_.size();
  const auto r_outputs_desc_size = r_op_desc.outputs_desc_.size();
  if (outputs_desc_size != r_outputs_desc_size) {
    GELOGE(GRAPH_FAILED, "Size of OpDesc's outputs desc verify failed, node name: %s.", this->GetName().c_str());
    return false;
  }
  // 2.Verify all inputs desc equal
  for (uint32_t i = 0; i < inputs_desc_size; i++) {
    const auto &in_ge_tensor_desc = this->GetInputDesc(i);
    const auto &r_in_ge_tensor_desc = r_op_desc.GetInputDesc(i);
    // Determine the connection relationship by GeTensorDesc
    if (!(in_ge_tensor_desc == r_in_ge_tensor_desc)) {
      GELOGE(GRAPH_FAILED, "Link info of OpDesc's inputs desc verify failed, OpDesc name: %s.",
             this->GetName().c_str());
      return false;
    }
  }
  // 3.Verify all outputs desc equal
  for (uint32_t i = 0; i < outputs_desc_size; i++) {
    const auto &out_ge_tensor_desc = this->GetOutputDesc(i);
    const auto &r_out_ge_tensor_desc = r_op_desc.GetOutputDesc(i);
    if (!(out_ge_tensor_desc == r_out_ge_tensor_desc)) {
      GELOGE(GRAPH_FAILED, "Link info of OpDesc's outputs desc verify failed, OpDesc name: %s.",
             this->GetName().c_str());
      return false;
    }
  }
  return true;
}

// Two OpDescs are equal when their proto fields, plain members and tensor descs all match.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpDesc &r_op_desc) const {
  return (OpDescAttrsAreEqual(r_op_desc) && OpDescMembersAreEqual(r_op_desc) &&
          OpDescGenTensorDescsAreEqual(r_op_desc));
}

// Replaces the input desc registered under `name` with a copy of `tensor_Desc`.
// Fails when the name is unknown or maps to an index outside inputs_desc_.
graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) {
  auto it = input_name_idx_.find(name);
  if (it == input_name_idx_.end()) {
    GELOGW("Cann't find the input desc. name[%s]", name.c_str());
    return GRAPH_FAILED;
  }
  if (it->second >= inputs_desc_.size()) {
    GELOGE(GRAPH_FAILED, "[%u] more than size of inputs_desc_", it->second);
    return GRAPH_FAILED;
  }

  // Allocate first so that a failed allocation never leaves a null slot behind.
  std::shared_ptr<GeTensorDesc> new_desc = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  if (new_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed.");
    return GRAPH_FAILED;
  }
  inputs_desc_[it->second] = new_desc;
  return GRAPH_SUCCESS;
}

// Returns true when the named input exists and its shape has at least one dimension.
bool OpDesc::InputIsSet(const string &name) const {
  auto it = input_name_idx_.find(name);
  if (it != input_name_idx_.end()) {
    GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false);
    auto tensor_desc = inputs_desc_[it->second];
    GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false);
    auto dims = tensor_desc->GetShape().GetDims();
    if (dims.size() > 0) {
      return true;
    }
  }
  return false;
}

// Returns a copy of the input desc at `index`; a default GeTensorDesc when the index is out of range
// or the slot is empty.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc(uint32_t index) const {
  GE_CHK_BOOL_RET_STATUS_NOLOG(index < inputs_desc_.size(), GeTensorDesc());
  GE_CHK_BOOL_RET_STATUS_NOLOG(inputs_desc_[index] != nullptr, GeTensorDesc());
  return *(inputs_desc_[index].get());
}

// Returns a copy of the input desc registered under `name`; a default GeTensorDesc when it cannot be resolved.
GeTensorDesc OpDesc::GetInputDesc(const string &name) const {
  auto it = input_name_idx_.find(name);
  GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc());
  GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc());
  GE_CHK_BOOL_RET_STATUS_NOLOG(inputs_desc_[it->second] != nullptr, GeTensorDesc());
  return *(inputs_desc_[it->second].get());
}

// Returns the shared input desc at `index`, or nullptr when the index is out of range,
// the slot is empty, or the desc reports itself as invalid.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const {
  GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index);
  if (inputs_desc_[index] == nullptr) {
    return nullptr;
  }
  if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) {
    GELOGW("input desc is invalid");
    return nullptr;
  }
  return inputs_desc_[index];
}

// Returns the shared input desc registered under `name`, or nullptr when it cannot be resolved.
GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const {
  // Look the name up in place; copying the whole map for a single lookup is unnecessary.
  auto it = input_name_idx_.find(name);
  if (it == input_name_idx_.end()) {
    GELOGW("Failed to get [%s] input desc", name.c_str());
    return nullptr;
  }
  return MutableInputDesc(it->second);
}

// Returns the names of all registered inputs, in the iteration order of the name map.
GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const {
  vector<string> names;
  names.reserve(input_name_idx_.size());
  for (const auto &input : input_name_idx_) {
    names.push_back(input.first);
  }
  return OpDesc::Vistor<string>(shared_from_this(), names);
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpKernelLibName(const std::string &name) {
  op_kernel_lib_name_ = name;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpKernelLibName() const {
  return op_kernel_lib_name_;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(const std::string &name) {
  engine_name_ = name;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; }

// Returns copies of all valid input descs; empty or invalid slots are skipped.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllInputsDesc() const {
  vector<GeTensorDesc> temp{};
  for (const auto &it : inputs_desc_) {
    if ((it != nullptr) && (it->IsValid() == GRAPH_SUCCESS)) {
      temp.push_back(*it);
    } else {
      GELOGW("this inputDesc is InValid, it won't be return");
      continue;
    }
  }
  return OpDesc::Vistor<GeTensorDesc>(shared_from_this(), temp);
}

// Returns the shared pointers of all valid input descs; empty or invalid slots are skipped.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDescPtr> OpDesc::GetAllInputsDescPtr() const {
  vector<GeTensorDescPtr> temp{};
  for (const auto &it : inputs_desc_) {
    if ((it != nullptr) && (it->IsValid() == GRAPH_SUCCESS)) {
      temp.push_back(it);
    } else {
      GELOGW("this inputDesc is InValid, it won't be return");
      continue;
    }
  }
  return OpDesc::Vistor<GeTensorDescPtr>(shared_from_this(), temp);
}

// Returns the number of valid input descs.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() const {
  // Only valid inputs are counted. Invalid descs are the defaults registered for OPTION_INPUT.
  size_t size = 0;
  for (const auto &it : inputs_desc_) {
    if ((it != nullptr) && (it->IsValid() == GRAPH_SUCCESS)) {
      size++;
    }
  }
  return size;
}

// Returns the number of input slots, valid or not.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); }

// Appends an output desc under the generated name "__output<N>", N being the current output count.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) {
  int index = static_cast<int>(outputs_desc_.size());
  return AddOutputDesc("__output" + std::to_string(index), output_desc);
}

// Appends a named output desc; fails when `name` is already registered.
// The name is also recorded in register_output_name_.
graphStatus OpDesc::AddOutputDesc(const string &name, const ge::GeTensorDesc &output_desc) {
  GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(name) == output_name_idx_.end()), GRAPH_FAILED,
                         "Add output tensor_Desc is existed. name[%s]", name.c_str());
  int index = static_cast<int>(outputs_desc_.size());

  std::shared_ptr<GeTensorDesc> tensor = ComGraphMakeShared<GeTensorDesc>(output_desc);
  if (tensor == nullptr) {
    GELOGE(GRAPH_FAILED, "AddOutputDesc failed, malloc shared_ptr failed.");
    return GRAPH_FAILED;
  }
  outputs_desc_.push_back(tensor);
  (void)output_name_idx_.insert(make_pair(name, index));
  if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) {
    register_output_name_.push_back(name);
  }
  return GRAPH_SUCCESS;
}

// Replaces the output desc at `index` with a copy of `tensor_Desc`.
// The slot is only overwritten after the copy has been allocated.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDesc::UpdateOutputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) {
  GE_CHK_BOOL_RET_STATUS((index < outputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index);

  std::shared_ptr<GeTensorDesc> new_desc = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  if (new_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed.");
    return GRAPH_FAILED;
  }
  outputs_desc_[index] = new_desc;
  return GRAPH_SUCCESS;
}

// Replaces the output desc registered under `name` with a copy of `tensor_Desc`.
// Fails when the name is unknown or maps to an index outside outputs_desc_.
graphStatus OpDesc::UpdateOutputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) {
  auto it = output_name_idx_.find(name);
  if (it == output_name_idx_.end()) {
    GELOGW("Cann't find the output desc. name[%s]", name.c_str());
    return GRAPH_FAILED;
  }
  GE_IF_BOOL_EXEC(it->second >= outputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid.");
                  return GRAPH_FAILED);

  // Allocate first so that a failed allocation never leaves a null slot behind.
  std::shared_ptr<GeTensorDesc> new_desc = ComGraphMakeShared<GeTensorDesc>(tensor_Desc);
  if (new_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed.");
    return GRAPH_FAILED;
  }
  outputs_desc_[it->second] = new_desc;
  return GRAPH_SUCCESS;
}

// Returns a copy of the output desc at `index`; a default GeTensorDesc when the index is out of range
// or the slot is empty.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetOutputDesc(uint32_t index) const {
  GE_CHK_BOOL_RET_STATUS_NOLOG(index < outputs_desc_.size(), GeTensorDesc());
  GE_CHK_BOOL_RET_STATUS_NOLOG(outputs_desc_[index] != nullptr, GeTensorDesc());
  return *(outputs_desc_[index].get());
}

// Returns a copy of the output desc registered under `name`; a default GeTensorDesc when it cannot be resolved.
GeTensorDesc OpDesc::GetOutputDesc(const string &name) const {
  auto it = output_name_idx_.find(name);
  GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), GeTensorDesc());
  GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < outputs_desc_.size(), GeTensorDesc());
  GE_CHK_BOOL_RET_STATUS_NOLOG(outputs_desc_[it->second] != nullptr, GeTensorDesc());
  return *(outputs_desc_[it->second].get());
}

// Returns the shared output desc at `index`, or nullptr when the index is out of range.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(uint32_t index) const {
  GE_CHK_BOOL_RET_STATUS(index < outputs_desc_.size(), nullptr, "Cann't find the output desc %u", index);
  return outputs_desc_[index];
}

// Returns the shared output desc registered under `name`, or nullptr when it cannot be resolved.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const {
  auto it = output_name_idx_.find(name);
  if (it == output_name_idx_.end()) {
    GELOGW("Failed to get [%s] output desc", name.c_str());
    return nullptr;
  }
  return MutableOutputDesc(it->second);
}

// Returns the number of output slots as a 32-bit value.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const {
  return static_cast<uint32_t>(outputs_desc_.size());
}

// Returns copies of all output descs; empty slots are skipped.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllOutputsDesc() const {
  vector<GeTensorDesc> temp{};
  for (const auto &it : outputs_desc_) {
    if (it != nullptr) {
      temp.push_back(*it);
    }
  }
  return OpDesc::Vistor<GeTensorDesc>(shared_from_this(), temp);
}

// Returns the shared pointers of all output descs.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDescPtr> OpDesc::GetAllOutputsDescPtr() const {
  return OpDesc::Vistor<GeTensorDescPtr>(shared_from_this(), outputs_desc_);
}

// Returns the number of output slots.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetOutputsSize() const { return outputs_desc_.size(); }

// Returns a read-only pointer to the output desc at `index`, or nullptr when the index is out of range.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetOutputDescPtr(uint32_t index) const {
  GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast<uint32_t>(outputs_desc_.size()), nullptr);
  return outputs_desc_[index];
}

// Returns a read-only pointer to the input desc at `index`, or nullptr when the index is out of range,
// the slot is empty, or the desc reports itself as invalid.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(uint32_t index) const {
  GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast<uint32_t>(inputs_desc_.size()), nullptr);
  if (inputs_desc_[index] == nullptr) {
    return nullptr;
  }
  if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) {
    GELOGW("inputsDesc[%u] is InValid", index);
    return nullptr;
  } else {
    return inputs_desc_[index];
  }
}

// Returns a read-only pointer to the input desc at `index` without the validity check,
// or nullptr when the index is out of range.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr
OpDesc::GetInputDescPtrDfault(uint32_t index) const {
  GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast<uint32_t>(inputs_desc_.size()), nullptr);
  return inputs_desc_[index];
}

// Returns a read-only pointer to the input desc registered under `name`,
// or an empty pointer when the name is unknown or maps outside inputs_desc_.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const {
  auto it = input_name_idx_.find(name);
  GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), shared_ptr<const GeTensorDesc>());
  GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), shared_ptr<const GeTensorDesc>());
  return inputs_desc_[it->second];
}

// Records `name` in register_input_name_ unless it is already present. Always succeeds.
graphStatus OpDesc::AddRegisterInputName(const std::string &name) {
  if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) {
    register_input_name_.push_back(name);
  }

  return GRAPH_SUCCESS;
}

vector<string> OpDesc::GetRegisterInputName() const { return register_input_name_; }

// Adds `num` dynamic inputs named "<name><i>", appended when `is_push_back` is true and
// inserted at the front otherwise, then records the base name as a registered input.
graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) {
  if (is_push_back) {
    for (unsigned int i = 0; i < num; i++) {
      if (AddInputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) {
        return GRAPH_FAILED;
      }
    }
  } else {
    if (AddInputDescForward(name, num) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }
  if (AddRegisterInputName(name) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }

  return GRAPH_SUCCESS;
}

// Adds `num` dynamic inputs named "<name><i>" starting at position `index`.
graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index) {
  if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

// Records `name` in register_output_name_ unless it is already present. Always succeeds.
graphStatus OpDesc::AddRegisterOutputName(const string &name) {
  if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) {
    register_output_name_.push_back(name);
  }

  return GRAPH_SUCCESS;
}

vector<string> OpDesc::GetRegisterOutputName() const { return register_output_name_; }

// Adds `num` dynamic outputs named "<name><i>", appended when `is_push_back` is true and
// inserted at the front otherwise, then records the base name as a registered output.
graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) {
  if (is_push_back) {
    for (unsigned int i = 0; i < num; i++) {
      if (AddOutputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) {
        return GRAPH_FAILED;
      }
    }
  } else {
    if (AddOutputDescForward(name, num) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }

  if (AddRegisterOutputName(name) != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}

// Returns true when `name` was added through AddOptionalInputDesc.
bool OpDesc::IsOptionalInput(const string &name) const {
  return optional_input_names_.find(name) != optional_input_names_.end();
}

bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); }

// Returns a copy of the input name -> index map.
std::map<string, uint32_t> OpDesc::GetAllInputName() const { return input_name_idx_; }

// Returns a copy of the output name -> index map.
std::map<string, uint32_t> OpDesc::GetAllOutputName() { return output_name_idx_; }

// Replaces the input name map with `input_name_idx`.
// When the given map has more entries than there are input slots, entries whose index is out of range
// are dropped first. Returns false (and leaves the current map untouched) when the sizes still differ.
bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) {
  bool ret = true;
  // inputs_desc_.size() is used so that invalid optional inputs are counted as well;
  // GetInputsSize() would leave the default optional input names out.
  auto input_map_size = inputs_desc_.size();
  auto factory_map_size = input_name_idx.size();
  // A larger factory map indicates that some inputs have no optional name.
  // The redundant optional names of the factory map are deleted before assigning.
  if (input_map_size < factory_map_size) {
    GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size,
           factory_map_size);
    for (auto it = input_name_idx.begin(); it != input_name_idx.end();) {
      if (it->second >= input_map_size) {
        it = input_name_idx.erase(it);
      } else {
        ++it;
      }
    }
    if (input_name_idx.size() == input_map_size) {
      GELOGI("UpdateInputName");
      input_name_idx_ = input_name_idx;
    } else {
      ret = false;
      GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size());
    }
  } else if (input_map_size == factory_map_size) {
    input_name_idx_ = input_name_idx;
  } else {
    ret = false;
    GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size);
  }
  return ret;
}

// Replaces the output name map with `output_name_idx`, using the same pruning rule as UpdateInputName.
// Returns false (and leaves the current map untouched) when the sizes cannot be reconciled.
bool OpDesc::UpdateOutputName(std::map<string, uint32_t> output_name_idx) {
  size_t output_map_size = GetAllOutputsDescSize();
  size_t factory_map_size = output_name_idx.size();
  if (output_map_size < factory_map_size) {
    GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size,
           factory_map_size);
    for (auto it = output_name_idx.begin(); it != output_name_idx.end();) {
      if (it->second >= output_map_size) {
        it = output_name_idx.erase(it);
      } else {
        ++it;
      }
    }
    if (output_name_idx.size() == output_map_size) {
      GELOGI("UpdateoutputName");
      output_name_idx_ = output_name_idx;
      return true;
    }
  } else if (output_map_size == factory_map_size) {
    output_name_idx_ = output_name_idx;
    return true;
  }

  // Either the factory map is too small, or pruning did not bring it down to the output count.
  GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size);
  return false;
}

std::function<graphStatus(Operator &)> OpDesc::GetInferFunc() const { return infer_func_; }

std::function<graphStatus(Operator &)> OpDesc::GetVerifyFunc() const { return verifier_func_; }

void OpDesc::AddInferFunc(const std::function<graphStatus(Operator &)> &func) { infer_func_ = func; }

std::function<graphStatus(Operator &)> OpDesc::GetInferFormatFunc() const { return infer_format_func_; }

void OpDesc::AddInferFormatFunc(const std::function<graphStatus(Operator &)> &func) { infer_format_func_ = func; }

void OpDesc::AddVerifierFunc(const std::function<graphStatus(Operator &)> &func) { verifier_func_ = func; }

// Runs the shape/type inference function of this op.
// When none is set, the one registered for the op type in OperatorFactoryImpl is looked up and cached;
// when none is registered either, inference is skipped and GRAPH_SUCCESS is returned.
graphStatus OpDesc::InferShapeAndType() {
  if (infer_func_ == nullptr) {
    infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType());
    if (infer_func_ == nullptr) {
      GELOGW("%s does not have inferfunc_.", GetName().c_str());
      /// The infershape function has not been added for every operator in the current operator
      /// information library. An operator without one skips the call
      /// and directly uses the shape information passed down by the upper framework
      return GRAPH_SUCCESS;
    }
  }
  Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this());
  graphStatus ret = static_cast<graphStatus>(infer_func_(op_proxy));
  op_proxy.BreakConnect();
  return ret;
}

graphStatus OpDesc::DefaultInferFormat() {
  ge::Format first_none_nd_format = FORMAT_ND;
  auto input_descs = GetAllInputsDescPtr();
  auto output_descs = GetAllOutputsDescPtr();
  // Overall input and output,get the first non-nd format
  for (const auto &input_desc : input_descs) {
    Format origin_format = input_desc->GetOriginFormat();
    if (origin_format != FORMAT_ND) {
      first_none_nd_format = origin_format;
      break;
    }
  }
  for (const auto &output_desc : output_descs) {
    Format origin_format = output_desc->GetOriginFormat();
    if (origin_format != FORMAT_ND) {
      first_none_nd_format = origin_format;
      break;
    }
  }
  // Refresh all input output format
  GELOGD("Default infer format.node[%s], first none nod format is:%d", GetName().c_str(), first_none_nd_format);

  for (const auto &input_desc : input_descs) {
    Format origin_format = input_desc->GetOriginFormat();
    GELOGD("Default infer format[in].node[%s].origin format is:%d", GetName().c_str(), origin_format);
    if (origin_format == FORMAT_ND) {
      input_desc->SetOriginFormat(first_none_nd_format);
input_desc->SetFormat(first_none_nd_format); } } for (const auto &output_desc : output_descs) { Format origin_format = output_desc->GetOriginFormat(); GELOGD("Default infer format[out].node[%s].origin format is:%d", GetName().c_str(), origin_format); if (origin_format == FORMAT_ND) { output_desc->SetOriginFormat(first_none_nd_format); output_desc->SetFormat(first_none_nd_format); } } return GRAPH_SUCCESS; } graphStatus OpDesc::OpVerify() { if (verifier_func_ == nullptr) { verifier_func_ = OperatorFactoryImpl::GetVerifyFunc(GetType()); } if (verifier_func_ != nullptr) { Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this()); graphStatus ret = (graphStatus)verifier_func_(op_proxy); op_proxy.BreakConnect(); return ret; } return GRAPH_SUCCESS; } graphStatus OpDesc::CommonVerify() const { for (const string &iname : GetAllInputNames()) { // Checking shape of all inputs vector ishape = GetInputDescPtr(iname)->GetShape().GetDims(); for (int64_t dim : ishape) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( dim < -2, ErrorManager::GetInstance().ATCReportErrMessage( "E19014", {"opname", "value", "reason"}, {GetName(), "input " + iname + " shape", "contains negative or zero dimension"}); return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), iname.c_str()); } } // Check all attributes defined const auto &all_attributes = GetAllAttrs(); for (const auto &name : GetAllAttrNames()) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( all_attributes.find(name) == all_attributes.end(), ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, {GetName(), "attribute " + name, "is empty"}); return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str()); } return GRAPH_SUCCESS; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const { auto it = input_name_idx_.begin(); for (; it != input_name_idx_.end(); ++it) { if (it->second == index) { break; } } 
GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), ""); return it->first; } int OpDesc::GetInputIndexByName(const string &name) const { auto it_find = input_name_idx_.find(name); GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); return static_cast(it_find->second); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetOutputNameByIndex(uint32_t index) const { auto it = output_name_idx_.begin(); for (; it != output_name_idx_.end(); ++it) { if (it->second == index) { break; } } GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), ""); return it->first; } int OpDesc::GetOutputIndexByName(const string &name) const { auto it_find = output_name_idx_.find(name); GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != output_name_idx_.end(), -1); return static_cast(it_find->second); } ProtoAttrMapHelper OpDesc::MutableAttrMap() { if (op_def_.GetProtoMsg() == nullptr) { GELOGE(GRAPH_FAILED, "op def get proto msg failed"); return GeIrProtoHelper(); } return ProtoAttrMapHelper(op_def_.GetProtoOwner(), op_def_.GetProtoMsg()->mutable_attr()); } ConstProtoAttrMapHelper OpDesc::GetAttrMap() const { return ConstProtoAttrMapHelper(op_def_.GetProtoOwner(), &op_def_.GetProtoMsg()->attr()); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetId(int64_t id) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->set_id(id); } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetId() const { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { return proto_msg->id(); } return 0; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetStreamId(int64_t stream_id) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->set_stream_id(stream_id); } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetStreamId() const { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { return proto_msg->stream_id(); } return 0; } 
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputName(const vector &input_name) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->clear_input_name(); for (auto &item : input_name) { proto_msg->add_input_name(item); } } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetInputName() const { vector input_name; auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { for (auto &item : proto_msg->input_name()) { input_name.push_back(item); } } return input_name; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcName(const vector &src_name) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->clear_src_name(); for (auto &item : src_name) { proto_msg->add_src_name(item); } } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetSrcName() const { vector src_name; auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { for (auto &item : proto_msg->src_name()) { src_name.push_back(item); } } return src_name; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcIndex(const vector &src_index) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->clear_src_index(); for (auto &item : src_index) { proto_msg->add_src_index(item); } } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetSrcIndex() const { vector src_index; auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { for (auto &item : proto_msg->src_index()) { src_index.push_back(item); } } return src_index; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputOffset(const vector &input) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->clear_input_i(); for (auto &item : input) { proto_msg->add_input_i(item); } } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetInputOffset() const { vector input; auto proto_msg = op_def_.GetProtoMsg(); if 
(proto_msg != nullptr) { for (auto &item : proto_msg->input_i()) { input.push_back(item); } } return input; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOutputOffset(const vector &output) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->clear_output_i(); for (auto &item : output) { proto_msg->add_output_i(item); } } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetOutputOffset() const { vector output; auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { for (auto &item : proto_msg->output_i()) { output.push_back(item); } } return output; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstName(const vector &dst_name) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->clear_dst_name(); for (auto &item : dst_name) { proto_msg->add_dst_name(item); } } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetDstName() const { vector dst_name; auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { for (auto &item : proto_msg->dst_name()) { dst_name.push_back(item); } } return dst_name; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpInferDepends(const vector &depend_names) { auto ret = AttrUtils::SetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); if (ret != true) { GELOGE(GRAPH_FAILED, "set op_infer_depends fail."); } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetOpInferDepends() const { vector depend_names; (void)AttrUtils::GetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); return depend_names; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector &dst_index) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->clear_dst_index(); for (auto &item : dst_index) { proto_msg->add_dst_index(item); } } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetDstIndex() const { vector dst_index; 
auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { for (auto &item : proto_msg->dst_index()) { dst_index.push_back(item); } } return dst_index; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspace(const vector &workspace) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->clear_workspace(); for (auto &item : workspace) { proto_msg->add_workspace(item); } } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetWorkspace() const { vector workspace; auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { for (auto &item : proto_msg->workspace()) { workspace.push_back(item); } } return workspace; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspaceBytes(const vector &workspace_bytes) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->clear_workspace_bytes(); for (auto &item : workspace_bytes) { proto_msg->add_workspace_bytes(item); } } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetWorkspaceBytes() const { vector workspace_bytes; auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { for (auto &item : proto_msg->workspace_bytes()) { workspace_bytes.push_back(item); } } return workspace_bytes; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetIsInputConst(const vector &is_input_const) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->clear_is_input_const(); for (auto item : is_input_const) { proto_msg->add_is_input_const(item); } } // If comes from ME,which is_input_const exist as attrs, outside no need to check GE_TRAIN flag auto ret = AttrUtils::SetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const); if (ret != true) { GELOGE(GRAPH_FAILED, "set is_input_const fail."); } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetIsInputConst() const { vector is_input_const; auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != 
nullptr) { for (auto item : proto_msg->is_input_const()) { is_input_const.push_back(item); } } return is_input_const; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name, const int &index) { if (input_name_idx_.find(name) != input_name_idx_.end()) { GELOGI("Restore input name index is existed. name[%s]", name.c_str()); } (void)input_name_idx_.insert(make_pair(name, index)); return GRAPH_SUCCESS; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreOutputNameIdx(const string &name, const int &index) { if (output_name_idx_.find(name) != output_name_idx_.end()) { GELOGI("Restore output name index is existed. name[%s]", name.c_str()); } (void)output_name_idx_.insert(make_pair(name, index)); return GRAPH_SUCCESS; } graphStatus OpDesc::CallInferFunc(Operator &op) { if (infer_func_ == nullptr) { infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType()); if (infer_func_ == nullptr) { GELOGW("%s does not have infer func.", GetName().c_str()); return GRAPH_PARAM_INVALID; } } graphStatus graph_status = (graphStatus)infer_func_(op); if (graph_status != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "%s call infer func. 
ret: %u", GetName().c_str(), graph_status); return GRAPH_FAILED; } return GRAPH_SUCCESS; } graphStatus OpDesc::CallInferFormatFunc(Operator &op) { if (infer_format_func_ == nullptr) { infer_format_func_ = OperatorFactoryImpl::GetInferFormatFunc(GetType()); if (infer_format_func_ == nullptr) { return DefaultInferFormat(); } } return (graphStatus)infer_format_func_(op); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const { if (static_cast(index) >= subgraph_instance_names_.size()) { return ""; } return subgraph_instance_names_.at(index); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector &OpDesc::GetSubgraphInstanceNames() const { return subgraph_instance_names_; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { if (*iter == name) { *iter = ""; return; } } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { GELOGI("Add subgraph name is %s", name.c_str()); auto iter = subgraph_names_to_index_.find(name); if (iter != subgraph_names_to_index_.end()) { GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); return GRAPH_FAILED; } auto size = subgraph_names_to_index_.size(); subgraph_names_to_index_[name] = size; subgraph_instance_names_.resize(size + 1); return GRAPH_SUCCESS; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map &OpDesc::GetSubgraphNameIndexes() const { return subgraph_names_to_index_; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::SetSubgraphInstanceName(uint32_t index, const std::string &name) { GELOGI("Add sub graph instans name is %s, index is %u", name.c_str(), index); if (index >= subgraph_instance_names_.size()) { GE_LOGE("The index %u exceeds the max instance coutn %zu", index, 
subgraph_instance_names_.size()); return GRAPH_PARAM_INVALID; } subgraph_instance_names_[index] = name; return GRAPH_SUCCESS; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RegisterSubgraphIrName(const string &name, SubgraphType type) { subgraph_ir_names_to_type_[name] = type; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map &OpDesc::GetSubgraphIrNames() const { return subgraph_ir_names_to_type_; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY SubgraphType OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { auto iter = subgraph_ir_names_to_type_.find(name); if (iter == subgraph_ir_names_to_type_.end()) { return kSubgraphTypeEnd; } return iter->second; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const { for (size_t idx = 0; idx < subgraph_instance_names_.size(); ++idx) { if (subgraph_instance_names_[idx] != instance_name) { // find subgraph index. continue; } for (auto name_to_index : subgraph_names_to_index_) { if (name_to_index.second != idx) { // find subgraph name. continue; } subgraph_name = name_to_index.first; return GRAPH_SUCCESS; } } return GRAPH_PARAM_INVALID; } } // namespace ge