From a4868ca4f5ba41c30124ae3a0665c99b919d429c Mon Sep 17 00:00:00 2001 From: taoxiangdong Date: Fri, 9 Oct 2020 17:41:54 +0800 Subject: [PATCH 1/2] git submodule parser and delete metadef --- .gitmodules | 3 + metadef/graph/CMakeLists.txt | 79 - metadef/graph/anchor.cc | 371 --- metadef/graph/attr_value.cc | 38 - metadef/graph/buffer.cc | 112 - metadef/graph/compute_graph.cc | 1314 -------- metadef/graph/debug/ge_log.h | 147 - metadef/graph/debug/ge_op_types.h | 69 - metadef/graph/debug/ge_util.h | 274 -- metadef/graph/debug/graph_debug.cc | 246 -- metadef/graph/debug/graph_debug.h | 48 - metadef/graph/detail/attributes_holder.cc | 241 -- metadef/graph/format_refiner.cc | 508 --- metadef/graph/format_refiner.h | 50 - metadef/graph/ge_attr_define.cc | 1078 ------- metadef/graph/ge_attr_value.cc | 1289 -------- metadef/graph/ge_tensor.cc | 1021 ------ metadef/graph/graph.cc | 384 --- metadef/graph/graph.mk | 294 -- metadef/graph/inference_context.cc | 112 - metadef/graph/model.cc | 190 -- metadef/graph/model_serialize.cc | 763 ----- metadef/graph/module.mk | 3 - metadef/graph/node.cc | 877 ------ metadef/graph/op_desc.cc | 1370 -------- metadef/graph/op_imp.cc | 79 - metadef/graph/operator.cc | 1587 ---------- metadef/graph/operator_factory.cc | 48 - metadef/graph/operator_factory_impl.cc | 149 - metadef/graph/opsproto/opsproto_manager.cc | 187 -- metadef/graph/option/ge_context.cc | 104 - metadef/graph/option/ge_local_context.cc | 60 - metadef/graph/ref_relation.cc | 455 --- metadef/graph/runtime_inference_context.cc | 96 - metadef/graph/shape_refiner.cc | 688 ---- metadef/graph/tensor.cc | 704 ----- metadef/graph/utils/anchor_utils.cc | 102 - metadef/graph/utils/ge_ir_utils.cc | 1178 ------- metadef/graph/utils/ge_ir_utils.h | 206 -- metadef/graph/utils/graph_utils.cc | 2767 ----------------- metadef/graph/utils/mem_utils.h | 32 - metadef/graph/utils/node_utils.cc | 956 ------ metadef/graph/utils/op_desc_utils.cc | 778 ----- metadef/graph/utils/string_utils.h | 68 - metadef/graph/utils/tensor_utils.cc | 401 --- metadef/graph/utils/tuning_utils.cc | 684 ---- metadef/graph/utils/type_utils.cc | 448 --- metadef/inc/external/graph/attr_value.h | 75 - metadef/inc/external/graph/ge_error_codes.h | 38 - metadef/inc/external/graph/graph.h | 81 - .../inc/external/graph/inference_context.h | 76 - metadef/inc/external/graph/operator.h | 289 -- metadef/inc/external/graph/operator_factory.h | 68 - metadef/inc/external/graph/operator_reg.h | 376 --- metadef/inc/external/graph/tensor.h | 131 - metadef/inc/external/graph/types.h | 240 -- metadef/inc/external/register/register.h | 163 - .../external/register/register_error_codes.h | 39 - .../external/register/register_fmk_types.h | 37 - .../inc/external/register/register_types.h | 59 - .../scope/scope_fusion_pass_register.h | 334 -- metadef/inc/graph/anchor.h | 284 -- metadef/inc/graph/attr_value_serializable.h | 191 -- metadef/inc/graph/buffer.h | 82 - metadef/inc/graph/compute_graph.h | 308 -- metadef/inc/graph/debug/ge_attr_define.h | 1122 ------- metadef/inc/graph/def_types.h | 195 -- metadef/inc/graph/detail/any_map.h | 120 - metadef/inc/graph/detail/attributes_holder.h | 165 - .../inc/graph/detail/model_serialize_imp.h | 93 - metadef/inc/graph/ge_attr_value.h | 343 -- metadef/inc/graph/ge_context.h | 46 - metadef/inc/graph/ge_global_options.h | 26 - metadef/inc/graph/ge_local_context.h | 44 - metadef/inc/graph/ge_tensor.h | 193 -- metadef/inc/graph/graph_util.h | 134 - metadef/inc/graph/model.h | 94 - metadef/inc/graph/model_serialize.h | 52 - metadef/inc/graph/node.h | 213 -- metadef/inc/graph/op_desc.h | 328 -- metadef/inc/graph/op_kernel_bin.h | 48 - metadef/inc/graph/operator_factory_impl.h | 56 - metadef/inc/graph/opsproto_manager.h | 46 - metadef/inc/graph/range_vistor.h | 53 - metadef/inc/graph/ref_relation.h | 79 - metadef/inc/graph/runtime_inference_context.h | 46 - metadef/inc/graph/shape_refiner.h | 40 - metadef/inc/graph/tuning_utils.h | 130 - metadef/inc/graph/usr_types.h | 133 - metadef/inc/graph/utils/anchor_utils.h | 45 - metadef/inc/graph/utils/attr_utils.h | 150 - metadef/inc/graph/utils/graph_utils.h | 771 ----- metadef/inc/graph/utils/node_utils.h | 170 - metadef/inc/graph/utils/op_desc_utils.h | 181 -- metadef/inc/graph/utils/tensor_adapter.h | 43 - metadef/inc/graph/utils/tensor_utils.h | 77 - metadef/inc/graph/utils/type_utils.h | 53 - metadef/proto/dump_task.proto | 127 - metadef/proto/fusion_model.proto | 26 - metadef/proto/fwk_adapter.proto | 42 - metadef/proto/ge_api.proto | 104 - metadef/proto/ge_ir.proto | 206 -- metadef/proto/insert_op.proto | 152 - metadef/proto/om.proto | 401 --- metadef/proto/op_mapping_info.proto | 89 - metadef/proto/optimizer_priority.proto | 23 - metadef/proto/task.proto | 170 - parser | 1 + 108 files changed, 4 insertions(+), 32155 deletions(-) create mode 100644 .gitmodules delete mode 100755 metadef/graph/CMakeLists.txt delete mode 100644 metadef/graph/anchor.cc delete mode 100644 metadef/graph/attr_value.cc delete mode 100644 metadef/graph/buffer.cc delete mode 100644 metadef/graph/compute_graph.cc delete mode 100644 metadef/graph/debug/ge_log.h delete mode 100644 metadef/graph/debug/ge_op_types.h delete mode 100644 metadef/graph/debug/ge_util.h delete mode 100644 metadef/graph/debug/graph_debug.cc delete mode 100644 metadef/graph/debug/graph_debug.h delete mode 100644 metadef/graph/detail/attributes_holder.cc delete mode 100644 metadef/graph/format_refiner.cc delete mode 100644 metadef/graph/format_refiner.h delete mode 100644 metadef/graph/ge_attr_define.cc delete mode 100644 metadef/graph/ge_attr_value.cc delete mode 100644 metadef/graph/ge_tensor.cc delete mode 100644 metadef/graph/graph.cc delete mode 100644 metadef/graph/graph.mk delete mode 100644 metadef/graph/inference_context.cc delete mode 100644 metadef/graph/model.cc delete mode 100644 metadef/graph/model_serialize.cc delete mode 100644 metadef/graph/module.mk delete mode 100644 metadef/graph/node.cc delete mode 100644 metadef/graph/op_desc.cc delete mode 100644 metadef/graph/op_imp.cc delete mode 100644 metadef/graph/operator.cc delete mode 100644 metadef/graph/operator_factory.cc delete mode 100644 metadef/graph/operator_factory_impl.cc delete mode 100644 metadef/graph/opsproto/opsproto_manager.cc delete mode 100644 metadef/graph/option/ge_context.cc delete mode 100644 metadef/graph/option/ge_local_context.cc delete mode 100644 metadef/graph/ref_relation.cc delete mode 100644 metadef/graph/runtime_inference_context.cc delete mode 100644 metadef/graph/shape_refiner.cc delete mode 100644 metadef/graph/tensor.cc delete mode 100644 metadef/graph/utils/anchor_utils.cc delete mode 100644 metadef/graph/utils/ge_ir_utils.cc delete mode 100644 metadef/graph/utils/ge_ir_utils.h delete mode 100644 metadef/graph/utils/graph_utils.cc delete mode 100644 metadef/graph/utils/mem_utils.h delete mode 100644 metadef/graph/utils/node_utils.cc delete mode 100644 metadef/graph/utils/op_desc_utils.cc delete mode 100644 metadef/graph/utils/string_utils.h delete mode 100644 metadef/graph/utils/tensor_utils.cc delete mode 100644 metadef/graph/utils/tuning_utils.cc delete mode 100644 metadef/graph/utils/type_utils.cc delete mode 100644 metadef/inc/external/graph/attr_value.h delete mode 100644 metadef/inc/external/graph/ge_error_codes.h delete mode 100644 metadef/inc/external/graph/graph.h delete mode 100644 metadef/inc/external/graph/inference_context.h delete mode 100644 metadef/inc/external/graph/operator.h delete mode 100644 metadef/inc/external/graph/operator_factory.h delete mode 100644 metadef/inc/external/graph/operator_reg.h delete mode 100644 metadef/inc/external/graph/tensor.h delete mode 100644 metadef/inc/external/graph/types.h delete mode 100644 metadef/inc/external/register/register.h delete mode 100644 metadef/inc/external/register/register_error_codes.h delete mode 100644 metadef/inc/external/register/register_fmk_types.h delete mode 100644 metadef/inc/external/register/register_types.h delete mode 100644 metadef/inc/external/register/scope/scope_fusion_pass_register.h delete mode 100644 metadef/inc/graph/anchor.h delete mode 100644 metadef/inc/graph/attr_value_serializable.h delete mode 100644 metadef/inc/graph/buffer.h delete mode 100644 metadef/inc/graph/compute_graph.h delete mode 100644 metadef/inc/graph/debug/ge_attr_define.h delete mode 100644 metadef/inc/graph/def_types.h delete mode 100644 metadef/inc/graph/detail/any_map.h delete mode 100644 metadef/inc/graph/detail/attributes_holder.h delete mode 100644 metadef/inc/graph/detail/model_serialize_imp.h delete mode 100644 metadef/inc/graph/ge_attr_value.h delete mode 100644 metadef/inc/graph/ge_context.h delete mode 100644 metadef/inc/graph/ge_global_options.h delete mode 100644 metadef/inc/graph/ge_local_context.h delete mode 100644 metadef/inc/graph/ge_tensor.h delete mode 100644 metadef/inc/graph/graph_util.h delete mode 100644 metadef/inc/graph/model.h delete mode 100644 metadef/inc/graph/model_serialize.h delete mode 100644 metadef/inc/graph/node.h delete mode 100644 metadef/inc/graph/op_desc.h delete mode 100644 metadef/inc/graph/op_kernel_bin.h delete mode 100644 metadef/inc/graph/operator_factory_impl.h delete mode 100644 metadef/inc/graph/opsproto_manager.h delete mode 100644 metadef/inc/graph/range_vistor.h delete mode 100644 metadef/inc/graph/ref_relation.h delete mode 100644 metadef/inc/graph/runtime_inference_context.h delete mode 100644 metadef/inc/graph/shape_refiner.h delete mode 100644 metadef/inc/graph/tuning_utils.h delete mode 100644 metadef/inc/graph/usr_types.h delete mode 100644 metadef/inc/graph/utils/anchor_utils.h delete mode 100644 metadef/inc/graph/utils/attr_utils.h delete mode 100644 metadef/inc/graph/utils/graph_utils.h delete mode 100644 metadef/inc/graph/utils/node_utils.h delete mode 100644 metadef/inc/graph/utils/op_desc_utils.h delete mode 100644 metadef/inc/graph/utils/tensor_adapter.h delete mode 100644 metadef/inc/graph/utils/tensor_utils.h delete mode 100644 metadef/inc/graph/utils/type_utils.h delete mode 100644 metadef/proto/dump_task.proto delete mode 100644 metadef/proto/fusion_model.proto delete mode 100644 metadef/proto/fwk_adapter.proto delete mode 100644 metadef/proto/ge_api.proto delete mode 100644 metadef/proto/ge_ir.proto delete mode 100644 metadef/proto/insert_op.proto delete mode 100644 metadef/proto/om.proto delete mode 100644 metadef/proto/op_mapping_info.proto delete mode 100644 metadef/proto/optimizer_priority.proto delete mode 100644 metadef/proto/task.proto create mode 160000 parser diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..4a36bfba --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "parser"] + path = parser + url = https://gitee.com/ascend/parser.git diff --git a/metadef/graph/CMakeLists.txt b/metadef/graph/CMakeLists.txt deleted file mode 100755 index 9c649cb2..00000000 --- a/metadef/graph/CMakeLists.txt +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2019-2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -# libgraph.so -# compiling proto files generates some warnings, use no-unused-variable to suppress them -set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") -# add all proto files, generate corresponding .h and .cc files -file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "${GE_SOURCE_DIR}/metadef/proto/om.proto" - "${GE_SOURCE_DIR}/metadef/proto/ge_ir.proto" - "${GE_SOURCE_DIR}/metadef/proto/insert_op.proto" - "${GE_SOURCE_DIR}/metadef/proto/task.proto" - "${GE_SOURCE_DIR}/metadef/proto/fwk_adaper.proto" - "${GE_SOURCE_DIR}/metadef/proto/op_mapping_info.proto" - "${GE_SOURCE_DIR}/metadef/proto/dump_task.proto" - ) - -file(GLOB_RECURSE ONNX_PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "${onnx_INC}/onnx/onnx.proto" - ) - -ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -ge_protobuf_generate(ge PROTO_ONNX_SRCS PROTO_ONNX_HDRS ${ONNX_PROTO_LIST}) - -# need to remove dependencies on pb files later -file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} - "*.cc" - "utils/*.cc" - "opsproto/*.cc" - "detail/*.cc" - "debug/*.cc" - "option/*.cc" - ) - -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${GE_SOURCE_DIR}) -#include_directories(${GE_SOURCE_DIR}/src) -include_directories(${GE_SOURCE_DIR}/ge) -include_directories(${GE_SOURCE_DIR}/metadef) -include_directories(${GE_SOURCE_DIR}/metadef/graph) -include_directories(${GE_SOURCE_DIR}/inc) -include_directories(${GE_SOURCE_DIR}/inc/framework) -include_directories(${GE_SOURCE_DIR}/inc/external) -include_directories(${GE_SOURCE_DIR}/metadef/inc) -include_directories(${GE_SOURCE_DIR}/metadef/inc/external/graph) -include_directories(${GE_SOURCE_DIR}/metadef/inc/external) -include_directories(${GE_SOURCE_DIR}/metadef/inc/graph) -include_directories(${GE_SOURCE_DIR}/inc/common) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) -include_directories(${GE_SOURCE_DIR}/build) - -######### libgraph.so ############# -add_library(graph SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_ONNX_SRCS}) -target_compile_definitions(graph PRIVATE - DAVINCI_CLOUD - Werror) -target_link_libraries(graph PRIVATE - ${PROTOBUF_LIBRARY} - ${c_sec} - ${slog} - ${error_manager} - rt - dl) diff --git a/metadef/graph/anchor.cc b/metadef/graph/anchor.cc deleted file mode 100644 index f02037e5..00000000 --- a/metadef/graph/anchor.cc +++ /dev/null @@ -1,371 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/anchor.h" -#include -#include -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/node.h" - -namespace ge { -Anchor::Anchor(const NodePtr &owner_node, int idx) : owner_node_(owner_node), idx_(idx) {} - -bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf(), type) == 0; } - -size_t Anchor::GetPeerAnchorsSize() const { return peer_anchors_.size(); } - -Anchor::Vistor Anchor::GetPeerAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - ret.push_back(anchor.lock()); - } - return Anchor::Vistor(shared_from_this(), ret); -} - -AnchorPtr Anchor::GetFirstPeerAnchor() const { - if (peer_anchors_.empty()) { - return nullptr; - } else { - return Anchor::DynamicAnchorCast(peer_anchors_.begin()->lock()); - } -} - -NodePtr Anchor::GetOwnerNode() const { return owner_node_.lock(); } - -void Anchor::UnlinkAll() noexcept { - if (!peer_anchors_.empty()) { - do { - auto peer_anchor_ptr = peer_anchors_.begin()->lock(); - if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { - GELOGW("unlink peer_anchor_ptr failed."); - } - } while (!peer_anchors_.empty()); - } -} - -graphStatus Anchor::Unlink(const AnchorPtr &peer) { - if (peer == nullptr) { - GELOGE(GRAPH_FAILED, "peer anchor is invalid."); - return GRAPH_FAILED; - } - auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr &an) { - auto anchor = an.lock(); - return peer->Equal(anchor); - }); - - GE_IF_BOOL_EXEC(it == peer_anchors_.end(), GELOGW("this anchor is not connected to peer"); return GRAPH_FAILED); - - auto it_peer = - std::find_if(peer->peer_anchors_.begin(), peer->peer_anchors_.end(), [this](const std::weak_ptr &an) { - auto anchor = an.lock(); - return Equal(anchor); - }); - - GE_CHK_BOOL_RET_STATUS(it_peer != peer->peer_anchors_.end(), GRAPH_FAILED, "peer is not connected to this anchor"); - - (void)peer_anchors_.erase(it); - (void)peer->peer_anchors_.erase(it_peer); - return GRAPH_SUCCESS; -} - -graphStatus Anchor::ReplacePeer(const AnchorPtr &old_peer, const AnchorPtr &first_peer, const AnchorPtr &second_peer) { - GE_CHK_BOOL_RET_STATUS(old_peer != nullptr, GRAPH_FAILED, "this old peer anchor is nullptr"); - GE_CHK_BOOL_RET_STATUS(first_peer != nullptr, GRAPH_FAILED, "this first peer anchor is nullptr"); - GE_CHK_BOOL_RET_STATUS(second_peer != nullptr, GRAPH_FAILED, "this second peer anchor is nullptr"); - auto this_it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [old_peer](const std::weak_ptr &an) { - auto anchor = an.lock(); - return old_peer->Equal(anchor); - }); - - GE_CHK_BOOL_RET_STATUS(this_it != peer_anchors_.end(), GRAPH_FAILED, "this anchor is not connected to old_peer"); - - auto old_it = std::find_if(old_peer->peer_anchors_.begin(), old_peer->peer_anchors_.end(), - [this](const std::weak_ptr &an) { - auto anchor = an.lock(); - return Equal(anchor); - }); - - GE_CHK_BOOL_RET_STATUS(old_it != old_peer->peer_anchors_.end(), GRAPH_FAILED, - "old_peer is not connected to this anchor"); - *this_it = first_peer; - first_peer->peer_anchors_.push_back(shared_from_this()); - *old_it = second_peer; - second_peer->peer_anchors_.push_back(old_peer); - return GRAPH_SUCCESS; -} - -bool Anchor::IsLinkedWith(const AnchorPtr &peer) { - auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr &an) { - auto anchor = an.lock(); - GE_CHK_BOOL_RET_STATUS(peer != nullptr, false, "this old peer anchor is nullptr"); - return peer->Equal(anchor); - }); - return (it != peer_anchors_.end()); -} - -int Anchor::GetIdx() const { return idx_; } - -void Anchor::SetIdx(int index) { idx_ = index; } - -DataAnchor::DataAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {} - -bool DataAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return Anchor::IsTypeOf(type); -} - -InDataAnchor::InDataAnchor(const NodePtr &owner_node, int idx) : DataAnchor(owner_node, idx) {} - -OutDataAnchorPtr InDataAnchor::GetPeerOutAnchor() const { - if (peer_anchors_.empty()) { - return nullptr; - } else { - return Anchor::DynamicAnchorCast(peer_anchors_.begin()->lock()); - } -} - -graphStatus InDataAnchor::LinkFrom(const OutDataAnchorPtr &src) { - // InDataAnchor must be only linkfrom once - if (src == nullptr || !peer_anchors_.empty()) { - GELOGE(GRAPH_FAILED, "src anchor is invalid or the peerAnchors is not empty."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(src); - src->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -bool InDataAnchor::Equal(AnchorPtr anchor) const { - auto in_data_anchor = Anchor::DynamicAnchorCast(anchor); - if (in_data_anchor != nullptr) { - if (GetOwnerNode() == in_data_anchor->GetOwnerNode() && GetIdx() == in_data_anchor->GetIdx()) { - return true; - } - } - return false; -} - -bool InDataAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return DataAnchor::IsTypeOf(type); -} - -OutDataAnchor::OutDataAnchor(const NodePtr &owner_node, int idx) : DataAnchor(owner_node, idx) {} - -OutDataAnchor::Vistor OutDataAnchor::GetPeerInDataAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_data_anchor != nullptr) { - ret.push_back(in_data_anchor); - } - } - return OutDataAnchor::Vistor(shared_from_this(), ret); -} - -uint32_t OutDataAnchor::GetPeerInDataNodesSize() const { - uint32_t out_nums = 0; - for (const auto &anchor : peer_anchors_) { - auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_data_anchor != nullptr && in_data_anchor->GetOwnerNode() != nullptr) { - out_nums++; - } - } - return out_nums; -} - -OutDataAnchor::Vistor OutDataAnchor::GetPeerInControlAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto in_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_control_anchor != nullptr) { - ret.push_back(in_control_anchor); - } - } - return OutDataAnchor::Vistor(shared_from_this(), ret); -} - -graphStatus OutDataAnchor::LinkTo(const InDataAnchorPtr &dest) { - if (dest == nullptr || !dest->peer_anchors_.empty()) { - GELOGE(GRAPH_FAILED, "dest anchor is invalid or the peerAnchors is not empty."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(dest); - dest->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -graphStatus OutDataAnchor::LinkTo(const InControlAnchorPtr &dest) { - if (dest == nullptr) { - GELOGE(GRAPH_FAILED, "dest anchor is invalid."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(dest); - dest->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -graphStatus OutControlAnchor::LinkTo(const InDataAnchorPtr &dest) { - if (dest == nullptr) { - GELOGE(GRAPH_FAILED, "dest anchor is invalid."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(dest); - dest->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -bool OutDataAnchor::Equal(AnchorPtr anchor) const { - CHECK_FALSE_EXEC(anchor != nullptr, return false); - auto out_data_anchor = Anchor::DynamicAnchorCast(anchor); - if (out_data_anchor != nullptr) { - if (GetOwnerNode() == out_data_anchor->GetOwnerNode() && GetIdx() == out_data_anchor->GetIdx()) { - return true; - } - } - return false; -} - -bool OutDataAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return DataAnchor::IsTypeOf(type); -} - -ControlAnchor::ControlAnchor(const NodePtr &owner_node) : Anchor(owner_node, -1) {} - -ControlAnchor::ControlAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {} - -bool ControlAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return Anchor::IsTypeOf(type); -} - -InControlAnchor::InControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {} - -InControlAnchor::InControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {} - -InControlAnchor::Vistor InControlAnchor::GetPeerOutControlAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto out_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (out_control_anchor != nullptr) { - ret.push_back(out_control_anchor); - } - } - return InControlAnchor::Vistor(shared_from_this(), ret); -} - -InControlAnchor::Vistor InControlAnchor::GetPeerOutDataAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto out_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (out_data_anchor != nullptr) { - ret.push_back(out_data_anchor); - } - } - return InControlAnchor::Vistor(shared_from_this(), ret); -} - -graphStatus InControlAnchor::LinkFrom(const OutControlAnchorPtr &src) { - if (src == nullptr) { - GELOGE(GRAPH_FAILED, "src anchor is invalid."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(src); - src->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -bool InControlAnchor::Equal(AnchorPtr anchor) const { - CHECK_FALSE_EXEC(anchor != nullptr, return false); - auto in_control_anchor = Anchor::DynamicAnchorCast(anchor); - if (in_control_anchor != nullptr) { - if (GetOwnerNode() == in_control_anchor->GetOwnerNode()) { - return true; - } - } - return false; -} - -bool InControlAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return ControlAnchor::IsTypeOf(type); -} - -OutControlAnchor::OutControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {} - -OutControlAnchor::OutControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {} - -OutControlAnchor::Vistor OutControlAnchor::GetPeerInControlAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto in_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_control_anchor != nullptr) { - ret.push_back(in_control_anchor); - } - } - return OutControlAnchor::Vistor(shared_from_this(), ret); -} - -OutControlAnchor::Vistor OutControlAnchor::GetPeerInDataAnchors() const { - vector ret; - for (const auto &anchor : peer_anchors_) { - auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_data_anchor != nullptr) { - ret.push_back(in_data_anchor); - } - } - return OutControlAnchor::Vistor(shared_from_this(), ret); -} - -graphStatus OutControlAnchor::LinkTo(const InControlAnchorPtr &dest) { - if (dest == nullptr) { - GELOGE(GRAPH_FAILED, "dest anchor is invalid."); - return GRAPH_FAILED; - } - peer_anchors_.push_back(dest); - dest->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -bool OutControlAnchor::Equal(AnchorPtr anchor) const { - auto out_control_anchor = Anchor::DynamicAnchorCast(anchor); - if (out_control_anchor != nullptr) { - if (GetOwnerNode() == out_control_anchor->GetOwnerNode()) { - return true; - } - } - return false; -} - -bool OutControlAnchor::IsTypeOf(TYPE type) const { - if (strcmp(Anchor::TypeOf(), type) == 0) { - return true; - } - return ControlAnchor::IsTypeOf(type); -} -} // namespace ge diff --git a/metadef/graph/attr_value.cc b/metadef/graph/attr_value.cc deleted file mode 100644 index 066767c2..00000000 --- a/metadef/graph/attr_value.cc +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "external/graph/attr_value.h" -#include "debug/ge_log.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/ge_attr_value.h" - -namespace ge { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue::AttrValue() { impl = ComGraphMakeShared(); } - -#define ATTR_VALUE_SET_GET_IMP(type) \ - graphStatus AttrValue::GetValue(type &val) const { \ - if (impl != nullptr) { \ - GELOGW("GetValue failed."); \ - return impl->geAttrValue_.GetValue(val); \ - } \ - return GRAPH_FAILED; \ - } - -ATTR_VALUE_SET_GET_IMP(AttrValue::STR) -ATTR_VALUE_SET_GET_IMP(AttrValue::INT) -ATTR_VALUE_SET_GET_IMP(AttrValue::FLOAT) -} // namespace ge diff --git a/metadef/graph/buffer.cc b/metadef/graph/buffer.cc deleted file mode 100644 index 48cdd397..00000000 --- a/metadef/graph/buffer.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/buffer.h" -#include "proto/ge_ir.pb.h" -#include "framework/common/debug/ge_log.h" - -namespace ge { -Buffer::Buffer() { - data_.InitDefault(); - if (data_.GetProtoMsg()) { - buffer_ = data_.GetProtoMsg()->mutable_bt(); - } -} - -Buffer::Buffer(const Buffer &other) { - // Share data - data_ = other.data_; - buffer_ = other.buffer_; -} - -Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { // default - auto proto_msg = data_.GetProtoMsg(); - if (proto_msg != nullptr) { - try { - proto_msg->set_bt(std::string(buffer_size, default_val)); - buffer_ = proto_msg->mutable_bt(); - } catch (std::bad_alloc &e) { - GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size); - buffer_ = nullptr; - } - } -} - -Buffer Buffer::CopyFrom(const std::uint8_t *data, std::size_t buffer_size) { - Buffer buffer; - auto proto_msg = buffer.data_.GetProtoMsg(); - if (proto_msg != nullptr && data != nullptr) { - try { - proto_msg->set_bt(data, buffer_size); - buffer.buffer_ = proto_msg->mutable_bt(); - } catch (std::bad_alloc &e) { - GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size); - buffer.buffer_ = nullptr; - } - } - return buffer; -} - -Buffer::Buffer(const std::shared_ptr &proto_owner, proto::AttrDef *buffer) - : data_(proto_owner, buffer) { - if (data_.GetProtoMsg() != nullptr) { - buffer_ = data_.GetProtoMsg()->mutable_bt(); - } -} - -Buffer::Buffer(const std::shared_ptr &proto_owner, std::string *buffer) - : data_(proto_owner, nullptr) { - buffer_ = buffer; -} - -Buffer &Buffer::operator=(const Buffer &other) { - if (&other != this) { - // Share data - data_ = other.data_; - buffer_ = other.buffer_; - } - return *this; -} - -const std::uint8_t *Buffer::GetData() const { - if (buffer_ != nullptr) { - return (const std::uint8_t *)buffer_->data(); - } - return nullptr; -} - -std::uint8_t *Buffer::GetData() { - if (buffer_ != nullptr && !buffer_->empty()) { - // Avoid copy on write - (void)(*buffer_)[0]; - return reinterpret_cast(const_cast(buffer_->data())); - } - return nullptr; -} - -std::size_t Buffer::GetSize() const { - if (buffer_ != nullptr) { - return buffer_->size(); - } - return 0; -} - -void Buffer::ClearBuffer() { - if (buffer_ != nullptr) { - buffer_->clear(); - } -} -} // namespace ge diff --git a/metadef/graph/compute_graph.cc b/metadef/graph/compute_graph.cc deleted file mode 100644 index bae4d362..00000000 --- a/metadef/graph/compute_graph.cc +++ /dev/null @@ -1,1314 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/compute_graph.h" -#include -#include "./format_refiner.h" -#include "./ge_context.h" -#include "debug/ge_attr_define.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "ge/ge_api_types.h" -#include "graph/shape_refiner.h" -#include "proto/ge_ir.pb.h" -#include "utils/ge_ir_utils.h" -#include "utils/graph_utils.h" -#include "utils/node_utils.h" -#include "utils/op_desc_utils.h" -#include "utils/string_utils.h" -#include "utils/tensor_utils.h" - -namespace ge { -namespace { -const size_t OUTPUT_PARAM_SIZE = 2; -const std::string alias_name_attr = "_aliasName"; -bool IsUseBFS() { - string run_mode; - const int base = 10; - if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { - if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) >= TRAIN) { - return true; - } - } else { - GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); - } - return false; -} -} // namespace - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) - : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { - attrs_.InitDefault(); -} - -ComputeGraph::~ComputeGraph() {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { - return GetAllNodes().size(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetAllNodes() const { - std::vector> subgraphs; - return AllGraphNodes(subgraphs); -} - -ComputeGraph::Vistor ComputeGraph::AllGraphNodes(std::vector> &subgraphs) const { - std::vector all_nodes; - std::deque candidates; - - candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); - while (!candidates.empty()) { - NodePtr node = candidates.front(); - all_nodes.emplace_back(node); - candidates.pop_front(); - - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - - const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); - for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { - auto subgraph = GetSubgraph(*name_iter); - if (subgraph != nullptr) { - subgraphs.emplace_back(subgraph); - candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); - } - } - } - - return Vistor(shared_from_this(), all_nodes); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetNodes( - bool is_unknown_shape) const { - if (is_unknown_shape) { - return GetDirectNode(); - } else { - return GetAllNodes(); - } -} - -size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetDirectNode() const { - return Vistor(shared_from_this(), nodes_); -} - -ComputeGraph::Vistor ComputeGraph::GetInputNodes() const { - return Vistor(shared_from_this(), input_nodes_); -} - -ComputeGraph::Vistor ComputeGraph::GetOutputNodes() const { - std::vector result; - for (auto iter = output_nodes_info_.begin(); iter != output_nodes_info_.end(); ++iter) { - result.push_back(iter->first); - } - return Vistor(shared_from_this(), result); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const { - for (const auto &node : nodes_) { - if (node == nullptr) { - continue; - } - if (node->GetName() == name) { - return node; - } - std::vector out_alias_name; - if (AttrUtils::GetListStr(node->GetOpDesc(), alias_name_attr, out_alias_name)) { - for (const auto &alias_name : out_alias_name) { - if (alias_name == name) { - return node; - } - } - } - } - return nullptr; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr -ComputeGraph::FindFirstNodeMatchType(const std::string &name) const { - for (const auto &node : nodes_) { - if (node == nullptr) { - continue; - } - if (node->GetType() == name) { - return node; - } - } - return nullptr; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphAttrsAreEqual( - const ComputeGraph &r_graph) const { - // ProtoMsgOwner <::google::protobuf::Message> is temporarily ignored - if ((this->attrs_.protoMsg_ != nullptr) && (r_graph.attrs_.protoMsg_ != nullptr)) { - const auto &proto_attr_map = *(this->attrs_.protoMsg_); - const auto &r_proto_attr_map = *(r_graph.attrs_.protoMsg_); - // 1.Verify graph's ProtoAttrMap size - if (proto_attr_map.size() != r_proto_attr_map.size()) { - GELOGE(GRAPH_FAILED, "Size of compute graph's ProtoAttrMap verify failed, graph name: %s.", - this->GetName().c_str()); - return false; - } - // 2.Verify graph's ProtoAttrMap key, verify values is temporarily not implemented - for (const auto &it : proto_attr_map) { - if (r_proto_attr_map.count(it.first) == 0) { - GELOGE(GRAPH_FAILED, "Key of compute graph's ProtoAttrMap verify failed, graph name: %s key name: %s.", - this->GetName().c_str(), it.first.c_str()); - return false; - } - } - return true; - } - return ((this->attrs_.protoMsg_ == nullptr) && (r_graph.attrs_.protoMsg_ == nullptr)); -} - -/// Since there may be different input nodes -/// chosen by user in the same graph, special judgment is needed -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNodePtrIsEqual( - const std::vector &left_nodes, const std::vector &right_nodes) const { - const auto left_nodes_size = left_nodes.size(); - const auto right_nodes_size = right_nodes.size(); - if (left_nodes_size != right_nodes_size) { - GELOGE(GRAPH_FAILED, - "Check failed with graph input_nodes_: " - "left inputNodes size %zu is different with right inputNodes size %zu .", - left_nodes_size, right_nodes_size); - return false; - } - for (size_t j = 0; j < left_nodes_size; j++) { - if (left_nodes.at(j) == nullptr || right_nodes.at(j) == nullptr) { - GELOGE(GRAPH_FAILED, "left_nodes.at(%zu) or right_nodes.at(%zu) is nullptr", j, j); - return false; - } - const auto &left_input_name = left_nodes.at(j)->GetName(); - const auto &right_input_name = right_nodes.at(j)->GetName(); - if (left_input_name != right_input_name) { - GELOGE(GRAPH_FAILED, - "Check failed with graph input_nodes_: " - "left inputNode name %s is different with right inputNode name %s at inputNodes index %zu.", - left_input_name.c_str(), right_input_name.c_str(), j); - return false; - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( - const ComputeGraph &r_graph) const { - return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") && - IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && - VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && - IsEqual(this->name_, r_graph.name_, "graph.name_") && - IsEqual(this->is_valid_flag_, r_graph.is_valid_flag_, "graph.is_valid_flag_") && - IsEqual(this->need_iteration_, r_graph.need_iteration_, "graph.need_iteration_") && - IsEqual(this->params_share_map_, r_graph.params_share_map_, "graph.params_share_map_") && - IsEqual(this->out_nodes_map_, r_graph.out_nodes_map_, "graph.out_nodes_map_") && - IsEqual(this->inputs_order_, r_graph.inputs_order_, "graph.inputs_order_") && - IsEqual(this->output_size_, r_graph.output_size_, "graph.output_size_") && - IsEqual(this->input_size_, r_graph.input_size_, "graph.input_size_") && - IsEqual(this->output_nodes_info_, r_graph.output_nodes_info_, "graph.output_nodes_info_")); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::operator==(const ComputeGraph &r_graph) const { - // Firstly: Graph's members equal - if ((!GraphMembersAreEqual(r_graph)) || (!GraphAttrsAreEqual(r_graph))) { - return false; - } - - // Secondly: Node equal means the link relationship between node and node itself equal - for (const auto &left_node : nodes_) { - if (left_node == nullptr) { - GELOGE(GRAPH_FAILED, "left_node is nullptr"); - return false; - } - const auto &node_name = left_node->GetName(); - // After TopologicalSorting, node order can change, so find node by name - const auto &right_node = r_graph.FindNode(node_name); - GE_IF_BOOL_EXEC(right_node == nullptr, GELOGE(GRAPH_FAILED, "right_node is NULL!!!"); return false); - if (!(*right_node == *left_node)) { - GELOGE(GRAPH_FAILED, "Compare graph failed, node name: %s.", node_name.c_str()); - return false; - } - } - - // Thirdly: Recursively determine whether the sub graphs are equal - for (size_t i = 0; i < this->sub_graph_.size(); i++) { - if (!(*((this->sub_graph_)[i]) == *((r_graph.sub_graph_)[i]))) { - return false; - } - } - return true; -} - -NodePtr ComputeGraph::AddNodeFront(NodePtr node) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); - return nullptr; - } - node->SetHostNode(is_valid_flag_); - node->GetOpDesc()->SetId(nodes_.size()); - if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) { - (void)nodes_.insert(nodes_.begin() + 1, node); - } else { - (void)nodes_.insert(nodes_.begin(), node); - } - return node; -} - -NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) { - if (op == nullptr) { - GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(nodes_.size()); - NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); - GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); - GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); - return AddNodeFront(node_ptr); -} - -NodePtr ComputeGraph::AddNodeAfter(NodePtr node, const NodePtr &pre_node) { - if (node == nullptr || node->GetOpDesc() == nullptr || pre_node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); - return nullptr; - } - node->SetHostNode(is_valid_flag_); - node->GetOpDesc()->SetId(nodes_.size()); - auto node_iter = std::find(nodes_.begin(), nodes_.end(), pre_node); - if (node_iter != nodes_.end()) { - nodes_.insert(node_iter + 1, node); - } else { - GELOGE(GRAPH_FAILED, "Cannot find pre_node in nodes_."); - return nullptr; - } - - return node; -} - -NodePtr ComputeGraph::AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node) { - if (op == nullptr) { - GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(nodes_.size()); - NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); - GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); - GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init failed."); return nullptr); - return AddNodeAfter(node_ptr, pre_node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(NodePtr node) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return nullptr; - } - node->SetHostNode(is_valid_flag_); - node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize()); - nodes_.push_back(node); - return node; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpDescPtr op) { - if (op == nullptr) { - GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(GetDirectNodesSize()); - NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); - GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); - GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); - return AddNode(node_ptr); -} - -NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. - if (op == nullptr) { - GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(id); - NodePtr node = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); - GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); - GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); - node->SetHostNode(is_valid_flag_); - nodes_.push_back(node); - return node; -} - -NodePtr ComputeGraph::AddInputNode(NodePtr node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return nullptr; - } - input_nodes_.push_back(node); - if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { - GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed"); - } - return node; -} - -NodePtr ComputeGraph::AddOutputNode(NodePtr node) { return AddOutputNodeByIndex(node, 0); } - -NodePtr ComputeGraph::AddOutputNodeByIndex(NodePtr node, int32_t index) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr or opdesc should not be null."); - return nullptr; - } - - bool already_have = false; - NodePtr result = node; - // [output_nodes_info_ : should not be null] - for (const auto &item : output_nodes_info_) { - if (item.first->GetName() == node->GetName() && item.second == index) { - already_have = true; - result = item.first; - break; - } - } - - if (!already_have) { - output_nodes_info_.emplace_back(std::make_pair(node, index)); - GELOGI("Push back node name:%s, index:%ld, into output_nodes_info_.", node->GetName().c_str(), index); - } - - if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { - GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "add node failed"); - } - return result; -} - -graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) { - GE_CHECK_NOTNULL(node); - - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) { - continue; - } - if (out_anchor->GetOwnerNode()->GetType() == CONSTANT || out_anchor->GetOwnerNode()->GetType() == CONSTANTOP) { - GE_CHK_BOOL_RET_STATUS(GraphUtils::RemoveEdge(out_anchor, in_anchor) == GRAPH_SUCCESS, GRAPH_FAILED, - "Remove edge from const op failed."); - if (out_anchor->GetOwnerNode()->GetOutNodes().size() == 0) { - GELOGI("Remove const op %s.", out_anchor->GetOwnerNode()->GetName().c_str()); - auto iter = find(nodes_.begin(), nodes_.end(), out_anchor->GetOwnerNode()); - if (iter != nodes_.end()) { - (void)nodes_.erase(iter); - } - } - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveNode(const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return GRAPH_FAILED; - } - - // delete const op for this node - (void)RemoveConstInput(node); - - // if the node save as input node, delete it - (void)RemoveInputNode(node); - - // if the node save as input node, delete it - (void)RemoveOutputNode(node); - - if (GRAPH_SUCCESS != IsolateNode(node)) { - GELOGE(GRAPH_FAILED, "Isolate node failed, node name: %s.", node->GetName().c_str()); - return GRAPH_FAILED; - } - - auto iter = find(nodes_.begin(), nodes_.end(), node); - if (iter != nodes_.end()) { - (void)nodes_.erase(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -// Used in sub_graph scenes -graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return GRAPH_FAILED; - } - - auto iter = find(input_nodes_.begin(), input_nodes_.end(), node); - if (iter != input_nodes_.end()) { - (void)input_nodes_.erase(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -// Used in sub_graph scenes -graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return GRAPH_FAILED; - } - - auto iter = output_nodes_info_.begin(); - bool find_node = false; - // [output_nodes_info_ : should not be null] - while (iter != output_nodes_info_.end()) { - if (node->GetName() == iter->first->GetName()) { - iter = output_nodes_info_.erase(iter); - find_node = true; - } else { - ++iter; - } - } - GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED); - return GRAPH_SUCCESS; -} - -std::shared_ptr ComputeGraph::AddSubGraph(std::shared_ptr sub_graph) { - if (sub_graph == nullptr) { - GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); - return nullptr; - } - sub_graph_.push_back(sub_graph); - names_to_subgraph_[sub_graph->GetName()] = sub_graph; - return sub_graph; -} - -graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr &sub_graph) { - if (sub_graph == nullptr) { - GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); - return GRAPH_FAILED; - } - - names_to_subgraph_.erase(sub_graph->GetName()); - auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph); - if (iter != sub_graph_.end()) { - (void)sub_graph_.erase(iter); - return GRAPH_SUCCESS; - } else { - GELOGW("find sub_graph failed"); - return GRAPH_SUCCESS; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr &subgraph) { - if (subgraph == nullptr) { - GE_LOGE("Try to add a null subgraph, name %s", name.c_str()); - return GRAPH_PARAM_INVALID; - } - auto parent_graph = subgraph->GetParentGraph(); - if (parent_graph == nullptr) { - GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str()); - return GRAPH_PARAM_INVALID; - } - auto parent_node = subgraph->GetParentNode(); - if (parent_node == nullptr) { - GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str()); - return GRAPH_PARAM_INVALID; - } - if (parent_node->GetOwnerComputeGraph() != parent_graph) { - GE_LOGE( - "Try to add a subgraph which parent node's parent graph is not equal to " - "the subgraph's parent graph, subgraph name %s, parent node name %s", - subgraph->GetName().c_str(), parent_graph->GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - if (!this->parent_graph_.expired()) { - GELOGW("The subgraphs should only be added to the root graph"); - } - if (name != subgraph->GetName()) { - GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); - } - if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) { - GE_LOGE("The subgraph %s existed", name.c_str()); - return GRAPH_PARAM_INVALID; - } - sub_graph_.push_back(subgraph); - names_to_subgraph_[name] = subgraph; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::AddSubgraph(const std::shared_ptr &subgraph) { - if (subgraph == nullptr) { - return GRAPH_PARAM_INVALID; - } - return AddSubgraph(subgraph->GetName(), subgraph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) { - auto iter = names_to_subgraph_.find(name); - if (iter == names_to_subgraph_.end()) { - return; - } - for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) { - if (*vec_iter == iter->second) { - sub_graph_.erase(vec_iter); - break; - } - } - names_to_subgraph_.erase(iter); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph( - const std::shared_ptr &subgraph) { - if (subgraph != nullptr) { - RemoveSubgraph(subgraph->GetName()); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr ComputeGraph::GetSubgraph( - const std::string &name) const { - std::shared_ptr parent = parent_graph_.lock(); - if (parent == nullptr) { - auto iter = names_to_subgraph_.find(name); - return iter == names_to_subgraph_.end() ? nullptr : iter->second; - } else { - return parent->GetSubgraph(name); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector> -ComputeGraph::GetAllSubgraphs() const { - return sub_graph_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr ComputeGraph::GetParentGraph() { - return parent_graph_.lock(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph( - const shared_ptr &parent) { - parent_graph_ = parent; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr ComputeGraph::GetParentNode() { - return parent_node_.lock(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr &parent) { - parent_node_ = parent; -} - -/// -/// @brief Update input-mapping -/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input -/// @return graphStatus -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::UpdateInputMapping(const std::map &input_mapping) { - for (auto &input : nodes_) { - if (input->GetType() == DATA) { - uint32_t cur_index = 0; - if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { - continue; - } - auto iter = input_mapping.find(cur_index); - if (iter == input_mapping.end()) { - continue; - } - if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { - GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); - return GRAPH_FAILED; - } - } - } - - return GRAPH_SUCCESS; -} - -/// -/// @brief Update output-mapping -/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output -/// @return graphStatus -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::UpdateOutputMapping(const std::map &output_mapping) { - NodePtr net_output = FindFirstNodeMatchType(NETOUTPUT); - if (net_output == nullptr) { - GE_LOGE("UpdateOutputMapping failed: node type %s not exist in graph.", NETOUTPUT); - return GRAPH_FAILED; - } - OpDescPtr op_desc = net_output->GetOpDesc(); - if (op_desc == nullptr) { - GE_LOGE("UpdateOutputMapping failed: op_desc is NULL."); - return GRAPH_FAILED; - } - - size_t num = op_desc->GetAllInputsSize(); - for (size_t i = 0; i < num; i++) { - GeTensorDesc tensor = op_desc->GetInputDesc(i); - uint32_t cur_index = 0; - if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { - continue; - } - auto iter = output_mapping.find(cur_index); - if (iter == output_mapping.end()) { - continue; - } - if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { - GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); - return GRAPH_FAILED; - } - if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) { - GE_LOGE("UpdateOutputMapping failed: update %u input_tensor failed.", i); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { - std::vector node_vec = nodes_; - for (const auto &node : GetDirectNode()) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGW("node or OpDescPtr is nullptr."); - continue; - } - GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should not be null."); return GRAPH_FAILED); - if (node->GetOpDesc()->GetType() == RECV) { - auto iter = find(node_vec.begin(), node_vec.end(), node); - if (iter == node_vec.end()) { - GELOGW("no node found."); - } else { - (void)node_vec.erase(iter); - } - - auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0)); - (void)node_vec.insert(dst_iter, node); - } - if (node->GetOpDesc()->GetType() == SEND) { - auto iter = find(node_vec.begin(), node_vec.end(), node); - if (iter == node_vec.end()) { - GELOGW("no node found."); - } else { - (void)node_vec.erase(iter); - } - - auto src_iter = find(node_vec.begin(), node_vec.end(), node->GetInControlNodes().at(0)); - (void)node_vec.insert(src_iter + 1, node); - } - } - nodes_.clear(); - for (size_t i = 0; i < node_vec.size(); ++i) { - NodePtr node = node_vec[i]; - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGW("node or OpDescPtr is nullptr."); - } else { - node->GetOpDesc()->SetId((int64_t)i); - nodes_.push_back(node); - } - } - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, - std::map &map_in_edge_num, - std::vector &stack) { - GELOGI("Runing_Dfs_Sort: %s", name_.c_str()); - // Record the number of non data nodes but no input nodes - GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); - - // Only data nodes here - while (!stack.empty()) { - NodePtr node = stack.back(); - stack.pop_back(); - node_vec.push_back(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); - for (const auto &anchor : node->GetAllOutDataAnchors()) { - GE_CHECK_NOTNULL(anchor); - for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && --iter->second == 0) { - stack.push_back(peer_in_anchor->GetOwnerNode()); - } - } - for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && --iter->second == 0) { - stack.push_back(peer_in_anchor->GetOwnerNode()); - } - } - } - GE_IF_BOOL_EXEC( - node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor - : node->GetOutControlAnchor()->GetPeerAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && --iter->second == 0) { - stack.push_back(peer_in_anchor->GetOwnerNode()); - } - }) - } - - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraph::BFSTopologicalSorting(std::vector &node_vec, - std::map &map_in_edge_num, - std::deque &stack) { - GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); - std::vector stack_input; - std::map breadth_node_map; - // Record the number of non data nodes but no input nodes - GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); - - // Only data nodes here - while (!stack_input.empty() || !stack.empty()) { - NodePtr node = nullptr; - if (!stack.empty()) { - node = stack.back(); - stack.pop_back(); - } else { - node = stack_input.back(); - stack_input.pop_back(); - } - - node_vec.push_back(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); - CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map); - - for (const auto &name_node : breadth_node_map) { - (void)stack.push_front(name_node.second); - } - breadth_node_map.clear(); - } - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, - std::map &breadth_node_map) { - for (const auto &anchor : node->GetAllOutDataAnchors()) { - for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && 0 == --iter->second) { - (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); - } - } - - for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && 0 == --iter->second) { - (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); - } - } - } - if (node->GetOutControlAnchor() != nullptr) { - for (AnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerAnchors()) { - auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end() && 0 == --iter->second) { - (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); - } - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { - auto ret = TopologicalSortingGraph(); - if (ret != SUCCESS) { - GraphUtils::DumpGEGraphToOnnx(*this, "black_box"); - GELOGE(ret, "Graph [%s] topological sort failed, saved to file black_box", name_.c_str()); - return ret; - } - - if (sub_graph_.empty()) { - return SUCCESS; - } - - // partition sub graph - for (const auto &sub_graph : sub_graph_) { - ret = sub_graph->TopologicalSortingGraph(); - if (ret != SUCCESS) { - GELOGE(ret, "Sub graph topological sort Failed"); - return ret; - } - } - - std::vector> subgraphs; - auto nodes = AllGraphNodes(subgraphs); - for (size_t i = 0; i < nodes.size(); i++) { - NodePtr node = nodes.at(i); // [node: should not be null] - node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] - } - if (sub_graph_.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original - GELOGW("Keep original subgraph for graph size %zu not equal %zu.", sub_graph_.size(), subgraphs.size()); - return SUCCESS; - } - sub_graph_.swap(subgraphs); - return SUCCESS; -} - -graphStatus ComputeGraph::TopologicalSortingGraph() { - std::vector node_vec; - std::map map_in_edge_num; - bool use_BFS = IsUseBFS(); - if (use_BFS) { - std::deque stack; - if (BFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } else { - std::vector stack; - if (DFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - - // If they are not equal, there is a closed loop - if (node_vec.size() != nodes_.size()) { - std::set itered_nodes_set; - for (auto &node : node_vec) { - itered_nodes_set.insert(node.get()); - } - GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", nodes_.size(), - node_vec.size()); - for (auto &node : nodes_) { - if (itered_nodes_set.count(node.get()) == 0) { - GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); - } - } - return GRAPH_FAILED; - } - - nodes_.clear(); - for (size_t i = 0; i < node_vec.size(); i++) { - NodePtr node = node_vec[i]; // [node: should not be null] - node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] - nodes_.push_back(node); - } - - is_valid_flag_ = true; - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraph::SortNodes(std::vector &stack, std::map &map_in_edge_num) { - // Record the number of non data nodes but no input nodes - uint32_t spec_node_size = 0; - bool verify_isolated = false; - string run_mode; - const int base = 10; - // Need verify isolated point in PREDICTION mode. - if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { - if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) < TRAIN) { - verify_isolated = true; - } - } - for (const auto &node : GetDirectNode()) { - GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); - map_in_edge_num[node] = static_cast(GetInEdgeSize(node)); - if (map_in_edge_num[node] == 0) { - if ((node->GetOpDesc()->GetType() != DATA) && (node->GetOpDesc()->GetType() != AIPPDATA) && - (node->GetOpDesc()->GetType() != INPUT_TYPE) && (node->GetOpDesc()->GetType() != ANN_DATA)) { - // At present, can only judge the isolated point without input and output. - // It is impossible to judge the situation with multiple output nodes. - if (verify_isolated && GetOutEdgeSize(node) == 0) { - GELOGE(GRAPH_FAILED, "May has isolated nodes in graph, node name: %s.", node->GetName().c_str()); - return GRAPH_FAILED; - } - (void)stack.insert(stack.begin(), node); - spec_node_size++; - continue; - } - // Need to insert the data nodes in reverse order - (void)stack.insert(stack.begin() + spec_node_size, node); - } - } - - /// Make sure the inputs order matches with user-designated - /// 1. Get the index of two input nodes in the user-inputs-order(inputs_order_) - /// 2. Compare two indices, if not match, swap the positions of two inputs - /// *: Remind: stack is reverse-order - for (size_t i = 0; i < stack.size(); ++i) { - // If not found in 'inputs_order_', skip it - auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); - GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); - auto inx_i = it_i - inputs_order_.begin(); - for (size_t j = i + 1; j < stack.size(); ++j) { - // If not found in 'inputs_order_', skip it - auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); - GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); - - // Compare index, swap them if it should be - auto inx_j = it_j - inputs_order_.begin(); - GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); - } - } - - return GRAPH_SUCCESS; -} - -size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { - size_t in_edge_size = 0; - if (node == nullptr) { - return in_edge_size; - } - for (const auto &anchor : node->GetAllInDataAnchors()) { - in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize(); - // Break flow control data loop. - OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor(); - if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) { - NodePtr out_node = out_anchor->GetOwnerNode(); - if (out_node == nullptr) { - GELOGW("out node is nullptr"); - continue; - } - if ((out_node->GetType() == NEXTITERATION) || (out_node->GetType() == REFNEXTITERATION)) { - GE_IF_BOOL_EXEC(in_edge_size == 0, GELOGE(GRAPH_FAILED, "If [in_edge_size = 0], the result will be reversed"); - return in_edge_size); - in_edge_size -= 1; - } - } - } - if (node->GetInControlAnchor() != nullptr) { - in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize(); - } - return in_edge_size; -} - -size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { - size_t out_edge_size = 0; - if (node == nullptr) { - return out_edge_size; - } - - // Break flow control data loop. - if ((node->GetType() != NEXTITERATION) && (node->GetType() != REFNEXTITERATION)) { - for (const auto &anchor : node->GetAllOutDataAnchors()) { - if (anchor != nullptr) { - out_edge_size = out_edge_size + anchor->GetPeerAnchors().size(); - } - } - } - if (node->GetOutControlAnchor() != nullptr) { - if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { - return 0; - } - out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size(); - } - return out_edge_size; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { return is_valid_flag_; } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { - GELOGI("graph name = %s.", GetName().c_str()); - for (const auto &node : GetAllNodes()) { - GELOGI("node name = %s.", node->GetName().c_str()); - for (const auto &anchor : node->GetAllOutDataAnchors()) { - for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { - GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, - GELOGI("node name = %s, out data node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { - GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, - GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - } - auto out_control_anchor = node->GetOutControlAnchor(); - if (out_control_anchor != nullptr) { - for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) { - GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, - GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) { - GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, - GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Swap(ComputeGraph &graph) { - this->AttrHolder::Swap(graph); - - origGraph_.swap(graph.origGraph_); - - name_.swap(graph.name_); - std::swap(graph_id_, graph.graph_id_); - attrs_.Swap(graph.attrs_); - nodes_.swap(graph.nodes_); - all_nodes_infos_.swap(graph.all_nodes_infos_); - target_nodes_info_.swap(graph.target_nodes_info_); - - input_nodes_.swap(graph.input_nodes_); - inputs_order_.swap(graph.inputs_order_); - std::swap(input_size_, graph.input_size_); - out_nodes_map_.swap(graph.out_nodes_map_); - std::swap(output_size_, graph.output_size_); - output_nodes_info_.swap(graph.output_nodes_info_); - - sub_graph_.swap(graph.sub_graph_); - names_to_subgraph_.swap(graph.names_to_subgraph_); - parent_graph_.swap(graph.parent_graph_); - parent_node_.swap(graph.parent_node_); - - // the members followed should not in the ComputeGraph class - std::swap(is_valid_flag_, graph.is_valid_flag_); - std::swap(is_summary_graph_, graph.is_summary_graph_); - std::swap(need_iteration_, graph.need_iteration_); - params_share_map_.swap(graph.params_share_map_); - op_name_map_.swap(graph.op_name_map_); - std::swap(session_id_, graph.session_id_); - std::swap(data_format_, graph.data_format_); - std::swap(is_unknown_shape_graph_, graph.is_unknown_shape_graph_); - - // Update Node owner. - SetNodesOwner(); - graph.SetNodesOwner(); -} - -void ComputeGraph::SetNodesOwner() { - for (const auto &node : nodes_) { - if (node == nullptr) { - continue; - } - node->SetOwnerComputeGraph(shared_from_this()); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) { - GE_CHECK_NOTNULL(node); - auto next_nodes = node->GetOutAllNodes(); - // If there is input data side - for (size_t i = 0; i < node->GetAllInDataAnchors().size(); i++) { - auto in_data_anchor = node->GetInDataAnchor(static_cast(i)); - auto pre_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - if (pre_out_data_anchor != nullptr) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_IF_BOOL_EXEC(pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANT || - pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANTOP, - continue); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - } - auto out_ctrl_anchor = node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(out_ctrl_anchor); - auto pre_out_ctrl_anchor = pre_out_data_anchor->GetOwnerNode()->GetOutControlAnchor(); - GE_CHECK_NOTNULL(pre_out_ctrl_anchor); - for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - } - } - - // If there is an input control side - auto in_ctrl_anchor = node->GetInControlAnchor(); - GE_CHECK_NOTNULL(in_ctrl_anchor); - for (const auto &pre_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_ctrl_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED, - "remove edge failed"); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - } - auto out_ctrl_anchor = node->GetOutControlAnchor(); - if (out_ctrl_anchor != nullptr) { - for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - } - } - - for (const auto &out_peer_data_anchor : in_ctrl_anchor->GetPeerOutDataAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_peer_data_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, return GRAPH_FAILED, - "remove edge failed"); - for (const auto &next_node : next_nodes) { - auto next_in_control_anchor = next_node->GetInControlAnchor(); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(out_peer_data_anchor, next_in_control_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "add edge failed"); - } - } - - return RemoveExtraOutEdge(node); -} - -graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) { - GE_CHECK_NOTNULL(node); - // Remove redundant output edges - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - } - - for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - } - } - auto out_ctrl_anchor = node->GetOutControlAnchor(); - if (out_ctrl_anchor != nullptr) { - for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - return GRAPH_FAILED, "remove edge failed"); - } - } - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraph::Verify() { - bool is_unknown_graph = GetGraphUnknownFlag(); - for (const auto &node_ptr : GetAllNodes()) { - GE_CHECK_NOTNULL(node_ptr); - GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); - GE_IF_BOOL_EXEC(is_unknown_graph, continue); - GE_CHK_BOOL_EXEC(node_ptr->GetOpDesc()->CommonVerify() == GRAPH_SUCCESS, return GRAPH_FAILED, - "Verifying %s failed.", node_ptr->GetName().c_str()); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferOriginFormat() { - return ge::FormatRefiner::InferOrigineFormat(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferShapeInNeed() { - GE_CHK_BOOL_ONLY_LOG(TopologicalSorting() == GRAPH_SUCCESS, "Verifying failed."); - for (const auto &node_ptr : GetAllNodes()) { - GE_CHECK_NOTNULL(node_ptr); - auto op_desc = node_ptr->GetOpDesc(); - bool is_need_infer = false; - (void)ge::AttrUtils::GetBool(op_desc, NEED_INFER, is_need_infer); - if (is_need_infer) { - GE_CHK_BOOL_EXEC(node_ptr->Verify() == GRAPH_SUCCESS, return GRAPH_FAILED, "Verifying %s failed.", - node_ptr->GetName().c_str()); - - graphStatus status = node_ptr->InferShapeAndType(); - GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == DATA || GRAPH_PARAM_INVALID != status, break, - "Op %s does not have the IMPLEMT_INFERFUNC definition," - " and subsequent operators no longer perform shape inference.", - node_ptr->GetName().c_str()); - GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS, return GRAPH_FAILED, "Inferring %s failed.", - node_ptr->GetName().c_str()); - - for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { - GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc()); - auto output_tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()); - ge::TensorUtils::SetRealDimCnt(output_tensor, output_tensor.GetShape().GetDims().size()); - (void)out_anchor->GetOwnerNode()->GetOpDesc()->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor); - for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { - (void)peer_anchor->GetOwnerNode()->GetOpDesc()->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor); - } - } - } - } - return GRAPH_SUCCESS; -} - -ProtoAttrMapHelper ComputeGraph::MutableAttrMap() { return attrs_; } - -ConstProtoAttrMapHelper ComputeGraph::GetAttrMap() const { - return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); -} - -const std::map &ComputeGraph::GetAllNodesInfo() const { return all_nodes_infos_; } - -void ComputeGraph::SetUserDefOutput(const std::string &output_name) { - if (output_name.empty()) { - return; - } - - vector nodes = StringUtils::Split(output_name, ';'); - for (string node : nodes) { - vector item = StringUtils::Split(node, ':'); - if (item.size() != OUTPUT_PARAM_SIZE) { - GELOGW("invalid output param!input:%s", output_name.c_str()); - continue; - } - - int32_t index; - try { - index = stoi(StringUtils::Trim(item[1])); - } catch (const std::out_of_range &) { - GELOGW("outputname cause out of range execption!output_name:%s", output_name.c_str()); - continue; - } catch (const std::invalid_argument &) { - GELOGW("outputname cause invalid argument!output_name:%s", output_name.c_str()); - continue; - } catch (...) { - GELOGW("stoi fail! output_name:%s", output_name.c_str()); - continue; - } - auto iter = out_nodes_map_.find(item[0]); - if (iter == out_nodes_map_.end()) { - out_nodes_map_[item[0]] = std::vector(1, index); - } else { - auto idx_iter = std::find(iter->second.begin(), iter->second.end(), index); - if (idx_iter == iter->second.end()) { - iter->second.push_back(index); - } - } - } -} - -const std::string ComputeGraph::GetOutput() { - static const int resultDefaultSize = 2048; - string result; - result.reserve(resultDefaultSize); - auto iter = out_nodes_map_.begin(); - while (iter != out_nodes_map_.end()) { - auto idxes = iter->second; - for (auto idx : idxes) { - (void)result.append(iter->first).append(":").append(std::to_string(idx)).append(";"); - } - ++iter; - } - - return result.substr(0, result.length() - 1); -} -} // namespace ge diff --git a/metadef/graph/debug/ge_log.h b/metadef/graph/debug/ge_log.h deleted file mode 100644 index 14a66709..00000000 --- a/metadef/graph/debug/ge_log.h +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_DEBUG_GE_LOG_H_ -#define COMMON_GRAPH_DEBUG_GE_LOG_H_ - -#include "graph/ge_error_codes.h" -#include "framework/common/debug/ge_log.h" - -#define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) - -#define GE_LOGI_IF(condition, ...) \ - if ((condition)) { \ - GELOGI(__VA_ARGS__); \ - } - -#define GE_LOGW_IF(condition, ...) \ - if ((condition)) { \ - GELOGW(__VA_ARGS__); \ - } - -#define GE_LOGE_IF(condition, ...) \ - if ((condition)) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - } - -#define GE_CHK_STATUS_RET_NOLOG(expr) \ - do { \ - const ge::graphStatus _status = (expr); \ - if (ge::SUCCESS != _status) { \ - return _status; \ - } \ - } while (0) - -#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -#define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ - { \ - bool b = (expr); \ - if (!b) { \ - exec_expr; \ - } \ - } - -#define GE_IF_BOOL_EXEC(expr, exec_expr) \ - { \ - if (expr) { \ - exec_expr; \ - } \ - } - -#define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ - do { \ - const ge::graphStatus _status = (expr); \ - if (_status) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -// If expr is true, the log is printed and a custom statement is executed -#define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ - { \ - bool b = (expr); \ - if (b) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - exec_expr; \ - } \ - } - -// Only check error log -#define GE_CHK_BOOL_ONLY_LOG(expr, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - GELOGI(__VA_ARGS__); \ - } \ - } while (0) - -// If expr is not true, do not print the log and return the specified status -#define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - return _status; \ - } \ - } while (0) - -// If expr is not true, the log is printed and a custom statement is executed -#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ - { \ - bool b = (expr); \ - if (!b) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - exec_expr; \ - } \ - } - -// If expr is not true, the log is printed and a custom statement is executed -#define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ - { \ - bool b = (expr); \ - if (!b) { \ - GELOGI(__VA_ARGS__); \ - exec_expr; \ - } \ - } - -// If expr is not GRAPH_SUCCESS, print the log and return the same value -#define GE_CHK_STATUS_RET(expr, ...) \ - do { \ - const ge::graphStatus _status = (expr); \ - if (ge::SUCCESS != _status) { \ - GELOGE(ge::FAILED, __VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ - try { \ - exec_expr0; \ - } catch (...) { \ - GELOGE(ge::FAILED, "Make shared failed"); \ - exec_expr1; \ - } - -#endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ diff --git a/metadef/graph/debug/ge_op_types.h b/metadef/graph/debug/ge_op_types.h deleted file mode 100644 index dff87331..00000000 --- a/metadef/graph/debug/ge_op_types.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ -#define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ - -namespace ge { -#define GE_REGISTER_OPTYPE(var_name, str_name) static const char *var_name __attribute__((unused)) = str_name - -GE_REGISTER_OPTYPE(DATA, "Data"); -GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); -GE_REGISTER_OPTYPE(MATMUL, "MatMul"); -GE_REGISTER_OPTYPE(RESHAPE, "Reshape"); -GE_REGISTER_OPTYPE(PERMUTE, "Permute"); -GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput"); -GE_REGISTER_OPTYPE(_WHILE, "_While"); -GE_REGISTER_OPTYPE(WHILE, "While"); -GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile"); -GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); -GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); -GE_REGISTER_OPTYPE(SWITCH, "Switch"); -GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch"); -GE_REGISTER_OPTYPE(SWITCHN, "SwitchN"); -GE_REGISTER_OPTYPE(MERGE, "Merge"); -GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); -GE_REGISTER_OPTYPE(ENTER, "Enter"); -GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); -GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); -GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); -GE_REGISTER_OPTYPE(CONSTANT, "Const"); -GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); -GE_REGISTER_OPTYPE(END, "End"); -GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); -GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); -GE_REGISTER_OPTYPE(INITDATA, "InitData"); -GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity"); -GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); - -GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); -GE_REGISTER_OPTYPE(VARIABLE, "Variable"); -GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); - -GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); - -// Horovod operator -GE_REGISTER_OPTYPE(HVDCALLBACKALLREDUCE, "hvdCallbackAllreduce"); -GE_REGISTER_OPTYPE(HVDCALLBACKALLGATHER, "hvdCallbackAllgather"); -GE_REGISTER_OPTYPE(HVDCALLBACKBROADCAST, "hvdCallbackBroadcast"); -GE_REGISTER_OPTYPE(HVDWAIT, "hvdWait"); - -GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); - -GE_REGISTER_OPTYPE(RECV, "Recv"); -GE_REGISTER_OPTYPE(SEND, "Send"); -}; // namespace ge -#endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ diff --git a/metadef/graph/debug/ge_util.h b/metadef/graph/debug/ge_util.h deleted file mode 100644 index 4c6ae051..00000000 --- a/metadef/graph/debug/ge_util.h +++ /dev/null @@ -1,274 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_DEBUG_GE_UTIL_H_ -#define COMMON_GRAPH_DEBUG_GE_UTIL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "framework/common/debug/ge_log.h" -#include "graph/debug/ge_log.h" -#include "graph/ge_error_codes.h" - -#if !defined(__ANDROID__) && !defined(ANDROID) -#define GE_DYNAMIC_CAST dynamic_cast -#define GE_DYNAMIC_POINTER_CAST std::dynamic_pointer_cast -#else -#define GE_DYNAMIC_CAST static_cast -#define GE_DYNAMIC_POINTER_CAST std::static_pointer_cast -#endif - -#define GE_RETURN_IF_ERROR(expr) \ - do { \ - const ::ge::optStatus _status = (expr); \ - if (_status) return _status; \ - } while (0) - -#define GE_RETURN_WITH_LOG_IF_INFO(expr, ...) \ - do { \ - const ::ge::optStatus _status = (expr); \ - if (_status) { \ - GELOGI(__VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -// Verify whether the parameter is true. If yes, return graph failed and record the error log -#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ - do { \ - if (condition) { \ - GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ - return ge::GRAPH_FAILED; \ - } \ - } while (0) - -// Verify whether the parameter is false. If yes, return graph failed and record the error log -#define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \ - do { \ - bool _condition = (condition); \ - if (!_condition) { \ - GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ - return ge::GRAPH_FAILED; \ - } \ - } while (0) - -// Verify whether the parameter is true. If yes, return GRAPH_PARAM_INVALID and record the error log -#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ - do { \ - if (condition) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Verify whether the parameter is false. If yes, return GRAPH_PARAM_INVALID and record the error log -#define GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(condition, ...) \ - do { \ - bool _condition = (condition); \ - if (!_condition) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log -#define GE_CHECK_NOTNULL(val) \ - do { \ - if (val == nullptr) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log -#define GE_CHECK_NOTNULL_EXEC(val, expr) \ - do { \ - if (val == nullptr) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ - expr; \ - } \ - } while (0) - -// Verify whether the parameter is null. If yes, return false and record the error log -#define GE_RT_FALSE_CHECK_NOTNULL(val) \ - do { \ - if (val == nullptr) { \ - GELOGE(ge::GRAPH_FAILED, "param[%s] must not be null.", #val); \ - return false; \ - } \ - } while (0) - -// Check whether the parameter is out of range -#define GE_CHECK_SIZE(size) \ - do { \ - if (size == 0) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -/// -/// @ingroup GE_common -/// eg:GE_DEFINE_BYTE_SIZE(filter_byte, filter.data().size(), sizeof(float)); -/// -#define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ - uint32_t _var_name; \ - do { \ - uint32_t _expr_size = (_expr); \ - uint32_t _sizeof_size = (_sizeof); \ - if (_expr_size > (0xffffffff) / _sizeof_size) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "byte size : %s is out of range", #_var_name); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - _var_name = _sizeof_size * _expr_size; \ - } while (0); - -// Check whether the container is empty -#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ - do { \ - if (vector.empty()) { \ - GELOGE(ge::GRAPH_FAILED, "param[#vector] is empty", #vector); \ - return ge::GRAPH_FAILED; \ - } \ - } while (0) - -// Check whether the container is empty and return the specified status code -#define GE_CHECK_VECTOR_NOT_EMPTY_RET_STATUS(vector, _status) \ - do { \ - if (vector.empty()) { \ - GELOGE(_status, "param[%s] is empty", #vector); \ - return _status; \ - } \ - } while (0) - -/// -/// @ingroup GE_common -/// @brief This macro provides the ability to disable copying constructors and assignment operators. -/// It is usually placed under private -/// -#define GE_DISALLOW_COPY_AND_ASSIGN(TypeName) \ - TypeName(const TypeName &) = delete; \ - void operator=(const TypeName &) = delete - -/// Check whether the size is 0 or out of range -/// @param锛歴ize锛歋ize to be verified -#define GE_CHECK_SIZE_RANGE(size) \ - do { \ - if (size == 0 || size >= UINT_MAX / 4) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -#define GE_CHECK_SHORT_SIZE_RANGE(size) \ - do { \ - if (size == 0 || size >= UINT_MAX / 2) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ - do { \ - if (size <= 0) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not a positive number", #size); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -#define GE_CHECK_POSITIVE_SHORT_SIZE_RANGE(size) \ - do { \ - if (size <= 0 || size == 0 || size >= UINT_MAX / 4) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Verify that the value on the left is greater than or equal to the value on the right -#define GE_CHECK_GE(lhs, rhs) \ - do { \ - if (lhs < rhs) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is less than[%s]", #lhs, #rhs); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Check whether the parameters are equal -#define GE_CHECK_EQ(val1, val2) \ - do { \ - if (val1 != val2) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not equals to[%s]", #val1, #val2); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Verify that the value on the left is less than or equal to the value on the right -#define GE_CHECK_LE(lhs, rhs) \ - do { \ - if (lhs > rhs) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is greater than[%s]", #lhs, #rhs); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// Check whether the parameters are equal -#define GE_CHECK_EQ_WITH_LOG(val1, val2, ...) \ - do { \ - if (val1 != val2) { \ - GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ - return ge::GRAPH_PARAM_INVALID; \ - } \ - } while (0) - -// If expr is false, the custom statement is executed -#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - exec_expr; \ - } \ - } while (0) - -#define GE_DELETE_NEW_SINGLE(var) \ - do { \ - if (var != nullptr) { \ - delete var; \ - var = nullptr; \ - } \ - } while (0) - -#define GE_DELETE_NEW_ARRAY(var) \ - do { \ - if (var != nullptr) { \ - delete[] var; \ - var = nullptr; \ - } \ - } while (0) - -template -static inline std::shared_ptr ComGraphMakeShared(Args &&... args) { - using T_nc = typename std::remove_const::type; - std::shared_ptr ret(new (std::nothrow) T_nc(std::forward(args)...)); - return ret; -} - -#endif // COMMON_GRAPH_DEBUG_GE_UTIL_H_ diff --git a/metadef/graph/debug/graph_debug.cc b/metadef/graph/debug/graph_debug.cc deleted file mode 100644 index 7ce9db37..00000000 --- a/metadef/graph/debug/graph_debug.cc +++ /dev/null @@ -1,246 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/debug/graph_debug.h" -#include -#include -#include -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" - -#define TAB " " -#define STR_FMT(str) (" \"" + std::string(str) + "\" ") -#define INPUT_ANCHOR_PORT(name) ("__input__" + (name)) -#define OUTPUT_ANCHOR_PORT(name) ("__output__" + (name)) - -namespace ge { -std::unordered_set control_anchor; -std::vector types = { - "DT_FLOAT", "DT_FLOAT16", "DT_INT8", "DT_INT32", "DT_UINT8", "", - "DT_INT16", "DT_UINT16", "DT_UINT32", "DT_INT64", "DT_UINT64", "DT_DOUBLE", - "DT_BOOL", "DT_DUAL", "DT_DUAL_SUB_INT8", "DT_DUAL_SUB_UINT8", "DT_UNDEFINED"}; - -std::vector formats = {"FORMAT_NCHW", - "FORMAT_NHWC", - "FORMAT_ND", - "FORMAT_NC1HWC0", - "FORMAT_FRACTAL_Z", - "FORMAT_NC1C0HWPAD", - "FORMAT_NHWC1C0", - "FORMAT_FSR_NCHW", - "FORMAT_FRACTAL_DECONV", - "FORMAT_C1HWNC0", - "FORMAT_FRACTAL_DECONV_TRANSPOSE", - "FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS", - "FORMAT_NC1HWC0_C04", - "FORMAT_FRACTAL_Z_C04", - "FORMAT_CHWN", - "FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS", - "FORMAT_HWCN", - "FORMAT_NC1KHKWHWC0", - "FORMAT_BN_WEIGHT", - "FORMAT_FILTER_HWCK", - "FORMAT_HASHTABLE_LOOKUP_LOOKUPS", - "FORMAT_HASHTABLE_LOOKUP_KEYS", - "FORMAT_HASHTABLE_LOOKUP_VALUE", - "FORMAT_HASHTABLE_LOOKUP_OUTPUT", - "FORMAT_HASHTABLE_LOOKUP_HITS", - "FORMAT_RESERVED"}; - -std::vector data_nodes = {"Const", "Data"}; - -void GraphDebugPrinter::DumpNodeToDot(const NodePtr node, std::ostringstream &out_) { - if (node == nullptr) { - GELOGI("Some nodes are null."); - return; - } - - bool in_control = false; - auto name = node->GetName(); - out_ << TAB << STR_FMT(name); - auto input_cnt = std::max(static_cast(1), node->GetAllInDataAnchors().size()); - auto output_cnt = std::max(static_cast(1), node->GetAllOutDataAnchors().size()); - if (control_anchor.find(node->GetName()) != control_anchor.end()) { - input_cnt++; - in_control = true; - } - auto max_col = input_cnt * output_cnt; - out_ << "[\n"; - if (find(data_nodes.begin(), data_nodes.end(), node->GetType()) != data_nodes.end()) { - out_ << TAB << TAB << "shape=plaintext, color=goldenrod\n"; - } else { - out_ << TAB << TAB << "shape=plaintext, color=deepskyblue\n"; - } - out_ << TAB << TAB << "label=<\n"; - out_ << TAB << TAB << R"(" << std::endl; - - auto input_anchors = node->GetAllInDataAnchors(); - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op_desc, return ); - if (!input_anchors.empty()) { - out_ << TAB << TAB << ""; - } - for (const auto &anchor : input_anchors) { - string anchor_text = op_desc->GetInputNameByIndex(anchor->GetIdx()); - - out_ << ""; - } - if (in_control) { - string anchor_text = "ctrl"; - out_ << ""; - } - if (!input_anchors.empty()) { - out_ << "\n"; - } - // Node type - out_ << TAB << TAB << "\n"; - // Output - auto output_anchors = node->GetAllOutDataAnchors(); - if (!output_anchors.empty()) { - out_ << TAB << TAB << ""; - } - for (const auto &anchor : output_anchors) { - string anchor_text = op_desc->GetOutputNameByIndex(anchor->GetIdx()); - - out_ << ""; - } - - if (!output_anchors.empty()) { - out_ << "\n"; - } - out_ << TAB << TAB << "
" - << anchor_text << "" - << anchor_text << "
" - << "" << node->GetType() << "
" - << anchor_text << "
\n" << TAB << ">];\n"; -} - -void GraphDebugPrinter::DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag) { - if (node == nullptr) { - GELOGI("Some nodes are null."); - return; - } - auto all_out_anchor = node->GetAllOutDataAnchors(); - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op_desc, return ); - for (const auto &anchor : all_out_anchor) { - auto src_anchor = anchor; - auto src_node_name = node->GetName(); - auto src_anchor_index = op_desc->GetOutputNameByIndex(static_cast(src_anchor->GetIdx())); - auto des_anchors = anchor->GetPeerAnchors(); - for (const auto &peer_in_anchor : des_anchors) { - auto in_data_anchor = Anchor::DynamicAnchorCast(peer_in_anchor); - std::string dst_node_name; - out_ << TAB << STR_FMT(src_node_name); - out_ << ":" << OUTPUT_ANCHOR_PORT(src_anchor_index); - auto op = peer_in_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op, continue); - if (in_data_anchor != nullptr) { - dst_node_name = in_data_anchor->GetOwnerNode()->GetName(); - string des_anchor_index = op->GetInputNameByIndex(static_cast(in_data_anchor->GetIdx())); - out_ << " -> " << STR_FMT(dst_node_name); - out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index); - out_ << "["; - } - auto in_control_anchor = Anchor::DynamicAnchorCast(peer_in_anchor); - if (in_control_anchor != nullptr) { - dst_node_name = in_control_anchor->GetOwnerNode()->GetName(); - string des_anchor_index = "ctrl"; - out_ << " -> " << STR_FMT(dst_node_name); - out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index); - out_ << "["; - out_ << " style=dashed "; - } - if (flag != DOT_NOT_SHOW_EDGE_LABEL && in_data_anchor) { - string label; - auto src_ops = src_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(src_ops, return ); - auto src_shape = src_ops->GetOutputDesc(src_anchor->GetIdx()).GetShape(); - auto dim = src_shape.GetDims(); - std::ostringstream tensor_info; - if (dim.size() > 0) { - for (size_t i = 0; i < dim.size(); i++) { - if (i != dim.size() - 1) { - tensor_info << dim[i] << "x"; - } else { - tensor_info << dim[i]; - } - } - } else { - tensor_info << "?"; - } - auto src_tensor_desc = src_ops->GetOutputDescPtr(src_anchor->GetIdx()); - GE_CHECK_NOTNULL_EXEC(src_tensor_desc, return ); - auto format = src_tensor_desc->GetFormat(); - auto datatype = src_tensor_desc->GetDataType(); - tensor_info << " : " << formats[format] << " : " << types[datatype]; - label = tensor_info.str(); - out_ << "label=" << STR_FMT(label); - } - out_ << "]" << std::endl; - } - } -} - -graphStatus GraphDebugPrinter::DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name, - uint32_t flag) { - auto compute_graph = GraphUtils::GetComputeGraph(graph); - if (compute_graph == nullptr) { - GELOGI("Compute graph is NULL ."); - return GRAPH_SUCCESS; - } - return DumpGraphDotFile(compute_graph, output_dot_file_name, flag); -} - -graphStatus GraphDebugPrinter::DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name, - uint32_t flag) { - if (graph == nullptr) { - GELOGI("graph is null."); - return GRAPH_SUCCESS; - } - std::ostringstream out_; - out_ << "digraph G{\n"; - out_ << TAB << R"(ratio=compress;size="8, 100")" << std::endl; - out_ << TAB << R"(node[fontname="Consolas"])" << std::endl; - out_ << TAB << R"(edge[fontsize = "8" fontname = "Consolas" color="dimgray" ])" << std::endl; - auto all_nodes = graph->GetAllNodes(); - for (const auto &node : all_nodes) { - for (const auto &temp : node->GetAllOutDataAnchors()) { - for (const auto &peer : temp->GetPeerAnchors()) { - auto temp_control_anchor = Anchor::DynamicAnchorCast(peer); - if (temp_control_anchor) { - (void)control_anchor.insert(peer->GetOwnerNode()->GetName()); - } - } - } - } - for (const auto &node : all_nodes) { - DumpNodeToDot(node, out_); - } - for (const auto &node : all_nodes) { - DumpEdgeToDot(node, out_, flag); - } - out_ << "}"; - std::ofstream output_file(output_dot_file_name); - if (output_file.is_open()) { - output_file << out_.str(); - } else { - GELOGW("%s open error.", output_dot_file_name.c_str()); - } - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/metadef/graph/debug/graph_debug.h b/metadef/graph/debug/graph_debug.h deleted file mode 100644 index 29de632a..00000000 --- a/metadef/graph/debug/graph_debug.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ -#define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ -#include -#include -#include -#include -#include -#include "external/graph/graph.h" -#include "./ge_error_codes.h" -#include "graph/compute_graph.h" -#include "graph/debug/ge_log.h" -#include "graph/node.h" -#include "utils/graph_utils.h" - -namespace ge { -enum DotFileFlag { - // Show nodes, edges, size, type and format - DOT_FLAG_DEFAULT = 0, - DOT_NOT_SHOW_EDGE_LABEL = 1, -}; -class GraphDebugPrinter { - public: - static graphStatus DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name, - uint32_t flag = DOT_FLAG_DEFAULT); - static graphStatus DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name, - uint32_t flag = DOT_FLAG_DEFAULT); - static void DumpNodeToDot(const NodePtr node, std::ostringstream &out_); - static void DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag = DOT_FLAG_DEFAULT); -}; -} // namespace ge - -#endif // COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ diff --git a/metadef/graph/detail/attributes_holder.cc b/metadef/graph/detail/attributes_holder.cc deleted file mode 100644 index 113f4b6f..00000000 --- a/metadef/graph/detail/attributes_holder.cc +++ /dev/null @@ -1,241 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "detail/attributes_holder.h" -#include -#include "debug/ge_log.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/ge_attr_value.h" -#include "proto/ge_ir.pb.h" - -namespace ge { -using std::map; -using std::unordered_set; -void AttrHolder::CopyAttrsFrom(const AttrHolder &holder) { MutableAttrMap().CopyValueFrom(holder.GetAttrMap()); } -graphStatus AttrHolder::SetAttr(const std::string &name, const GeAttrValue &value) { - if (value.IsEmpty()) { - GELOGE(GRAPH_FAILED, "value is empty, key %s", name.c_str()); - return GRAPH_FAILED; - } - auto proto_map = MutableAttrMap().GetProtoMsg(); - auto proto_val = value.value_.GetProtoMsg(); - if (proto_map == nullptr || proto_val == nullptr) { - return GRAPH_FAILED; - } - auto it = proto_map->find(name); - if (it != proto_map->end()) { - if (it->second.value_case() != proto::AttrDef::VALUE_NOT_SET && - it->second.value_case() != proto_val->value_case()) { - return GRAPH_FAILED; - } - } - (*proto_map)[name] = *proto_val; - return GRAPH_SUCCESS; -} - -graphStatus AttrHolder::AddRequiredAttr(const std::string &name) { - if (HasAttr(name)) { - return GRAPH_FAILED; - } - requiredAttrs_.push_back(name); - return GRAPH_SUCCESS; -} - -graphStatus AttrHolder::GetAttr(const std::string &name, GeAttrValue &value) const { - auto proto_map = GetAttrMap().GetProtoMsg(); - auto proto_val = value.value_.GetProtoMsg(); - if (proto_map == nullptr || proto_val == nullptr) { - return GRAPH_FAILED; - } - auto it = proto_map->find(name); - if (it != proto_map->end()) { - *proto_val = it->second; - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -bool AttrHolder::HasAttr(const std::string &name) const { - auto proto_map = GetAttrMap().GetProtoMsg(); - if (proto_map != nullptr) { - if (proto_map->find(name) != proto_map->end()) { - return true; - } - } - return std::find(requiredAttrs_.begin(), requiredAttrs_.end(), name) != requiredAttrs_.end(); -} - -graphStatus AttrHolder::DelAttr(const std::string &name) { - auto proto_map = MutableAttrMap().GetProtoMsg(); - if (proto_map == nullptr) { - return GRAPH_FAILED; - } - auto it = proto_map->find(name); - if (it != proto_map->end()) { - (void)proto_map->erase(it); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -const std::map AttrHolder::GetAllAttrs() const { - std::map attr_value_map; - auto proto_map = GetAttrMap().GetProtoMsg(); - if (proto_map != nullptr) { - auto proto_owner = GetAttrMap().GetProtoOwner(); - GE_CHK_BOOL_EXEC(proto_owner != nullptr, return attr_value_map, "proto_owner is nullptr"); - for (const auto &it : *proto_map) { - attr_value_map[it.first] = GeAttrValue(proto_owner, const_cast(&it.second)); - } - } - return attr_value_map; -} - -const std::unordered_set AttrHolder::GetAllAttrNames() const { - std::unordered_set names; - auto proto_map = GetAttrMap().GetProtoMsg(); - if (proto_map != nullptr) { - for (const auto &it : *proto_map) { - (void)names.insert(it.first); - } - } - for (const string &it : requiredAttrs_) { - (void)names.insert(it); - } - return names; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::AttrDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::TensorDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::ShapeDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::NamedAttrs make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); - return; - } - protoMsg_ = proto_owner->mutable_attr(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); - return; - } - protoMsg_ = &proto_owner->attr(); - protoOwner_ = proto_owner; -} -} // namespace ge diff --git a/metadef/graph/format_refiner.cc b/metadef/graph/format_refiner.cc deleted file mode 100644 index c716825a..00000000 --- a/metadef/graph/format_refiner.cc +++ /dev/null @@ -1,508 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "format_refiner.h" - -#include -#include -#include -#include -#include - -#include "graph/ref_relation.h" -#include "./compute_graph.h" -#include "./ge_error_codes.h" -#include "./graph/ge_tensor.h" -#include "./operator.h" -#include "./operator_factory.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "utils/node_utils.h" -#include "utils/op_desc_utils.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" - -using namespace ge; -using namespace std; -namespace ge { -namespace { -const std::unordered_set kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; -const string kIsGraphInferred = "_is_graph_inferred"; -thread_local RefRelations reflection_builder; -} // namespace - -graphStatus ReflectionProcess(const std::unordered_set &reflection, - std::deque &nodes, ge::Format to_be_set_format) { - for (const auto &cell : reflection) { - auto node = cell.node; - auto in_out_idx = cell.in_out_idx; - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - if (cell.in_out == ge::NODE_IN) { - auto desc = node->GetOpDesc()->GetInputDesc(static_cast(in_out_idx)); - desc.SetOriginFormat(to_be_set_format); - desc.SetFormat(to_be_set_format); - (void)node->GetOpDesc()->UpdateInputDesc(static_cast(in_out_idx), desc); - } else { - auto desc = node->GetOpDesc()->GetOutputDesc(static_cast(in_out_idx)); - desc.SetOriginFormat(to_be_set_format); - desc.SetFormat(to_be_set_format); - (void)node->GetOpDesc()->UpdateOutputDesc(static_cast(in_out_idx), desc); - } - nodes.push_back(cell.node); - } - - return GRAPH_SUCCESS; -} - -graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) { - // 5 meas dim num - if (node_ptr->GetType() != "BiasAdd") { - return GRAPH_SUCCESS; - } - std::unordered_map kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}}; - for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) { - auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i); - GE_CHECK_NOTNULL(in_desc); - if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num - continue; - } - auto format = in_desc->GetOriginFormat(); - auto key = TypeUtils::FormatToSerialString(format); - auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; - in_desc->SetOriginFormat(fixed_format); - in_desc->SetFormat(fixed_format); - GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i, - node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), - TypeUtils::FormatToSerialString(fixed_format).c_str()); - } - for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) { - auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i); - GE_CHECK_NOTNULL(out_desc); - if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num - continue; - } - auto format = out_desc->GetOriginFormat(); - auto key = TypeUtils::FormatToSerialString(format); - auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; - out_desc->SetOriginFormat(fixed_format); - out_desc->SetFormat(fixed_format); - GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", i, - node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), - TypeUtils::FormatToSerialString(fixed_format).c_str()); - } - return GRAPH_SUCCESS; -} - -graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { - GE_CHECK_NOTNULL(graph); - GE_CHECK_NOTNULL(op_desc); - if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) { - ConstGeTensorPtr tensor_value; - if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { - GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str()); - return GRAPH_FAILED; - } - GE_CHECK_NOTNULL(tensor_value); - (void)op_desc->UpdateOutputDesc(0, tensor_value->GetTensorDesc()); - } - return GRAPH_SUCCESS; -} - -graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector &anchor_points, - std::vector &data_nodes, - std::unordered_map &node_status) { - if (graph == nullptr) { - GELOGE(GRAPH_FAILED, "input graph is null"); - return GRAPH_FAILED; - } - anchor_points.clear(); - // Get all anchor point nodes and switch nodes - for (auto &node_ptr : graph->GetAllNodes()) { - if (node_ptr == nullptr) { - return GRAPH_FAILED; - } - auto op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - return GRAPH_FAILED; - } - graphStatus status = RefreshConstantOutProcess(graph, op_desc); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "refresh constant out process failed!"); - return GRAPH_FAILED; - } - // consider special node save process - // get all input desc format - bool node_is_all_nd = false; - auto input_size = static_cast(op_desc->GetAllInputsSize()); - for (uint32_t i = 0; i < input_size; i++) { - // Operator pre-set format but not origin format - GE_IF_BOOL_EXEC(op_desc->MutableInputDesc(i) == nullptr, continue); - auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); - // Pre-save data node (only main graph data) and default infer fail - if (node_ptr->GetType() == DATA) { - data_nodes.push_back(node_ptr); - } - if (input_format != FORMAT_ND && input_format != FORMAT_RESERVED) { - node_is_all_nd = true; - } - } - // Get all output desc format - auto output_size = static_cast(op_desc->GetOutputsSize()); - for (uint32_t i = 0; i < output_size; i++) { - GE_IF_BOOL_EXEC(op_desc->MutableOutputDesc(i) == nullptr, continue); - auto output_format = op_desc->MutableOutputDesc(i)->GetFormat(); - if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) { - node_is_all_nd = true; - } - } - // check anchor point valid - if (!node_is_all_nd) { - continue; - } - // special process for biasAdd op - // In tensorflow, biasAdd's format is alwayse NHWC even though set the arg - // "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism - // so here do special process - status = BiasAddFormatFixProcess(node_ptr); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "fix biasAdd process failed!"); - return GRAPH_FAILED; - } - - GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str()); - anchor_points.push_back(node_ptr); - } - GELOGI("anchor_points number is %zu", anchor_points.size()); - return GRAPH_SUCCESS; -} -graphStatus FormatRefiner::AnchorProcess(const ge::NodePtr &anchor_node, - std::unordered_map &node_status) { - if (anchor_node == nullptr) { - GELOGE(GRAPH_FAILED, "anchor node is null!"); - return GRAPH_FAILED; - } - std::deque nodes; - nodes.push_back(anchor_node); - while (!nodes.empty()) { - ge::NodePtr node = nodes.front(); - nodes.pop_front(); - graphStatus status = BackInferProcess(nodes, node, node_status); - if (status != GRAPH_SUCCESS && node != nullptr) { - GELOGE(status, "BackInferProcess failed!node name [%s]", node->GetName().c_str()); - return status; - } - status = ForwardInferProcess(nodes, node, node_status); - if (status != GRAPH_SUCCESS && node != nullptr) { - GELOGE(status, "ForwardInferProcess failed!node name [%s]", node->GetName().c_str()); - return status; - } - } - return GRAPH_SUCCESS; -} -graphStatus FormatRefiner::BackInferProcess(std::deque &nodes, ge::NodePtr &node, - std::unordered_map &node_status) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - - GELOGD("Enter back infer process!Node is [%s]", (node->GetName()).c_str()); - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - GELOGD("Node is [%s] [B]", (node->GetName()).c_str()); - auto in_data_anchor_idx = in_anchor->GetIdx(); - auto input_desc = node->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); - GE_IF_BOOL_EXEC(input_desc == nullptr, continue); - auto to_be_set_format = input_desc->GetOriginFormat(); - if (to_be_set_format == FORMAT_ND) { - GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str()); - continue; - } - auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_data_anchor == nullptr) { - GELOGW("Node[%s] %dth in data anchor's peer_out_anchor is null", (node->GetName()).c_str(), in_data_anchor_idx); - continue; - } - auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); - if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { - GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (node->GetName()).c_str()); - continue; - } - // Check format whether have been set - int idx = peer_out_data_anchor->GetIdx(); - // do peer_out_node name and index as key to lookup reflections - ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx); - std::unordered_set reflection; - auto status = reflection_builder.LookUpRefRelations(key, reflection); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d out edge", - (peer_out_data_node->GetName()).c_str(), idx); - return GRAPH_FAILED; - } - - auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast(idx)); - if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { - auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); - if (dim_num == 0) { - GELOGD("node name:%s idx:%d out is scalar. stop back infer!", peer_out_data_node->GetName().c_str(), idx); - continue; - } - /// Check whether node to change dims () - /// Because some node will calculate with 5D, C dim maybe multi meaning - auto peer_out_data_node_type = peer_out_data_node->GetType(); - auto iter1 = kChangeDimNodes.find(peer_out_data_node_type); - // 4 means dims num - if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { - GELOGD("Node[%s] is change dim node and shape is smaller than 4. do not modify format", - (peer_out_data_node->GetName()).c_str()); - continue; - } - - if (reflection.empty()) { - ge_tensor_desc.SetOriginFormat(to_be_set_format); - ge_tensor_desc.SetFormat(to_be_set_format); - (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast(idx), ge_tensor_desc); - - // Call operator infer format api (forward) to get out format - GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); - status = peer_out_data_node->InferOriginFormat(); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); - return GRAPH_FAILED; - } - nodes.push_back(peer_out_data_node); - } else { - auto status = ReflectionProcess(reflection, nodes, to_be_set_format); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "reflection process failed!"); - return GRAPH_FAILED; - } - } - } - } - return GRAPH_SUCCESS; -} -graphStatus FormatRefiner::ForwardInferProcess(std::deque &nodes, ge::NodePtr &node, - std::unordered_map &node_status) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - GELOGD("Enter forward infer process!Node is [%s]", (node->GetName()).c_str()); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - GELOGD("Node is [%s] [F]", (node->GetName()).c_str()); - GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); - auto out_data_anchor_idx = out_data_anchor->GetIdx(); - auto to_be_set_format = - node->GetOpDesc()->MutableOutputDesc(static_cast(out_data_anchor_idx))->GetOriginFormat(); - if (to_be_set_format == FORMAT_ND) { - GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str()); - continue; - } - for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue); - - auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); - GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue); - GE_IF_BOOL_EXEC(peer_in_data_node->GetOpDesc() == nullptr, continue); - - // Check format whether have been set - int idx = peer_in_data_anchor->GetIdx(); - // do peer_out_node name and index as key to lookup reflections - ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx); - std::unordered_set reflection; - auto status = reflection_builder.LookUpRefRelations(key, reflection); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d input edge", - (peer_in_data_node->GetName()).c_str(), idx); - return GRAPH_FAILED; - } - auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast(idx)); - if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { - auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); - if (dim_num == 0) { - GELOGI("node name:%s idx:%d in is scalar. stop forward infer!", peer_in_data_node->GetName().c_str(), idx); - continue; - } - /// Check whether node to change dims () - /// Because some node will calculate with 5D, C dim maybe multi meaning - auto peer_in_data_node_type = peer_in_data_node->GetType(); - auto iter1 = kChangeDimNodes.find(peer_in_data_node_type); - // 4 means dims num - if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { - GELOGD("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str()); - continue; - } - - if (reflection.empty()) { - ge_tensor_desc.SetOriginFormat(to_be_set_format); - ge_tensor_desc.SetFormat(to_be_set_format); - (void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(static_cast(idx), ge_tensor_desc); - - /// Because netoutput node added before infer format ,so netoutput is end condition - /// must set netoutput format , because saved result depend on format - if (peer_in_data_node_type == NETOUTPUT) { - continue; - } - - // Call operator infer format api (forward) to get out format - GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); - status = peer_in_data_node->InferOriginFormat(); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); - return GRAPH_FAILED; - } - nodes.push_back(peer_in_data_node); - } else { - auto status = ReflectionProcess(reflection, nodes, to_be_set_format); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "reflection process failed!"); - return GRAPH_FAILED; - } - } - } - } - } - return GRAPH_SUCCESS; -} - -void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector &anchor_points) { - for (const auto &node : anchor_points) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - continue; - } - for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) { - if (input_desc != nullptr) { - input_desc->SetOriginFormat(input_desc->GetFormat()); - } - } - for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) { - if (output_desc != nullptr) { - output_desc->SetOriginFormat(output_desc->GetFormat()); - } - } - } -} - -graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, - ge::Format data_format, - std::unordered_map &node_status) { - if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) { - GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph), - TypeUtils::FormatToSerialString(data_format).c_str()); - return GRAPH_SUCCESS; - } - GELOGD("Enter DataNodeFormatProcess"); - std::vector uninfered_data_nodes; - // Check and renew data nodes format - for (const auto &data_node : data_nodes) { - GE_CHECK_NOTNULL(data_node); - auto op_desc = data_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(0)); - auto curr_format = op_desc->GetOutputDescPtr(0)->GetOriginFormat(); - if (curr_format != FORMAT_ND) { - // Data format has been infered , continue - continue; - } - // Set format for un-infered data node - auto input_descs = op_desc->GetAllInputsDescPtr(); - auto output_descs = op_desc->GetAllOutputsDescPtr(); - - for (const auto &input_desc : input_descs) { - if (input_desc != nullptr) { - input_desc->SetOriginFormat(data_format); - input_desc->SetFormat(data_format); - } - } - for (const auto &output_desc : output_descs) { - if (output_desc != nullptr) { - output_desc->SetOriginFormat(data_format); - output_desc->SetFormat(data_format); - } - } - uninfered_data_nodes.push_back(data_node); - } - // Reinfer format from uninfered data nodes - for (const auto &node : uninfered_data_nodes) { - if (node == nullptr) { - continue; - } - GELOGD("data node [%s] start infer format process", node->GetName().c_str()); - auto status = AnchorProcess(node, node_status); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "data node [%s] infer format process failed!", node->GetName().c_str()); - return GRAPH_FAILED; - } - } - GELOGD("DataNodeFormatProcess success"); - return GRAPH_SUCCESS; -} - -graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) { - GELOGI("Enter InferOrigineFormat process!"); - - // True: infered false:no-infered - std::unordered_map node_status; - std::vector anchor_points; - std::vector data_nodes; - // global net format - - if (graph == nullptr) { - GELOGE(GRAPH_FAILED, "input graph is null"); - return GRAPH_FAILED; - } - // build reflection relations of boundary - (void)reflection_builder.Clear(); - auto status = reflection_builder.BuildRefRelations(*graph); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "build reflection relations failed for main and subgraph!"); - return GRAPH_FAILED; - } - // User set global net format - status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!"); - return GRAPH_FAILED; - } - // Refresh origin format of anchor point - RefreshOriginFormatOfAnchor(anchor_points); - // Infer format process - for (const auto &anchor_node : anchor_points) { - if (anchor_node == nullptr) { - continue; - } - status = AnchorProcess(anchor_node, node_status); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Anchor node [%s] process failed!", anchor_node->GetName().c_str()); - return GRAPH_FAILED; - } - } - /// According to discuss with sys-enginer, data node default format is ND.Its format - /// should be set by infered.But if some data-node can not be got by infer, set context's - /// format for these data nodes. - /// Notice: ignore 5D formats - auto data_format = graph->GetDataFormat(); - status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status); - - (void)AttrUtils::SetBool(graph, kIsGraphInferred, true); - - return status; -} - -bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) { - bool is_graph_inferred = false; - return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred); -} -} // namespace ge diff --git a/metadef/graph/format_refiner.h b/metadef/graph/format_refiner.h deleted file mode 100644 index eca93bae..00000000 --- a/metadef/graph/format_refiner.h +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_FORMAT_REFINER_H_ -#define COMMON_GRAPH_FORMAT_REFINER_H_ - -#include -#include -#include -#include -#include "./compute_graph.h" -#include "./external/graph/types.h" -#include "./ge_error_codes.h" - -namespace ge { -// ShapeRefiner performs shape inference for compute graphs -class FormatRefiner { - public: - static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph); - - private: - static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); - static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector &anchor_points, - std::vector &data_nodes, - std::unordered_map &node_status); - static graphStatus AnchorProcess(const ge::NodePtr &anchor_node, std::unordered_map &node_status); - static void RefreshOriginFormatOfAnchor(std::vector &anchor_points); - static graphStatus BackInferProcess(std::deque &nodes, ge::NodePtr &node, - std::unordered_map &node_status); - static graphStatus ForwardInferProcess(std::deque &nodes, ge::NodePtr &node, - std::unordered_map &node_status); - static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, - ge::Format data_format, std::unordered_map &node_status); - static bool IsGraphInferred(const ComputeGraphPtr &graph); -}; -} // namespace ge -#endif // COMMON_GRAPH_FORMAT_REFINER_H_ diff --git a/metadef/graph/ge_attr_define.cc b/metadef/graph/ge_attr_define.cc deleted file mode 100644 index 4834c73b..00000000 --- a/metadef/graph/ge_attr_define.cc +++ /dev/null @@ -1,1078 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace ge { -// Public attribute -const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape"; - -const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned"; - -const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE = "_unknown_shape_type"; - -const std::string ATTR_NAME_NAME = "name"; - -const std::string ATTR_NAME_TYPE = "type"; - -const std::string ATTR_NAME_WEIGHT_NAME = "weight_name"; - -const std::string ATTR_NAME_IS_QUANTIZE_FACTOR = "quantize_factor"; - -const std::string ATTR_NAME_ALPHA = "alpha"; - -const std::string ATTR_NAME_BETA = "beta"; - -const std::string ATTR_NAME_PADMODE = "pad_mode"; - -const std::string ATTR_NAME_PADMODES = "padding"; - -const std::string ATTR_NAME_MODE = "mode"; - -const std::string ATTR_NAME_FILTER = "filter"; - -const std::string ATTR_NAME_BIAS = "bias"; - -const std::string ATTR_NAME_BIAS_TERM = "bias_term"; - -const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; - -const std::string ATTR_NAME_PAD = "pad"; - -const std::string ATTR_NAME_PADS = "pad"; - -const std::string ATTR_NAME_PAD_SIZE = "pad size"; - -const std::string ATTR_NAME_PAD_MODE = "pad mode"; - -const std::string ATTR_NAME_SCALE = "scale"; - -const std::string ATTR_NAME_WINDOWS = "windows"; - -const std::string ATTR_NAME_GLOBAL_POOLING = "global_pooling"; - -const std::string ATTR_NAME_CEIL_MODE = "ceil_mode"; - -const std::string ATTR_NAME_RELUMODE = "relu_mode"; - -const std::string ATTR_NAME_STRIDE_SIZE = "stride size"; - -const std::string ATTR_NAME_RELU_FLAG = "relu_flag"; - -const std::string ATTR_NAME_ALGO = "algo"; - -const std::string ATTR_NAME_FORMAT = "format"; - -const std::string ATTR_NAME_STORAGE_FORMAT = "storage_format"; - -const std::string ATTR_NAME_STORAGE_SHAPE = "storage_shape"; - -const std::string ATTR_NAME_FILTER_FORMAT = "filter_format"; - -const std::string ATTR_NAME_LRN_K = "lrn_k"; - -const std::string ATTR_NAME_LRN_NORM_REGION = "lrn_normregion"; - -const std::string ATTR_NAME_LRN_LOCAL_SIZE = "lrn_localsize"; - -const std::string ATTR_NAME_LRN_ALPHA = "lrn_alpha"; - -const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; - -const std::string ATTR_NAME_AXIS = "axis"; -const std::string ATTR_NAME_BROADCAST = "broadcast"; - -const std::string ATTR_NAME_OUTPUT = "output"; -const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; -const std::string ATTR_NAME_TIDX = "t_idx"; - -const std::string ATTR_NAME_TPADDINGS = "t_paddings"; -const std::string ATTR_IMG_H = "img_h"; -const std::string ATTR_IMG_W = "img_w"; -const std::string ATTR_NET_H = "net_h"; -const std::string ATTR_NET_W = "net_w"; - -const std::string ATTR_NAME_TMULTIPLES = "t_multiples"; - -const std::string ATTR_NAME_MULTIPLES = "multiples"; - -const std::string ATTR_NAME_T = "T"; -const std::string ATTR_NAME_N = "N"; - -const std::string ATTR_NAME_TSHAPE = "Tshape"; -const std::string ATTR_NAME_NAN_OPT = "nan_opt"; - -const std::string ATTR_NAME_AIPP = "aipp"; -const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; - -const std::string ATTR_NAME_AIPP_INPUTS = "_aipp_inputs"; -const std::string ATTR_NAME_AIPP_OUTPUTS = "_aipp_outputs"; - -const std::string ATTR_NAME_INPUT_DIMS = "input_dims"; - -const std::string ATTR_NAME_GRAPH_HAS_BEEN_ADDED = "_graph_has_been_added"; -const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; -const std::string ATTR_NAME_PARENT_GRAPH_NAME = "_parent_graph_name"; - -const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; -const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; -const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; - -const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; -const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; - -const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; -const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; -const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; -const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; -const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; - -const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; -const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; - -const std::string ATTR_NAME_INFERRED_FORMAT = "inferred_format"; -const std::string ATTR_NAME_PRED_PERMUTE_DELETED = "pred_permute_deleted"; -const std::string ATTR_NAME_IGNORE_PRED_FORMAT = "ignore_pred_format"; -const std::string ATTR_NAME_WEIGHTS = "value"; -const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; -const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; -const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; -const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; -const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL = "_continuous_stream_label"; -const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; -const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; -const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; -const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; -const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS = "_dynamic_output_dims"; -const std::string ATTR_NAME_INPUT_ORIGIN_SIZE = "input_origin_size"; - -// Identify node connecting to input and output -const std::string ATTR_NAME_NODE_CONNECT_INPUT = "_is_connected_to_data"; -const std::string ATTR_NAME_NODE_CONNECT_OUTPUT = "_is_connected_to_netoutput"; - -// To be deleted -const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; -const std::string PERMUTE_RESHAPE_FUSION = "permute_reshape_fusion"; -const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL = "fusion_conv_proposal"; -const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX = "fusion_conv_decodebbox"; -const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM = "box_type_num"; -const std::string SSD_MBOX_LOC_FUSION = "permute_flatten_fusion"; -const std::string SSD_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; -const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; -const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; -const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; - -// Refinedet -const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; - -const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; -const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; -const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; -const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; - -// _Arg -const std::string ATTR_NAME_INDEX = "index"; -// _RetVal -const std::string RETVAL_ATTR_NAME_INDEX = "retval_index"; -// Data -const std::string DATA_ATTR_NAME_DATA_TYPE = "data_type"; - -// Send -const std::string SEND_ATTR_EVENT_ID = "event_id"; - -// Recv -const std::string RECV_ATTR_EVENT_ID = "event_id"; - -// convolution -const std::string ATTR_NAME_COEF = "coef"; - -const std::string ATTR_NAME_STRIDE = "stride"; - -const std::string ATTR_NAME_STRIDES = "stride"; - -const std::string ATTR_NAME_DILATION = "dilation"; - -const std::string ATTR_NAME_DILATIONS = "dilation"; - -const std::string CONV_ATTR_NAME_MODE = "mode"; - -const std::string CONV_ATTR_NAME_ALGO = "algo"; - -const std::string CONV_ATTR_NAME_GROUP = "group"; - -const std::string CONV_ATTR_NAME_PAD_MODE = "pad_mode"; - -const std::string CONV_ATTR_NAME_PAD = "pad"; - -const std::string CONV_ATTR_NAME_STRIDE = "stride"; - -const std::string CONV_ATTR_NAME_DILATION = "dilation"; - -const std::string CONV_ATTR_NAME_NUM_OUTPUT = "num_output"; - -const std::string CONV_ATTR_NAME_KERNEL = "kernel"; - -const std::string CONV_ATTR_NAME_FILTER = "filter"; - -const std::string CONV_ATTR_NAME_BIAS = "bias"; - -const std::string CONV_ATTR_NAME_RELU_FLAG = "relu_flag"; - -const std::string CONV_ATTR_NAME_ADJ = "adj"; - -const std::string CONV_ATTR_NAME_TARGET_SHAPE = "target_shape"; - -const std::string CONV_ATTR_NAME_BEFORE_PAD = "before_pad"; - -const std::string CONV_ATTR_NAME_HAS_BIAS = "has_bias"; - -const std::string NEED_INFER = "isNeedInfer"; - -// Pooling -const std::string POOLING_ATTR_MODE = "mode"; -const std::string POOLING_ATTR_NAN_OPT = "nan_opt"; -const std::string POOLING_ATTR_PAD_MODE = "pad_mode"; -const std::string POOLING_ATTR_GLOBAL_POOLING = "global_pooling"; -const std::string POOLING_ATTR_WINDOW = "window"; -const std::string POOLING_ATTR_PAD = "pad"; -const std::string POOLING_ATTR_STRIDE = "stride"; -const std::string POOLING_ATTR_CEIL_MODE = "ceil_mode"; -const std::string POOLING_ATTR_DATA_MODE = "data_mode"; -const std::string POOLING_ATTR_BEFORE_PAD = "before_pad"; -const std::string POOLING_ATTR_NAME_ALGO = "algo"; - -// Eltwise -const std::string ELTWISE_ATTR_MODE = "mode"; -const std::string ELTWISE_ATTR_COEFF = "coeff"; -const std::string ELTWISE_ATTR_WEIGHT = "weight"; -const std::string ELTWISE_ATTR_RELU_FLAG = "relu_flag"; -const std::string ELTWISE_ATTR_ALPHA = "alpha"; -const std::string ELTWISE_ATTR_BETA = "beta"; - -// BatchNorm -const std::string BATCHNORM_ATTR_MODE = "mode"; -const std::string BATCHNORM_ATTR_EPSILON = "epsilon"; -const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS = "use_global_stats"; -const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION = "moving_average_fraction"; -const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; -const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; -const std::string BATCHNORM_ATTR_SCALE = "scale"; -const std::string BATCHNORM_ATTR_BIAS = "bias"; -const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; -const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; -const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; - -// huberloss -const std::string HUBER_LOSS_ATTR_DELTA = "delta"; - -// SSDRealDivTileMul -const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; - -// SSDSumMulRealDivMean -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; - -// ConcatFive2Four -// ConcatFour2Five -const std::string SSD_BOX_TYPE_NUM = "box_type_num"; -const std::string SSD_CLASS_NUM = "class_num"; -const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; -const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; -const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; -const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; - -// Scale -const std::string SCALE_ATTR_SCALE = "scale"; -const std::string SCALE_ATTR_BIAS = "bias"; - -// FullConnection -const std::string FULL_CONNECTION_ATTR_FILTER = "filter"; -const std::string FULL_CONNECTION_ATTR_BIAS = "bias"; -const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT = "num_output"; -const std::string FULL_CONNECTION_ATTR_RELU_FLAG = "relu_flag"; -const std::string FULL_ATTR_NAME_ALGO = "algo"; - -// SoftmaxOpParams -const std::string SOFTMAX_ATTR_ALGO = "algo"; -const std::string SOFTMAX_ATTR_MODE = "mode"; - -// SparseSoftmaxCrossEntropy -const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE = "cross_entropy_mode"; -const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD = "cross_entropy_is_grad"; -// Attr labelSmoothing -const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING = "labelSmoothing"; - -// ApplyMomentum -const std::string APPLYMENTUM_ATTR_IS_GRAPH_FUSION = "applymomentum_is_graph_fusion"; - -// Activation -const std::string ACTIVATION_ATTR_MODE = "mode"; -const std::string ACTIVATION_ATTR_COEF = "coef"; - -// Concat -const std::string CONCAT_ATTR_NAME_AXIS = "axis"; - -// Const -const std::string CONST_ATTR_NAME_DATA_TRANSTYPE = "data_transtype"; -const std::string CONST_ATTR_NAME_OUTPUT_FORMAT = "output_format"; -const std::string CONST_ATTR_NAME_OUTPUT_TYPE = "output_type"; - -// Roipooling -const std::string ROIPOOLING_ATTR_NAME_POOLED_H = "pooled_h"; -const std::string ROIPOOLING_ATTR_NAME_POOLED_W = "pooled_w"; -const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE = "spatial_scale"; -const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE = "rio_pooling_mode"; -const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE = "pooling_mode"; -const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO = "sampling_ratio"; - -// DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES = "num_classes"; -const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES = "ocr_num_classes"; -const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD = "nms_threshold"; -const std::string DETECTIONOUTPUT_ATTR_TOP_K = "top_k"; -const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD = "confidence_threshold"; -const std::string DETECTIONOUTPUT_ATTR_IMG_H = "img_h"; -const std::string DETECTIONOUTPUT_ATTR_IMG_W = "img_w"; -const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE = "batch_size"; -// Ssd DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_ETA = "eta"; -const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION = "shared_location"; -const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID = "background_label_id"; -const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE = "code_type"; -const std::string DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET = "variance_encoded_in_target"; -const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K = "keep_top_k"; -// Refinedet DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE = "objectness_score"; -// yolo DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_ClASSES = "classes"; -const std::string DETECTIONOUTPUT_ATTR_BIASES = "biases"; -const std::string DETECTIONOUTPUT_ATTR_RELATIVE = "relative"; -const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD = "objectness_threshold"; -const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD = "class_threshold"; -const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K = "post_top_k"; -const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY = "iou_threshold_decay"; -const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR = "coor_scale_factor"; -const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION = "yolo_version"; - -// DetectionPostprocess -const std::string POSTPROCESS_ATTR_NAME_CLS_NUM = "cls_num"; -const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH = "conf_thresh"; -const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH = "nms_thresh"; -const std::string POSTPROCESS_ATTR_POST_NMS_TOPN = "post_nms_topn"; -const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT = "bbox_reg_weights"; - -// Spatialtransfrom -const std::string SPTIALTF_ATTR_NAME_OUTPUT_H = "output_h"; -const std::string SPTIALTF_ATTR_NAME_OUTPUT_W = "output_w"; -const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE = "border_value"; -const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM = "affine_transform"; - -// Proposa -const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE = "feat_stride"; -const std::string PROPOSAL_ATTR_NAME_BASE_SIZE = "base_size"; -const std::string PROPOSAL_ATTR_NAME_MIN_SIZE = "min_size"; -const std::string PROPOSAL_ATTR_NAME_RATIO = "ratio"; -const std::string PROPOSAL_ATTR_NAME_SCALE = "scale"; -const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN = "pre_nms_topn"; -const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN = "post_nms_topn"; -const std::string PROPOSAL_ATTR_NAME_NMS_THRESH = "nms_thresh"; -const std::string PROPOSAL_ATTR_NAME_TOP_SIZE = "top_size"; -const std::string PROPOSAL_ATTR_IMG_H = "img_h"; -const std::string PROPOSAL_ATTR_IMG_W = "img_w"; -// Softmax -const std::string SOFTMAX_ATTR_AXIS = "axis"; - -// Permute -const std::string PERMUTE_ATTR_ORDER = "order"; -const std::string PERMUTE_ATTR_PERM = "perm"; - -// SSD Normalize -const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; -const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED = "channel_shared"; -const std::string SSDNORMALIZE_ATTR_EPS = "eps"; - -// Flatten -const std::string FLATTEN_ATTR_AXIS = "axis"; -const std::string FLATTEN_ATTR_END_AXIS = "end_axis"; - -// SsdPRIORBOX -const std::string SSD_PRIOR_BOX_ATTR_FLIP = "flip"; -const std::string SSD_PRIOR_BOX_ATTR_CLIP = "clip"; -const std::string SSD_PRIOR_BOX_ATTR_IMG_H = "img_h"; -const std::string SSD_PRIOR_BOX_ATTR_IMG_W = "img_w"; -const std::string SSD_PRIOR_BOX_ATTR_STEP_H = "step_h"; -const std::string SSD_PRIOR_BOX_ATTR_STEP_W = "step_w"; -const std::string SSD_PRIOR_BOX_ATTR_OFFSET = "offset"; -const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE = "min_size"; -const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE = "max_size"; -const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM = "min_size_num"; -const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM = "max_size_num"; -const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO = "aspect_ratio"; -const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; -const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; -const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; - -// RefinedetDetectionOutput -const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; -const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; - -// PRelu -const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; - -// Psroi pooling -const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE = "spatial_scale"; -const std::string PSROIPOOLING_ATTR_OUTPUT_DIM = "output_dim"; -const std::string PSROIPOOLING_ATTR_GROUP_SIZE = "group_size"; - -// Power -const std::string POWER_ATTR_NAME_POWER = "power"; -const std::string POWER_ATTR_NAME_SCALE = "scale"; -const std::string POWER_ATTR_NAME_SHIFT = "shift"; - -// log -const std::string LOG_ATTR_NAME_SCALE = "scale"; -const std::string LOG_ATTR_NAME_SHIFT = "shift"; -const std::string LOG_ATTR_NAME_BASE = "base"; -// Pack -const std::string PACK_ATTR_NAME_NUM = "N"; - -// Unpack -const std::string UNPACK_ATTR_NAME_NUM = "num"; -const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; -// Gathernd -const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; -const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; - -// Argmax -const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; -const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; -const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; -const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; -const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; -const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; -const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; - -// upsample -const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; -const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; - -// Relu -const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; - -// FreeSpaceExtract -const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT = "org_height"; - -// Split -const std::string SPLIT_ATTR_NAME_SLICE_POINT = "slice_point"; -const std::string SPLIT_ATTR_NAME_SIZE_SPLIT = "size_split"; -const std::string SPLIT_ATTR_NAME_NUM_SPLIT = "num_split"; - -// Tvm -const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; -const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; -const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; -const std::string TVM_ATTR_NAME_WORKSPACE_TYPE = "tvm_workspace_type"; - -// Squeeze -const std::string SQUEEZE_ATTR_AXIS = "axis"; -const std::string SQUEEZE_ATTR_DIMS = "squeeze_dims"; -const std::string SQUEEZE_OP_NAME = "Squeeze"; - -// Stride slice -const std::string STRIDE_SLICE_ATTR_BEGIN_MASK = "begin_mask"; -const std::string STRIDE_SLICE_ATTR_END_MASK = "end_mask"; -const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK = "ellipsis_mask"; -const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK = "new_axis_mask"; -const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK = "shrink_axis_mask"; - -// Slice -const std::string SLICE_ATTR_NAME_BEGINS = "begins"; -const std::string SLICE_ATTR_NAME_SIZES = "sizes"; - -// Roialign -const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; -const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; -const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; -const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; - -// Generate_rpn_proposal -const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK = "post_nms_topk"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE = "rpn_mini_size"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH = "rpn_proposal_nms_thresh"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH = "rpn_proposal_filter_thresh"; -// Decode_bbox -const std::string DECODE_BBOX_ATTR_DECODECLIP = "decodeClip"; - -// Cast -const std::string CAST_ATTR_DSTT = "DstT"; -const std::string CAST_ATTR_SRCT = "SrcT"; -const std::string CAST_ATTR_DST_TYPE = "dst_type"; -const std::string CAST_ATTR_TRUNCATE = "truncate"; - -// Fastrcnnn predications -const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK = "fsr_topk"; -const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD = "fsr_score_thres"; -const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD = "fsr_nms_thres"; -const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES = "fsr_num_classes"; - -// REORG -const std::string REORG_ATTR_STRIDE = "stride"; -const std::string REORG_ATTR_REVERSE = "reverse"; - -// MERGE -const std::string MERGE_DEAD_INDEX = "merge_dead_index"; -const std::string MERGE_PRENODE_FLAG = "merge_prenode_flag"; -const std::string TO_BE_OUTPUT = "to_be_output"; - -// ENTER -const std::string ENTER_ATTR_FRAME_NAME = "frame_name"; -const std::string ENTER_ATTR_CONSTANT_FLAG = "is_constant"; - -// Concatv2 -const std::string CONCAT_V2_ATTR_TIDX = "Tidx"; -const std::string CONCAT_V2_ATTR_N = "N"; -// SUM -const std::string SUM_ATTR_TIDX = "Tidx"; -const std::string SUM_ATTR_AXIS = "axis"; -const std::string SUM_ATTR_KEEP_DIMS = "keep_dims"; - -// ResizeBilinear -const std::string RESIZE_BILINEAR_ATTR_MODE = "mode"; -const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS = "align_corners"; -const std::string RESIZE_BILINEAR_ATTR_HEIGHT = "height"; -const std::string RESIZE_BILINEAR_ATTR_WIDTH = "width"; -const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR = "zoom_factor"; -const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR = "shrink_factor"; -const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN = "pad_begin"; -const std::string RESIZE_BILINEAR_ATTR_PAD_END = "pad_end"; -const std::string RESIZE_BILINEAR_ATTR_ALPHA = "alpha"; -const std::string RESIZE_BILINEAR_ATTR_BETA = "beta"; - -// RetinaNet -const std::string RETINANET_FILTER_BACKGROUND_TRUE = "retina_conv_filter_background"; -const std::string RETINANET_ANCHOR_FUSION = "retina_anchor_fusion"; - -// MatMul -const std::string MATMUL_TRANSPOSE_X = "transposeX"; -const std::string MATMUL_TRANSPOSE_W = "transposeW"; -const std::string MATMUL_HAS_BIAS = "has_bias"; -const std::string MATMUL_ATTR_IS_TRAINING = "matmul_is_training"; - -// Flatten -const std::string FLATTEN_START_AXIS = "start_axis"; -const std::string FLATTEN_END_AXIS = "end_axis"; - -// Reshape -const std::string RESHAPE_ATTR_AXIS = "axis"; -const std::string RESHAPE_ATTR_NUM_AXES = "num_axes"; -const std::string RESHAPE_ATTR_FORMAT = "format"; -const std::string RESHAPE_ATTR_SHAPE = "shape"; -const std::string RESHAPE_ATTR_ALPHA = "alpha"; -const std::string RESHAPE_ATTR_BETA = "beta"; - -// Frameoworkop -const std::string T_IN_DATATYPE = "t_in_datatype"; -const std::string T_OUT_DATATYPE = "t_out_datatype"; -const std::string ATTR_NAME_OUT_N = "out_n"; -const std::string ATTR_NAME_OUT_C = "out_c"; -const std::string ATTR_NAME_OUT_H = "out_h"; -const std::string ATTR_NAME_OUT_W = "out_w"; -const std::string ATTR_PAD_DEPTH_CONV = "pad_depth_conv"; -const std::string ATTR_PAD_CONV = "pad_conv"; - -const std::string ATTR_NAME_BEFORE_PAD = "before_pad"; -const std::string ANN_MEAN_KEEPDIMS = "AnnMeanKeepDims"; -const std::string PAD_ATTR_PADDINGDS = "paddings"; -const std::string PAD_ATTR_CONSTANT_VALUE = "padvalue"; - -// ConvGradFilter -const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape"; -// ConvGradInput -const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; - -// Rnn -const std::string RNN_MODE_STATIC = "rnn_static"; -const std::string MUTI_RNN = "multi_rnn"; -const std::string CNN_RNN = "cnn_rnn"; -const std::string RNN_MODE_ = "rnn_"; - -const std::string CELL_MODE = "mode"; -const std::string LSTM_CELL = "lstm_cell"; -const std::string GRU_CELL = "gru_cell"; -const std::string RNN_HT = "ht"; -const std::string RNN_XT_HT = "xt_ht"; -const std::string RNN_BATCH_SIZE = "batch_size"; -const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; -const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; -const std::string LSTM_ACTIVATE = "lstm_activate"; -const std::string LSTM_OUT_MAP = "lstm_out_map"; -const std::string LSTM_OUT_MODE = "lstm_out_mode"; -const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; -const std::string LSTM_TIME_MAJOR = "lstm_time_major"; -const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; - -// Upsample -const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; - -// PadV2 -const std::string PADV2_ATTR_NAME_MODE = "mode"; -const std::string PADV2_ATTR_NAME_PADS = "paddings"; -const std::string PADV2_ATTR_NAME_T = "T"; -const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; -const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; - -// MirrorPad -const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; -const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; -const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; -const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; -// Filler -const std::string FILLER_TYPE = "filler_type"; -const std::string FILLER_VALUE = "filler_value"; - -// Shufflechannel -const std::string SHUFFLE_CHANNEL_GROUP = "group"; - -// TopKV2 -const std::string TOPKV2_ATTR_K = "k"; - -// Calibaration -const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; -const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; -const std::string PAD_TOP_INDEX = "PAD_TOP_INDEX"; -const std::string PAD_BOTTOM_INDEX = "PAD_BOTTOM_INDEX"; -const std::string PAD_RIGHT_INDEX = "PAD_RIGHT_INDEX"; -const std::string PAD_LEFT_INDEX = "PAD_LEFT_INDEX"; -const std::string QUANTIZE_ALGO_ATTR = "quantize_algo"; -const std::string SCALE_TYPE_ATTR = "scale_type"; - -const std::string QUANTIZE_SCALE_MODE = "quantize_scale_mode"; -const std::string QUANTIZE_SCALE_VALUE = "quantize_scale_value"; -const std::string QUANTIZE_SCALE_OFFSET = "quantize_scale_offset"; -const std::string QUANTIZE_OFFSET_DATA_VALUE = "quantize_offset_data_value"; -const std::string QUANTIZE_OFFSET_DATA_OFFSET = "quantize_offset_data_offset"; -const std::string QUANTIZE_OFFSET_WEIGHT_VALUE = "quantize_offset_weight_value"; -const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET = "quantize_offset_weight_offset"; -const std::string QUANTIZE_OFFSET_PAD_VALUE = "quantize_offset_pad_value"; -const std::string QUANTIZE_OFFSET_PAD_OFFSET = "quantize_offset_pad_offset"; - -const std::string DEQUANTIZE_SCALE_MODE = "dequantize_scale_mode"; -const std::string DEQUANTIZE_SCALE_VALUE = "dequantize_scale_value"; -const std::string DEQUANTIZE_SCALE_OFFSET = "dequantize_scale_offset"; -const std::string DEQUANTIZE_OFFSET_DATA_TYPE = "dequantize_offset_data_value"; -const std::string DEQUANTIZE_OFFSET_DATA_OFFSET = "dequantize_offset_data_offset"; -const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE = "dequantize_offset_weight_value"; -const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET = "dequantize_offset_weight_offset"; -const std::string DEQUANTIZE_OFFSET_PAD_VALUE = "dequantize_offset_pad_value"; -const std::string DEQUANTIZE_OFFSET_PAD_OFFSET = "dequantize_offset_pad_offset"; - -const std::string REQUANTIZE_SCALE_MODE = "requantize_scale_mode"; -const std::string REQUANTIZE_SCALE_VALUE = "requantize_scale_value"; -const std::string REQUANTIZE_SCALE_OFFSET = "requantize_scale_offset"; -const std::string REQUANTIZE_OFFSET_DATA_VALUE = "requantize_offset_data_value"; -const std::string REQUANTIZE_OFFSET_DATA_OFFSET = "requantize_offset_data_offset"; -const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE = "requantize_offset_weight_value"; -const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET = "requantize_offset_weight_offset"; -const std::string REQUANTIZE_OFFSET_PAD_VALUE = "requantize_offset_pad_value"; -const std::string REQUANTIZE_OFFSET_PAD_OFFSET = "requantize_offset_pad_offset"; - -const std::string ATTR_NAME_IS_CONST = "attr_name_is_const"; - -const std::string ATTR_NAME_GROUP = "group"; -const std::string ATTR_NAME_DILATION_SIZE = "dilation_size"; -const std::string ATTR_NAME_EPSILON = "epsilon"; -const std::string ATTR_NAME_POOLING_MODE = "mode"; -const std::string ATTR_NAME_CLASS_NUM = "class_num"; -// model -const std::string ATTR_MODEL_TARGET_TYPE = "target_type"; - -const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; - -const std::string ATTR_MODEL_EVENT_NUM = "event_num"; - -const std::string ATTR_MODEL_HUGE_STREAM_LIST = "huge_stream_list"; - -const std::string ATTR_MODEL_LABEL_NUM = "label_num"; - -const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; - -const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; - -const std::string ATTR_MODEL_OUT_NODES_NAME = "attr_model_out_nodes_name"; - -const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; - -const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; - -const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; - -const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR = "task_gen_variable_addr"; - -const std::string ATTR_MODEL_VAR_SIZE = "variable_size"; - -const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; - -const std::string ATTR_MODEL_CORE_TYPE = "core_type"; - -const std::string ATTR_MODEL_ATC_VERSION = "atc_version"; - -const std::string ATTR_MODEL_OPP_VERSION = "opp_version"; - -// Public attribute -const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; - -const std::string ATTR_NAME_BYTE_SIZE = "op_byte_size"; - -const std::string ATTR_NAME_FUSION_INFERENCE_ID = "fusion_inference_id"; - -const std::string ATTR_NAME_FUSION_OPDEF = "fusion_opdef"; - -const std::string ATTR_NAME_IO_OP = "io_op"; - -const std::string ATTR_NAME_FUSION_SCOPE = "fusion_scope"; - -const std::string ATTR_NAME_OPATTR = "opattr"; - -const std::string ATTR_NAME_RELUFLAG = "relu_flag"; - -const std::string ATTR_NAME_SEQLEN_INDEX = "seqlen_index"; - -const std::string ATTR_NAME_X_INDEX = "x_index"; - -const std::string ATTR_NAME_CONT_INDEX = "cont_index"; - -const std::string ATTR_NAME_XSTATIC_INDEX = "xstatic_index"; - -const std::string TARGET_TYPE_MINI = "MINI"; - -const std::string TARGET_TYPE_TINY = "TINY"; - -const std::string TARGET_TYPE_LITE = "LITE"; - -// l2_normalize -const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; -const std::string L2_NORMALIZE_ATTR_EPS = "eps"; - -const std::string POOL_PARAMA_ATTR_WINDOW = "window"; -const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; -const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; -const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; -const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; -const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; - -// HCOM -const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; -const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; - -const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; -const std::string HCOM_ATTR_GROUP = "group"; -const std::string HCOM_ATTR_SR_TAG = "sr_tag"; -const std::string HCOM_ATTR_SRC_RANK = "src_rank"; -const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; -const std::string HCOM_ATTR_FUSION = "fusion"; -const std::string HCOM_ATTR_SHAPE = "shape"; -const std::string HCOM_ATTR_DATA_TYPE = "dtype"; - -// SpaceToDepth/DepthToSpace -const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; - -// SparseSoftmaxCrossEntropyWithLogits -const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; - -// MaxPoolGradWithArgmax -const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; - -// AvgPoolGrad -const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; - -// Pad -const std::string ATTR_PAD_FORMAT = "attr_pad_format"; - -// Varible -const std::string VAR_ATTR_FORMAT = "_var_format"; -const std::string VAR_ATTR_NAME = "var_name"; -const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; -const std::string VAR_ATTR_4D_FORMAT = "4D"; -const std::string VAR_ATTR_5D_FORMAT = "5D"; -const std::string VAR_ATTR_DATA_TYPE = "data_format"; -const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; -const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; -const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; -const std::string VAR_ATTR_SHAPE = "shape"; -const std::string HALF_VAR_NAME_END = "_fp16"; -const std::string VAR_ATTR_INITED = "var_is_inited"; - -const std::string VAR_ATTR_CONTAINER = "container"; -const std::string VAR_ATTR_SHARED_NAME = "shared_name"; -const std::string VAR_ATTR_DTYPE = "dtype"; - -const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; -const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; -const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; -const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; -const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; -const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; - -// Assign -const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; -const std::string ASSIGN_VAR_NAME = "_assign_var_name"; - -// space2bacth batch2space -const std::string BATCH_SPACE_ATTR_BLOCK = "block"; -const std::string BATCH_SPACE_ATTR_PADDING = "padding"; - -// depth_to_space space_to_depth -const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; - -// FakeQuantWithMinMaxVars -const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; -const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; - -// mobilenet_ssd_conv_fusion -const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; -const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; -const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; - -// lsh project -const std::string LSH_PROJ_TYPE = "lsh_project_type"; - -// log time stamp -const std::string LOG_TIME_STAMP_LOGID = "logid"; -const std::string LOG_TIME_STAMP_NOTIFY = "notify"; - -// ShapeN -const std::string SHAPEN_ATTR_N = "N"; -const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; -const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; - -// GatherV2 attr def -const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; -const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; -const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; - -// Reshape attr def -const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; -const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; - -// axis attr def -const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; - -const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; - -const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; -const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; - -// For constant folding -const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; - -const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; - -const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC = "continuous_input_alloc"; - -const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; - -const std::string ATTR_NAME_REFERENCE = "reference"; - -const std::string ATTR_NAME_NOTASK = "_no_task"; - -const std::string ATTR_NAME_OUTPUT_REUSE_INPUT = "_output_reuse_input"; - -const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX = "_reuse_input_on_dim_index"; - -const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT = "_no_padding_continuous_input"; - -const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT = "_no_padding_continuous_output"; - -const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index"; - -// Used for mark the active label list stream of activated node -const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list"; - -// Used for l2cache, true: the memory of all inputs is used for the last time. -const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_lifecycle"; - -// Multi batch -const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; -const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; -const std::string ATTR_NAME_BATCH_LABEL = "_batch_label"; -const std::string ATTR_NAME_COMBINED_BATCH = "_combined_batch"; - -// Control flow -const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; -const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; -const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; -const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; -const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; -const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; -const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE = "subgraph_first_active"; -const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS = "combined_dynamic_dims"; - -const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; -const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; -const std::string ATTR_NAME_SWITCH_DATA_TYPE = "_switch_data_type"; -const std::string ATTR_NAME_ORIG_NODE_NAME = "_original_node_name"; -const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG = "_cyclic_dependence_flag"; - -const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; - -// Function Op -const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; - -// Used for mark the active node is for loop, type:bool -const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; - -const std::string ATTR_NAME_MEMORY_TYPE_INPUT = "memory_type_input"; - -const std::string ATTR_NAME_MEMORY_TYPE_OUTPUT = "memory_type_output"; - -const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; - -const std::string ATTR_NAME_MEMORY_TYPE_RANGE = "_memory_type_range"; - -const std::string MODEL_ATTR_SESSION_ID = "session_id"; - -// lx fusion -const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; -const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; -const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; -const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; -const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; -const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type"; -const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type"; -const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content"; -const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size"; -const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison"; -const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; -const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; -const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; -const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; -const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; -const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; -const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; -const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; -const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; -const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; -const std::string ATTR_NAME_ENGINE_NAME_FOR_LX = "_lxfusion_engine_name"; -const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX = "_lxfusion_op_kernel_lib_name"; -const std::string ATTR_NAME_NEED_LX_FUSION = "_lx_fusion"; -const std::string ATTR_NAME_OPTIMIZE_GROUP = "_optimize_group"; -const std::string ATTR_NAME_OP_COMPILE_STRATEGY = "_op_compile_strategy"; -const std::string ATTR_NAME_TBE_KERNEL_NAME = "_tbe_kernel_name"; -const std::string ATTR_NAME_TBE_KERNEL_BUFFER = "_tbe_kernel_buffer"; - -// Op debug attrs -const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; -const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; - -// Atomic addr clean attrs -const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; -const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index"; -const std::string ATOMIC_ATTR_IS_FUSION_NODE = "is_fusion_node"; -const std::string EXT_ATTR_ATOMIC_WORKSPACE_INFO = "sub_node_workspace_info"; -const std::string EXT_ATTR_ATOMIC_WORKSPACE_OFFSET = "sub_node_workspace_offset"; -const std::string ATOMIC_ATTR_IS_ATOMIC_NODE = "is_atomic_node"; - -// Source/dst format for Op FormatTransfer -const std::string FORMAT_TRANSFER_SRC_FORMAT = "src_format"; -const std::string FORMAT_TRANSFER_DST_FORMAT = "dst_format"; - -// For compile op by ge call -const std::string ATTR_NEED_COMPILE = "_node_need_compile"; - -const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; - -const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; - -const std::string ATTR_DYNAMIC_TYPE = "mbatch_dynamic_type"; - -const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER = "user_designate_shape_order"; - -// For inserted op -const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; - -// For compress weight -const std::string ATTR_NAME_COMPRESS_WEIGHT = "_is_compress_weight"; - -// For data dump -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names"; -const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop"; -const std::string ATTR_NAME_DATA_DUMP_SUB_SPLITER_INDEX = "_datadump_sub_spliter_index"; -const std::string ATTR_NAME_DATA_DUMP_GROUP_OP_NAME = "_datadump_group_op_name"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_NAME = "_datadump_origin_name"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_output_index"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; - -// functional ops attr -const std::string ATTR_NAME_IF_THEN_BRANCH = "then_branch"; -const std::string ATTR_NAME_IF_ELSE_BRANCH = "else_branch"; -const std::string ATTR_NAME_WHILE_COND = "cond"; -const std::string ATTR_NAME_WHILE_BODY = "body"; - -// used for label switch -const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; -const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; -const std::string ATTR_NAME_SUBGRAPH_END_NODE = "_subgraph_end_node"; - -const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; -const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; - -// used for LX tiling -const std::string ATTR_NAME_OP_L1_SPACE = "_l1_space"; -const std::string ATTR_NAME_FUSION_TYPE_LIST = "_fusion_type_list"; -const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST = "_valid_input_shape_list_list"; -const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; -const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; -const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_output_offset_list_list"; - -// for unregistered op -const std::string ATTR_NAME_UNREGST_OPPATH = "_unregst_oppath"; -const std::string ATTR_NAME_UNREGST_ATTRLIST = "_unregst_attrlist"; - -// used for Horovod -const std::string ATTR_INTER_EVENT_IDENTIFY = "event_id"; -const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; -// used for allreduce tailing optimization -const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; -const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; - -// dynamic shape attr -const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr"; -const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index"; - -// atc user def dtype&format -const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type"; -const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format"; - -// for fusion op plugin -const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; - -// graph partition for aicpu -const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME = "pld_front_node_engine_name"; -const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME = "end_rear_node_engine_name"; - -// input and output memory type -const std::string ATTR_VARIABLE_PLACEMENT = "_variable_placement"; -const std::string ATTR_INPUT_MEMORY_TYPE = "_input_memory_type"; -const std::string ATTR_OUTPUT_MEMORY_TYPE = "_output_memory_type"; - -// input_output_offset -const std::string ATTR_ZERO_COPY_BASIC_OFFSET = "_zero_copy_basic_offset"; -const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET = "_zero_copy_relative_offset"; -} // namespace ge diff --git a/metadef/graph/ge_attr_value.cc b/metadef/graph/ge_attr_value.cc deleted file mode 100644 index a8490470..00000000 --- a/metadef/graph/ge_attr_value.cc +++ /dev/null @@ -1,1289 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/ge_attr_value.h" -#include "graph/ge_tensor.h" -#include "external/graph/graph.h" -#include "utils/attr_utils.h" -#include "framework/common/debug/ge_log.h" -#include "graph/model_serialize.h" -#include "proto/ge_ir.pb.h" -#include "detail/model_serialize_imp.h" -#include "debug/ge_attr_define.h" -#include "debug/ge_log.h" -#include "debug/ge_util.h" - -using std::map; -using std::string; -using std::vector; - -namespace ge { -NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } - -NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) - : named_attrs_(owner, proto_msg) {} // lint !e1744 - -void NamedAttrs::SetName(const std::string &name) { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_name(name); - } -} - -string NamedAttrs::GetName() const { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->name(); - } - return string(); -} - -GeAttrValue NamedAttrs::GetItem(const string &key) const { - GeAttrValue value; - (void)GetAttr(key, value); - return value; -} - -ProtoAttrMapHelper NamedAttrs::MutableAttrMap() { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr()); - } - return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); -} - -ConstProtoAttrMapHelper NamedAttrs::GetAttrMap() const { - auto proto_msg = named_attrs_.GetProtoMsg(); - if (proto_msg != nullptr) { - return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr()); - } - return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); -} - -class GeAttrValueImp { - public: - static map attr_val_one_type_map_; - static map attr_val_list_type_map_; - - static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::INT val); - static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::FLOAT val); - static bool SetValue(proto::AttrDef &attr_def, GeAttrValue::BOOL val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::STR &val); - static bool SetValue(proto::AttrDef &attr_def, const ConstGeTensorPtr &val); - static bool SetValue(proto::AttrDef &attr_def, const GeTensor &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::TENSOR_DESC &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::BYTES &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::NAMED_ATTRS &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::GRAPH &val); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_FLOAT &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BOOL &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_STR &val); - static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); - static bool SetValue(proto::AttrDef &proto_attr_val, const vector &value); - static bool SetValue(proto::AttrDef &attr_def, const vector &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_TENSOR_DESC &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_BYTES &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_NAMED_ATTRS &val); - static bool SetValue(proto::AttrDef &attr_def, const GeAttrValue::LIST_GRAPH &val); - - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::INT &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::FLOAT &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BOOL &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::STR &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::TENSOR &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeTensor &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::TENSOR_DESC &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::BYTES &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::NAMED_ATTRS &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, GeAttrValue::GRAPH &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_INT &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_FLOAT &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_BOOL &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_STR &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_TENSOR &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, vector &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_TENSOR_DESC &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_BYTES &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_NAMED_ATTRS &val); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - GeAttrValue::LIST_GRAPH &val); - // Value will be moved - static bool SetZeroCopyBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &&buffer); - static bool GetZeroCopyBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, Buffer &buffer); - // Value will be moved - static bool SetZeroCopyListBytes(proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - vector &list_buffer); - static bool GetZeroCopyListBytes(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - vector &list_buffer); - - static bool SetValue(proto::AttrDef &attr_def, const vector> &value); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - vector> &value); - static bool SetValue(proto::AttrDef &attr_def, const vector &value); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, - vector &value); - - static bool SetValue(proto::AttrDef &attr_def, const ge::DataType &value); - static bool GetValue(const proto::AttrDef &attr_def, const ProtoMsgOwner &proto_msg_owner, ge::DataType &value); -}; - -map GeAttrValueImp::attr_val_one_type_map_ = { - {proto::AttrDef::kI, GeAttrValue::VT_INT}, - {proto::AttrDef::kF, GeAttrValue::VT_FLOAT}, - {proto::AttrDef::kB, GeAttrValue::VT_BOOL}, - {proto::AttrDef::kS, GeAttrValue::VT_STRING}, - {proto::AttrDef::kT, GeAttrValue::VT_TENSOR}, - {proto::AttrDef::kTd, GeAttrValue::VT_TENSOR_DESC}, - {proto::AttrDef::kG, GeAttrValue::VT_GRAPH}, - {proto::AttrDef::kBt, GeAttrValue::VT_BYTES}, - {proto::AttrDef::kFunc, GeAttrValue::VT_NAMED_ATTRS}, - {proto::AttrDef::kListListInt, GeAttrValue::VT_LIST_LIST_INT}, - {proto::AttrDef::kDt, GeAttrValue::VT_DATA_TYPE}, -}; -map GeAttrValueImp::attr_val_list_type_map_ = { - {proto::AttrDef_ListValue_ListValueType_VT_LIST_INT, GeAttrValue::VT_LIST_INT}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT, GeAttrValue::VT_LIST_FLOAT}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_BOOL, GeAttrValue::VT_LIST_BOOL}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_STRING, GeAttrValue::VT_LIST_STRING}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, GeAttrValue::VT_LIST_TENSOR}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, GeAttrValue::VT_LIST_TENSOR_DESC}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, GeAttrValue::VT_LIST_GRAPH}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, GeAttrValue::VT_LIST_BYTES}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, GeAttrValue::VT_LIST_NAMED_ATTRS}, - {proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, GeAttrValue::VT_LIST_DATA_TYPE}, -}; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue::GeAttrValue() { value_.InitDefault(); } - -GeAttrValue::GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val) : value_(proto_owner, val) {} - -GeAttrValue::ValueType GeAttrValue::GetValueType() const { - auto proto_msg = value_.GetProtoMsg(); - if (proto_msg != nullptr) { - auto val_case = proto_msg->value_case(); - if (val_case != proto::AttrDef::kList) { - auto it = GeAttrValueImp::attr_val_one_type_map_.find(val_case); - if (it != GeAttrValueImp::attr_val_one_type_map_.end()) { - return it->second; - } - } else { - auto it = GeAttrValueImp::attr_val_list_type_map_.find(proto_msg->list().val_type()); - if (it != GeAttrValueImp::attr_val_list_type_map_.end()) { - return it->second; - } - } - } - return GeAttrValue::VT_NONE; -} - -bool GeAttrValue::IsEmpty() const { return GetValueType() == VT_NONE; } - -GeAttrValue GeAttrValue::Copy() const { - GeAttrValue valueRet; - auto proto_msg = value_.GetProtoMsg(); - auto proto_msg_ret = valueRet.value_.GetProtoMsg(); - if (proto_msg != nullptr && proto_msg_ret != nullptr) { - *proto_msg_ret = *proto_msg; - } - return valueRet; -} - -#define ATTR_VALUE_SET_GET_IMP(type) \ - graphStatus GeAttrValue::SetValue(const type &val) { \ - auto proto_msg = value_.GetProtoMsg(); \ - if (proto_msg) { \ - if (GeAttrValueImp::SetValue(*proto_msg, val)) { \ - return GRAPH_SUCCESS; \ - } \ - } \ - return GRAPH_FAILED; \ - } \ - \ - graphStatus GeAttrValue::GetValue(type &val) const { \ - auto proto_msg = value_.GetProtoMsg(); \ - if (proto_msg) { \ - if (GeAttrValueImp::GetValue(*proto_msg, value_.GetProtoOwner(), val)) { \ - return GRAPH_SUCCESS; \ - } \ - } \ - return GRAPH_FAILED; \ - } - -ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524 -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR_DESC) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::TENSOR) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::GRAPH) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS) -ATTR_VALUE_SET_GET_IMP(vector) -/*lint -e665*/ -ATTR_VALUE_SET_GET_IMP(vector>) -/*lint +e665*/ -ATTR_VALUE_SET_GET_IMP(vector) // lint !e665 -ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665 - -#undef ATTR_VALUE_SET_GET_IMP - -graphStatus GeAttrValue::MutableTensor(GeTensorPtr &tensor) { return GetValue(tensor); } - -graphStatus GeAttrValue::MutableListTensor(vector &list_tensor) { return GetValue(list_tensor); } - -class AttrUtilsHelper { - public: - inline static bool GetValueCheckType(const proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) { - if (attr_def.value_case() != proto_case) { - GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case); - return false; - } - return true; - } - - inline static bool GetValueCheckListType( - const proto::AttrDef &attr_def, proto::AttrDef_ListValue_ListValueType proto_list_case, - const std::function item_check_fun) { - if (attr_def.value_case() != proto::AttrDef::kList) { - GELOGW("Check ListType Failed, value_case %u", attr_def.value_case()); - return false; - } - auto &list = attr_def.list(); - if (list.val_type() == proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE) { - return item_check_fun(attr_def); - } - if (list.val_type() != proto_list_case) { - GELOGW("Check ListType Failed, val_type %u, expected %u", list.val_type(), proto_list_case); - return false; - } - return true; - } - - inline static bool SetValueCheckType(proto::AttrDef &attr_def, proto::AttrDef::ValueCase proto_case) { - if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto_case) { - GELOGW("Check Type Failed, proto case type %u, expected %u", attr_def.value_case(), proto_case); - return false; - } - return true; - } - - inline static bool SetValueCheckAndSetListType(proto::AttrDef &attr_def, - proto::AttrDef_ListValue_ListValueType proto_list_case) { - if (attr_def.value_case() != proto::AttrDef::VALUE_NOT_SET && attr_def.value_case() != proto::AttrDef::kList) { - GELOGW("AttrUtils::Check Type Failed, value_case %u", attr_def.value_case()); - return false; - } - auto list = attr_def.mutable_list(); - if (list == nullptr) { - GELOGE(GRAPH_FAILED, "list is nullptr"); - return false; - } - if (list->val_type() != proto::AttrDef_ListValue_ListValueType_VT_LIST_NONE && - list->val_type() != proto_list_case) { - GELOGW("AttrUtils::Check ListType Type Failed, val_type %d, expected %d", static_cast(list->val_type()), - static_cast(proto_list_case)); - return false; - } - list->set_val_type(proto_list_case); - return true; - } - - static bool GetAttrMapItem(const AttrHolder *obj, const string &name, const proto::AttrDef *&attr_def) { - if (obj == nullptr) { - GELOGE(FAILED, "%s obj is nullptr", name.c_str()); - return false; - } - auto attr_map = obj->GetAttrMap().GetProtoMsg(); - if (attr_map == nullptr) { - GELOGE(FAILED, "%s attr map is nullptr", name.c_str()); - return false; - } - auto it = attr_map->find(name); - if (it == attr_map->end()) { - return false; - } - attr_def = &it->second; - return true; - } - - inline static bool MutableAttrMapItem(AttrHolder *obj, const string &name, proto::AttrDef *&attr_def) { - if (obj == nullptr) { - GELOGE(FAILED, " %s obj is nullptr", name.c_str()); - return false; - } - auto attr_map = obj->MutableAttrMap().GetProtoMsg(); - if (attr_map == nullptr) { - GELOGE(FAILED, "%s attr map is nullptr", name.c_str()); - return false; - } - // Get or add - attr_def = &((*attr_map)[name]); - return true; - } -}; - -#define ATTR_VALUE_IMP_SET_ONE(ValType, proto_case, protoItem) \ - bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \ - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \ - return false; \ - } \ - proto_attr_val.set_##protoItem(value); \ - return true; \ - } - -#define ATTR_VALUE_IMP_SET_LIST(ValType, proto_list_case, protoItem) \ - bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, ValType value) { \ - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, \ - proto::AttrDef_ListValue_ListValueType_##proto_list_case)) { \ - return false; \ - } \ - auto list = proto_attr_val.mutable_list(); \ - list->clear_##protoItem(); \ - for (const auto &item : value) { \ - list->add_##protoItem(item); \ - } \ - return true; \ - } - -ATTR_VALUE_IMP_SET_ONE(int64_t, kI, i) -ATTR_VALUE_IMP_SET_ONE(float, kF, f) -ATTR_VALUE_IMP_SET_ONE(const string &, kS, s) -ATTR_VALUE_IMP_SET_ONE(bool, kB, b) - -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_INT, i) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_FLOAT, f) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_STRING, s) -ATTR_VALUE_IMP_SET_LIST(const vector &, VT_LIST_BOOL, b) - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensorDesc &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) { - return false; - } - auto proto_msg = value.tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_attr_val.mutable_td() = *proto_msg; - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_td(); - for (const auto &item : value) { - auto proto_msg = item.tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { - proto_attr_val.clear_list(); - return false; - } - *list->add_td() = *proto_msg; - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ConstGeTensorPtr &value) { - if (value) { - return SetValue(proto_attr_val, *value); - } else { - return SetValue(proto_attr_val, GeTensor()); - } -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeTensor &val) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { - return false; - } - auto proto_msg = val.tensor_def_.GetProtoMsg(); - if (proto_msg == nullptr) { - GELOGE(FAILED, "Proto msg is nullptr"); - return false; - } - *proto_attr_val.mutable_t() = *proto_msg; - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - vector constList(value.size()); - std::copy(value.begin(), value.end(), constList.begin()); - return SetValue(proto_attr_val, constList); -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_t(); - for (const auto &item : value) { - if (item == nullptr) { - GELOGE(GRAPH_FAILED, "AttrUtils::SetListTensor item is nullptr"); - proto_attr_val.clear_list(); - return false; - } - auto proto_msg = item->tensor_def_.GetProtoMsg(); - if (proto_msg == nullptr) { - GELOGE(FAILED, "Proto msg is nullptr"); - proto_attr_val.clear_list(); - return false; - } - *list->add_t() = *proto_msg; - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_t(); - for (const auto &item : value) { - auto proto_msg = item.tensor_def_.GetProtoMsg(); - if (proto_msg == nullptr) { - GELOGE(FAILED, "Proto msg is nullptr"); - proto_attr_val.clear_list(); - return false; - } - *list->add_t() = *proto_msg; - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::BYTES &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { - return false; - } - size_t val_size = value.GetSize(); - proto_attr_val.set_bt(value.GetData(), val_size); - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_bt(); - for (const auto &item : value) { - list->add_bt(item.GetData(), item.GetSize()); - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { - return false; - } - auto proto_msg = value.named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - GELOGE(FAILED, "Proto msg is nullptr"); - return false; - } - *proto_attr_val.mutable_func() = *proto_msg; - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_na(); - for (const auto &item : value) { - auto proto_msg = item.named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - proto_attr_val.clear_list(); - return false; - } - *list->add_na() = *proto_msg; - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::ComputeGraphPtr &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kG)) { - return false; - } - ModelSerializeImp imp; - if (!imp.SerializeGraph(value, proto_attr_val.mutable_g())) { - GELOGE(GRAPH_FAILED, "AttrUtils::SetGraph SerializeGraph Failed"); - proto_attr_val.clear_g(); - return false; - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_g(); - - ModelSerializeImp imp; - for (const auto &item : value) { - if (!imp.SerializeGraph(item, list->add_g())) { - GELOGE(GRAPH_FAILED, "AttrUtils::SetListGraph SerializeGraph"); - proto_attr_val.clear_list(); - return false; - } - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector> &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) { - return false; - } - proto_attr_val.clear_list_list_int(); - auto list_list_int = proto_attr_val.mutable_list_list_int(); - GE_CHECK_NOTNULL_EXEC(list_list_int, return false); - for (auto &list_int : value) { - auto list_item = list_list_int->add_list_list_i(); - GE_CHECK_NOTNULL_EXEC(list_item, return false); - for (auto &int_item : list_int) { - list_item->add_list_i(int_item); - } - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_dt(); - for (const auto &item : value) { - list->add_dt(static_cast(item)); - } - return true; -} - -bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const ge::DataType &value) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) { - return false; - } - proto_attr_val.set_dt(static_cast(value)); - - return true; -} - -#define ATTR_VALUE_IMP_GET_ONE(ValType, proto_case, protoItem) \ - bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ValType value) { \ - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::proto_case)) { \ - return false; \ - } \ - value = proto_attr_val.protoItem(); \ - return true; \ - } - -#define ListValueItemCheck(protoItem) \ - [](const proto::AttrDef &proto_attr_val) { return proto_attr_val.list().protoItem##_size() > 0; } - -#define ATTR_VALUE_IMP_GET_LIST(ValType, proto_list_case, protoItem) \ - bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, vector &value) { \ - value.clear(); \ - if (!AttrUtilsHelper::GetValueCheckListType( \ - proto_attr_val, proto::AttrDef_ListValue_ListValueType_##proto_list_case, ListValueItemCheck(protoItem))) { \ - return false; \ - } \ - auto &list = proto_attr_val.list(); \ - for (const auto &item : list.protoItem()) { \ - value.push_back(item); \ - } \ - return true; \ - } - -ATTR_VALUE_IMP_GET_ONE(int64_t &, kI, i) -ATTR_VALUE_IMP_GET_ONE(float &, kF, f) -ATTR_VALUE_IMP_GET_ONE(string &, kS, s) -ATTR_VALUE_IMP_GET_ONE(bool &, kB, b) - -ATTR_VALUE_IMP_GET_LIST(int64_t, VT_LIST_INT, i) -ATTR_VALUE_IMP_GET_LIST(float, VT_LIST_FLOAT, f) -ATTR_VALUE_IMP_GET_LIST(string, VT_LIST_STRING, s) -ATTR_VALUE_IMP_GET_LIST(bool, VT_LIST_BOOL, b) - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeTensorDesc &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kTd)) { - return false; - } - auto proto_msg = value.tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = proto_attr_val.td(); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - if (!AttrUtilsHelper::GetValueCheckListType( - proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC, ListValueItemCheck(td))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.td()) { - value.emplace_back(GeTensorDesc()); - auto proto_msg = value.back().tensor_descriptor_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = item; - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - GeTensorPtr &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kT)) { - return false; - } - value = std::shared_ptr(new (std::nothrow) - GeTensor(proto_owner, const_cast(proto_attr_val).mutable_t())); - GE_CHK_BOOL_RET_STATUS(value != nullptr, false, "value is nullptr"); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR, - ListValueItemCheck(t))) { - return false; - } - auto list = const_cast(proto_attr_val).mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - for (auto &item : *(list->mutable_t())) { - std::shared_ptr temp_value = std::shared_ptr(new (std::nothrow) GeTensor(proto_owner, &item)); - GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); - value.push_back(temp_value); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, GeAttrValue::BYTES &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { - return false; - } - auto &proto_val = proto_attr_val.bt(); - GE_LOGI_IF(proto_val.size() == 0, "size res is 0."); - value = Buffer::CopyFrom(reinterpret_cast(proto_val.data()), proto_val.size()); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, - ListValueItemCheck(bt))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.bt()) { - value.push_back(Buffer::CopyFrom((const uint8_t *)item.data(), item.size())); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - GeAttrValue::NAMED_ATTRS &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { - return false; - } - auto proto_msg = value.named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = proto_attr_val.func(); - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType( - proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.na()) { - value.emplace_back(GeAttrValue::NAMED_ATTRS()); - if (value.empty()) { - return false; - } - auto proto_msg = value.back().named_attrs_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - *proto_msg = item; - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ComputeGraphPtr &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kG)) { - return false; - } - ComputeGraphPtr graph = nullptr; - std::shared_ptr graph_def; - graph_def = ComGraphMakeShared(proto_attr_val.g()); - if (graph_def == nullptr) { - GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); - graph_def = nullptr; - return false; // lint !e665 - } else { - ModelSerializeImp imp; - imp.SetProtobufOwner(graph_def); - if (!imp.UnserializeGraph(graph, *graph_def)) { - GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); - return false; - } // lint !e514 - value = graph; - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH, - ListValueItemCheck(g))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.g()) { - std::shared_ptr graph_def; - graph_def = ComGraphMakeShared(item); - if (graph_def == nullptr) { - GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); - graph_def = nullptr; - return false; // lint !e665 - } else { - ComputeGraphPtr graph = nullptr; - ModelSerializeImp imp; - imp.SetProtobufOwner(graph_def); - if (!imp.UnserializeGraph(graph, *graph_def)) { - GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); - return false; - } // lint !e514 - value.push_back(graph); - } - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector> &value) { - value.clear(); - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kListListInt)) { - return false; - } - - auto &list_listint = proto_attr_val.list_list_int().list_list_i(); - for (auto &list_int : list_listint) { - vector list_item(list_int.list_i().size()); - if (!list_int.list_i().empty()) { - (void)std::copy(list_int.list_i().begin(), list_int.list_i().end(), list_item.begin()); - } - value.push_back(list_item); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_DATA_TYPE, - ListValueItemCheck(dt))) { - return false; - } - auto &list = proto_attr_val.list(); - for (const auto &item : list.dt()) { - value.emplace_back(static_cast(item)); - } - return true; -} - -bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, ge::DataType &value) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kDt)) { - return false; - } - value = static_cast(proto_attr_val.dt()); - return true; -} - -GE_FUNC_HOST_VISIBILITY bool GeAttrValueImp::SetZeroCopyBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - Buffer &&buffer) { - if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { - return false; - } - auto proto_msg = buffer.data_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - proto_attr_val.set_bt(std::move(*proto_msg->mutable_bt())); - return true; -} - -bool GeAttrValueImp::GetZeroCopyBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - Buffer &buffer) { - if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kBt)) { - return false; - } - buffer = Buffer(proto_owner, &const_cast(proto_attr_val)); - return true; -} - -bool GeAttrValueImp::SetZeroCopyListBytes(proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &list_buffer) { - if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, - proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES)) { - return false; - } - auto list = proto_attr_val.mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - list->clear_bt(); - for (auto &item : list_buffer) { - auto proto_msg = item.data_.GetProtoMsg(); - if (proto_msg == nullptr) { - return false; - } - list->add_bt(std::move(*proto_msg->mutable_bt())); - } - return true; -} - -bool GeAttrValueImp::GetZeroCopyListBytes(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &proto_owner, - vector &list_buffer) { - list_buffer.clear(); - if (!AttrUtilsHelper::GetValueCheckListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_BYTES, - ListValueItemCheck(bt))) { - return false; - } - auto list = const_cast(proto_attr_val).mutable_list(); - GE_CHECK_NOTNULL_EXEC(list, return false); - for (auto &item : *(list->mutable_bt())) { - list_buffer.emplace_back(Buffer(proto_owner, &item)); - } - return true; -} - -bool AttrUtils::HasAttr(ConstAttrHolderAdapter &&obj, const string &name) { - if (!obj) { - return false; - } - return obj->HasAttr(name); -} - -#define ATTR_UTILS_SET_IMP(FuncName, Type) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Set##FuncName( \ - AttrHolderAdapter &&obj, const string &name, const Type &value) { \ - proto::AttrDef *proto_attr_val = nullptr; \ - if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \ - return false; \ - } \ - if (!GeAttrValueImp::SetValue(*proto_attr_val, value)) { \ - GELOGW("Set" #FuncName " failed key %s", name.c_str()); \ - return false; \ - } \ - return true; \ - } - -#define ATTR_UTILS_GET_IMP(FuncName, Type) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::Get##FuncName(ConstAttrHolderAdapter &&obj, \ - const string &name, Type &value) { \ - const proto::AttrDef *proto_attr_val = nullptr; \ - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { \ - return false; \ - } \ - if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value)) { \ - GELOGW("Get" #FuncName " failed key %s", name.c_str()); \ - return false; \ - } \ - return true; \ - } - -#define ATTR_UTILS_SET_GET_IMP(FuncName, Type) \ - ATTR_UTILS_SET_IMP(FuncName, Type) \ - ATTR_UTILS_GET_IMP(FuncName, Type) - -ATTR_UTILS_SET_GET_IMP(Int, int64_t) -ATTR_UTILS_SET_GET_IMP(Float, float) -ATTR_UTILS_SET_GET_IMP(Bool, bool) -ATTR_UTILS_SET_GET_IMP(Str, string) -ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc) -ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr) -ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr) -ATTR_UTILS_SET_IMP(Tensor, GeTensor) -ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) -ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) -ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) -/*lint -e665*/ -ATTR_UTILS_SET_GET_IMP(ListListInt, vector>) -/*lint +e665*/ - -ATTR_UTILS_SET_GET_IMP(ListInt, vector) -ATTR_UTILS_SET_IMP(ListInt, vector) -ATTR_UTILS_SET_IMP(ListInt, vector) -ATTR_UTILS_SET_GET_IMP(ListFloat, vector) -ATTR_UTILS_SET_GET_IMP(ListBool, vector) -ATTR_UTILS_SET_GET_IMP(ListStr, vector) -ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector) -ATTR_UTILS_SET_IMP(ListTensor, vector) -ATTR_UTILS_SET_IMP(ListTensor, vector) -ATTR_UTILS_SET_IMP(ListTensor, vector) -ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector) -ATTR_UTILS_SET_GET_IMP(ListBytes, vector) -ATTR_UTILS_SET_GET_IMP(ListGraph, vector) -ATTR_UTILS_SET_GET_IMP(ListDataType, vector) // lint !e665 -ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) // lint !e665 - -bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name, - std::initializer_list &&value) { - return SetListTensor(std::move(obj), name, vector(value)); -} - -bool AttrUtils::GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value) { - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - GeTensorPtr tensor; - if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) { - return false; - } - value = tensor; - return true; -} - -bool AttrUtils::GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector &value) { - value.clear(); - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - vector tensor; - if (!GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), tensor)) { - return false; - } - value.insert(value.begin(), tensor.begin(), tensor.end()); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::MutableTensor(AttrHolderAdapter &&obj, - const string &name, GeTensorPtr &value) { - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value); -} - -bool AttrUtils::MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector &value) { - value.clear(); - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::GetValue(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), value); -} - -bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value) { - proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::SetValue(*proto_attr_val, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name, - int32_t &value) { - int64_t int64_val = 0; - if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { - return false; - } - if (int64_val > INT32_MAX) { - GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to int32_t", int64_val); - return false; - } - value = static_cast(int64_val); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const string &name, - uint32_t &value) { - int64_t int64_val = 0; - if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { - return false; - } - if (int64_val > UINT32_MAX) { - GELOGE(GRAPH_FAILED, "%ld int64_t value cannot cast to uint32_t", int64_val); - return false; - } - value = static_cast(int64_val); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, - const string &name, vector &value) { - value.clear(); - vector int64_list; - if (!GetListInt(std::move(obj), name, int64_list)) { - return false; - } - - for (size_t i = 0; i < int64_list.size(); ++i) { - if (int64_list[i] > INT32_MAX) { - GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); - return false; - } - } - value.insert(value.begin(), int64_list.begin(), int64_list.end()); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, - const string &name, vector &value) { - value.clear(); - vector int64_list; - if (!GetListInt(std::move(obj), name, int64_list)) { - return false; - } - - for (size_t i = 0; i < int64_list.size(); ++i) { - if (int64_list[i] > UINT32_MAX) { - GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); - return false; - } - } - value.insert(value.begin(), int64_list.begin(), int64_list.end()); - return true; -} - -bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value) { - if (obj) { - vector bytes_vals; - for (auto &item : value) { - ModelSerialize serialize; - auto buffer = serialize.SerializeOpDesc(item); - if (buffer.GetSize() == 0) { - return false; - } - bytes_vals.push_back(buffer); - } - return SetZeroCopyListBytes(std::move(obj), name, bytes_vals); - } - return false; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetListOpDesc(AttrHolderAdapter &&obj, - const string &name, - const vector &value) { - if (obj) { - vector bytes_vals; - for (auto &item : value) { - ModelSerialize serialize; - auto buffer = serialize.SerializeOpDesc(item); - if (buffer.GetSize() == 0) { - return false; - } - bytes_vals.push_back(buffer); - } - return SetZeroCopyListBytes(std::move(obj), name, bytes_vals); - } - return false; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(ConstAttrHolderAdapter &&obj, - const string &name, - vector &value) { - value.clear(); - - vector bytes_vals; - if (!GetZeroCopyListBytes(std::move(obj), name, bytes_vals)) { - return false; - } - for (const auto &item : bytes_vals) { - ModelSerialize serialize; - auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732 - value.push_back(op_desc); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetZeroCopyBytes(AttrHolderAdapter &&obj, - const string &name, Buffer &&buffer) { - // Value will be moved - proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::SetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), std::move(buffer)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, - const string &name, Buffer &buffer) { - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::GetZeroCopyBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), buffer); -} - -bool AttrUtils::SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector &list_buffer) { - // Value will be moved - proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::MutableAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::SetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer); -} - -bool AttrUtils::GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &list_buffer) { - list_buffer.clear(); - const proto::AttrDef *proto_attr_val = nullptr; - if (!AttrUtilsHelper::GetAttrMapItem(obj.get(), name, proto_attr_val) || proto_attr_val == nullptr) { - return false; - } - return GeAttrValueImp::GetZeroCopyListBytes(*proto_attr_val, obj->GetAttrMap().GetProtoOwner(), list_buffer); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { - if (org_op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "org_op_desc is null"); - return nullptr; - } - std::shared_ptr op_def; - op_def = ComGraphMakeShared(); - if (op_def == nullptr) { - GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); - return nullptr; // lint !e665 - } - ModelSerializeImp imp; - (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); - - imp.SetProtobufOwner(op_def); - OpDescPtr op_desc = nullptr; - GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); - op_desc->extAttrs_ = org_op_desc->extAttrs_; - - // This function may be called by some passes of fusion engine, in this condition, do not need these attribute - if (!op_desc->input_name_idx_.empty()) { - op_desc->input_name_idx_.clear(); - } - if (!op_desc->output_name_idx_.empty()) { - op_desc->output_name_idx_.clear(); - } - if (!op_desc->optional_input_names_.empty()) { - op_desc->optional_input_names_.clear(); - } - - return op_desc; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { - if (org_op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "org_op_desc is null"); - return nullptr; - } - std::shared_ptr op_def = ComGraphMakeShared(); - if (op_def == nullptr) { - GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); - return nullptr; - } - ModelSerializeImp imp; - (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); - - imp.SetProtobufOwner(op_def); - OpDescPtr op_desc = nullptr; - GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); - - op_desc->extAttrs_ = org_op_desc->extAttrs_; - - op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end()); - op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(), - org_op_desc->optional_input_names_.end()); - op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); - - op_desc->infer_func_ = org_op_desc->infer_func_; - op_desc->infer_format_func_ = org_op_desc->infer_format_func_; - op_desc->verifier_func_ = org_op_desc->verifier_func_; - - return op_desc; -} -std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) { - auto holder = obj.get(); - if (holder == nullptr) { - return ""; - } - auto attrs_map = holder->GetAttrMap(); - if (attrs_map.GetProtoMsg() == nullptr) { - return ""; - } - - std::map ordered_attrs; - for (auto &attr : *(attrs_map.GetProtoMsg())) { - ordered_attrs[attr.first] = attr.second.SerializeAsString(); - } - - std::stringstream ss; - for (auto &attr : ordered_attrs) { - ss << attr.first << ":" << attr.second << ";"; - } - return ss.str(); -} -} // namespace ge diff --git a/metadef/graph/ge_tensor.cc b/metadef/graph/ge_tensor.cc deleted file mode 100644 index 196b8569..00000000 --- a/metadef/graph/ge_tensor.cc +++ /dev/null @@ -1,1021 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/ge_tensor.h" -#include -#include -#include -#include -#include "debug/ge_attr_define.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/ge_attr_value.h" -#include "graph/model_serialize.h" -#include "proto/ge_ir.pb.h" -#include "utils/attr_utils.h" -#include "utils/ge_ir_utils.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" - -namespace ge { -static const char *const kKeyDataTypeSelfDefined = "__tensor_desc_data_type__"; - -static const std::map kDataTypeMap = { - {DT_UNDEFINED, proto::DT_UNDEFINED}, - {DT_FLOAT, proto::DT_FLOAT}, - {DT_FLOAT16, proto::DT_FLOAT16}, - {DT_INT8, proto::DT_INT8}, - {DT_UINT8, proto::DT_UINT8}, - {DT_INT16, proto::DT_INT16}, - {DT_UINT16, proto::DT_UINT16}, - {DT_INT32, proto::DT_INT32}, - {DT_INT64, proto::DT_INT64}, - {DT_UINT32, proto::DT_UINT32}, - {DT_UINT64, proto::DT_UINT64}, - {DT_BOOL, proto::DT_BOOL}, - {DT_DOUBLE, proto::DT_DOUBLE}, - {DT_DUAL, proto::DT_DUAL}, - {DT_DUAL_SUB_INT8, proto::DT_DUAL_SUB_INT8}, - {DT_DUAL_SUB_UINT8, proto::DT_DUAL_SUB_UINT8}, - {DT_COMPLEX64, proto::DT_COMPLEX64}, - {DT_COMPLEX128, proto::DT_COMPLEX128}, - {DT_QINT8, proto::DT_QINT8}, - {DT_QINT16, proto::DT_QINT16}, - {DT_QINT32, proto::DT_QINT32}, - {DT_QUINT8, proto::DT_QUINT8}, - {DT_QUINT16, proto::DT_QUINT16}, - {DT_RESOURCE, proto::DT_RESOURCE}, - {DT_STRING_REF, proto::DT_STRING_REF}, - {DT_STRING, proto::DT_STRING}, -}; - -static const std::map kDataTypeSelfDefinedMap = { - {DT_DUAL, 13}, {DT_DUAL_SUB_INT8, 14}, {DT_DUAL_SUB_UINT8, 15}, {DT_COMPLEX64, 16}, {DT_COMPLEX128, 17}, - {DT_QINT8, 18}, {DT_QINT16, 19}, {DT_QINT32, 20}, {DT_QUINT8, 21}, {DT_QUINT16, 22}, -}; - -GeShape::GeShape() { shape_def_.InitDefault(); } - -// Default -GeShape::GeShape(std::vector s) : GeShape() { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto i : s) { - proto_msg->add_dim(i); - } - } -} - -size_t GeShape::GetDimNum() const { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - if (proto_msg->dim_size() >= 0) { - // check whether contain -2, if true, return -1 - for (auto i : proto_msg->dim()) { - if (i == UNKNOWN_DIM_NUM) { - return 0; - } - } - return proto_msg->dim_size(); - } else { - return 0; - } - } - return 0; -} - -int64_t GeShape::GetDim(size_t idx) const { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - if (proto_msg->dim_size() > static_cast(idx)) { - return proto_msg->dim(static_cast(idx)); - } - } - return 0; -} - -graphStatus GeShape::SetDim(size_t idx, int64_t value) { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - auto dims = proto_msg->mutable_dim(); - GE_CHECK_NOTNULL(dims); - if (dims->empty()) { - GELOGE(GRAPH_FAILED, "shape is empty"); - return GRAPH_FAILED; - } - if (static_cast(idx) >= dims->size()) { - GELOGE(GRAPH_FAILED, "idx is out of range"); - return GRAPH_FAILED; - } - proto_msg->set_dim(static_cast(idx), value); - } - return GRAPH_SUCCESS; -} - -std::vector GeShape::GetDims() const { - vector dims; - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto i : proto_msg->dim()) { - dims.push_back(i); - } - } - return dims; -} - -std::string GeShape::ToString() const { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg == nullptr) { - return ""; - } - - std::stringstream ss; - bool first = true; - for (auto i : proto_msg->dim()) { - if (first) { - first = false; - } else { - ss << ","; - } - ss << i; - } - return ss.str(); -} - -int64_t GeShape::GetShapeSize() const { - int64_t res = 1; - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - if (proto_msg->dim().empty()) { - return 0; - } - for (auto i : proto_msg->dim()) { - // if unknown shape, return -1 - if (i == UNKNOWN_DIM || i == UNKNOWN_DIM_NUM) { - return UNKNOWN_DIM; - } - res *= i; - } - } - return res; -} - -/// -/// @brief Check is unknown shape -/// @return bool -/// /// -bool GeShape::IsUnknownShape() const { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto i : proto_msg->dim()) { - if (i < 0) { - return true; - } - } - } - return false; -} - -/// -/// @brief Check is a scalar -/// @return bool -/// -bool GeShape::IsScalar() const { - auto proto_msg = shape_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->dim().empty(); - } - return false; -} - -const string TENSOR_UTILS_SIZE = "size"; -const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size"; -const string TENSOR_UTILS_REUSE_INPUT = "reuse_input"; -const string TENSOR_UTILS_OUTPUT_TENSOR = "output_tensor"; -const string TENSOR_UTILS_DEVICE_TYPE = "device_type"; -const string TENSOR_UTILS_INPUT_TENSOR = "input_tensor"; -const string TENSOR_UTILS_REAL_DIM_CNT = "real_dim_cnt"; -const string TENSOR_UTILS_REUSE_INPUT_INDEX = "reuse_input_index"; -const string TENSOR_UTILS_DATA_OFFSET = "data_offset"; -const string TENSOR_UTILS_CMPS_SIZE = "cmps_size"; -const string TENSOR_UTILS_CMPS_TAB = "cmps_tab"; -const string TENSOR_UTILS_CMPS_TAB_OFFSET = "cmps_tab_offset"; -const string TENSOR_UTILS_CMPSINFO = "cmps_info"; -const string TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO = "alloffset_quantize_info"; -const string TENSOR_UTILS_RC = "rc"; -const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; -const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; -const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; -const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; -const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index"; - -GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} - -GeShape::GeShape(const GeShape &other) : GeShape() { shape_def_.CopyValueFrom(other.shape_def_); } - -GeShape::GeShape(GeShape &&other) : GeShape() { shape_def_.MoveValueFrom(std::move(other.shape_def_)); } - -GeShape &GeShape::operator=(const GeShape &other) { - if (&other != this) { - shape_def_.CopyValueFrom(other.shape_def_); - } - return *this; -} - -GeShape &GeShape::operator=(GeShape &&other) { - if (&other != this) { - shape_def_.CopyValueFrom(std::move(other.shape_def_)); - } - return *this; -} - -GeTensorDesc::GeTensorDesc() { - tensor_descriptor_.InitDefault(); - SetDataType(DT_FLOAT); - Init(); -} - -// Default -GeTensorDesc::GeTensorDesc(GeShape shape, Format format, DataType dt) : GeTensorDesc() { - SetFormat(format); - SetDataType(dt); - ShapeReference() = std::move(shape); -} - -// Default -GeTensorDesc::GeTensorDesc(const GeTensorDesc &desc) : GeTensorDesc() { - tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_); -} - -// Default -GeTensorDesc::GeTensorDesc(GeTensorDesc &&desc) : GeTensorDesc() { - tensor_descriptor_.MoveValueFrom(std::move(desc.tensor_descriptor_)); -} - -GeTensorDesc::GeTensorDesc(const ProtoMsgOwner &proto_owner, proto::TensorDescriptor *proto_msg) - : tensor_descriptor_(proto_owner, proto_msg) { - if (proto_msg != nullptr && !proto_msg->has_out_attr()) { - proto_msg->set_has_out_attr(true); - - int64_t size = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_SIZE, size); - proto_msg->set_size(size); - - int64_t weight_size = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_WEIGHT_SIZE, weight_size); - proto_msg->set_weight_size(weight_size); - - bool reuse_input = false; - (void)AttrUtils::GetBool(this, TENSOR_UTILS_REUSE_INPUT, reuse_input); - proto_msg->set_reuse_input(reuse_input); - - bool output_tensor = false; - (void)AttrUtils::GetBool(this, TENSOR_UTILS_OUTPUT_TENSOR, output_tensor); - proto_msg->set_output_tensor(output_tensor); - - string device_type = "NPU"; - (void)AttrUtils::GetStr(this, TENSOR_UTILS_DEVICE_TYPE, device_type); - proto_msg->set_device_type(device_type); - - bool input_tensor = false; - (void)AttrUtils::GetBool(this, TENSOR_UTILS_INPUT_TENSOR, input_tensor); - proto_msg->set_input_tensor(input_tensor); - - int64_t real_dim_cnt = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_REAL_DIM_CNT, real_dim_cnt); - proto_msg->set_real_dim_cnt(real_dim_cnt); - - int64_t reuse_input_index = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_REUSE_INPUT_INDEX, reuse_input_index); - proto_msg->set_reuse_input_index(reuse_input_index); - - int64_t data_offset = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_DATA_OFFSET, data_offset); - proto_msg->set_data_offset(data_offset); - - int64_t cmps_size = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_SIZE, cmps_size); - proto_msg->set_cmps_size(cmps_size); - - string cmps_tab; - (void)AttrUtils::GetStr(this, TENSOR_UTILS_CMPS_TAB, cmps_tab); - proto_msg->set_cmps_tab(cmps_tab); - - int64_t cmps_tab_offset = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_CMPS_TAB_OFFSET, cmps_tab_offset); - proto_msg->set_cmps_tab_offset(cmps_tab_offset); - } -} - -bool GeTensorDesc::GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const { - const auto &tensor_descriptor = this->tensor_descriptor_.GetProtoMsg(); - const auto &r_tensor_descriptor = r_ge_tensor_desc.tensor_descriptor_.GetProtoMsg(); - if ((tensor_descriptor != nullptr) && (r_tensor_descriptor != nullptr)) { - // Message TensorDescriptor in ge_ir.proto - return ( - IsEqual(tensor_descriptor->name(), r_tensor_descriptor->name(), "TensorDescriptor.name()") && - IsEqual(tensor_descriptor->dtype(), r_tensor_descriptor->dtype(), "TensorDescriptor.dtype()") && - // Message ShapeDef in ge_ir.proto - IsEqual(ToString(tensor_descriptor->shape().dim()), ToString(r_tensor_descriptor->shape().dim()), - "TensorDescriptor.shape().dim()") && - IsEqual(tensor_descriptor->layout(), r_tensor_descriptor->layout(), "TensorDescriptor.layout()") && - IsEqual(tensor_descriptor->has_out_attr(), r_tensor_descriptor->has_out_attr(), - "TensorDescriptor.has_out_attr()") && - IsEqual(tensor_descriptor->size(), r_tensor_descriptor->size(), "TensorDescriptor.size()") && - IsEqual(tensor_descriptor->weight_size(), r_tensor_descriptor->weight_size(), "TensorDescriptor.weight_size()") && - IsEqual(tensor_descriptor->reuse_input(), r_tensor_descriptor->reuse_input(), "TensorDescriptor.reuse_input()") && - IsEqual(tensor_descriptor->output_tensor(), r_tensor_descriptor->output_tensor(), - "TensorDescriptor.output_tensor()") && - IsEqual(tensor_descriptor->device_type(), r_tensor_descriptor->device_type(), "TensorDescriptor.device_type()") && - IsEqual(tensor_descriptor->input_tensor(), r_tensor_descriptor->input_tensor(), - "TensorDescriptor.input_tensor()") && - IsEqual(tensor_descriptor->real_dim_cnt(), r_tensor_descriptor->real_dim_cnt(), - "TensorDescriptor.real_dim_cnt()") && - IsEqual(tensor_descriptor->reuse_input_index(), r_tensor_descriptor->reuse_input_index(), - "TensorDescriptor.reuse_input_index()") && - IsEqual(tensor_descriptor->data_offset(), r_tensor_descriptor->data_offset(), "TensorDescriptor.data_offset()") && - IsEqual(tensor_descriptor->cmps_size(), r_tensor_descriptor->cmps_size(), "TensorDescriptor.cmps_size()") && - IsEqual(tensor_descriptor->cmps_tab(), r_tensor_descriptor->cmps_tab(), "TensorDescriptor.cmps_tab()") && - IsEqual(tensor_descriptor->cmps_tab_offset(), r_tensor_descriptor->cmps_tab_offset(), - "TensorDescriptor.cmps_tab_offset()")); - } else { - return ((tensor_descriptor == nullptr) && (r_tensor_descriptor == nullptr)); - } -} - -bool GeTensorDesc::operator==(const GeTensorDesc &r_ge_tensor_desc) const { - return GeTensorDescAttrsAreEqual(r_ge_tensor_desc); -} - -GeShape &GeTensorDesc::ShapeReference() const { - if (tensor_descriptor_.GetProtoMsg() != nullptr) { - GeShape refShape(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_shape()); - __shape_.RefTo(refShape); - } else { - GeShape refShape(tensor_descriptor_.GetProtoOwner(), nullptr); - __shape_.RefTo(refShape); - } - return __shape_; -} - -void GeTensorDesc::Init() { - SetFormat(FORMAT_ND); - SetOriginFormat(FORMAT_ND); - TensorUtils::SetDeviceType(*this, DeviceType::NPU); - if (tensor_descriptor_.GetProtoMsg() == nullptr) { - GELOGE(GRAPH_FAILED, "ProtoType nullptr."); - return; - } - tensor_descriptor_.GetProtoMsg()->set_has_out_attr(true); -} - -ProtoAttrMapHelper GeTensorDesc::MutableAttrMap() { - if (tensor_descriptor_.GetProtoMsg() != nullptr) { - return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), tensor_descriptor_.GetProtoMsg()->mutable_attr()); - } - return ProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr); -} - -ConstProtoAttrMapHelper GeTensorDesc::GetAttrMap() const { - if (tensor_descriptor_.GetProtoMsg() != nullptr) { - return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), - tensor_descriptor_.GetProtoMsg()->mutable_attr()); - } - return ConstProtoAttrMapHelper(tensor_descriptor_.GetProtoOwner(), nullptr); -} - -void GeTensorDesc::Update(GeShape shape, Format format, DataType dt) { - ShapeReference() = std::move(shape); - SetFormat(format); - SetDataType(dt); -} -GeShape GeTensorDesc::GetShape() const { return ShapeReference(); } - -GeShape &GeTensorDesc::MutableShape() { return ShapeReference(); } - -void GeTensorDesc::SetShape(GeShape shape) { ShapeReference() = std::move(shape); } - -// set shape with -2, it stand for unknown shape -void GeTensorDesc::SetUnknownDimNumShape() { SetShape(GeShape({UNKNOWN_DIM_NUM})); } - -// for unknown shape -graphStatus GeTensorDesc::SetShapeRange(const std::vector> &range) { - std::vector> shape_range; - for (const auto &ele : range) { - shape_range.emplace_back(std::vector({ele.first, ele.second})); - } - auto ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); - return ret ? GRAPH_SUCCESS : GRAPH_FAILED; -} -graphStatus GeTensorDesc::GetShapeRange(std::vector> &range) const { - std::vector> shape_range; - (void)AttrUtils::GetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); - - for (const auto &ele : shape_range) { - // here must be only two elemenet because pair - if (ele.size() != 2) { - GELOGE(GRAPH_FAILED, "shape_range must contain only 2 value but really is %lu", ele.size()); - return GRAPH_FAILED; - } - std::pair pair({ele[0], ele[1]}); - range.push_back(pair); - } - - return GRAPH_SUCCESS; -} - -GeShape GeTensorDesc::GetOriginShape() const { - vector origin_shape; - if (!AttrUtils::GetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape)) { - return GeShape(); - } - return GeShape(origin_shape); -} - -void GeTensorDesc::SetOriginShape(const GeShape &origin_shape) { - std::vector origin_shape_tmp = origin_shape.GetDims(); - (void)AttrUtils::SetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape_tmp); -} - -Format GeTensorDesc::GetFormat() const { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - return TypeUtils::SerialStringToFormat(tensor_descriptor_msg->layout()); - } - return FORMAT_RESERVED; -} - -void GeTensorDesc::SetFormat(Format format) { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_layout(TypeUtils::FormatToSerialString(format)); - } -} - -void GeTensorDesc::SetName(const std::string &name) { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_name(name); - return; - } - GELOGW("[SetName]tensor_descriptor_msg is null."); -} - -const std::string GeTensorDesc::GetName() const { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - return tensor_descriptor_msg->name(); - } - GELOGW("[GetName]tensor_descriptor_msg is null."); - return ""; -} - -Format GeTensorDesc::GetOriginFormat() const { - std::string origin_format_str; - if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str)) { - // Can not get the certificate and it's not set, return directly - return FORMAT_RESERVED; - } - if (origin_format_str == "RESERVED") { - return FORMAT_RESERVED; - } - return TypeUtils::SerialStringToFormat(origin_format_str); -} - -void GeTensorDesc::SetOriginFormat(Format origin_format) { - std::string origin_format_str = "RESERVED"; - if (origin_format != FORMAT_RESERVED) { - origin_format_str = TypeUtils::FormatToSerialString(origin_format); - } - (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_FORMAT, origin_format_str); -} - -DataType GeTensorDesc::GetDataType() const { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg == nullptr) { - return DT_UNDEFINED; - } - auto &attr_map = *(tensor_descriptor_msg->mutable_attr()); - // Data type - auto it_data_type = attr_map.find(kKeyDataTypeSelfDefined); - if (it_data_type != attr_map.end()) { - int64_t data_type_proto = it_data_type->second.i(); - for (auto it : kDataTypeSelfDefinedMap) { - if (it.second == data_type_proto) { - return it.first; - } - } - } else { - auto data_type_proto = tensor_descriptor_msg->dtype(); - for (auto it : kDataTypeMap) { - if (it.second == data_type_proto) { - return it.first; - } - } - } - return DT_UNDEFINED; -} - -void GeTensorDesc::SetDataType(DataType dataType) { - auto tensor_descriptor_msg = tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg == nullptr) { - return; - } - auto &attr_maps = *(tensor_descriptor_msg->mutable_attr()); - (void)attr_maps.erase(kKeyDataTypeSelfDefined); - - // Data type - auto it = kDataTypeMap.find(dataType); - if (it != kDataTypeMap.end()) { - tensor_descriptor_msg->set_dtype(it->second); - return; - } - auto it2 = kDataTypeSelfDefinedMap.find(dataType); - if (it2 != kDataTypeSelfDefinedMap.end()) { - attr_maps[kKeyDataTypeSelfDefined].set_i(it2->second); - } -} - -void GeTensorDesc::SetOriginDataType(DataType origin_data_type) { - std::string origin_data_type_str = "RESERVED"; - if (origin_data_type != DT_UNDEFINED) { - origin_data_type_str = TypeUtils::DataTypeToSerialString(origin_data_type); - } - (void)AttrUtils::SetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str); -} - -DataType GeTensorDesc::GetOriginDataType() const { - std::string origin_data_type_str; - if (!AttrUtils::GetStr(this, TENSOR_UTILS_ORIGIN_DATA_TYPE, origin_data_type_str)) { - return DT_UNDEFINED; - } - if (origin_data_type_str == "RESERVED") { - return DT_UNDEFINED; - } - return TypeUtils::SerialStringToDataType(origin_data_type_str); -} - -std::vector GeTensorDesc::GetRefPortIndex() const { - vector ref_port_index; - (void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index); - return ref_port_index; -} - -void GeTensorDesc::SetRefPortByIndex(const std::vector &index) { - (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index); -} - -graphStatus GeTensorDesc::IsValid() const { - auto dtype = this->GetDataType(); - auto format = this->GetFormat(); - if (dtype == DT_UNDEFINED && format == FORMAT_RESERVED) { - return GRAPH_PARAM_INVALID; - } - return GRAPH_SUCCESS; -} - -GeTensorDesc GeTensorDesc::Clone() const { return *this; } - -GeTensorDesc &GeTensorDesc::operator=(const GeTensorDesc &desc) { - if (&desc != this) { - tensor_descriptor_.CopyValueFrom(desc.tensor_descriptor_); - } - return *this; -} - -GeTensorDesc &GeTensorDesc::operator=(GeTensorDesc &&desc) { - if (&desc != this) { - tensor_descriptor_.CopyValueFrom(std::move(desc.tensor_descriptor_)); - } - return *this; -} - -GeTensor::GeTensor::GeTensor() { - tensor_def_.InitDefault(); - // Default init desc - DescReference() = GeTensorDesc(); -} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc) : GeTensor() { DescReference() = tensor_desc; } - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const vector &data) : GeTensor() { - DescReference() = tensor_desc; - auto proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_data(data.data(), data.size()); - } -} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const uint8_t *data, size_t size) : GeTensor() { - DescReference() = tensor_desc; - auto proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr && data != nullptr) { - proto_msg->set_data(data, size); - } -} - -GeTensor::GeTensor(GeTensorDesc &&tensor_desc, vector &&data) : GeTensor() { - DescReference() = std::move(tensor_desc); - auto proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_data(data.data(), data.size()); - } -} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const Buffer &data) : GeTensor() { - DescReference() = tensor_desc; - auto proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - if (data.size() == 0) { - GELOGI("GetSize res is 0."); - } - if (data.data() == nullptr) { - GELOGI("data addr is null."); - } - proto_msg->set_data(data.GetData(), data.GetSize()); - } -} - -GeTensor::GeTensor(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg) - : tensor_def_(proto_owner, proto_msg) {} - -GeTensorDesc GeTensor::GetTensorDesc() const { return DescReference(); } - -GeTensorDesc &GeTensor::MutableTensorDesc() { return DescReference(); } - -GeTensorDesc &GeTensor::DescReference() const { - if (tensor_def_.GetProtoMsg() != nullptr) { - GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), tensor_def_.GetProtoMsg()->mutable_desc()); - __desc_.RefTo(tensor_desc); - } else { - GeTensorDesc tensor_desc(tensor_def_.GetProtoOwner(), nullptr); - __desc_.RefTo(tensor_desc); - } - return __desc_; -} - -void GeTensor::SetTensorDesc(const GeTensorDesc &tensor_desc) { DescReference() = tensor_desc; } - -const Buffer GeTensor::GetData() const { - auto proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data()); - } - return Buffer(); -} - -Buffer GeTensor::MutableData() { - auto proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return Buffer(tensor_def_.GetProtoOwner(), proto_msg->mutable_data()); - } - return Buffer(); -} - -graphStatus GeTensor::SetData(vector &&data) { - auto proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - proto_msg->set_data(data.data(), data.size()); - return GRAPH_SUCCESS; -} - -graphStatus GeTensor::SetData(const vector &data) { - auto proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - proto_msg->set_data(data.data(), data.size()); - return GRAPH_SUCCESS; -} - -graphStatus GeTensor::SetData(const uint8_t *data, size_t size) { - GE_CHECK_NOTNULL(data); - auto proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - proto_msg->set_data(data, size); - return GRAPH_SUCCESS; -} - -graphStatus GeTensor::SetData(const Buffer &data) { - auto proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - if (data.size() == 0) { - GELOGI("GetSize res is 0."); - } - if (data.data() == nullptr) { - GELOGI("data addr is null."); - } - proto_msg->set_data(data.data(), data.size()); - return GRAPH_SUCCESS; -} - -GeTensor GeTensor::Clone() const { - GeTensor tensor; - tensor.tensor_def_.CopyValueFrom(tensor_def_); - return tensor; -} - -GeTensor::GeTensor(const GeTensor &other) { tensor_def_ = other.tensor_def_; } - -GeTensor &GeTensor::operator=(const GeTensor &other) { - if (&other != this) { - tensor_def_ = other.tensor_def_; - } - return *this; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, - int64_t &size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - size = static_cast(tensor_descriptor_msg->size()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, int64_t size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_size(size); - } -} - -uint32_t TensorUtils::GetWeightSize(const GeTensorDesc &tensor_desc) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - return static_cast(tensor_descriptor_msg->weight_size()); - } - return 0; -} - -uint32_t TensorUtils::GetWeightSize(const GeTensor &tensor) { return GetWeightSize(tensor.GetTensorDesc()); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t TensorUtils::GetWeightSize(const ConstGeTensorPtr &tensor_ptr) { - if (tensor_ptr == nullptr) { - return 0; - } - return GetWeightSize(*tensor_ptr); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint8_t *TensorUtils::GetWeightAddr(const ConstGeTensorPtr &tensor_ptr, - uint8_t *base) { - if (tensor_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "tensor_ptr is null."); - return nullptr; - } - return GetWeightAddr(*tensor_ptr, base); -} - -uint8_t *TensorUtils::GetWeightAddr(const GeTensor &tensor, uint8_t *base) { - if (base == nullptr) { - GELOGE(GRAPH_FAILED, "base is null."); - return nullptr; - } - int64_t weight_data_offset = 0; - if (GetDataOffset(tensor.GetTensorDesc(), weight_data_offset) != GRAPH_SUCCESS) return nullptr; - - if (weight_data_offset == 0) { - // The weight of offset 0 is still in const op, still get from ATTR_NAME_WEIGHTS. - return const_cast(tensor.GetData().data()); - } - - return base + weight_data_offset; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetWeightSize(GeTensorDesc &tensor_desc, - uint32_t size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_weight_size(size); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetReuseInput(const GeTensorDesc &tensor_desc, - bool &flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - flag = tensor_descriptor_msg->reuse_input(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInput(GeTensorDesc &tensor_desc, bool flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_reuse_input(flag); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetOutputTensor(const GeTensorDesc &tensor_desc, - bool &flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - flag = tensor_descriptor_msg->output_tensor(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetOutputTensor(GeTensorDesc &tensor_desc, bool flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_output_tensor(flag); - } -} - -static map device_to_str_map{ - {0, "NPU"}, - {1, "CPU"}, -}; -static map str_to_device_map{ - {"NPU", 0}, - {"CPU", 1}, -}; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDeviceType(const GeTensorDesc &tensor_desc, - DeviceType &type) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - string type_str = tensor_descriptor_msg->device_type(); - type = DeviceType(str_to_device_map[type_str]); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDeviceType(GeTensorDesc &tensor_desc, - DeviceType type) { - auto type_str = device_to_str_map[type]; - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_device_type(type_str); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetInputTensor(const GeTensorDesc &tensor_desc, - bool &flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - flag = tensor_descriptor_msg->input_tensor(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetInputTensor(GeTensorDesc &tensor_desc, bool flag) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_input_tensor(flag); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRealDimCnt(const GeTensorDesc &tensor_desc, - uint32_t &cnt) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - cnt = static_cast(tensor_descriptor_msg->real_dim_cnt()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRealDimCnt(GeTensorDesc &tensor_desc, - uint32_t cnt) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_real_dim_cnt(cnt); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetReuseInputIndex(const GeTensorDesc &tensor_desc, uint32_t &idx) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - GE_CHECK_NOTNULL(tensor_descriptor_msg); - - idx = static_cast(tensor_descriptor_msg->reuse_input_index()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInputIndex(GeTensorDesc &tensor_desc, - uint32_t idx) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_reuse_input_index(idx); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDataOffset(const GeTensorDesc &tensor_desc, - int64_t &offset) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - offset = tensor_descriptor_msg->data_offset(); - return GRAPH_SUCCESS; - } else { - GELOGW("tensor_descriptor_msg is nullptr."); - return GRAPH_FAILED; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDataOffset(GeTensorDesc &tensor_desc, - int64_t offset) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_data_offset(offset); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetCmpsSize(const GeTensorDesc &tensor_desc, - uint32_t &cmp_size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - cmp_size = static_cast(tensor_descriptor_msg->cmps_size()); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsSize(GeTensorDesc &tensor_desc, - uint32_t cmp_size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_cmps_size(cmp_size); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetCmpsTab(const GeTensorDesc &tensor_desc, - vector &vec) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - string str = tensor_descriptor_msg->cmps_tab(); - vec.assign(str.begin(), str.end()); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsTab(GeTensorDesc &tensor_desc, - const uint8_t *data, size_t size) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - GE_CHK_BOOL_EXEC(data != nullptr, return, "data is null."); - string str((const char *)data, size); - tensor_descriptor_msg->set_cmps_tab(str); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetCmpsTabOffset(const GeTensorDesc &tensor_desc, int64_t &tab_offset) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tab_offset = tensor_descriptor_msg->cmps_tab_offset(); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsTabOffset(GeTensorDesc &tensor_desc, - int64_t tab_offset) { - auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor_msg != nullptr) { - tensor_descriptor_msg->set_cmps_tab_offset(tab_offset); - } -} - -graphStatus TensorUtils::GetCmpsInfo(const GeTensorDesc &tensor_desc, CompressInfo &info) { - GeAttrValue attr_value; - if (tensor_desc.GetAttr(TENSOR_UTILS_CMPSINFO, attr_value) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - return attr_value.GetValue(info); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetCmpsInfo(GeTensorDesc &tensor_desc, - const CompressInfo &info) { - (void)tensor_desc.SetAttr(TENSOR_UTILS_CMPSINFO, GeAttrValue::CreateFrom(info)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool TensorUtils::HasAlloffsetQuantizeInfo( - const GeTensorDesc &tensor_desc) { - return tensor_desc.HasAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetAlloffsetQuantizeInfo(const GeTensorDesc &tensor_desc, AllOffsetQuantizeInfo &info) { - GeAttrValue attr_value; - if (tensor_desc.GetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, attr_value) != GRAPH_SUCCESS) { - GELOGW("get attr alloffset_quantize_info fail."); - } - return attr_value.GetValue(info); -} - -void TensorUtils::SetAlloffsetQuantizeInfo(GeTensorDesc &tensor_desc, const AllOffsetQuantizeInfo &info) { - (void)tensor_desc.SetAttr(TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO, GeAttrValue::CreateFrom(info)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRC(const GeTensorDesc &tensor_desc, - uint32_t &rc) { - return AttrUtils::GetInt(&tensor_desc, TENSOR_UTILS_RC, rc) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRC(GeTensorDesc &tensor_desc, uint32_t rc) { - (void)AttrUtils::SetInt(&tensor_desc, TENSOR_UTILS_RC, rc); -} -} // namespace ge diff --git a/metadef/graph/graph.cc b/metadef/graph/graph.cc deleted file mode 100644 index fc30e9d6..00000000 --- a/metadef/graph/graph.cc +++ /dev/null @@ -1,384 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "external/graph/graph.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_op_types.h" -#include "graph/model.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" - -using std::map; -using std::pair; -using std::string; -using std::vector; - -namespace ge { -class GraphImpl { - public: - friend class GraphUtils; - GraphImpl(const GraphImpl &) = delete; - GraphImpl &operator=(const GraphImpl &) = delete; - - explicit GraphImpl(const std::string &name) : name_(name) {} - - ~GraphImpl() { - if (IsValid()) { - if (compute_graph_ != nullptr) { - GraphUtils::BreakConnect(compute_graph_->GetAllNodesInfo()); - } - } - for (const auto &it : op_list_) { - Operator op = it.second; - op.BreakConnect(); - } - } - - graphStatus SetInputs(const std::vector &inputs) { - compute_graph_ = GraphUtils::CreateGraphFromOperator(name_, inputs); - GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "Build Graph failed."); - GE_CHK_BOOL_RET_STATUS(inputs.size() != 0, GRAPH_FAILED, "set input NULL."); - compute_graph_->SetInputSize(static_cast(inputs.size())); - return GRAPH_SUCCESS; - } - - graphStatus SetOutputs(const std::vector &outputs) { - if (compute_graph_ == nullptr) { - GELOGE(GRAPH_FAILED, "set ComputeGraph failed."); - return GRAPH_FAILED; - } - if (outputs.empty()) { - GELOGW("set outputs size is 0."); - return GRAPH_SUCCESS; - } - - // Construct special output node - std::vector>> output_indexs; - for (size_t i = 0; i < outputs.size(); ++i) { - output_indexs.emplace_back(outputs[i], std::vector{}); - } - - graphStatus ret = SetOutputs(output_indexs); - return ret; - } - - graphStatus SetOutputs(const std::vector>> &output_indexs) { - if (compute_graph_ == nullptr) { - GELOGE(GRAPH_FAILED, "set ComputeGraph failed."); - return GRAPH_FAILED; - } - if (output_indexs.empty()) { - GELOGW("set outputs size is 0."); - return GRAPH_SUCCESS; - } - - // Construct special output node - std::vector> output_nodes; - for (const auto &item : output_indexs) { - const Operator &output = item.first; - const vector &indexs = item.second; - ge::NodePtr node = compute_graph_->FindNode(output.GetName()); - if (node == nullptr) { - GELOGW("user designated out_node [%s] not exist in graph, will ignored!", output.GetName().c_str()); - continue; - } - - ge::OpDescPtr tmp_op_ptr = node->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue); - size_t out_size = tmp_op_ptr->GetOutputsSize(); - if (indexs.empty()) { - for (size_t i = 0; i < out_size; ++i) { - output_name_ += output.GetName() + ":" + std::to_string(i) + ";"; - output_nodes.emplace_back(node, i); - } - } else { - for (size_t i = 0; i < indexs.size(); ++i) { - if (indexs[i] >= out_size) { - GELOGW("index[%zu] is not belong to out_node[%s]", indexs[i], output.GetName().c_str()); - } else { - output_name_ += output.GetName() + ":" + std::to_string(i) + ";"; - output_nodes.emplace_back(node, indexs[i]); - } - } - } - } - - // Del last ";" - if (!output_name_.empty()) { - output_name_ = output_name_.substr(0, output_name_.length() - 1); - } - compute_graph_->SetUserDefOutput(output_name_); - compute_graph_->SetOutputSize(static_cast(output_indexs.size())); - compute_graph_->SetGraphOutNodesInfo(output_nodes); - return GRAPH_SUCCESS; - } - - graphStatus SetOutputs(const std::vector> &outputs) { - GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild."); - GE_CHK_BOOL_EXEC_INFO(outputs.size() != 0, return GRAPH_SUCCESS, "set outputs size is 0."); - - // Construct specified output - std::vector> output_nodes; - for (auto item : outputs) { - ge::NodePtr node = compute_graph_->FindNode(item.first.GetName()); - if (node == nullptr) { - GELOGE(GRAPH_FAILED, " Warning, user designated out_node (%s) not exist in graph, this out_node ignored!", - item.first.GetName().c_str()); - return GRAPH_FAILED; - } - ge::OpDescPtr tmp_op_ptr = node->GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue); - size_t out_size = tmp_op_ptr->GetOutputsSize(); - - if (item.second.empty()) { - for (size_t i = 0; i < out_size; ++i) { - output_name_ += item.first.GetName() + ":" + std::to_string(i) + ";"; - output_nodes.push_back(std::make_pair(node, i)); - } - } else { - int32_t index = tmp_op_ptr->GetOutputIndexByName(item.second); - if (index < 0) { - GELOGE(GRAPH_FAILED, - " Warning, user designated out_node (%s):(%s) not exist in graph, this out_node ignored!", - item.first.GetName().c_str(), item.second.c_str()); - return GRAPH_FAILED; - } - output_name_ += item.first.GetName() + ":" + std::to_string(index) + ";"; - output_nodes.push_back(std::make_pair(node, index)); - } - } - // Del last ";" - if (!output_name_.empty()) { - output_name_ = output_name_.substr(0, output_name_.length() - 1); - } - compute_graph_->SetOutputSize(static_cast(outputs.size())); - compute_graph_->SetGraphOutNodesInfo(output_nodes); - GELOGI("********************SetOutputs Success***********************"); - GE_IF_BOOL_EXEC(!output_name_.empty(), GELOGI(" NetOutputs: (%s)", output_name_.c_str())); - - return GRAPH_SUCCESS; - } - - graphStatus SetTargets(const std::vector &targets) { - GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild."); - GE_CHK_BOOL_EXEC_INFO(targets.size() != 0, return GRAPH_SUCCESS, "set targets size is 0."); - - std::vector target_nodes; - for (auto item : targets) { - ge::NodePtr node = compute_graph_->FindNode(item.GetName()); - if (node == nullptr) { - GELOGW(" Warning, user designated target_node (%s) not exist in graph, this target_node ignored!", - item.GetName().c_str()); - continue; - } - target_nodes.push_back(node); - } - compute_graph_->SetGraphTargetNodesInfo(target_nodes); - return GRAPH_SUCCESS; - } - bool IsValid() const { return (compute_graph_ != nullptr); } - - graphStatus AddOp(const ge::Operator &op) { - std::pair::iterator, bool> ret; - ret = op_list_.emplace(std::pair(op.GetName(), op)); - GE_CHK_BOOL_RET_STATUS(ret.second != false, GRAPH_FAILED, "the op have added before, op name:%s.", - op.GetName().c_str()); - return GRAPH_SUCCESS; - } - - graphStatus GetAllOpName(std::vector &op_name) const { - for (const auto &it : op_list_) { - op_name.push_back(it.second.GetName()); - } - return GRAPH_SUCCESS; - } - - graphStatus FindOpByName(const string &name, ge::Operator &op) const { - auto it = op_list_.find(name); - GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str()); - op = it->second; - return GRAPH_SUCCESS; - } - - graphStatus FindOpByType(const string &type, std::vector &ops) const { - for (auto &op : op_list_) { - auto op_type = op.second.GetOpType(); - if (op_type == type) { - ops.push_back(op.second); - continue; - } - if (op_type == ge::FRAMEWORKOP) { - op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type); - if (op_type == type) { - ops.push_back(op.second); - } - } - } - return GRAPH_SUCCESS; - } - - void SetNeedIteration(bool need_iteration) { - if (compute_graph_ == nullptr) { - GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null."); - return; - } - compute_graph_->SetNeedIteration(need_iteration); - } - - const std::string &GetName() const { return name_; } - - private: - std::string name_; - std::string output_name_; - std::map op_list_; - ComputeGraphPtr compute_graph_{nullptr}; -}; - -Graph::Graph(const std::string &name) { - impl_ = ComGraphMakeShared(name); - if (impl_ == nullptr) { - GELOGW("GraphImpl make shared failed, impl_ is nullptr"); - } -} - -graphStatus Graph::AddOp(const ge::Operator &op) { - GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, "AddOp failed: graph can not be used, impl is nullptr."); - return impl_->AddOp(op); -} - -graphStatus Graph::GetAllOpName(std::vector &op_name) const { - GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, - "GetAllOpName failed: graph can not be used, impl is nullptr."); - return impl_->GetAllOpName(op_name); -} - -graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const { - Operator op_find_op_def("NULL"); - op = op_find_op_def; - GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, - "FindOpByName failed: graph can not be used, impl is nullptr."); - return impl_->FindOpByName(name, op); -} - -graphStatus Graph::FindOpByType(const string &type, std::vector &ops) const { - GE_CHECK_NOTNULL(impl_); - return impl_->FindOpByType(type, ops); -} - -Graph &Graph::SetInputs(const vector &inputs) { - GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.") - GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0."); - (void)impl_->SetInputs(inputs); - return *this; -} - -Graph &Graph::SetOutputs(const vector &outputs) { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr."); - return *this; - } - (void)impl_->SetOutputs(outputs); - return *this; -} - -Graph &Graph::SetOutputs(const std::vector>> &output_indexs) { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr."); - return *this; - } - (void)impl_->SetOutputs(output_indexs); - return *this; -} - -Graph &Graph::SetOutputs(const std::vector> &outputs) { - GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetOutputs failed: graph can not be used, impl is nullptr.") - (void)impl_->SetOutputs(outputs); - return *this; -} - -Graph &Graph::SetTargets(const vector &targets) { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "SetTargets failed: graph can not be used, impl is nullptr."); - return *this; - } - (void)impl_->SetTargets(targets); - return *this; -} - -bool Graph::IsValid() const { - if (impl_ == nullptr) { - return false; - } - return impl_->IsValid(); -} - -void Graph::SetNeedIteration(bool need_iteration) { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "Set need iteration failed, as impl is null."); - return; - } - impl_->SetNeedIteration(need_iteration); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::GetComputeGraph(const ge::Graph &graph) { - GE_CHK_BOOL_EXEC_NOLOG(graph.IsValid(), return nullptr); - return graph.impl_->compute_graph_; -} - -graphStatus Graph::SaveToFile(const string &file_name) const { - Model model = Model(); - model.SetGraph(*this); - return model.SaveToFile(file_name); -} - -graphStatus Graph::LoadFromFile(const string &file_name) { - Model model = Model(); - graphStatus ret = model.LoadFromFile(file_name); - if (ret != GRAPH_SUCCESS) { - return ret; - } - *this = model.GetGraph(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string &Graph::GetName() const { return impl_->GetName(); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph -GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { - GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph("")); - - auto name = compute_graph->GetName(); - auto graph = Graph(name); - - GE_CHK_BOOL_EXEC_NOLOG(graph.impl_ != nullptr, return graph); - graph.impl_->compute_graph_ = compute_graph; - - return graph; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { - GE_CHECK_NOTNULL(graph.impl_); - GE_CHECK_NOTNULL(graph.impl_->compute_graph_); - - graph.impl_->op_list_.clear(); - for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) { - graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node); - } - return SUCCESS; -} -} // namespace ge diff --git a/metadef/graph/graph.mk b/metadef/graph/graph.mk deleted file mode 100644 index 9e9ffa3a..00000000 --- a/metadef/graph/graph.mk +++ /dev/null @@ -1,294 +0,0 @@ -LOCAL_PATH := $(call my-dir) -include $(LOCAL_PATH)/stub/Makefile -COMMON_LOCAL_SRC_FILES := \ - ./proto/om.proto \ - ./proto/ge_ir.proto \ - ./proto/ge_onnx.proto \ - ./proto/insert_op.proto \ - ./proto/task.proto \ - ./proto/fwk_adapter.proto \ - ./proto/op_mapping_info.proto \ - ./proto/dump_task.proto \ - ./anchor.cc \ - ./ge_attr_value.cc \ - ./attr_value.cc \ - ./buffer.cc \ - ./compute_graph.cc \ - ./graph.cc \ - ./inference_context.cc \ - ./shape_refiner.cc \ - ./format_refiner.cc \ - ./ref_relation.cc \ - ./model.cc \ - ./model_serialize.cc \ - ./node.cc \ - ./op_desc.cc \ - ./operator.cc \ - ./operator_factory.cc \ - ./operator_factory_impl.cc \ - ./ge_attr_define.cc \ - ./ge_tensor.cc \ - ./detail/attributes_holder.cc \ - ./utils/anchor_utils.cc \ - ./utils/tuning_utils.cc \ - ./utils/graph_utils.cc \ - ./utils/ge_ir_utils.cc \ - ./utils/node_utils.cc \ - ./utils/op_desc_utils.cc \ - ./utils/type_utils.cc \ - ./utils/tensor_utils.cc \ - ./tensor.cc \ - ./debug/graph_debug.cc \ - ./opsproto/opsproto_manager.cc \ - ../ops/op_imp.cpp \ - option/ge_context.cc \ - option/ge_local_context.cc \ - ./runtime_inference_context.cc \ - -COMMON_LOCAL_C_INCLUDES := \ - proto/om.proto \ - proto/ge_ir.proto \ - proto_inner/ge_onnx.proto \ - proto/insert_op.proto \ - proto/task.proto \ - proto/fwk_adapter.proto \ - proto/op_mapping_info.proto \ - proto/dump_task.proto \ - inc \ - inc/external \ - inc/external/graph \ - inc/graph \ - inc/common \ - common \ - common/graph \ - third_party/protobuf/include \ - libc_sec/include \ - ops/built-in/op_proto/inc \ - - -#compiler for host -include $(CLEAR_VARS) -LOCAL_MODULE := libgraph - -LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -LOCAL_CPPFLAGS += -fexceptions - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libprotobuf \ - libslog \ - liberror_manager \ - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_HOST_SHARED_LIBRARY) - -#compiler for host -include $(CLEAR_VARS) -LOCAL_MODULE := stub/libgraph - -LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -LOCAL_CPPFLAGS += -fexceptions - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := \ - ../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/tensor.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - - -LOCAL_SHARED_LIBRARIES := - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_HOST_SHARED_LIBRARY) - -#compiler for host -include $(CLEAR_VARS) -LOCAL_MODULE := fwk_stub/libgraph - -LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -LOCAL_CPPFLAGS += -fexceptions - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := \ - ../../out/graph/lib64/stub/attr_value.cc \ - ../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - ../../out/graph/lib64/stub/tensor.cc \ - ../../out/graph/lib64/stub/inference_context.cc \ - - -LOCAL_SHARED_LIBRARIES := - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_HOST_SHARED_LIBRARY) - -#compiler for device -include $(CLEAR_VARS) -LOCAL_MODULE := libgraph - -LOCAL_CFLAGS += -O2 - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libprotobuf \ - libslog \ - liberror_manager \ - -LOCAL_LDFLAGS := -lrt -ldl - -ifeq ($(device_os),android) -LOCAL_LDFLAGS := -ldl -endif - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_SHARED_LIBRARY) - -#compiler for device -include $(CLEAR_VARS) -LOCAL_MODULE := stub/libgraph - -LOCAL_CFLAGS += -O2 - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := \ - ../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/tensor.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - - -LOCAL_SHARED_LIBRARIES := - -LOCAL_LDFLAGS := -lrt -ldl - -ifeq ($(device_os),android) -LOCAL_LDFLAGS := -ldl -endif - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_SHARED_LIBRARY) - -#compiler for device -include $(CLEAR_VARS) -LOCAL_MODULE := fwk_stub/libgraph - -LOCAL_CFLAGS += -O2 - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := \ - ../../out/graph/lib64/stub/attr_value.cc \ - ../../out/graph/lib64/stub/graph.cc \ - ../../out/graph/lib64/stub/operator.cc \ - ../../out/graph/lib64/stub/operator_factory.cc \ - ../../out/graph/lib64/stub/tensor.cc \ - ../../out/graph/lib64/stub/inference_context.cc \ - - -LOCAL_SHARED_LIBRARIES := - -LOCAL_LDFLAGS := -lrt -ldl - -ifeq ($(device_os),android) -LOCAL_LDFLAGS := -ldl -endif - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_SHARED_LIBRARY) - -# compile for ut/st -include $(CLEAR_VARS) -LOCAL_MODULE := libgraph - -LOCAL_CFLAGS += - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libprotobuf \ - libslog \ - liberror_manager \ - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_LLT_SHARED_LIBRARY) - - -#compiler for host static lib -include $(CLEAR_VARS) -LOCAL_MODULE := libgraph - -LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 -LOCAL_CPPFLAGS += -fexceptions - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) - -LOCAL_STATIC_LIBRARIES := \ - libprotobuf \ - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libslog \ - liberror_manager \ - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_HOST_STATIC_LIBRARY) - -#compiler for device static lib -include $(CLEAR_VARS) -LOCAL_MODULE := libgraph - -LOCAL_CFLAGS += -O2 - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) - -LOCAL_STATIC_LIBRARIES := \ - libprotobuf \ - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libslog \ - liberror_manager \ - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_MULTILIB := 64 -LOCAL_PROPRIETARY_MODULE := true - -include $(BUILD_STATIC_LIBRARY) diff --git a/metadef/graph/inference_context.cc b/metadef/graph/inference_context.cc deleted file mode 100644 index ed8193dc..00000000 --- a/metadef/graph/inference_context.cc +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "external/graph/inference_context.h" -#include "debug/ge_util.h" - -namespace ge { -class ShapeAndTypeImpl { - public: - ShapeAndTypeImpl() = default; - ~ShapeAndTypeImpl() = default; - - ShapeAndTypeImpl(const Shape &shape, DataType data_type) : shape_(shape), data_type_(data_type) {} - - Shape shape_; - DataType data_type_ = DT_UNDEFINED; -}; - -class InferenceContextImpl { - public: - InferenceContextImpl() = default; - ~InferenceContextImpl() = default; - - // For deliver to op in pair, help to support dynamic shape - std::vector marks_; - std::vector> input_handle_shapes_and_types_; - std::vector> output_handle_shapes_and_types_; -}; - -ShapeAndType::ShapeAndType() { shape_and_type_impl_ = ComGraphMakeShared(); } - -ShapeAndType::ShapeAndType(const Shape &shape, DataType data_type) { - shape_and_type_impl_ = ComGraphMakeShared(shape, data_type); -} - -void ShapeAndType::SetShape(const Shape &shape) { - if (shape_and_type_impl_ != nullptr) { - shape_and_type_impl_->shape_ = shape; - } -} - -void ShapeAndType::SetType(DataType data_type) { - if (shape_and_type_impl_ != nullptr) { - shape_and_type_impl_->data_type_ = data_type; - } -} - -Shape ShapeAndType::GetShape() const { - if (shape_and_type_impl_ != nullptr) { - return shape_and_type_impl_->shape_; - } - return Shape(); -} - -DataType ShapeAndType::GetDataType() const { - if (shape_and_type_impl_ != nullptr) { - return shape_and_type_impl_->data_type_; - } - return DT_UNDEFINED; -} - -InferenceContext::InferenceContext(std::unique_ptr &impl) { - inference_context_impl_ = std::move(impl); -} - -std::unique_ptr InferenceContext::Create() { - std::unique_ptr impl = - std::unique_ptr(new (std::nothrow) InferenceContextImpl()); - if (impl == nullptr) { - return nullptr; - } - - return std::unique_ptr(new (std::nothrow) InferenceContext(impl)); -} - -void InferenceContext::SetInputHandleShapesAndTypes(std::vector> &&shapes_and_types) { - inference_context_impl_->input_handle_shapes_and_types_.swap(shapes_and_types); -} - -const std::vector> &InferenceContext::GetInputHandleShapesAndTypes() const { - return inference_context_impl_->input_handle_shapes_and_types_; -} - -const std::vector> &InferenceContext::GetOutputHandleShapesAndTypes() const { - return inference_context_impl_->output_handle_shapes_and_types_; -} - -void InferenceContext::SetOutputHandleShapesAndTypes(const std::vector> &shapes_and_types) { - inference_context_impl_->output_handle_shapes_and_types_ = shapes_and_types; -} - -void InferenceContext::SetOutputHandleShapesAndTypes(std::vector> &&shapes_and_types) { - inference_context_impl_->output_handle_shapes_and_types_.swap(shapes_and_types); -} - -void InferenceContext::SetMarks(const std::vector &marks) { inference_context_impl_->marks_ = marks; } - -const std::vector &InferenceContext::GetMarks() const { return inference_context_impl_->marks_; } -} // namespace ge diff --git a/metadef/graph/model.cc b/metadef/graph/model.cc deleted file mode 100644 index a3628204..00000000 --- a/metadef/graph/model.cc +++ /dev/null @@ -1,190 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/model.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "debug/ge_attr_define.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/model_serialize.h" -#include "proto/ge_ir.pb.h" -#include "utils/attr_utils.h" -#include "utils/ge_ir_utils.h" - -using google::protobuf::io::FileInputStream; -using google::protobuf::io::FileOutputStream; -using google::protobuf::io::ZeroCopyInputStream; - -namespace { -const int DEFAULT_VERSION = 1; -const int ACCESS_PERMISSION_BITS = 0400; -} // namespace - -namespace ge { -void Model::Init() { - (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); - (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); - version_ = 0; -} - -Model::Model() { - attrs_.InitDefault(); - Init(); -} - -Model::Model(const string &name, const string &custom_version) - : name_(name), version_(DEFAULT_VERSION), platform_version_(custom_version) { - attrs_.InitDefault(); - Init(); -} - -string Model::GetName() const { return name_; } - -void Model::SetName(const string &name) { name_ = name; } - -uint32_t Model::GetVersion() const { return version_; } - -string Model::GetPlatformVersion() const { return platform_version_; } - -void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; } - -Graph Model::GetGraph() const { return graph_; } - -graphStatus Model::Save(Buffer &buffer, bool is_dump) const { - ModelSerialize serialize; - buffer = serialize.SerializeModel(*this, is_dump); - return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -void Model::SetAttr(const ProtoAttrMapHelper &attrs) { attrs_ = attrs; } - -graphStatus Model::Load(const uint8_t *data, size_t len, Model &model) { - ModelSerialize serialize; - model = serialize.UnserializeModel(data, len); - return model.IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Model::SaveToFile(const string &file_name) const { - Buffer buffer; - if ((*this).Save(buffer) != GRAPH_SUCCESS) { - GE_LOGE("save to file fail."); - return GRAPH_FAILED; - } - // Write file - ge::proto::ModelDef ge_proto; - if (buffer.GetData() != nullptr) { - std::string str((const char *)buffer.GetData(), buffer.GetSize()); - if (!ge_proto.ParseFromString(str)) { - return GRAPH_FAILED; - } - char real_path[PATH_MAX] = {0x00}; - if (strlen(file_name.c_str()) >= PATH_MAX) { - return GRAPH_FAILED; - } - if (realpath(file_name.c_str(), real_path) == nullptr) { - GELOGI("file %s does not exit, it will be created.", file_name.c_str()); - } - int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS); - if (fd < 0) { - GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno)); - return GRAPH_FAILED; - } - bool ret = ge_proto.SerializeToFileDescriptor(fd); - if (!ret) { - GELOGE(GRAPH_FAILED, "SerializeToFileDescriptor failed"); - if (close(fd) != 0) { - GELOGE(GRAPH_FAILED, "close file descriptor fail."); - return GRAPH_FAILED; - } - return GRAPH_FAILED; - } - if (close(fd) != 0) { - GELOGE(GRAPH_FAILED, "close file descriptor fail."); - return GRAPH_FAILED; - } - if (!ret) { - GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed"); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -graphStatus Model::Load(ge::proto::ModelDef &model_def) { - ModelSerialize serialize; - *this = serialize.UnserializeModel(model_def); - return this->IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -bool Model::IsValid() const { return graph_.IsValid(); } - -graphStatus Model::LoadFromFile(const string &file_name) { - char real_path[PATH_MAX] = {0x00}; - if (strlen(file_name.c_str()) >= PATH_MAX) { - return GRAPH_FAILED; - } - if (realpath(file_name.c_str(), real_path) == nullptr) { - GELOGE(GRAPH_FAILED, "file %s does not exit, can not load.", file_name.c_str()); - return GRAPH_FAILED; - } - int fd = open(real_path, O_RDONLY); - if (fd < 0) { - GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno)); - return GRAPH_FAILED; - } - - ge::proto::ModelDef model_def; - bool ret = model_def.ParseFromFileDescriptor(fd); - if (!ret) { - GELOGE(GRAPH_FAILED, "ParseFromFileDescriptor failed"); - if (close(fd) != 0) { - GELOGE(GRAPH_FAILED, "close file descriptor fail."); - return GRAPH_FAILED; - } - return GRAPH_FAILED; - } - if (close(fd) != 0) { - GELOGE(GRAPH_FAILED, "close file descriptor fail."); - return GRAPH_FAILED; - } - if (!ret) { - GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed"); - return GRAPH_FAILED; - } - return Load(model_def); -} - -ProtoAttrMapHelper Model::MutableAttrMap() { return attrs_; } - -ConstProtoAttrMapHelper Model::GetAttrMap() const { - return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); -} -} // namespace ge diff --git a/metadef/graph/model_serialize.cc b/metadef/graph/model_serialize.cc deleted file mode 100644 index 16855fc5..00000000 --- a/metadef/graph/model_serialize.cc +++ /dev/null @@ -1,763 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/model_serialize.h" -#include - -#include -#include - -#include "debug/ge_attr_define.h" -#include "debug/ge_log.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/detail/model_serialize_imp.h" -#include "proto/ge_ir.pb.h" -#include "utils/graph_utils.h" -#include "debug/ge_op_types.h" - -using std::map; -using std::string; - -namespace ge { -bool ModelSerializeImp::ParseNodeIndex(const string &node_index, string &node_name, int32_t &index) { - auto sep = node_index.rfind(":"); - if (sep == string::npos) { - GELOGW("separator is not found in node_index."); - return false; - } - node_name = node_index.substr(0, sep); - auto index_str = node_index.substr(sep + 1); - index = static_cast(std::strtol(index_str.c_str(), nullptr, 10)); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeTensor(const ConstGeTensorPtr &tensor, - proto::TensorDef *tensor_proto) { - GE_CHK_BOOL_EXEC(tensor != nullptr, return false, "tensor is null."); - GE_CHK_BOOL_EXEC(tensor_proto != nullptr, return false, "tensor_proto is null."); - - if (tensor->tensor_def_.GetProtoMsg() != nullptr) { - *tensor_proto = *tensor->tensor_def_.GetProtoMsg(); - return true; - } - return false; -} - -bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_proto) { - GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is null."); - GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); - - op_def_proto->clear_input(); - // Inputs - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - if (in_data_anchor != nullptr) { - auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) { - op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + - std::to_string(peer_out_anchor->GetIdx())); - } else { - op_def_proto->add_input(""); - } - } - } - // Control edge - auto control_anchor = node->GetInControlAnchor(); - if (control_anchor != nullptr) { - auto peer_out_anchors = control_anchor->GetPeerOutControlAnchors(); - for (const auto &peer_out_anchor : peer_out_anchors) { - if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) { - op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":-1"); - } - } - } - return true; -} - -bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { - GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null."); - GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); - if (op_desc->op_def_.GetProtoMsg() != nullptr) { - *op_def_proto = *op_desc->op_def_.GetProtoMsg(); - // Delete unnecessary attr - if (is_dump) { - auto attr = op_def_proto->mutable_attr(); - attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF); - attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF); - attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF); - GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP), - attr->erase(ATTR_NAME_WEIGHTS)); - } - op_def_proto->clear_input_desc(); - op_def_proto->clear_output_desc(); - // Input descs - if (op_desc->GetAllInputsSize() > 0) { - auto size = static_cast(op_desc->GetAllInputsSize()); - for (uint32_t i = 0; i < size; i++) { - auto tensor_desc = op_desc->GetInputDescPtrDfault(i); - if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { - *op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); - } - } - } - // Output descs - if (op_desc->GetOutputsSize() > 0) { - auto size = static_cast(op_desc->GetOutputsSize()); - for (uint32_t i = 0; i < size; i++) { - auto tensor_desc = op_desc->GetOutputDescPtr(i); - if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { - *op_def_proto->add_output_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); - } - } - } - - op_def_proto->set_id(op_desc->GetId()); - for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { - op_def_proto->add_subgraph_name(name); - } - OpDescToAttrDef(op_desc, op_def_proto); - } - return true; -} - -void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) { - proto::AttrDef key_in; - proto::AttrDef value_in; - auto op_desc_attr = op_def_proto->mutable_attr(); - if (!op_desc->input_name_idx_.empty()) { - for (auto &item : op_desc->input_name_idx_) { - key_in.mutable_list()->add_s(item.first); - value_in.mutable_list()->add_i(item.second); - } - op_desc_attr->insert({"_input_name_key", key_in}); - op_desc_attr->insert({"_input_name_value", value_in}); - } - proto::AttrDef key_out; - proto::AttrDef value_out; - if (!op_desc->output_name_idx_.empty()) { - for (auto &item : op_desc->output_name_idx_) { - key_out.mutable_list()->add_s(item.first); - value_out.mutable_list()->add_i(item.second); - } - op_desc_attr->insert({"_output_name_key", key_out}); - op_desc_attr->insert({"_output_name_value", value_out}); - } - proto::AttrDef opt_input; - if (!op_desc->optional_input_names_.empty()) { - for (auto &item : op_desc->optional_input_names_) { - opt_input.mutable_list()->add_s(item); - } - op_desc_attr->insert({"_opt_input", opt_input}); - } -} - -bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) { - if (node == nullptr || op_def_proto == nullptr) { - GELOGE(GRAPH_FAILED, "Input Para Node Invalid"); - return false; - } - if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) { - GELOGE(GRAPH_FAILED, "Serialize OpDesc failed"); - return false; - } - if (SerializeEdge(node, op_def_proto)) { - return true; - } else { - return false; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, - proto::GraphDef *graph_proto, - bool is_dump) { - if (graph == nullptr || graph_proto == nullptr) { - GELOGE(GRAPH_FAILED, "Input para Invalid"); - return false; - } - graph_proto->set_name(graph->GetName()); - // Inputs - for (const auto &input : graph->GetInputNodes()) { - if (input != nullptr) { - graph_proto->add_input(input->GetName() + ":0"); - } - } - // Outputs - for (const auto &output : graph->GetGraphOutNodesInfo()) { - if (output.first != nullptr) { - graph_proto->add_output(output.first->GetName() + ":" + std::to_string(output.second)); - GELOGI("Add output to graph proto, node name:%s, index:%ld", output.first->GetName().c_str(), output.second); - } - } - if (graph->attrs_.GetProtoMsg() != nullptr) { - *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg(); - } - for (const auto &node : graph->GetDirectNode()) { - if (!SerializeNode(node, graph_proto->add_op(), is_dump)) { - if (node->GetOpDesc() != nullptr) { - GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str()); - } - return false; - } - } - return true; -} - -bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) { - if (model_proto == nullptr) { - GELOGE(GRAPH_FAILED, "model_proto para Invalid"); - return false; - } - model_proto->set_name(model.GetName()); - model_proto->set_custom_version(model.GetPlatformVersion()); - model_proto->set_version(model.GetVersion()); - if (model.attrs_.GetProtoMsg()) { - *model_proto->mutable_attr() = *model.attrs_.GetProtoMsg(); - } - auto &graph = model.graph_; - auto compute_graph = GraphUtils::GetComputeGraph(graph); - if (compute_graph == nullptr) { - GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr"); - return false; - } - if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) { - GELOGE(GRAPH_FAILED, "SerializeGraph fail"); - return false; - } - - for (auto subgraph : compute_graph->GetAllSubgraphs()) { - if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) { - GELOGE(GRAPH_FAILED, "Serialize subgraph failed"); - return false; - } - } - - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor( - GeTensorPtr &tensor, proto::TensorDef &tensor_proto) { - tensor = std::shared_ptr(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto)); - if (tensor == nullptr) { - GELOGE(GRAPH_FAILED, "tensor is nullptr"); - return false; - } else { - return true; - } -} - -void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_in, std::vector &key_out, - std::vector &value_in, std::vector &value_out, - std::vector &opt_input) { - if (!key_in.empty()) { - if (key_in.size() != value_in.size()) { - GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(), - value_in.size()); - } else { - for (uint32_t i = 0; i < key_in.size(); ++i) { - op_desc->input_name_idx_.insert(std::pair(key_in.at(i), value_in.at(i))); - } - } - } - if (!key_out.empty()) { - if (key_out.size() != value_out.size()) { - GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(), - value_out.size()); - } else { - for (uint32_t i = 0; i < key_out.size(); ++i) { - op_desc->output_name_idx_.insert(std::pair(key_out.at(i), value_out.at(i))); - } - } - } - if (!opt_input.empty()) { - for (const auto &i : opt_input) { - op_desc->optional_input_names_.insert(i); - } - } -} - -bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) { - std::vector opt_input; - std::vector key_in; - std::vector value_in; - if (op_def_proto.attr().count("_opt_input") > 0) { - auto &name_list = op_def_proto.attr().at("_opt_input").list(); - for (const auto &item_s : name_list.s()) { - opt_input.push_back(item_s); - } - auto op_desc_attr = op_def_proto.mutable_attr(); - op_desc_attr->erase("_opt_input"); - } - if (op_def_proto.attr().count("_input_name_key") > 0) { - auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list(); - for (const auto &item_s : output_name_key_list.s()) { - key_in.push_back(item_s); - } - auto op_desc_attr = op_def_proto.mutable_attr(); - op_desc_attr->erase("_input_name_key"); - } - if (op_def_proto.attr().count("_input_name_value") > 0) { - auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list(); - for (const auto &item_i : input_name_value_list.i()) { - value_in.push_back(static_cast(item_i)); - } - auto op_desc_attr = op_def_proto.mutable_attr(); - op_desc_attr->erase("_input_name_value"); - } - std::vector key_out; - std::vector value_out; - if (op_def_proto.attr().count("_output_name_key") > 0) { - auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list(); - for (const auto &item_s : output_name_key_list.s()) { - key_out.push_back(item_s); - } - auto op_desc_attr = op_def_proto.mutable_attr(); - op_desc_attr->erase("_output_name_key"); - } - if (op_def_proto.attr().count("_output_name_value") > 0) { - auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list(); - for (const auto &item_i : output_name_value_list.i()) { - value_out.push_back(static_cast(item_i)); - } - auto op_desc_attr = op_def_proto.mutable_attr(); - op_desc_attr->erase("_output_name_value"); - } - - op_desc = std::shared_ptr(new (std::nothrow) OpDesc(protobuf_owner_, &op_def_proto)); - GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr."); - - // Input tensor - for (auto &input_desc : *op_def_proto.mutable_input_desc()) { - std::shared_ptr temp_value = - std::shared_ptr(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc)); - GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); - op_desc->inputs_desc_.push_back(temp_value); - } - // Output tensor - for (auto &output_desc : *op_def_proto.mutable_output_desc()) { - std::shared_ptr temp_value = - std::shared_ptr(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc)); - GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); - op_desc->outputs_desc_.push_back(temp_value); - } - - op_desc->SetId(op_def_proto.id()); - uint32_t graph_index = 0; - for (const std::string &name : op_def_proto.subgraph_name()) { - op_desc->AddSubgraphName(name); - op_desc->SetSubgraphInstanceName(graph_index++, name); - } - - // insert name index by key and value - AttrDefToOpDesc(op_desc, key_in, key_out, value_in, value_out, opt_input); - - return true; -} - -bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto) { - GE_RT_FALSE_CHECK_NOTNULL(graph); - OpDescPtr op_desc = nullptr; - if (!UnserializeOpDesc(op_desc, op_def_proto)) { - GELOGW("UnserializeOpDesc error."); - } - - NodePtr node = graph->AddNode(op_desc, op_desc->GetId()); - GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr."); - - // Inputs - int dst_index = 0; - for (const auto &input : op_def_proto.input()) { - string node_name; - int32_t index = 0; - if (ParseNodeIndex(input, node_name, index)) { - node_input_node_names_.push_back(NodeNameNodeReq{node_name, index, node, dst_index, op_def_proto.name()}); - } - if (index >= 0) { - dst_index++; - } - } - node_map_[op_def_proto.name()] = node; - return true; -} - -bool ModelSerializeImp::HandleNodeNameRef() { - // Edges - for (auto &item : node_input_node_names_) { - auto src_node_it = node_map_.find(item.src_node_name); - if (src_node_it == node_map_.end()) { - GELOGE(GRAPH_FAILED, "cannot find node %s", item.src_node_name.c_str()); - return false; - } - GE_IF_BOOL_EXEC(src_node_it->second == nullptr || item.dst_node == nullptr, continue); - if (item.src_out_index >= 0) { - auto src_anchor = src_node_it->second->GetOutDataAnchor(item.src_out_index); - auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index); - if (src_anchor == nullptr || dst_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "get anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, - item.dst_node_name.c_str(), item.dst_in_index); - return false; - } - GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 - } else { - // Control edge - auto src_anchor = src_node_it->second->GetOutControlAnchor(); - auto dst_anchor = item.dst_node->GetInControlAnchor(); - if (src_anchor != nullptr && dst_anchor != nullptr) { - GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 - } - } - } - // Graph input - for (auto &item : graph_input_node_names_) { - auto node_it = node_map_.find(item.node_name); - if (node_it == node_map_.end()) { - GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str()); - return false; - } - GE_IF_BOOL_EXEC(item.graph == nullptr, continue); - auto ret = item.graph->AddInputNode(node_it->second); - if (ret == nullptr) { - return false; - } - } - // Graph output - for (auto &item : graph_output_node_names_) { - auto node_it = node_map_.find(item.node_name); - if (node_it == node_map_.end()) { - GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str()); - return false; - } - - GE_IF_BOOL_EXEC(item.graph == nullptr, continue); - auto ret = item.graph->AddOutputNodeByIndex(node_it->second, item.index); - GELOGI("node name:%s, item.index:%ld", node_it->second->GetName().c_str(), item.index); - if (ret == nullptr) { - GELOGE(GRAPH_FAILED, "AddOutputNode failed."); - return false; - } - } - node_input_node_names_.clear(); - graph_input_node_names_.clear(); - graph_output_node_names_.clear(); - node_map_.clear(); - return true; -} - -bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map &subgraphs) { - std::queue all_graphs; - all_graphs.emplace(compute_graph); - while (!all_graphs.empty()) { - ComputeGraphPtr graph = all_graphs.front(); - all_graphs.pop(); - - for (const NodePtr &node : graph->GetDirectNode()) { - const OpDescPtr op_desc = node->GetOpDesc(); - for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { - auto it = subgraphs.find(name); - if (it == subgraphs.end()) { - GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), name.c_str(), - subgraphs.size()); - return false; - } - - ComputeGraphPtr &subgraph = it->second; - subgraph->SetParentGraph(graph); - subgraph->SetParentNode(node); - compute_graph->AddSubgraph(subgraph->GetName(), subgraph); - all_graphs.emplace(subgraph); - } - } - } - - return true; -} - -bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) { - model.name_ = model_proto.name(); - model.version_ = model_proto.version(); - model.platform_version_ = model_proto.custom_version(); - model.attrs_ = ProtoAttrMapHelper(protobuf_owner_, model_proto.mutable_attr()); - - auto &graphs_proto = *model_proto.mutable_graph(); - if (!graphs_proto.empty()) { - auto &graph_proto = graphs_proto[0]; - ComputeGraphPtr compute_graph_ptr; - if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) { - model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); - } - - // 0 is main graph, following is subgraph. - map subgraphs; - for (int idx = 1; idx < graphs_proto.size(); ++idx) { - ComputeGraphPtr subgraph; - ModelSerializeImp impl; - if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) { - GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed"); - return false; - } - - if (!impl.HandleNodeNameRef()) { - GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); - return false; - } - - subgraphs[subgraph->GetName()] = subgraph; - } - - if (!RebuildOwnership(compute_graph_ptr, subgraphs)) { - GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed"); - return false; - } - } - - if (!HandleNodeNameRef()) { - GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); - return false; - } - return true; -} - -bool ModelSerializeImp::UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto) { - graph = ComGraphMakeShared(graph_proto.name()); - if (graph == nullptr) { - GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); - return false; - } - - // Inputs - for (auto input : graph_proto.input()) { - string node_name; - int32_t index; - if (ParseNodeIndex(input, node_name, index)) { - graph_input_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); - } - } - // Outputs - for (auto output : graph_proto.output()) { - string node_name; - int32_t index; - if (ParseNodeIndex(output, node_name, index)) { - graph_output_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); - } - } - graph->attrs_ = ProtoAttrMapHelper(protobuf_owner_, graph_proto.mutable_attr()); - for (auto &op_def_proto : *graph_proto.mutable_op()) { - if (!UnserializeNode(graph, op_def_proto)) { - GELOGE(GRAPH_FAILED, "UnserializeNode fail"); - return false; - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeGraph(ComputeGraphPtr &graph, - proto::GraphDef &graph_proto) { - if (!UnserializeGraphWithoutEdge(graph, graph_proto)) { - GELOGW("UnserializeGraphWithoutEdge fail"); - } - if (!HandleNodeNameRef()) { - GELOGE(GRAPH_FAILED, "Link Anchor or set graph input or output fail"); - return false; - } - return true; -} - -bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf::Message *proto) { - GE_CHK_BOOL_EXEC(data != nullptr, return false, "data is null."); - GE_CHK_BOOL_EXEC(proto != nullptr, return false, "proto is null."); - - google::protobuf::io::CodedInputStream coded_stream(data, len); - // 2048M -1 - coded_stream.SetTotalBytesLimit(INT32_MAX, -1); - if (!proto->ParseFromCodedStream(&coded_stream)) { - GELOGE(GRAPH_FAILED, "ReadProtoFromBinaryFile failed len %zu", len); - return false; - } - return true; -} - -Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) { - proto::ModelDef model_def; - ModelSerializeImp imp; - if (!imp.SerializeModel(model, &model_def, is_dump)) { - return Buffer(); - } -#if !defined(__ANDROID__) && !defined(ANDROID) - Buffer buffer(model_def.ByteSizeLong()); -#else - Buffer buffer(model_def.ByteSize()); -#endif - GE_CHK_BOOL_ONLY_LOG(buffer.GetSize() != 0, "get size failed"); - GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); - auto ret = model_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); - if (ret != true) { - GELOGW("serialize to array fail."); - } - return buffer; -} - -size_t ModelSerialize::GetSerializeModelSize(const Model &model) { - proto::ModelDef model_def; - ModelSerializeImp imp; - if (!imp.SerializeModel(model, &model_def)) { - return 0; - } -#if !defined(__ANDROID__) && !defined(ANDROID) - return model_def.ByteSizeLong(); -#else - return model_def.ByteSize(); -#endif -} - -Model ModelSerialize::UnserializeModel(const uint8_t *data, size_t len) { - if (data == nullptr) { - GELOGE(GRAPH_FAILED, "data is nullptr"); - return Model(); - } - - std::shared_ptr model_proto_ptr; - model_proto_ptr = ComGraphMakeShared(); - if (model_proto_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed"); - return Model(); - } - - auto &model_proto = *model_proto_ptr; - if (!ReadProtoFromBinaryFile(data, len, &model_proto)) { - GELOGE(GRAPH_FAILED, "ParseFromArray fail"); - return Model(); - } - - Model model; - ModelSerializeImp imp; - imp.SetProtobufOwner(model_proto_ptr); - if (!imp.UnserializeModel(model, model_proto)) { - GELOGE(GRAPH_FAILED, "Unserialize Model fail"); - return Model(); - } - return model; -} - -Model ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def) { - std::shared_ptr model_def_ptr = ComGraphMakeShared(model_def); - GE_CHK_BOOL_EXEC(model_def_ptr != nullptr, return Model(), "mode_def make shared failed"); - - ModelSerializeImp imp; - imp.SetProtobufOwner(model_def_ptr); - Model model; - if (!imp.UnserializeModel(model, *model_def_ptr)) { - GELOGE(GRAPH_FAILED, "Unserialize Model fail"); - return Model(); - } - return model; -} - -Buffer ModelSerialize::SerializeGraph(const ComputeGraphPtr &graph) { - proto::GraphDef graph_def; - ModelSerializeImp imp; - if (!imp.SerializeGraph(graph, &graph_def)) { - return Buffer(); - } -#if !defined(__ANDROID__) && !defined(ANDROID) - Buffer buffer(graph_def.ByteSizeLong()); -#else - Buffer buffer(graph_def.ByteSize()); -#endif - GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed"); - GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); - auto ret = graph_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); - if (ret != true) { - GE_LOGE("serialize to array fail."); - } - - return buffer; -} - -ComputeGraphPtr ModelSerialize::UnserializeGraph(const uint8_t *data, size_t len) { - if (data == nullptr) { - GELOGE(GRAPH_FAILED, "data is nullptr"); - return nullptr; - } - - std::shared_ptr graph_proto_ptr; - graph_proto_ptr = ComGraphMakeShared(); - if (graph_proto_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); - return nullptr; - } - proto::GraphDef &graph_proto = *graph_proto_ptr; - if (!ReadProtoFromBinaryFile(data, len, &graph_proto)) { - GELOGE(GRAPH_FAILED, "ParseFromArray fail"); - return nullptr; - } - - ComputeGraphPtr graph; - ModelSerializeImp imp; - imp.SetProtobufOwner(graph_proto_ptr); - if (!imp.UnserializeGraph(graph, graph_proto)) { - return nullptr; - } - return graph; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer ModelSerialize::SerializeOpDesc(const ConstOpDescPtr &op_desc) { - proto::OpDef op_def; - ModelSerializeImp imp; - if (!imp.SerializeOpDesc(op_desc, &op_def)) { - return Buffer(); - } -#if !defined(__ANDROID__) && !defined(ANDROID) - Buffer buffer(op_def.ByteSizeLong()); -#else - Buffer buffer(op_def.ByteSize()); -#endif - GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed"); - GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); - auto ret = op_def.SerializeToArray(buffer.GetData(), static_cast(buffer.GetSize())); - if (ret != true) { - GE_LOGE("serialize to array fail."); - } - - return buffer; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr ModelSerialize::UnserializeOpDesc(const uint8_t *data, - size_t len) { - if (data == nullptr) { - GELOGE(GRAPH_FAILED, "data is nullptr"); - return nullptr; - } - - std::shared_ptr op_def_ptr; - op_def_ptr = ComGraphMakeShared(); - if (op_def_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); - return nullptr; - } - proto::OpDef &op_def = *op_def_ptr; - if (!ReadProtoFromBinaryFile(data, len, &op_def)) { - GELOGE(GRAPH_FAILED, "ParseFromArray fail"); - return nullptr; - } - - OpDescPtr op_desc; - ModelSerializeImp imp; - imp.SetProtobufOwner(op_def_ptr); - if (!imp.UnserializeOpDesc(op_desc, op_def)) { - GELOGW("UnserializeOpDesc error."); - } - return op_desc; -} -} // namespace ge diff --git a/metadef/graph/module.mk b/metadef/graph/module.mk deleted file mode 100644 index 1e00b7fc..00000000 --- a/metadef/graph/module.mk +++ /dev/null @@ -1,3 +0,0 @@ -LOCAL_PATH := $(call my-dir) - -include $(LOCAL_PATH)/graph.mk diff --git a/metadef/graph/node.cc b/metadef/graph/node.cc deleted file mode 100644 index 10d6b3ed..00000000 --- a/metadef/graph/node.cc +++ /dev/null @@ -1,877 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/node.h" -#include -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "external/graph/operator_factory.h" -#include "framework/common/debug/ge_log.h" -#include "graph/ge_tensor.h" -#include "graph/operator_factory_impl.h" -#include "graph/shape_refiner.h" -#include "utils/ge_ir_utils.h" -#include "utils/node_utils.h" -#include "utils/op_desc_utils.h" -#include "common/util/error_manager/error_manager.h" - -using std::string; -using std::vector; - -namespace ge { -Node::Node(const OpDescPtr &op, const ComputeGraphPtr &owner_graph) - : op_(op), - owner_graph_(owner_graph), - in_data_anchors_(), - out_data_anchors_(), - in_control_anchor_(nullptr), - out_control_anchor_(nullptr), - attrs_(), - has_init_(false) { - anchor_status_updated_ = false; -} - -Node::~Node() { - for (const auto &in_data_anchor : in_data_anchors_) { - if (in_data_anchor != nullptr) { - in_data_anchor->UnlinkAll(); - } - } - for (const auto &out_data_anchor : out_data_anchors_) { - if (out_data_anchor != nullptr) { - out_data_anchor->UnlinkAll(); - } - } - if (in_control_anchor_ != nullptr) { - in_control_anchor_->UnlinkAll(); - } - if (out_control_anchor_ != nullptr) { - out_control_anchor_->UnlinkAll(); - } -} - -graphStatus Node::Init() { - if (has_init_) { - return GRAPH_SUCCESS; - } - GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); - size_t size = op_->GetAllInputsSize(); - for (size_t i = 0; i < size; i++) { - std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), i); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - } - size = op_->GetOutputsSize(); - for (size_t i = 0; i < size; i++) { - std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), i); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Current out_data_anchor is null, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - out_data_anchors_.push_back(anchor); - } - in_control_anchor_ = ComGraphMakeShared(shared_from_this(), -1); - out_control_anchor_ = ComGraphMakeShared(shared_from_this(), -1); - if (in_control_anchor_ == nullptr || out_control_anchor_ == nullptr) { - GELOGE(GRAPH_FAILED, "Current in_control_anchor or out_control_anchor is null, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - has_init_ = true; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetName() const { - GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr"); - return op_->GetName(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetType() const { - GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr"); - return op_->GetType(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAttrsAreEqual(const Node &r_node) const { - const auto &attr_map = this->attrs_; - const auto &r_attr_map = r_node.attrs_; - // 1.Verify node's map size - if (attr_map.size() != r_attr_map.size()) { - GELOGE(GRAPH_FAILED, "Size of node's attr map verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - // 2.Verify node's map key, verify values is temporarily not implemented - for (const auto &it : attr_map) { - if (r_attr_map.count(it.first) == 0) { - GELOGE(GRAPH_FAILED, "Key of node's attr map verify failed, node name: %s key name: %s.", this->GetName().c_str(), - it.first.c_str()); - return false; - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeMembersAreEqual(const Node &r_node) const { - return ((((this->op_ != nullptr) && (r_node.op_ != nullptr) && (IsEqual(*(this->op_), *(r_node.op_), "node.op_"))) || - ((this->op_ == nullptr) && (r_node.op_ == nullptr))) && - IsEqual(this->has_init_, r_node.has_init_, "node.has_init_") && - IsEqual(this->anchor_status_updated_, r_node.anchor_status_updated_, "node.anchor_status_updated_") && - IsEqual(this->send_event_id_list_, r_node.send_event_id_list_, "node.send_event_id_list_") && - IsEqual(this->recv_event_id_list_, r_node.recv_event_id_list_, "node.recv_event_id_list_")); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(const AnchorPtr &left_anchor, - const AnchorPtr &right_anchor, - size_t i) const { - GE_IF_BOOL_EXEC(left_anchor == nullptr, GELOGE(GRAPH_FAILED, "left_anchor is null."); return false); - GE_IF_BOOL_EXEC(right_anchor == nullptr, GELOGE(GRAPH_FAILED, "right_anchor is null."); return false); - - const auto anchor_peer_size = left_anchor->GetPeerAnchors().size(); - const auto right_anchor_peer_size = right_anchor->GetPeerAnchors().size(); - // Firstly, verify anchor's peer anchors size equal or not - if (anchor_peer_size != right_anchor_peer_size) { - GELOGE(GRAPH_FAILED, - "Size of anchor's peer anchors verify failed, node name: %s " - "anchor_peer_size [%zu] is different form [%zu] at index [%zu].", - this->GetName().c_str(), anchor_peer_size, right_anchor_peer_size, i); - return false; - } - // Secondly, verify anchor's peer anchor owner node equal or not - for (size_t j = 0; j < anchor_peer_size; j++) { - const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); - const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); - if (peer_node == nullptr || r_peer_node == nullptr) { - GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", - this->GetName().c_str(), i, j); - return false; - } - // Determine the connection relationship by linking the node's name - if (peer_node->GetName() != r_peer_node->GetName()) { - GELOGE(GRAPH_FAILED, - "anchor's peer node name verify failed, node name: %s index[%zu]" - "peer node name %s is different from %s at index [%zu].", - this->GetName().c_str(), i, peer_node->GetName().c_str(), r_peer_node->GetName().c_str(), j); - return false; - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeInConnectsAreEqual(const Node &r_node) const { - // 1.Verify all in data and control anchors size - const auto in_data_anchor_size = this->GetAllInDataAnchors().size(); - const auto r_in_data_anchor_size = r_node.GetAllInDataAnchors().size(); - if (in_data_anchor_size != r_in_data_anchor_size) { - GELOGE(GRAPH_FAILED, "Size of node's in data anchors verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - const auto l_in_anchors = this->GetAllInAnchors(); - const auto r_in_anchors = r_node.GetAllInAnchors(); - // Data anchors size equal, all anchors size not equal, means control anchor size not equal - const auto in_control_anchor_size = l_in_anchors.size() - in_data_anchor_size; - const auto r_in_control_anchor_size = r_in_anchors.size() - r_in_data_anchor_size; - if (in_control_anchor_size != r_in_control_anchor_size) { - GELOGE(GRAPH_FAILED, "Size of node's in control anchors verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - // 2.Verify all in data and control anchors connect info - for (size_t i = 0; i < this->GetAllInAnchors().size(); i++) { - // Verify data anchors - if (i < in_data_anchor_size) { - const auto &in_anchor = l_in_anchors.at(i); - const auto &r_in_anchor = r_in_anchors.at(i); - if (!(NodeAnchorIsEqual(in_anchor, r_in_anchor, i))) { - GELOGE(GRAPH_FAILED, "Node's in data control anchor verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - } else { - // Verify control anchors - const auto &in_control_anchor = l_in_anchors.at(i); - const auto &r_in_control_anchor = r_in_anchors.at(i); - if (!(NodeAnchorIsEqual(in_control_anchor, r_in_control_anchor, i - in_data_anchor_size))) { - GELOGE(GRAPH_FAILED, "Node's in control anchor verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeOutConnectsAreEqual(const Node &r_node) const { - // 1.Verify all out data and control anchors size - const auto l_out_data_anchors = this->GetAllOutDataAnchors(); - const auto r_out_data_anchors = r_node.GetAllOutDataAnchors(); - const auto out_data_anchor_size = l_out_data_anchors.size(); - const auto r_out_data_anchor_size = r_out_data_anchors.size(); - if (out_data_anchor_size != r_out_data_anchor_size) { - GELOGE(GRAPH_FAILED, "Size of node's out data anchors verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - const auto l_out_anchors = this->GetAllOutAnchors(); - const auto r_out_anchors = r_node.GetAllOutAnchors(); - // Data anchors size equal, all anchors size not equal, means control anchor size not equal - const auto out_control_anchor_size = l_out_anchors.size() - out_data_anchor_size; - const auto r_out_control_anchor_size = r_out_anchors.size() - r_out_data_anchor_size; - if (out_control_anchor_size != r_out_control_anchor_size) { - GELOGE(GRAPH_FAILED, "Size of node's out control anchors verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - - // 2.Verify all out data and control anchors connect info - for (size_t i = 0; i < this->GetAllOutAnchors().size(); i++) { - // Verify data anchors - if (i < out_data_anchor_size) { - const auto &out_anchor = l_out_data_anchors.at(i); - const auto &r_out_anchor = r_out_data_anchors.at(i); - if (!(NodeAnchorIsEqual(out_anchor, r_out_anchor, i))) { - GELOGE(GRAPH_FAILED, "Node's out data control anchor verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - } else { - // Verify control anchors - const auto &out_control_anchor = l_out_anchors.at(i); - const auto &r_out_control_anchor = r_out_anchors.at(i); - if (!(NodeAnchorIsEqual(out_control_anchor, r_out_control_anchor, i - out_data_anchor_size))) { - GELOGE(GRAPH_FAILED, "Node's out control anchor verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::operator==(const Node &r_node) const { - return (NodeMembersAreEqual(r_node) && NodeAttrsAreEqual(r_node) && NodeInConnectsAreEqual(r_node) && - NodeOutConnectsAreEqual(r_node)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const NodePtr &input_node) { - // This function is deprecated, please use other two overloaded functions - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1) { - GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); - auto op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - if (op_->AddInputDesc(op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "add input desc failed."); - return GRAPH_FAILED; - } - std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const uint32_t &index, - NodePtr input_node) { - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1) { - GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - GE_CHECK_NOTNULL(op_); - auto op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - if (op_->AddInputDesc(index, op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "add input desc failed."); - return GRAPH_FAILED; - } - - if (index < GetAllInDataAnchors().size()) { - (void)out_anchors.at(0)->LinkTo(in_data_anchors_[index]); - } else { - std::shared_ptr anchor = - ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFromForParse(const NodePtr &input_node) { - // This function is used for ParseWeights. - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1) { - GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - std::shared_ptr anchor = ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, make anchor failed", out_anchors.size()); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const string &name, NodePtr input_node) { - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1) { - GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - GE_CHECK_NOTNULL(op_); - auto input_op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(input_op_desc); - auto index = op_->GetInputIndexByName(name); - if (index != -1) { - if (index >= static_cast(in_data_anchors_.size())) { - GELOGE(GRAPH_FAILED, "op %s get input name %s 's index %d is illegal.", op_->GetName().c_str(), name.c_str(), - index); - return GRAPH_FAILED; - } - (void)out_anchors.at(0)->LinkTo(in_data_anchors_[index]); - } else { - std::shared_ptr anchor = - ComGraphMakeShared(shared_from_this(), in_data_anchors_.size()); - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "in_data_anchors_size is:%zu, malloc shared_ptr failed.", in_data_anchors_.size()); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); - } - if (op_->AddInputDesc(name, input_op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "add input desc failed."); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr Node::GetOwnerComputeGraph() const { - return owner_graph_.lock(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::SetOwnerComputeGraph(const ComputeGraphPtr &graph) { - if (graph == nullptr) { - return GRAPH_PARAM_INVALID; - } - owner_graph_ = graph; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllInDataAnchors() const { - return Vistor(shared_from_this(), in_data_anchors_); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllOutDataAnchors() const { - return Vistor(shared_from_this(), out_data_anchors_); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllInDataAnchorsSize() const { - return in_data_anchors_.size(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllOutDataAnchorsSize() const { - return out_data_anchors_.size(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllInAnchors() const { - std::vector vec; - // Push back in_data_anchors_ - for (const auto &in_anchor_iter : Vistor(shared_from_this(), in_data_anchors_)) { - auto in_anchor = Anchor::DynamicAnchorCast(in_anchor_iter); - if (in_anchor != nullptr) { - vec.push_back(in_anchor); - } - } - // Push back in_control_anchor_ - if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) || - (in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) { - auto in_anchor = Anchor::DynamicAnchorCast(in_control_anchor_); - if (in_anchor != nullptr) { - vec.push_back(in_anchor); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllOutAnchors() const { - std::vector vec; - // Push back out_data_anchors_ - for (const auto &out_anchor_iter : Vistor(shared_from_this(), out_data_anchors_)) { - auto out_anchor = Anchor::DynamicAnchorCast(out_anchor_iter); - if (out_anchor != nullptr) { - vec.push_back(out_anchor); - } - } - // Push back out_control_anchor_ - if (out_control_anchor_->GetPeerInControlAnchors().size() > 0 || - out_control_anchor_->GetPeerInDataAnchors().size() > 0) { - auto out_anchor = Anchor::DynamicAnchorCast(out_control_anchor_); - if (out_anchor != nullptr) { - vec.push_back(out_anchor); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const { - if (idx < 0 || idx >= static_cast(in_data_anchors_.size())) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19019", {"opname", "index", "anchorname", "optype"}, - {GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()}); - GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx, - GetType().c_str()); - return nullptr; - } else { - return in_data_anchors_[idx]; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const { - // Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ - if (idx < -1 || idx >= static_cast(in_data_anchors_.size())) { - GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); - return nullptr; - } else { - // Return control anchor - if (idx == -1) { - auto in_anchor = Anchor::DynamicAnchorCast(in_control_anchor_); - return in_anchor; - } - // Return data anchor - return in_data_anchors_[idx]; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const { - // Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ - if (idx < -1 || idx >= static_cast(out_data_anchors_.size())) { - ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"}, - { - GetName().c_str(), - std::to_string(idx), - "out_anchor", - GetType().c_str(), - }); - GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx, - GetType().c_str()); - return nullptr; - } else { - // Return control anchor - if (idx == -1) { - auto out_anchor = Anchor::DynamicAnchorCast(out_control_anchor_); - return out_anchor; - } - // Return data anchor - return out_data_anchors_[idx]; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const { - if (idx < 0 || idx >= static_cast(out_data_anchors_.size())) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19019", {"opname", "index", "anchorname", "optype"}, - {GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()}); - GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx, - GetType().c_str()); - return nullptr; - } else { - return out_data_anchors_[idx]; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchorPtr Node::GetInControlAnchor() const { - return in_control_anchor_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchorPtr Node::GetOutControlAnchor() const { - return out_control_anchor_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInNodes() const { - std::vector vec; - for (const auto &in_anchor : in_data_anchors_) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - continue; - } - auto node = out_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - if (in_control_anchor_ != nullptr) { - if (in_control_anchor_->IsPeerOutAnchorsEmpty()) { - return Node::Vistor(shared_from_this(), vec); - } - - auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors(); - for (const auto &out_anchor : peer_out_anchors) { - GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "in_control_anchor_ peer out data anchors is nullptr"); - auto node = out_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - - auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); - for (const auto &out_control_anchor : peer_out_control_anchors) { - GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, - "in_control_anchor_ peer out control anchors is nullptr"); - auto node = out_control_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::IsAllInNodesSeen( - std::unordered_set &nodes_seen) const { - for (const auto &in_anchor : in_data_anchors_) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - continue; - } - auto node = out_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { - continue; - } - if (nodes_seen.count(node.get()) == 0) { - return false; - } - } - - if (in_control_anchor_ != nullptr) { - if (in_control_anchor_->IsPeerOutAnchorsEmpty()) { - return true; - } - auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); - for (const auto &out_control_anchor : peer_out_control_anchors) { - GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, "out_control_anchor is nullptr"); - auto node = out_control_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { - continue; - } - if (nodes_seen.count(node.get()) == 0) { - return false; - } - } - } - - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInDataNodes() const { - std::vector vec; - for (const auto &in_anchor : in_data_anchors_) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); - auto anchor_ptr = in_anchor->GetPeerOutAnchor(); - if (anchor_ptr == nullptr) { - continue; - } - auto node = anchor_ptr->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInControlNodes() const { - std::vector vec; - if (in_control_anchor_ != nullptr) { - for (const auto &in_anchor : in_control_anchor_->GetPeerOutControlAnchors()) { - GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerOutControlAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutNodes() const { - std::vector vec; - for (const auto &out_anchor : out_data_anchors_) { - GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); - for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC((peer_in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr"); - auto node = peer_in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - if (out_control_anchor_ != nullptr) { - auto peer_in_control_anchors = out_control_anchor_->GetPeerInControlAnchors(); - for (const auto &in_control_anchor : peer_in_control_anchors) { - GE_CHK_BOOL_EXEC(in_control_anchor != nullptr, continue, - "out_control_anchor_ peer in control anchors is nullptr"); - auto node = in_control_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInAllNodes() const { - std::vector vec; - for (const auto &in_node : GetInDataNodes()) { - vec.push_back(in_node); - } - for (const auto &in_control_node : GetInControlNodes()) { - vec.push_back(in_control_node); - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutDataNodes() const { - std::vector vec; - for (const auto &out_anchor : out_data_anchors_) { - GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); - for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetOutDataNodesSize() const { - uint32_t out_nums = 0; - for (const auto &out_anchor : out_data_anchors_) { - GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); - out_nums += out_anchor->GetPeerInDataNodesSize(); - } - return out_nums; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutControlNodes() const { - std::vector vec; - - for (const auto &out_anchor : out_data_anchors_) { - GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); - for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInControlAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - - if (out_control_anchor_ != nullptr) { - for (const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) { - GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - - return Node::Vistor(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutAllNodes() const { - std::vector vec; - for (const auto &out_anchor : out_data_anchors_) { - GE_CHK_BOOL_EXEC((out_anchor != nullptr), { continue; }, "out_data_anchors_ is nullptr"); - for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), { continue; }, "GetPeerInDataAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - - if (out_control_anchor_ != nullptr) { - for (const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) { - GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); - auto node = in_anchor->GetOwnerNode(); - GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); - vec.push_back(node); - } - } - return Node::Vistor(shared_from_this(), vec); -} - -graphStatus Node::InferShapeAndType() const { - Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); - graphStatus ret = ShapeRefiner::InferShapeAndType(shared_from_this(), op); - return ret; -} - -graphStatus Node::InferOriginFormat() const { - Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); - // Get infer func and execute - GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); - return op_->CallInferFormatFunc(op); -} -graphStatus Node::Verify() const { - const string data_type = "Data"; - const string aipp_data_type = "AippData"; - const string const_type = "Const"; - const string variable_type = "Variable"; - bool is_unknown_graph = GetOwnerComputeGraph()->GetGraphUnknownFlag(); - GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); - - if (!is_unknown_graph) { - for (const auto &in_anchor_ptr : GetAllInDataAnchors()) { - GE_IF_BOOL_EXEC(in_anchor_ptr == nullptr, GELOGW("in anchor ptr is null"); continue); - bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type || - op_->GetType() == const_type || op_->GetType() == variable_type || - op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0; - if (!valid_anchor) { - ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"opname", "index"}, - {GetName(), std::to_string(in_anchor_ptr->GetIdx())}); - GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); - return GRAPH_FAILED; - } - } - } - - string frameworkop_type = "FrameworkOp"; - bool need_update_name = op_->GetType() != frameworkop_type && !is_unknown_graph; - if (need_update_name) { - auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_->GetType()); - if (node_op.IsEmpty()) { - GELOGW("get op from OperatorFactory fail. opType: %s", op_->GetType().c_str()); - } else { - GELOGD("get op from OperatorFactory success. opType: %s", op_->GetType().c_str()); - auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); - if (temp_op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "temp op desc is null"); - return GRAPH_FAILED; - } - if (!op_->UpdateInputName(temp_op_desc->GetAllInputName())) { - GELOGW("Verify UpdateInputName failed"); - } - if (!op_->UpdateOutputName(temp_op_desc->GetAllOutputName())) { - GELOGW("Verify UpdateOutputName failed"); - } - } - node_op.BreakConnect(); - } - GE_IF_BOOL_EXEC(is_unknown_graph, return GRAPH_SUCCESS;); - if (op_->CommonVerify() == GRAPH_SUCCESS) { - Operator op_proxy = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); - auto verify_func = op_->GetVerifyFunc(); - if (verify_func == nullptr) { - verify_func = OperatorFactoryImpl::GetVerifyFunc(GetType()); - } - if (verify_func != nullptr) { - return (graphStatus)verify_func(op_proxy); - } - return GRAPH_SUCCESS; - } else { - GELOGE(GRAPH_FAILED, "%s Verify failed.", op_->GetType().c_str()); - return GRAPH_FAILED; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr Node::GetOpDesc() const { return op_; } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(const OpDescPtr &op_desc) { - GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); - GE_CHK_BOOL_EXEC(op_desc != nullptr, return GRAPH_PARAM_INVALID, "Param OpDesc is nullptr"); - GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID, - "Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(), - op_desc->GetInputsSize()); - - GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID, - "Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(), - op_desc->GetOutputsSize()); - op_ = op_desc; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor> -Node::GetInDataNodesAndAnchors() const { - std::vector> vec; - for (const auto &p : in_data_anchors_) { - if (p == nullptr) { - GELOGW("indata anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); - continue; - } - auto anchor_ptr = p->GetPeerOutAnchor(); - if (anchor_ptr == nullptr) { - continue; - } - auto node = anchor_ptr->GetOwnerNode(); - if (node == nullptr) { - GELOGW("src node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); - continue; - } - vec.push_back(std::make_pair(node, anchor_ptr)); - } - return Node::Vistor>(shared_from_this(), vec); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor> -Node::GetOutDataNodesAndAnchors() const { - std::vector> vec; - for (const auto &p : out_data_anchors_) { - if (p == nullptr) { - GELOGW("out data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); - continue; - } - for (const auto &in_anchor : p->GetPeerInDataAnchors()) { - if (in_anchor == nullptr) { - GELOGW("dst in data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); - continue; - } - auto node = in_anchor->GetOwnerNode(); - if (node == nullptr) { - GELOGW("dst node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); - continue; - } - vec.push_back(std::make_pair(node, in_anchor)); - } - } - return Node::Vistor>(shared_from_this(), vec); -} -} // namespace ge diff --git a/metadef/graph/op_desc.cc b/metadef/graph/op_desc.cc deleted file mode 100644 index fdd1acb7..00000000 --- a/metadef/graph/op_desc.cc +++ /dev/null @@ -1,1370 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/op_desc.h" -#include "debug/ge_attr_define.h" -#include "debug/ge_util.h" -#include "external/graph/operator.h" -#include "framework/common/debug/ge_log.h" -#include "common/util/error_manager/error_manager.h" -#include "graph/ge_attr_value.h" -#include "graph/ge_tensor.h" -#include "graph/operator_factory_impl.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/ge_ir_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "proto/ge_ir.pb.h" - -using std::make_pair; -using std::shared_ptr; -using std::string; -using std::vector; - -/*lint -save -e521 -e681 -e732 -e737*/ -namespace ge { -const std::string ATTR_NAME_ID = "id"; - -const std::string ATTR_NAME_STREAM_ID = "stream_id"; - -const std::string ATTR_NAME_INPUT_NAME = "input_name"; - -const std::string ATTR_NAME_SRC_NAME = "src_name"; - -const std::string ATTR_NAME_SRC_INDEX = "src_index"; - -const std::string ATTR_NAME_INPUT = "input"; - -const std::string ATTR_NAME_OUTPUT = "output"; - -const std::string ATTR_NAME_INPUT_DESC = "input_desc"; - -const std::string ATTR_NAME_OUTPUT_DESC = "output_desc"; - -const std::string ATTR_NAME_DST_NAME = "dst_name"; - -const std::string ATTR_NAME_DST_INDEX = "dst_index"; - -const std::string ATTR_NAME_WORKSPACE = "workspace"; - -const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; - -const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; - -const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { - op_def_.InitDefault(); - if (op_def_.GetProtoMsg() != nullptr) { - op_def_.GetProtoMsg()->set_has_out_attr(true); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::~OpDesc() {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const std::string &name, const std::string &type) { - op_def_.InitDefault(); - if (op_def_.GetProtoMsg() != nullptr) { - op_def_.GetProtoMsg()->set_has_out_attr(true); - } - SetName(name); - SetType(type); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const ProtoMsgOwner &proto_msg_owner, - ge::proto::OpDef *op_def) - : op_def_(proto_msg_owner, op_def) { - if (op_def != nullptr && !op_def->has_out_attr()) { - op_def->set_has_out_attr(true); - - int64_t id = 0; - (void)AttrUtils::GetInt(this, ATTR_NAME_ID, id); - op_def->set_id(id); - - int64_t stream_id = 0; - (void)AttrUtils::GetInt(this, ATTR_NAME_STREAM_ID, stream_id); - op_def->set_stream_id(stream_id); - - vector input_name; - (void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME, input_name); - for (auto &item : input_name) { - op_def->add_input_name(item); - } - vector src_name; - (void)AttrUtils::GetListStr(this, ATTR_NAME_SRC_NAME, src_name); - for (auto &item : src_name) { - op_def->add_src_name(item); - } - vector src_index; - (void)AttrUtils::GetListInt(this, ATTR_NAME_SRC_INDEX, src_index); - for (auto &item : src_index) { - op_def->add_src_index(item); - } - vector input; - (void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT, input); - for (auto &item : input) { - op_def->add_input_i(item); - } - vector output; - (void)AttrUtils::GetListInt(this, ATTR_NAME_OUTPUT, output); - for (auto &item : output) { - op_def->add_output_i(item); - } - vector dst_name; - (void)AttrUtils::GetListStr(this, ATTR_NAME_DST_NAME, dst_name); - for (auto &item : dst_name) { - op_def->add_dst_name(item); - } - vector dst_index; - (void)AttrUtils::GetListInt(this, ATTR_NAME_DST_INDEX, dst_index); - for (auto &item : dst_index) { - op_def->add_dst_index(item); - } - vector workspace; - (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE, workspace); - for (auto &item : workspace) { - op_def->add_workspace(item); - } - vector workspace_bytes; - (void)AttrUtils::GetListInt(this, ATTR_NAME_WORKSPACE_BYTES, workspace_bytes); - for (auto &item : workspace_bytes) { - op_def->add_workspace_bytes(item); - } - vector is_input_const; - (void)AttrUtils::GetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const); - for (auto item : is_input_const) { - op_def->add_is_input_const(item); - } - auto input_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_INPUT_DESC].mutable_list(); - if (input_desc_mutable_list != nullptr) { - *op_def->mutable_input_desc() = *(input_desc_mutable_list->mutable_td()); - } - auto output_desc_mutable_list = (*op_def->mutable_attr())[ATTR_NAME_OUTPUT_DESC].mutable_list(); - if (output_desc_mutable_list != nullptr) { - *op_def->mutable_output_desc() = *(output_desc_mutable_list->mutable_td()); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetName() const { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->name(); - } - return ""; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetName(const std::string &name) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_name(name); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetType() const { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->type(); - } - return ""; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetType(const string &type) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_type(type); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddInputDesc(const ge::GeTensorDesc &input_desc) { - int index = static_cast(inputs_desc_.size()); - return AddInputDesc("__input" + std::to_string(index), input_desc); -} - -graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc) { - graphStatus ret = GRAPH_SUCCESS; - if (index < inputs_desc_.size()) { - // InputsDesc[index] is exist, then update it - ret = UpdateInputDesc(index, input_desc); - } else { - // InputDesc[index] is not exist, then add it - ret = AddInputDesc(input_desc); - } - return ret; -} - -graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { - if (input_name_idx_.find(name) != input_name_idx_.end()) { - GELOGI("input %s is exist, update it", name.c_str()); - graphStatus ret = UpdateInputDesc(name, input_desc); - return ret; - } else { - int index = static_cast(inputs_desc_.size()); - std::shared_ptr in_desc = ComGraphMakeShared(input_desc); - if (in_desc == nullptr) { - GELOGE(GRAPH_FAILED, "AddInputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - inputs_desc_.push_back(in_desc); - (void)input_name_idx_.insert(make_pair(name, index)); - if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) { - register_input_name_.push_back(name); - } - - return GRAPH_SUCCESS; - } -} - -graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) { - for (unsigned int i = 0; i < num; i++) { - string input_name = name + std::to_string(i); - GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, - "Add input tensor_desc is existed. name[%s]", input_name.c_str()); - - std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); - if (in_desc == nullptr) { - GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - if (index > inputs_desc_.size()) { - GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); - return GRAPH_FAILED; - } - - (void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc); - - // Update index in input_name_idx - for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { - if (it->second >= (index + i)) { - it->second += 1; - } - } - - (void)input_name_idx_.insert(make_pair(input_name, i + index)); - } - - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddOutputDescMiddle(const string &name, const unsigned int num, size_t index) { - for (unsigned int i = 0; i < num; i++) { - string output_name = name + std::to_string(i); - GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED, - "Add input tensor_desc is existed. name[%s]", output_name.c_str()); - - std::shared_ptr out_desc = ComGraphMakeShared(GeTensorDesc()); - if (out_desc == nullptr) { - GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - if (index > outputs_desc_.size()) { - GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); - return GRAPH_FAILED; - } - - (void)outputs_desc_.insert(outputs_desc_.begin() + index + i, out_desc); - - // Update index in input_name_idx - for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) { - if (it->second >= (index + i)) { - it->second += 1; - } - } - - (void)output_name_idx_.insert(make_pair(output_name, i + index)); - } - - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { - for (unsigned int i = 0; i < num; i++) { - string input_name = name + std::to_string(i); - GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, - "Add input tensor_desc is existed. name[%s]", input_name.c_str()); - - std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); - if (in_desc == nullptr) { - GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); - - // Update index in input_name_idx - for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { - it->second += 1; - } - - (void)input_name_idx_.insert(make_pair(input_name, 0)); - } - - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int num) { - for (unsigned int i = 0; i < num; i++) { - string output_name = name + std::to_string(i); - GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED, - "Add output tensor_desc is existed. name[%s]", output_name.c_str()); - - std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); - if (in_desc == nullptr) { - GELOGE(GRAPH_FAILED, "AddOutputDescForward failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - (void)outputs_desc_.insert(outputs_desc_.begin(), in_desc); - - // Update index in output_name_idx - for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) { - it->second += 1; - } - (void)output_name_idx_.insert(make_pair(output_name, 0)); - } - - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { - if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; - (void)optional_input_names_.insert(name); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { - GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); - - inputs_desc_[index] = ComGraphMakeShared(tensor_Desc); - if (inputs_desc_[index] == nullptr) { - GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { - return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") && - IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && - IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") && - IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && - IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { - const auto &op_def = this->op_def_.GetProtoMsg(); - const auto &r_op_def = r_op_desc.op_def_.GetProtoMsg(); - if ((op_def != nullptr) && (r_op_def != nullptr)) { - // Message OpDef in ge_ir.proto - return ( - IsEqual(op_def->name(), r_op_def->name(), "OpDef_.name()") && - IsEqual(op_def->type(), r_op_def->type(), "OpDef_.type()") && - IsEqual(ToString(op_def->input()), ToString(r_op_def->input()), "OpDef_.input()") && - IsEqual(op_def->has_out_attr(), r_op_def->has_out_attr(), "OpDef_.has_out_attr()") && - IsEqual(op_def->stream_id(), r_op_def->stream_id(), "OpDef_.stream_id()") && - IsEqual(ToString(op_def->input_name()), ToString(r_op_def->input_name()), "OpDef_.input_name()") && - IsEqual(ToString(op_def->src_name()), ToString(r_op_def->src_name()), "OpDef_.src_name()") && - IsEqual(ToString(op_def->dst_name()), ToString(r_op_def->dst_name()), "OpDef_.dst_name()") && - IsEqual(ToString(op_def->src_index()), ToString(r_op_def->src_index()), "OpDef_.src_index()") && - IsEqual(ToString(op_def->dst_index()), ToString(r_op_def->dst_index()), "OpDef_.dst_index()") && - IsEqual(ToString(op_def->input_i()), ToString(r_op_def->input_i()), "OpDef_.input_i()") && - IsEqual(ToString(op_def->output_i()), ToString(r_op_def->output_i()), "OpDef_.output_i()") && - IsEqual(ToString(op_def->workspace()), ToString(r_op_def->workspace()), "OpDef_.workspace()") && - IsEqual(ToString(op_def->workspace_bytes()), ToString(r_op_def->workspace_bytes()), "OpDef_.workspace_bytes()") && - IsEqual(ToString(op_def->is_input_const()), ToString(r_op_def->is_input_const()), "OpDef_.is_input_const()")); - } else { - return ((op_def == nullptr) && (r_op_def == nullptr)); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescGenTensorDescsAreEqual( - const OpDesc &r_op_desc) const { - // 1.Verify inputs and outputs desc size - const auto inputs_desc_size = this->inputs_desc_.size(); - const auto r_inputs_desc_size = r_op_desc.inputs_desc_.size(); - if (inputs_desc_size != r_inputs_desc_size) { - GELOGE(GRAPH_FAILED, "Size of OpDesc's inputs desc verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - const auto outputs_desc_size = this->outputs_desc_.size(); - const auto r_outputs_desc_size = r_op_desc.outputs_desc_.size(); - if (outputs_desc_size != r_outputs_desc_size) { - GELOGE(GRAPH_FAILED, "Size of OpDesc's outputs desc verify failed, node name: %s.", this->GetName().c_str()); - return false; - } - // 2.Verify all inputs desc equal - for (uint32_t i = 0; i < inputs_desc_size; i++) { - const auto &in_ge_tensor_desc = this->GetInputDesc(i); - const auto &r_in_ge_tensor_desc = r_op_desc.GetInputDesc(i); - // Determine the connection relationship by GeTensorDesc - if (!(in_ge_tensor_desc == r_in_ge_tensor_desc)) { - GELOGE(GRAPH_FAILED, "Link info of OpDesc's inputs desc verify failed, OpDesc name: %s.", - this->GetName().c_str()); - return false; - } - } - // 3.Verify all outputs desc equal - for (uint32_t i = 0; i < outputs_desc_size; i++) { - const auto &out_ge_tensor_desc = this->GetOutputDesc(i); - const auto &r_out_ge_tensor_desc = r_op_desc.GetOutputDesc(i); - if (!(out_ge_tensor_desc == r_out_ge_tensor_desc)) { - GELOGE(GRAPH_FAILED, "Link info of OpDesc's outputs desc verify failed, OpDesc name: %s.", - this->GetName().c_str()); - return false; - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpDesc &r_op_desc) const { - return (OpDescAttrsAreEqual(r_op_desc) && OpDescMembersAreEqual(r_op_desc) && - OpDescGenTensorDescsAreEqual(r_op_desc)); -} - -graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { - auto it = input_name_idx_.find(name); - if (it == input_name_idx_.end()) { - GELOGW("Cann't find the input desc. name[%s]", name.c_str()); - return GRAPH_FAILED; - } - if (it->second >= inputs_desc_.size()) { - GELOGE(GRAPH_FAILED, "[%d] more than size of inputs_desc_", it->second); - return GRAPH_FAILED; - } - GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); - return GRAPH_FAILED); - inputs_desc_[it->second] = ComGraphMakeShared(tensor_Desc); - if (inputs_desc_[it->second] == nullptr) { - GELOGE(GRAPH_FAILED, "UpdateInputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -bool OpDesc::InputIsSet(const string &name) const { - auto it = input_name_idx_.find(name); - if (it != input_name_idx_.end()) { - GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); - auto tensor_desc = inputs_desc_[it->second]; - GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); - auto dims = tensor_desc->GetShape().GetDims(); - if (dims.size() > 0) { - return true; - } - } - return false; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG(index < inputs_desc_.size(), GeTensorDesc()); - return *(inputs_desc_[index].get()); -} - -GeTensorDesc OpDesc::GetInputDesc(const string &name) const { - auto it = input_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc()); - GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); - return *(inputs_desc_[it->second].get()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); - if (inputs_desc_[index] == nullptr) { - return nullptr; - } - if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) { - GELOGW("input desc is invalid"); - return nullptr; - } - return inputs_desc_[index]; -} - -GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const { - auto input_name_idx = GetAllInputName(); - auto it = input_name_idx.find(name); - if (it == input_name_idx.end()) { - GELOGW("Failed to get [%s] input desc", name.c_str()); - return nullptr; - } - return MutableInputDesc(it->second); -} - -GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputNames() const { - vector names; - if (input_name_idx_.empty()) { - return OpDesc::Vistor(shared_from_this(), names); - } - for (std::pair input : input_name_idx_) { - names.push_back(input.first); - } - return OpDesc::Vistor(shared_from_this(), names); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpKernelLibName(const std::string &name) { - op_kernel_lib_name_ = name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpKernelLibName() const { - return op_kernel_lib_name_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(const std::string &name) { - engine_name_ = name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDesc() const { - vector temp{}; - for (const auto &it : inputs_desc_) { - if (it->IsValid() == GRAPH_SUCCESS) { - temp.push_back(*it); - } else { - GELOGW("this inputDesc is InValid, it won't be return"); - continue; - } - } - return OpDesc::Vistor(shared_from_this(), temp); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDescPtr() const { - vector temp{}; - for (const auto &it : inputs_desc_) { - if (it->IsValid() == GRAPH_SUCCESS) { - temp.push_back(it); - } else { - GELOGW("this inputDesc is InValid, it won't be return"); - continue; - } - } - return OpDesc::Vistor(shared_from_this(), temp); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() const { - // Just return valid inputs size.InValid desc is set in default OPTION_INPUT register. - size_t size = 0; - for (const auto &it : inputs_desc_) { - if (it->IsValid() == GRAPH_SUCCESS) { - size++; - } - } - return size; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) { - int index = static_cast(outputs_desc_.size()); - return AddOutputDesc("__output" + std::to_string(index), output_desc); -} - -graphStatus OpDesc::AddOutputDesc(const string &name, const ge::GeTensorDesc &output_desc) { - GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(name) == output_name_idx_.end()), GRAPH_FAILED, - "Add output tensor_Desc is existed. name[%s]", name.c_str()); - int index = static_cast(outputs_desc_.size()); - - std::shared_ptr tensor = ComGraphMakeShared(output_desc); - if (tensor == nullptr) { - GELOGE(GRAPH_FAILED, "AddOutputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - outputs_desc_.push_back(tensor); - (void)output_name_idx_.insert(make_pair(name, index)); - if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) { - register_output_name_.push_back(name); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDesc::UpdateOutputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { - GE_CHK_BOOL_RET_STATUS((index < outputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); - - outputs_desc_[index] = ComGraphMakeShared(tensor_Desc); - if (outputs_desc_[index] == nullptr) { - GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::UpdateOutputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { - auto it = output_name_idx_.find(name); - if (it == output_name_idx_.end()) { - GELOGW("Cann't find the output desc. name[%s]", name.c_str()); - return GRAPH_FAILED; - } - GE_IF_BOOL_EXEC(it->second >= outputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); - return GRAPH_FAILED); - outputs_desc_[it->second] = ComGraphMakeShared(tensor_Desc); - if (outputs_desc_[it->second] == nullptr) { - GELOGE(GRAPH_FAILED, "UpdateOutputDesc failed, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetOutputDesc(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG(index < outputs_desc_.size(), GeTensorDesc()); - return *(outputs_desc_[index].get()); -} - -GeTensorDesc OpDesc::GetOutputDesc(const string &name) const { - auto it = output_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), GeTensorDesc()); - GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < outputs_desc_.size(), GeTensorDesc()); - return *(outputs_desc_[it->second].get()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS(index < outputs_desc_.size(), nullptr, "Cann't find the output desc %u", index); - return outputs_desc_[index]; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const { - auto it = output_name_idx_.find(name); - if (it == output_name_idx_.end()) { - GELOGW("Failed to get [%s] output desc", name.c_str()); - return nullptr; - } - return MutableOutputDesc(it->second); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { - return static_cast(outputs_desc_.size()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllOutputsDesc() const { - vector temp{}; - for (const auto &it : outputs_desc_) { - temp.push_back(*it); - } - return OpDesc::Vistor(shared_from_this(), temp); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllOutputsDescPtr() const { - return OpDesc::Vistor(shared_from_this(), outputs_desc_); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetOutputsSize() const { return outputs_desc_.size(); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetOutputDescPtr(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast(outputs_desc_.size()), nullptr); - return outputs_desc_[index]; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast(inputs_desc_.size()), nullptr); - if (inputs_desc_[index] == nullptr) { - return nullptr; - } - if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) { - GELOGW("inputsDesc[%u] is InValid", index); - return nullptr; - } else { - return inputs_desc_[static_cast(index)]; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr -OpDesc::GetInputDescPtrDfault(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG((index) < (uint32_t)(inputs_desc_.size()), nullptr); - return inputs_desc_[(int32_t)index]; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const { - auto it = input_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), shared_ptr()); - return inputs_desc_[it->second]; -} - -graphStatus OpDesc::AddRegisterInputName(const std::string &name) { - if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) { - register_input_name_.push_back(name); - } - - return GRAPH_SUCCESS; -} - -vector OpDesc::GetRegisterInputName() const { return register_input_name_; } - -graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { - if (is_push_back) { - for (unsigned int i = 0; i < num; i++) { - if (AddInputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) return GRAPH_FAILED; - } - } else { - if (AddInputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED; - } - if (AddRegisterInputName(name) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index) { - if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddRegisterOutputName(const string &name) { - if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) { - register_output_name_.push_back(name); - } - - return GRAPH_SUCCESS; -} - -vector OpDesc::GetRegisterOutputName() const { return register_output_name_; } - -graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) { - if (is_push_back) { - for (unsigned int i = 0; i < num; i++) { - if (AddOutputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) return GRAPH_FAILED; - } - } else { - if (AddOutputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED; - } - - if (AddRegisterOutputName(name) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -bool OpDesc::IsOptionalInput(const string &name) const { - return optional_input_names_.find(name) != optional_input_names_.end(); -} - -bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } - -std::map OpDesc::GetAllInputName() const { return input_name_idx_; } - -std::map OpDesc::GetAllOutputName() { return output_name_idx_; } - -bool OpDesc::UpdateInputName(std::map input_name_idx) { - bool ret = true; - // Use inputDesc_.size() to contain the InValid OptionInput.GetInputsSize() will remove default OptionInput name. - auto input_map_size = inputs_desc_.size(); - auto factory_map_size = input_name_idx.size(); - // It indicates that some inputs have no optionalname. - // The redundant optionalname of factory needs to be deleted and then assigned - if (input_map_size < factory_map_size) { - GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, - factory_map_size); - for (auto it = input_name_idx.begin(); it != input_name_idx.end();) { - if (it->second >= input_map_size) { - it = input_name_idx.erase(it); - } else { - ++it; - } - } - if (input_name_idx.size() == input_map_size) { - GELOGI("UpdateInputName"); - input_name_idx_ = input_name_idx; - } else { - ret = false; - GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); - } - } else if (input_map_size == factory_map_size) { - input_name_idx_ = input_name_idx; - } else { - ret = false; - GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); - } - return ret; -} - -bool OpDesc::UpdateOutputName(std::map output_name_idx) { - size_t output_map_size = GetAllOutputsDescSize(); - size_t factory_map_size = output_name_idx.size(); - if (output_map_size < factory_map_size) { - GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size, - factory_map_size); - for (auto it = output_name_idx.begin(); it != output_name_idx.end();) { - if (it->second >= output_map_size) { - it = output_name_idx.erase(it); - } else { - ++it; - } - } - if (output_name_idx.size() == output_map_size) { - GELOGI("UpdateoutputName"); - output_name_idx_ = output_name_idx; - return true; - } - } else if (output_map_size == factory_map_size) { - output_name_idx_ = output_name_idx; - return true; - } else { - GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size); - return false; - } - GELOGW("UpdateOutputName org name map size: %zu, factory map size: %zu", output_map_size, factory_map_size); - return false; -} - -std::function OpDesc::GetInferFunc() const { return infer_func_; } - -std::function OpDesc::GetVerifyFunc() const { return verifier_func_; } - -void OpDesc::AddInferFunc(const std::function &func) { infer_func_ = func; } - -std::function OpDesc::GetInferFormatFunc() const { return infer_format_func_; } - -void OpDesc::AddInferFormatFunc(const std::function &func) { infer_format_func_ = func; } - -void OpDesc::AddVerifierFunc(const std::function &func) { verifier_func_ = func; } - -graphStatus OpDesc::InferShapeAndType() { - if (infer_func_ == nullptr) { - infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType()); - if (infer_func_ == nullptr) { - GELOGW("%s does not have inferfunc_.", GetName().c_str()); - /// The infoshape function has not been added for each operator in the current operator information library. - /// No infoshape added operator skips the call - /// and directly uses the shape information passed down by the upper framework - return GRAPH_SUCCESS; - } - } - Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this()); - graphStatus ret = (graphStatus)infer_func_(op_proxy); - op_proxy.BreakConnect(); - return ret; -} - -graphStatus OpDesc::DefaultInferFormat() { - ge::Format first_none_nd_format = FORMAT_ND; - auto input_descs = GetAllInputsDescPtr(); - auto output_descs = GetAllOutputsDescPtr(); - // Overall input and output,get the first non-nd format - for (const auto &input_desc : input_descs) { - Format origin_format = input_desc->GetOriginFormat(); - if (origin_format != FORMAT_ND) { - first_none_nd_format = origin_format; - break; - } - } - for (const auto &output_desc : output_descs) { - Format origin_format = output_desc->GetOriginFormat(); - if (origin_format != FORMAT_ND) { - first_none_nd_format = origin_format; - break; - } - } - // Refresh all input output format - GELOGD("Default infer format.node[%s], first none nod format is:%d", GetName().c_str(), first_none_nd_format); - - for (const auto &input_desc : input_descs) { - Format origin_format = input_desc->GetOriginFormat(); - GELOGD("Default infer format[in].node[%s].origin format is:%d", GetName().c_str(), origin_format); - if (origin_format == FORMAT_ND) { - input_desc->SetOriginFormat(first_none_nd_format); - input_desc->SetFormat(first_none_nd_format); - } - } - for (const auto &output_desc : output_descs) { - Format origin_format = output_desc->GetOriginFormat(); - GELOGD("Default infer format[out].node[%s].origin format is:%d", GetName().c_str(), origin_format); - if (origin_format == FORMAT_ND) { - output_desc->SetOriginFormat(first_none_nd_format); - output_desc->SetFormat(first_none_nd_format); - } - } - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::OpVerify() { - if (verifier_func_ == nullptr) { - verifier_func_ = OperatorFactoryImpl::GetVerifyFunc(GetType()); - } - if (verifier_func_ != nullptr) { - Operator op_proxy = ge::OpDescUtils::CreateOperatorFromOpDesc(shared_from_this()); - graphStatus ret = (graphStatus)verifier_func_(op_proxy); - op_proxy.BreakConnect(); - return ret; - } - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::CommonVerify() const { - for (const string &iname : GetAllInputNames()) { - // Checking shape of all inputs - vector ishape = GetInputDescPtr(iname)->GetShape().GetDims(); - for (int64_t dim : ishape) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - dim < -2, ErrorManager::GetInstance().ATCReportErrMessage( - "E19014", {"opname", "value", "reason"}, - {GetName(), "input " + iname + " shape", "contains negative or zero dimension"}); - return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), - iname.c_str()); - } - } - // Check all attributes defined - const auto &all_attributes = GetAllAttrs(); - for (const auto &name : GetAllAttrNames()) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - all_attributes.find(name) == all_attributes.end(), - ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, - {GetName(), "attribute " + name, "is empty"}); - return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str()); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const { - auto it = input_name_idx_.begin(); - for (; it != input_name_idx_.end(); ++it) { - if (it->second == index) { - break; - } - } - GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), ""); - return it->first; -} - -int OpDesc::GetInputIndexByName(const string &name) const { - auto it_find = input_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); - return static_cast(it_find->second); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetOutputNameByIndex(uint32_t index) const { - auto it = output_name_idx_.begin(); - for (; it != output_name_idx_.end(); ++it) { - if (it->second == index) { - break; - } - } - GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), ""); - return it->first; -} - -int OpDesc::GetOutputIndexByName(const string &name) const { - auto it_find = output_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != output_name_idx_.end(), -1); - return static_cast(it_find->second); -} - -ProtoAttrMapHelper OpDesc::MutableAttrMap() { - if (op_def_.GetProtoMsg() == nullptr) { - GELOGE(GRAPH_FAILED, "op def get proto msg failed"); - return GeIrProtoHelper(); - } - return ProtoAttrMapHelper(op_def_.GetProtoOwner(), op_def_.GetProtoMsg()->mutable_attr()); -} - -ConstProtoAttrMapHelper OpDesc::GetAttrMap() const { - return ConstProtoAttrMapHelper(op_def_.GetProtoOwner(), &op_def_.GetProtoMsg()->attr()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetId(int64_t id) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_id(id); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetId() const { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->id(); - } - return 0; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetStreamId(int64_t stream_id) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->set_stream_id(stream_id); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetStreamId() const { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - return proto_msg->stream_id(); - } - return 0; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputName(const vector &input_name) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_input_name(); - for (auto &item : input_name) { - proto_msg->add_input_name(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetInputName() const { - vector input_name; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->input_name()) { - input_name.push_back(item); - } - } - return input_name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcName(const vector &src_name) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_src_name(); - for (auto &item : src_name) { - proto_msg->add_src_name(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetSrcName() const { - vector src_name; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->src_name()) { - src_name.push_back(item); - } - } - return src_name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcIndex(const vector &src_index) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_src_index(); - for (auto &item : src_index) { - proto_msg->add_src_index(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetSrcIndex() const { - vector src_index; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->src_index()) { - src_index.push_back(item); - } - } - return src_index; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputOffset(const vector &input) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_input_i(); - for (auto &item : input) { - proto_msg->add_input_i(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetInputOffset() const { - vector input; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->input_i()) { - input.push_back(item); - } - } - return input; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOutputOffset(const vector &output) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_output_i(); - for (auto &item : output) { - proto_msg->add_output_i(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetOutputOffset() const { - vector output; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->output_i()) { - output.push_back(item); - } - } - return output; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstName(const vector &dst_name) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_dst_name(); - for (auto &item : dst_name) { - proto_msg->add_dst_name(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetDstName() const { - vector dst_name; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->dst_name()) { - dst_name.push_back(item); - } - } - return dst_name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpInferDepends(const vector &depend_names) { - auto ret = AttrUtils::SetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); - if (ret != true) { - GELOGE(GRAPH_FAILED, "set op_infer_depends fail."); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetOpInferDepends() const { - vector depend_names; - (void)AttrUtils::GetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); - return depend_names; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector &dst_index) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_dst_index(); - for (auto &item : dst_index) { - proto_msg->add_dst_index(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetDstIndex() const { - vector dst_index; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->dst_index()) { - dst_index.push_back(item); - } - } - return dst_index; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspace(const vector &workspace) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_workspace(); - for (auto &item : workspace) { - proto_msg->add_workspace(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetWorkspace() const { - vector workspace; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->workspace()) { - workspace.push_back(item); - } - } - return workspace; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspaceBytes(const vector &workspace_bytes) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_workspace_bytes(); - for (auto &item : workspace_bytes) { - proto_msg->add_workspace_bytes(item); - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetWorkspaceBytes() const { - vector workspace_bytes; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto &item : proto_msg->workspace_bytes()) { - workspace_bytes.push_back(item); - } - } - return workspace_bytes; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetIsInputConst(const vector &is_input_const) { - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_is_input_const(); - for (auto item : is_input_const) { - proto_msg->add_is_input_const(item); - } - } - // If comes from ME,which is_input_const exist as attrs, outside no need to check GE_TRAIN flag - auto ret = AttrUtils::SetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const); - if (ret != true) { - GELOGE(GRAPH_FAILED, "set is_input_const fail."); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetIsInputConst() const { - vector is_input_const; - auto proto_msg = op_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - for (auto item : proto_msg->is_input_const()) { - is_input_const.push_back(item); - } - } - return is_input_const; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name, - const int &index) { - if (input_name_idx_.find(name) != input_name_idx_.end()) { - GELOGI("Restore input name index is existed. name[%s]", name.c_str()); - } - (void)input_name_idx_.insert(make_pair(name, index)); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreOutputNameIdx(const string &name, - const int &index) { - if (output_name_idx_.find(name) != output_name_idx_.end()) { - GELOGI("Restore output name index is existed. name[%s]", name.c_str()); - } - (void)output_name_idx_.insert(make_pair(name, index)); - return GRAPH_SUCCESS; -} -graphStatus OpDesc::CallInferFunc(Operator &op) { - if (infer_func_ == nullptr) { - infer_func_ = OperatorFactoryImpl::GetInferShapeFunc(GetType()); - if (infer_func_ == nullptr) { - GELOGW("%s does not have infer func.", GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - } - graphStatus graph_status = (graphStatus)infer_func_(op); - if (graph_status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "%s call infer func. ret: %u", GetName().c_str(), graph_status); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} -graphStatus OpDesc::CallInferFormatFunc(Operator &op) { - if (infer_format_func_ == nullptr) { - infer_format_func_ = OperatorFactoryImpl::GetInferFormatFunc(GetType()); - if (infer_format_func_ == nullptr) { - return DefaultInferFormat(); - } - } - return (graphStatus)infer_format_func_(op); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const { - if (static_cast(index) >= subgraph_instance_names_.size()) { - return ""; - } - return subgraph_instance_names_.at(index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector &OpDesc::GetSubgraphInstanceNames() - const { - return subgraph_instance_names_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { - for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { - if (*iter == name) { - *iter = ""; - return; - } - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { - GELOGI("Add subgraph name is %s", name.c_str()); - auto iter = subgraph_names_to_index_.find(name); - if (iter != subgraph_names_to_index_.end()) { - GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); - return GRAPH_FAILED; - } - auto size = subgraph_names_to_index_.size(); - subgraph_names_to_index_[name] = size; - subgraph_instance_names_.resize(size + 1); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map &OpDesc::GetSubgraphNameIndexes() - const { - return subgraph_names_to_index_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::SetSubgraphInstanceName(uint32_t index, - const std::string &name) { - GELOGI("Add sub graph instans name is %s, index is %u", name.c_str(), index); - if (index >= subgraph_instance_names_.size()) { - GE_LOGE("The index %u exceeds the max instance coutn %zu", index, subgraph_instance_names_.size()); - return GRAPH_PARAM_INVALID; - } - subgraph_instance_names_[index] = name; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RegisterSubgraphIrName(const string &name, - SubgraphType type) { - subgraph_ir_names_to_type_[name] = type; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map &OpDesc::GetSubgraphIrNames() - const { - return subgraph_ir_names_to_type_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY SubgraphType -OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { - auto iter = subgraph_ir_names_to_type_.find(name); - if (iter == subgraph_ir_names_to_type_.end()) { - return kSubgraphTypeEnd; - } - return iter->second; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDesc::GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const { - for (size_t idx = 0; idx < subgraph_instance_names_.size(); ++idx) { - if (subgraph_instance_names_[idx] != instance_name) { // find subgraph index. - continue; - } - - for (auto name_to_index : subgraph_names_to_index_) { - if (name_to_index.second != idx) { // find subgraph name. - continue; - } - - subgraph_name = name_to_index.first; - return GRAPH_SUCCESS; - } - } - - return GRAPH_PARAM_INVALID; -} - -} // namespace ge diff --git a/metadef/graph/op_imp.cc b/metadef/graph/op_imp.cc deleted file mode 100644 index 9abf242b..00000000 --- a/metadef/graph/op_imp.cc +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include "debug/ge_log.h" -#include "debug/ge_util.h" - -using namespace std; - -namespace ge { - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -BroadCastInfer(const function()>& get_in1_shape, const function()>& get_in2_shape, - const function& outShape)>& set_out_shape) { - auto x1_shape = get_in1_shape(); - auto x2_shape = get_in2_shape(); - vector y_shape; - - if (x1_shape.empty()) { - y_shape = x2_shape; - set_out_shape(y_shape); - return GRAPH_SUCCESS; - } - if (x2_shape.empty()) { - y_shape = x1_shape; - set_out_shape(y_shape); - return GRAPH_SUCCESS; - } - - int len_diff = static_cast(x1_shape.size() - x2_shape.size()); - if (len_diff >= 0) { - for (int i = 0; i < len_diff; i++) { - y_shape.push_back(x1_shape[i]); - } - int x2_shape_size = static_cast(x2_shape.size()); - for (int i = 0; i < x2_shape_size; i++) { - bool shapeFlag = - ((x1_shape[i + len_diff] != x2_shape[i]) && (std::min(x1_shape[i + len_diff], x2_shape[i]) != 1)); - if (shapeFlag) { - GE_LOGE("operands could not be broadcast together"); - return GRAPH_FAILED; - } - y_shape.push_back(std::max(x1_shape[i + len_diff], x2_shape[i])); - } - } else { - for (int i = 0; i < -len_diff; i++) { - y_shape.push_back(x2_shape[i]); - } - int x1_shape_size = static_cast(x1_shape.size()); - for (int i = 0; i < x1_shape_size; i++) { - bool shapeFlag = - ((x1_shape[i] != x2_shape[i - len_diff]) && (std::min(x1_shape[i], x2_shape[i - len_diff]) != 1)); - if (shapeFlag) { - GE_LOGE("operands could not be broadcast together"); - return GRAPH_FAILED; - } - y_shape.push_back(std::max(x1_shape[i], x2_shape[i - len_diff])); - } - } - set_out_shape(y_shape); - return GRAPH_SUCCESS; -} - -} // namespace ge diff --git a/metadef/graph/operator.cc b/metadef/graph/operator.cc deleted file mode 100644 index 21554fa1..00000000 --- a/metadef/graph/operator.cc +++ /dev/null @@ -1,1587 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "external/graph/operator.h" -#include "external/graph/operator_factory.h" -#include -#include -#include -#include -#include -#include "./array_ops.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "external/graph/attr_value.h" -#include "external/graph/types.h" -#include "framework/common/debug/ge_log.h" -#include "graph/compute_graph.h" -#include "graph/ge_attr_value.h" -#include "graph/ge_context.h" -#include "graph/ge_tensor.h" -#include "graph/node.h" -#include "graph/op_desc.h" -#include "graph/runtime_inference_context.h" -#include "graph/usr_types.h" -#include "graph/utils/node_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "utils/graph_utils.h" -#include "utils/op_desc_utils.h" -#include "utils/tensor_adapter.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" -#include -#include -#include -#include -#include - -using std::enable_shared_from_this; -using std::make_pair; -using std::shared_ptr; -using std::string; -using std::to_string; -using std::vector; - -/*lint -save -e529 -e728*/ -/*lint -e446 -e732*/ -/*lint -e665*/ -namespace ge { -class OpIO { - public: - OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {} - - ~OpIO() = default; - - string GetName() const { return name_; } - - int GetIndex() const { return index_; } - - OperatorImplPtr GetOwner() const { return owner_; } - - bool operator==(const OpIO &r_value) const { - return (this->name_ == r_value.GetName()) && (this->index_ == r_value.GetIndex()) && - (this->GetOwner() == r_value.GetOwner()); - } - - private: - string name_; - int index_; - std::shared_ptr owner_; -}; - -class TensorTypeImpl { - public: - TensorTypeImpl() = default; - ~TensorTypeImpl() = default; - - std::vector dt_vec_; -}; - -TensorType::TensorType(DataType dt) { - tensor_type_impl_ = ComGraphMakeShared(); - if (tensor_type_impl_ != nullptr) { - tensor_type_impl_->dt_vec_.push_back(dt); - } -} - -TensorType::TensorType(const std::initializer_list &types) { - tensor_type_impl_ = ComGraphMakeShared(); - if (tensor_type_impl_ != nullptr) { - tensor_type_impl_->dt_vec_ = types; - } -} - -class OperatorImpl : public std::enable_shared_from_this { - friend class GraphBuilderImpl; - friend class OpDescUtils; - - public: - explicit OperatorImpl(const string &name, const string &type) : op_desc_(ComGraphMakeShared(name, type)) { - if (op_desc_ == nullptr) { - GELOGW("OpDesc make shared failed"); - } - } - explicit OperatorImpl(const OpDescPtr &op_desc) : op_desc_(op_desc) {} - explicit OperatorImpl(ge::ConstNodePtr node) : node_(std::move(node)) { - if (node_ != nullptr && node_->GetOpDesc() != nullptr) { - op_desc_ = node_->GetOpDesc(); - } - } - ~OperatorImpl() {} - void SetInputImpl(const string &dst_name, const ge::Operator &src_oprt) { - GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty"); - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr."); - GE_CHK_BOOL_EXEC(src_oprt.operator_impl_ != nullptr, return, "operator_impl_ is nullptr."); - GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_ != nullptr, return, "op_desc_ is nullptr."); - - auto src_op_impl = src_oprt.GetOperatorImplPtr(); - GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return, "Src impl is null."); - GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return, "Src impl's opdesc is null."); - GE_CHK_BOOL_EXEC(src_oprt.operator_impl_->op_desc_->GetOutputsSize() == 1, return, - "The source operator[%s] must has one output", - src_oprt.operator_impl_->op_desc_->GetName().c_str()) - - uint32_t src_index = 0; - string src_name = src_op_impl->op_desc_->GetOutputNameByIndex(src_index); - GE_CHK_BOOL_EXEC(!src_name.empty(), return, "Src output's name is empty."); - - OpIO out_handler(src_name, src_index, src_op_impl); - input_link_.insert(std::make_pair(dst_name, out_handler)); - - int dst_index = op_desc_->GetInputIndexByName(dst_name); - GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), - op_desc_->GetName().c_str()); - - bool is_const = false; - if (src_oprt.GetOpType() == CONSTANT) { - is_const = true; - } - auto is_input_const = op_desc_->GetIsInputConst(); - for (int i = static_cast(is_input_const.size()); i <= dst_index; ++i) { - is_input_const.push_back(false); - } - - is_input_const[dst_index] = is_const; - op_desc_->SetIsInputConst(is_input_const); - - OpIO op_dst(dst_name, dst_index, shared_from_this()); - src_op_impl->UpdateLinkMapImpl(src_name, op_dst); - auto output_desc = src_op_impl->GetOutputDesc(src_name); - auto input_desc = op_desc_->GetInputDesc(dst_name); - if (input_desc.GetFormat() == FORMAT_RESERVED) { - output_desc.SetFormat(FORMAT_ND); - } else { - output_desc.SetFormat(input_desc.GetFormat()); - } - // Fix for linking opdesc - if (op_desc_->UpdateInputDesc(dst_name, output_desc) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Update inputdesc failed,dst name is %s, src name is %s", dst_name.c_str(), - src_name.c_str()); - return; - } - } - - void SetInputImpl(const string &dst_name, const ge::OutHandler &out_handler) { - GE_CHK_BOOL_EXEC(!dst_name.empty(), return, "dst name is empty"); - GE_CHK_BOOL_EXEC(out_handler != nullptr, return, "SetInputImpl faild, out_handler is nullptr."); - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return, "op_desc_ is nullptr."); - input_link_.insert(std::make_pair(dst_name, *out_handler)); - - string src_name = out_handler->GetName(); - int dst_index = op_desc_->GetInputIndexByName(dst_name); - GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), - op_desc_->GetName().c_str()); - auto out_op_impl = out_handler->GetOwner(); - GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return, - "out_handler invalid. name[%s]", dst_name.c_str()); - bool is_const = false; - if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { - is_const = true; - } - auto is_input_const = op_desc_->GetIsInputConst(); - for (int i = static_cast(is_input_const.size()); i <= dst_index; ++i) { - is_input_const.push_back(false); - } - is_input_const[dst_index] = is_const; - op_desc_->SetIsInputConst(is_input_const); - - OpIO in_handler(dst_name, dst_index, shared_from_this()); - GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed."); - - out_op_impl->UpdateLinkMapImpl(src_name, in_handler); - auto src_output_desc = out_op_impl->GetOutputDesc(src_name); - auto dst_input_desc = op_desc_->GetInputDesc(dst_name); - if (dst_input_desc.GetFormat() == FORMAT_RESERVED) { - src_output_desc.SetFormat(FORMAT_ND); - } else { - src_output_desc.SetFormat(dst_input_desc.GetFormat()); - } - GE_CHK_BOOL_EXEC(op_desc_->UpdateInputDesc(dst_name, src_output_desc) == GRAPH_SUCCESS, return, - "Update input desc failed,dst name is %s,src name is %s", dst_name.c_str(), - src_name.c_str()); // fix for linking opdesc - } - - void AddControlInputImp(const ge::Operator &src_oprt) { - if (src_oprt.operator_impl_ == nullptr) { - GELOGE(FAILED, "Src operator impl is nullptr"); - return; - } - for (auto &input : control_input_link_) { - if (input.lock() == src_oprt.operator_impl_) { - return; - } - } - control_input_link_.push_back(src_oprt.operator_impl_); - src_oprt.operator_impl_->control_output_link_.push_back(shared_from_this()); - } - - graphStatus GetInputImpl(const string &dst_name, ge::OpIO &out_handler) { - auto out = input_link_.find(dst_name); - if (out == input_link_.end()) { - return GRAPH_FAILED; - } - out_handler = out->second; - return GRAPH_SUCCESS; - } - - bool InputIsSet(const string &name) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return false, "op_desc_ is nullptr."); - return op_desc_->InputIsSet(name); - } - - string GetName() const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return string(), "op_desc_ is nullptr."); - return op_desc_->GetName(); - } - - GeTensorDesc GetInputDesc(const string &name) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); - return op_desc_->GetInputDesc(name); - } - - GeTensorDesc GetInputDesc(uint32_t index) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); - return op_desc_->GetInputDesc(index); - } - - graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GRAPH_FAILED, "op_desc_ is nullptr."); - - return op_desc_->UpdateInputDesc(name, tensor_desc); - } - - OutHandler GetOutput(const string &name) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr."); - - int src_index = op_desc_->GetOutputIndexByName(name); - GE_CHK_BOOL_EXEC(src_index >= 0, return nullptr, "Find src index by name failed. name[%s]", name.c_str()); - shared_ptr output_ptr = ComGraphMakeShared(name, src_index, shared_from_this()); - if (output_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "OpIO make shared failed"); - return nullptr; - } - return output_ptr; - } - - OutHandler GetOutput(uint32_t index) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return nullptr, "op_desc_ is nullptr."); - - string name = op_desc_->GetOutputNameByIndex(index); - if (name.empty()) { - GELOGE(GRAPH_FAILED, "Find src name by index failed. index[%u]", index); - return nullptr; - } - shared_ptr output_ptr = ComGraphMakeShared(name, index, shared_from_this()); - if (output_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "OpIO make shared failed"); - return nullptr; - } - return output_ptr; - } - - GeTensorDesc GetOutputDesc(const string &name) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); - - return op_desc_->GetOutputDesc(name); - } - - GeTensorDesc GetOutputDesc(uint32_t index) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, return GeTensorDesc(), "op_desc_ is nullptr."); - - return op_desc_->GetOutputDesc(index); - } - - graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc) { - GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); - - auto res = op_desc_->UpdateOutputDesc(name, tensor_desc); - if (res == GRAPH_SUCCESS) { - for (auto ol : output_links_[name]) { - if (ol.GetOwner() == nullptr) { - GELOGW("%s get owner is nullptr", ol.GetName().c_str()); - continue; - } - GE_CHK_BOOL_RET_STATUS(ol.GetOwner()->UpdateInputDesc(ol.GetName(), tensor_desc) == GRAPH_SUCCESS, GRAPH_FAILED, - "Could not update next operator's input %s.", ol.GetName().c_str()); - } - } - return res; - } - - size_t GetInputsSize() const { - GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0); - return op_desc_->GetInputsSize(); - } - - size_t GetOutputsSize() const { - GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0); - return op_desc_->GetOutputsSize(); - } - - graphStatus SetAttr(const string &name, GeAttrValue &&attr_value) { - GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); - return op_desc_->SetAttr(name, std::move(attr_value)); - } - - graphStatus GetAttr(const string &name, GeAttrValue &attr_value) const { - GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "op_desc is nullptr."); - return op_desc_->GetAttr(name, attr_value); - } - - OpDescPtr GetOpDescImpl() const { return op_desc_; } - - void UpdateLinkMapImpl(const string &src_name, OpIO &op_dst) { - auto it_find = output_links_.find(src_name); - if (it_find == output_links_.end()) { - std::vector dsts{op_dst}; - output_links_.insert(std::make_pair(src_name, dsts)); - } else { - it_find->second.push_back(op_dst); - } - } - - Operator ToOperator() { return Operator(shared_from_this()); } - - static OpDescPtr GetOpDesc(const Operator &oprt) { - GE_IF_BOOL_EXEC(oprt.operator_impl_ == nullptr, return nullptr); - return oprt.operator_impl_->op_desc_; - } - - void ClearOutputLinks() noexcept { output_links_.clear(); } - - void ClearInputLinks() noexcept { input_link_.clear(); } - - ge::ConstNodePtr GetNode() { return node_; } - - void SetInferenceContext(const InferenceContextPtr &inference_context) { inference_context_ = inference_context; } - - InferenceContextPtr GetInferenceContext() const { return inference_context_; } - - void SubgraphRegister(const std::string &ir_name, bool dynamic) { - op_desc_->RegisterSubgraphIrName(ir_name, dynamic ? kDynamic : kStatic); - } - - void SubgraphCountRegister(const std::string &ir_name, uint32_t count) { - if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kStatic) { - op_desc_->AddSubgraphName(ir_name); - subgraph_names_to_builders_[ir_name] = nullptr; - } else { - for (uint32_t i = 0; i < count; ++i) { - string key_name = ir_name + std::to_string(i); - op_desc_->AddSubgraphName(key_name); - subgraph_names_to_builders_[key_name] = nullptr; - } - } - } - - void SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { - string key_name = ir_name; - if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { - key_name += std::to_string(index); - } - - auto it = subgraph_names_to_builders_.find(key_name); - if (it == subgraph_names_to_builders_.end()) { - GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), index); - return; - } - it->second = builder; - } - - SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, uint32_t index) const { - string key_name = ir_name; - if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { - key_name += std::to_string(index); - } - - return GetSubgraphBuilder(key_name); - } - - SubgraphBuilder GetSubgraphBuilder(const std::string &name) const { - auto iter = subgraph_names_to_builders_.find(name); - if (iter == subgraph_names_to_builders_.end()) { - GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s", name.c_str()); - return nullptr; - } - - return iter->second; - } - - std::vector GetSubgraphNames() const { - std::vector names; - for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) { - names.emplace_back(subgraph_name_to_type.first); - } - return names; - } - - size_t GetSubgraphNamesCount() const { return op_desc_->GetSubgraphIrNames().size(); } - - OpDescPtr op_desc_ = nullptr; - - private: - ge::ConstNodePtr node_{nullptr}; - ge::InferenceContextPtr inference_context_; - std::map> output_links_{}; - std::map input_link_{}; - std::vector> control_input_link_{}; - std::vector> control_output_link_{}; - std::map subgraph_names_to_builders_; -}; - -// Used to manage OperatorImpl instances created by ge api. -class OperatorKeeper { - private: - OperatorKeeper() = default; - ~OperatorKeeper() { - for (const auto &iter : operators_) { - if (iter) { - iter->ClearInputLinks(); - iter->ClearOutputLinks(); - } - } - } - std::set operators_; - std::mutex mutex_; - - public: - static OperatorKeeper &GetInstance() { - static OperatorKeeper instance; - return instance; - } - void CheckInOperator(const OperatorImplPtr &op_impl) { - if (op_impl) { - std::lock_guard lock(mutex_); - operators_.insert(op_impl); - } - } - void CheckOutOperator(const OperatorImplPtr &op_impl) { - if (op_impl) { - std::lock_guard lock(mutex_); - operators_.erase(op_impl); - } - } -}; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromNode(ge::ConstNodePtr node_ptr) { - ge::OperatorImplPtr operator_impl_ptr = ComGraphMakeShared(node_ptr); - if (operator_impl_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); - return Operator("default"); - } - return operator_impl_ptr->ToOperator(); -} - -Operator::Operator(const std::string &type) { - static uint32_t index = 0; - string name = type + "_" + std::to_string(index++); - operator_impl_ = ComGraphMakeShared(name, type); - if (operator_impl_ == nullptr) { - GELOGW("OperatorImpl make shared failed"); - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromOpDesc(OpDescPtr op_desc) { - shared_ptr operator_impl_ptr; - operator_impl_ptr = ComGraphMakeShared(op_desc); - if (operator_impl_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); - return Operator("default"); - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_ptr); - return operator_impl_ptr->ToOperator(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) { - return OperatorImpl::GetOpDesc(oprt); -} - -GE_FUNC_HOST_VISIBILITY Operator::Operator(const string &name, const string &type) { - operator_impl_ = ComGraphMakeShared(name, type); - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); - return; - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); -} - -Operator::Operator(ge::OperatorImplPtr &&op_impl) { operator_impl_ = std::move(op_impl); } - -bool Operator::IsEmpty() const { - if (operator_impl_ == nullptr) { - return true; - } - return false; -} - -string Operator::GetName() const { - if (operator_impl_ != nullptr) { - return operator_impl_->GetName(); - } - return ""; -} - -GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const string &dst_name, const ge::Operator &src_oprt) { - // Describe the connection relationship between operators, no create action - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); - operator_impl_->SetInputImpl(dst_name, src_oprt); - return *this; -} - -Operator &Operator::SetInput(const string &dst_name, const ge::OutHandler &out_handler) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); - operator_impl_->SetInputImpl(dst_name, out_handler); - return *this; -} - -Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, const std::string &name) { - auto out_handler = src_oprt.GetOutput(name); - GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr."); - (void)SetInput(dst_name, out_handler); - return *this; -} - -Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, uint32_t index) { - auto out_handler = src_oprt.GetOutput(index); - GE_CHK_BOOL_EXEC(out_handler != nullptr, return *this, "out_handler is nullptr."); - (void)SetInput(dst_name, out_handler); - return *this; -} - -Operator &Operator::AddControlInput(const Operator &src_oprt) { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr."); - return *this; - } - operator_impl_->AddControlInputImp(src_oprt); - return *this; -} - -graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const { - GE_CHECK_NOTNULL(operator_impl_); - auto node_ptr = operator_impl_->GetNode(); - if (node_ptr != nullptr) { - // For inner compute graph - auto op_desc = node_ptr->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - auto index = op_desc->GetInputIndexByName(dst_name); - auto in_data_anchor = node_ptr->GetInDataAnchor(index); - GE_CHECK_NOTNULL(in_data_anchor); - auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(out_data_anchor); - auto peer_node = out_data_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(peer_node); - auto peer_op_desc = peer_node->GetOpDesc(); - GE_CHECK_NOTNULL(peer_op_desc); - auto peer_op_type = peer_op_desc->GetType(); - if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { - auto const_op_impl = ComGraphMakeShared(peer_node); - GE_CHECK_NOTNULL(const_op_impl); - Operator const_op(std::move(const_op_impl)); - return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); - } else if (peer_op_type == DATA) { - auto parent_node = NodeUtils::GetParentInput(peer_node); - while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { - parent_node = NodeUtils::GetParentInput(parent_node); - } - if ((parent_node != nullptr) && - ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { - auto const_op_impl = ComGraphMakeShared(parent_node); - GE_CHECK_NOTNULL(const_op_impl); - Operator const_op(std::move(const_op_impl)); - return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); - } - } - // Try get from runtime inference context - auto session_id = std::to_string(GetContext().SessionId()); - RuntimeInferenceContext *runtime_infer_ctx = nullptr; - if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { - GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); - auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); - if (ret == GRAPH_SUCCESS) { - return GRAPH_SUCCESS; - } - } - } else { - // For outer graph - return GetInputConstDataOut(dst_name, data); - } - auto op_name = operator_impl_->GetName(); - GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str()); - return GRAPH_FAILED; -} -graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const { - ge::OpIO out_handle("", 0, nullptr); - GE_CHECK_NOTNULL(operator_impl_); - if (operator_impl_->GetInputImpl(dst_name, out_handle) != GRAPH_SUCCESS) { - GELOGE(FAILED, "%s get input impl failed", dst_name.c_str()); - return GRAPH_FAILED; - } - if (out_handle.GetOwner() != nullptr && out_handle.GetOwner()->GetOpDescImpl() != nullptr) { - Operator const_op(out_handle.GetOwner()); - const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); - if (op_desc_impl_type == CONSTANTOP) { - return const_op.GetAttr(op::Constant::name_attr_value(), data); - } else if (op_desc_impl_type == CONSTANT) { - return const_op.GetAttr(op::Const::name_attr_value(), data); - } - } - return GRAPH_FAILED; -} - -std::shared_ptr Operator::GetNode() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); - return operator_impl_->GetNode(); -} - -TensorDesc Operator::GetInputDesc(const std::string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); -} - -void Operator::SetInferenceContext(const InferenceContextPtr &inference_context) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - operator_impl_->SetInferenceContext(inference_context); -} - -InferenceContextPtr Operator::GetInferenceContext() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); - return operator_impl_->GetInferenceContext(); -} -TensorDesc Operator::GetInputDesc(uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index)); -} - -graphStatus Operator::TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - auto check = operator_impl_->InputIsSet(name); - if (check) tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); - return check ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Operator::UpdateInputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - return operator_impl_->UpdateInputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -OutHandler Operator::GetOutput(const string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); - return operator_impl_->GetOutput(name); -} - -OutHandler Operator::GetOutput(uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); - return operator_impl_->GetOutput(index); -} - -TensorDesc Operator::GetOutputDesc(const std::string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name)); -} - -TensorDesc Operator::GetOutputDesc(uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(index)); -} - -graphStatus Operator::UpdateOutputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - return operator_impl_->UpdateOutputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -TensorDesc Operator::GetDynamicInputDesc(const string &name, uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name + std::to_string(index))); -} - -graphStatus Operator::UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - return operator_impl_->UpdateInputDesc(name + std::to_string(index), - TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -TensorDesc Operator::GetDynamicOutputDesc(const string &name, uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name + std::to_string(index))); -} - -graphStatus Operator::UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - return operator_impl_->UpdateOutputDesc(name + std::to_string(index), - TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -graphStatus Operator::InferShapeAndType() { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); - - return operator_impl_->GetOpDescImpl()->CallInferFunc(*this); -} - -graphStatus Operator::VerifyAllAttr(bool disable_common_verifier) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); - - if (!disable_common_verifier && (graphStatus)Operator::VerifyAll() == GRAPH_FAILED) { - return GRAPH_FAILED; - } else { - return (graphStatus)operator_impl_->GetOpDescImpl()->OpVerify(); - } -} - -GE_FUNC_HOST_VISIBILITY size_t Operator::GetInputsSize() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr"); - return operator_impl_->GetInputsSize(); -} - -GE_FUNC_HOST_VISIBILITY size_t Operator::GetOutputsSize() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "OperatorImpl_ is nullptr"); - return operator_impl_->GetOutputsSize(); -} - -// According to op get the attrs name and type -namespace { -const std::map kAttrTypesMap = { - {GeAttrValue::VT_NONE, "VT_STRING"}, - {GeAttrValue::VT_STRING, "VT_STRING"}, - {GeAttrValue::VT_FLOAT, "VT_FLOAT"}, - {GeAttrValue::VT_BOOL, "VT_BOOL"}, - {GeAttrValue::VT_INT, "VT_INT"}, - {GeAttrValue::VT_TENSOR_DESC, "VT_TENSOR_DESC"}, - {GeAttrValue::VT_TENSOR, "VT_TENSOR"}, - {GeAttrValue::VT_BYTES, "VT_BYTES"}, - {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, - {GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"}, - {GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"}, - {GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"}, - {GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"}, - {GeAttrValue::VT_LIST_BOOL, "VT_LIST_BOOL"}, - {GeAttrValue::VT_LIST_INT, "VT_LIST_INT"}, - {GeAttrValue::VT_LIST_TENSOR_DESC, "VT_LIST_TENSOR_DESC"}, - {GeAttrValue::VT_LIST_TENSOR, "VT_LIST_TENSOR"}, - {GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"}, - {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, - {GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"}, -}; -} // namespace -const std::map Operator::GetAllAttrNamesAndTypes() const { - std::map attr_types; - - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return attr_types, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return attr_types, "GetOpDescImpl is nullptr."); - std::map attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs(); - - map::iterator iter; - for (iter = attr_map.begin(); iter != attr_map.end(); ++iter) { - string name = iter->first; - GeAttrValue attr_value = iter->second; - - GeAttrValue::ValueType type = attr_value.GetValueType(); - - auto iter2 = kAttrTypesMap.find(type); - if (iter2 != kAttrTypesMap.end()) { - attr_types[name] = iter2->second; - } - } - - return attr_types; -} - -void Operator::InputRegister(const string &name) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - (void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); -} - -void Operator::OptionalInputRegister(const string &name) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, - GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); -} - -void Operator::InferFuncRegister(const std::function &func) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func); -} - -void Operator::InferFormatFuncRegister(const std::function &func) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); -} - -void Operator::VerifierFuncRegister(const std::function &func) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); -} - -void Operator::OutputRegister(const string &name) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); -} - -void Operator::DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return, - "set int failed"); - (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back); -} - -void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) { - GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr."); - operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index); -} - -int Operator::GetDynamicInputNum(const string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); - int num = 0; - GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num), return num, - "Get %s int failed", name.c_str()); - return num; -} - -void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return, - "Set %s int failed", name.c_str()); - (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); -} - -int Operator::GetDynamicOutputNum(const string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); - int num = 0; - GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num, - "Get %s int failed", name.c_str()); - return num; -} - -void Operator::RequiredAttrRegister(const string &name) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - operator_impl_->GetOpDescImpl()->AddRequiredAttr(name); -} - -graphStatus Operator::VerifyAll() { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return GRAPH_FAILED, "GetOpDescImpl is nullptr."); - - // Check all inputs defined - for (const string &iname : operator_impl_->GetOpDescImpl()->GetAllInputNames()) { - GE_CHK_BOOL_RET_STATUS(operator_impl_->GetOpDescImpl()->IsOptionalInput(iname) || operator_impl_->InputIsSet(iname), - GRAPH_FAILED, "operator input %s is not linked.", iname.c_str()); - vector ishape = operator_impl_->GetOpDescImpl()->GetInputDesc(iname).GetShape().GetDims(); - for (int64_t dim : ishape) { - GE_CHK_BOOL_RET_STATUS(dim > 0, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", - iname.c_str()); - } - } - // Check all attributes defined - const auto all_attributes = operator_impl_->GetOpDescImpl()->GetAllAttrs(); - for (const auto &name : operator_impl_->GetOpDescImpl()->GetAllAttrNames()) { - GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, - "operator attribute %s is empty.", name.c_str()); - } - - return GRAPH_SUCCESS; -} - -string Operator::GetOpType() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return "Data", "operator impl is nullptr."); - return OperatorImpl::GetOpDesc(*this)->GetType(); -} - -Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt) { - string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index); - return SetInput(dynamic_dst_name, src_oprt); -} - -Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt, - const std::string &name) { - string dynamic_dst_name = DYNAMIN_INPUT_NAME(dst_name, dst_index); - return SetInput(dynamic_dst_name, src_oprt, name); -} - -OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; } - -#define OP_ATTR_SET_IMP(ArgType, AttrUtilsFun) \ - Operator &Operator::SetAttr(const string &name, ArgType attr_value) { \ - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ - return *this; \ - } \ - if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ - GELOGW("set attr name %s failed.", name.c_str()); \ - } \ - return *this; \ - } // lint !e665 - -#define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \ - graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \ - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ - return GRAPH_FAILED; \ - } \ - if (!AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ - GELOGW("get attr name %s failed.", name.c_str()); \ - return GRAPH_FAILED; \ - } \ - return GRAPH_SUCCESS; \ - } // lint !e665 - -void Operator::BreakConnect() const { - if (operator_impl_ == nullptr) { - GELOGW("operator impl is nullptr."); - return; - } - operator_impl_->ClearInputLinks(); - operator_impl_->ClearOutputLinks(); - OperatorKeeper::GetInstance().CheckOutOperator(operator_impl_); -} - -#define OP_ATTR_REG_IMP(ArgType, AttrUtilsFun) \ - void Operator::AttrRegister(const string &name, ArgType attr_value) { \ - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { \ - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); \ - return; \ - } \ - if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ - GELOGW("reg attr name %s failed.", name.c_str()); \ - } \ - } // lint !e665 - -OP_ATTR_SET_IMP(int64_t, Int) -OP_ATTR_SET_IMP(int32_t, Int) -OP_ATTR_SET_IMP(uint32_t, Int) -OP_ATTR_GET_IMP(int64_t &, Int) -OP_ATTR_GET_IMP(int32_t &, Int) -OP_ATTR_GET_IMP(uint32_t &, Int) -OP_ATTR_SET_IMP(const vector &, ListInt) -OP_ATTR_SET_IMP(const vector &, ListInt) -OP_ATTR_SET_IMP(const vector &, ListInt) -OP_ATTR_SET_IMP(std::initializer_list &&, ListInt) -OP_ATTR_GET_IMP(vector &, ListInt) -OP_ATTR_GET_IMP(vector &, ListInt) -OP_ATTR_GET_IMP(vector &, ListInt) -OP_ATTR_GET_IMP(vector> &, ListListInt) -OP_ATTR_SET_IMP(const vector> &, ListListInt) - -OP_ATTR_SET_IMP(float, Float) -OP_ATTR_GET_IMP(float &, Float) -OP_ATTR_SET_IMP(const vector &, ListFloat) -OP_ATTR_GET_IMP(vector &, ListFloat) // lint !e665 - -OP_ATTR_SET_IMP(bool, Bool) -OP_ATTR_GET_IMP(bool &, Bool) -OP_ATTR_SET_IMP(const vector &, ListBool) -OP_ATTR_GET_IMP(vector &, ListBool) // lint !e665 - -OP_ATTR_SET_IMP(const string &, Str) -OP_ATTR_GET_IMP(string &, Str) -OP_ATTR_SET_IMP(const vector &, ListStr) -OP_ATTR_GET_IMP(vector &, ListStr) // lint !e665 - -OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) -OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) -OP_ATTR_SET_IMP(const vector &, ListNamedAttrs) -OP_ATTR_GET_IMP(vector &, ListNamedAttrs) // lint !e665 - -OP_ATTR_REG_IMP(int64_t, Int) -OP_ATTR_REG_IMP(const vector &, ListInt) -OP_ATTR_REG_IMP(float, Float) -OP_ATTR_REG_IMP(const vector &, ListFloat) -OP_ATTR_REG_IMP(const string &, Str) -OP_ATTR_REG_IMP(const vector &, ListStr) -OP_ATTR_REG_IMP(bool, Bool) -OP_ATTR_REG_IMP(const vector &, ListBool) -OP_ATTR_REG_IMP(const vector> &, ListListInt) -OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) -OP_ATTR_REG_IMP(const vector &, ListNamedAttrs) - -#undef OP_ATTR_SET_IMP -#undef OP_ATTR_GET_IMP -#undef OP_ATTR_REG_IMP - -Operator &Operator::SetAttr(const string &name, const Tensor &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - GeTensor tensor = TensorAdapter::AsGeTensor(attr_value); - if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { - GELOGW("set attr name %s failed.", name.c_str()); - } - return *this; -} - -Operator &Operator::SetAttr(const string &name, const vector &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - vector val_list; - for (const auto &item : attr_value) { - auto tensor = TensorAdapter::AsGeTensor(item); - val_list.push_back(tensor); - } - if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { - GELOGW("set attr name %s failed.", name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const string &name, Tensor &attr_value) const { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - ConstGeTensorPtr tensor; - if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { - GELOGW("get attr name %s failed.", name.c_str()); - return GRAPH_FAILED; - } - attr_value = TensorAdapter::GeTensor2Tensor(tensor); - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetAttr(const string &name, vector &attr_value) const { - attr_value.clear(); - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - vector val_list; - if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { - GELOGW("get attr name %s failed.", name.c_str()); - return GRAPH_FAILED; - } - for (auto &tensor : val_list) { - attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor)); - } - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const string &name, const OpBytes &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, - Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { - GELOGW("set attr name %s failed.", name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const string &name, OpBytes &attr_value) const { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - Buffer buffer; - if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, buffer)) { - GELOGW("get attr name %s failed.", name.c_str()); - return GRAPH_FAILED; - } - attr_value.clear(); - if (buffer.data() == nullptr) { - GELOGE(GRAPH_FAILED, "buffer data is null."); - return GRAPH_FAILED; - } - attr_value.assign(buffer.data(), buffer.data() + buffer.size()); - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const string &name, ge::AttrValue &&attrValue) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return *this, "operator impl is nullptr."); - (void)operator_impl_->SetAttr(name, std::move(attrValue.impl->geAttrValue_)); - return *this; -} - -graphStatus Operator::GetAttr(const string &name, ge::AttrValue &attrValue) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return GRAPH_FAILED, "operator impl is nullptr."); - return operator_impl_->GetAttr(name, attrValue.impl->geAttrValue_); -} - -Operator &Operator::SetAttr(const string &name, const std::vector &attr_value) { - if (operator_impl_ == nullptr || !operator_impl_->GetOpDescImpl()) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("set attr name %s failed.", name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const string &name, std::vector &attr_value) const { - attr_value.clear(); - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("get attr name %s failed.", name.c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const string &name, const ge::DataType &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("set attr name %s failed.", name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const string &name, ge::DataType &attr_value) const { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("get attr name %s failed.", name.c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -void Operator::AttrRegister(const string &name, const std::vector &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("set attr name %s failed.", name.c_str()); - } -} - -void Operator::AttrRegister(const string &name, const ge::DataType &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("set attr name %s failed.", name.c_str()); - } -} - -void Operator::AttrRegister(const string &name, const Tensor &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - auto tensor = TensorAdapter::AsGeTensor(attr_value); - if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { - GELOGW("reg attr name %s failed.", name.c_str()); - } -} - -void Operator::AttrRegister(const string &name, const vector &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - vector val_list; - for (const auto &item : attr_value) { - val_list.push_back(TensorAdapter::AsGeTensor(item)); - } - if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { - GELOGW("reg attr name %s failed.", name.c_str()); - } -} - -void Operator::AttrRegister(const string &name, const OpBytes &attr_value) { - if (operator_impl_ == nullptr || operator_impl_->GetOpDescImpl() == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, - Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { - GELOGW("reg attr name %s failed.", name.c_str()); - } -} - -void Operator::SubgraphRegister(const std::string &name, bool dynamic) { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - operator_impl_->SubgraphRegister(name, dynamic ? kDynamic : kStatic); -} - -void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); - return; - } - operator_impl_->SubgraphCountRegister(name, count); -} - -void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", ir_name.c_str()); - return; - } - operator_impl_->SetSubgraphBuilder(ir_name, index, builder); -} - -std::vector Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); } - -SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &ir_name, uint32_t index) const { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr."); - return nullptr; - } - return operator_impl_->GetSubgraphBuilder(ir_name, index); -} - -SubgraphBuilder Operator::GetSubgraphBuilder(const string &ir_name) const { - return GetDynamicSubgraphBuilder(ir_name, 0); -} - -Graph Operator::GetSubgraph(const string &name) const { - if (operator_impl_ == nullptr) { - GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str()); - return Graph(""); - } - auto op_desc = OpDescUtils::GetOpDescFromOperator(*this); - if (op_desc == nullptr) { - GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str()); - return Graph(""); - } - const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); - auto iter = subgraph_names_to_index.find(name); - if (iter == subgraph_names_to_index.end()) { - GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str()); - return Graph(""); - } - auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second); - if (subgraph_instance_name.empty()) { - GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added", name.c_str(), iter->second); - return Graph(""); - } - - auto node = operator_impl_->GetNode(); - if (node == nullptr) { - GE_LOGE("Failed to get subgraph %s, the node is null", name.c_str()); - return Graph(""); - } - auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); - if (root_graph == nullptr) { - GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str()); - return Graph(""); - } - auto subgraph = root_graph->GetSubgraph(subgraph_instance_name); - if (subgraph == nullptr) { - GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph", name.c_str(), - iter->second, subgraph_instance_name.c_str()); - return Graph(""); - } - return GraphUtils::CreateGraphFromComputeGraph(subgraph); -} - -Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const { - return GetSubgraph(name + std::to_string(index)); -} - -size_t Operator::GetSubgraphNamesCount() const { - if (operator_impl_ == nullptr) { - GE_LOGE("Failed to get subgraph names count, the operator impl is null"); - return 0; - } - return operator_impl_->GetSubgraphNamesCount(); -} - -class GraphBuilderImpl { - public: - explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared(name)) { - if (graph_ == nullptr) { - GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); - return; - } - } - - ~GraphBuilderImpl() {} - - ComputeGraphPtr BuildGraph(const std::vector &inputs) { - std::vector vec_inputs; - for (auto &it : inputs) { - auto src_op_impl = it.operator_impl_; - GE_CHK_BOOL_EXEC(src_op_impl != nullptr, return nullptr, "Operator Impl is null."); - GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, return nullptr, "Operator impl's opdesc is null."); - - string type = src_op_impl->op_desc_->GetType(); - auto node_op = ge::OperatorFactory::CreateOperator("node_op", type); - auto tensor_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); - node_op.BreakConnect(); - - GE_CHK_BOOL_EXEC(tensor_desc != nullptr, continue, "tensor_desc is null."); - if ((tensor_desc->GetInputsSize() == 0 && tensor_desc->GetOutputsSize() > 0) || type == DATA || - type == VARIABLE || type == INITDATA || type == GETNEXT) { - vec_inputs.push_back(it.operator_impl_); - } else { - GELOGW("Input operator should be Data, Variable operator or operator that has output but no input."); - } - } - GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, - "User Input do not include operator such as " - "Data, Variable operator or operator that has output but no input."); - auto ret = WalkAllOperators(vec_inputs); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); - - ret = AddEdge(); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "AddEdge failed."); - - return graph_; - } - - const std::map &GetAllNodesInfo() const { return all_nodes_info_; } - - private: - graphStatus WalkAllOperators(const std::vector &vec_ops) { - GE_CHK_BOOL_EXEC(graph_ != nullptr, return GRAPH_FAILED, "graph_ is null.") - std::queue> que; - que.push(vec_ops); - while (!que.empty()) { - auto vec_tem = que.front(); - que.pop(); - for (const auto &op_impl : vec_tem) { - GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.") - GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue, - "This node %s has created.", op_impl->GetName().c_str()) - auto node_ptr = graph_->AddNode(op_impl->op_desc_); - GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed."); - all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); - - auto &out_links = op_impl->output_links_; - std::vector vec_op_forward{}; - for (const auto &out_link : out_links) { - for (const auto &op_forward : out_link.second) { - vec_op_forward.push_back(op_forward.GetOwner()); - } - } - - auto &out_control_links = op_impl->control_output_link_; - for (const auto &out_link : out_control_links) { - vec_op_forward.push_back(out_link.lock()); - } - que.push(vec_op_forward); - - auto &in_links = op_impl->input_link_; - std::vector vec_op_back_forward{}; - for (const auto &in_link : in_links) { - vec_op_back_forward.push_back(in_link.second.GetOwner()); - } - - auto &in_control_links = op_impl->control_input_link_; - for (const auto &in_link : in_control_links) { - vec_op_back_forward.push_back(in_link.lock()); - } - que.push(vec_op_back_forward); - - if (WalkAllSubgraphs(node_ptr, op_impl) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } - return MoveSubgraphToRoot(graph_); - } - - graphStatus WalkAllSubgraphs(const NodePtr &node, const OperatorImplPtr &op_impl) { - const string name = node->GetName(); - for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) { - const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first); - GE_CHK_BOOL_EXEC(builder != nullptr, return GRAPH_FAILED, "Node: %s, Get builder failed.", name.c_str()); - - Graph graph = builder(); // Build subgraph from user define builder. - const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph); - GE_CHK_BOOL_EXEC(subgraph != nullptr, return GRAPH_FAILED, "Node: %s, Build graph failed.", name.c_str()); - - subgraph->SetParentNode(node); - subgraph->SetParentGraph(graph_); - if (graph_->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - - if (op_impl->op_desc_->SetSubgraphInstanceName(name_idx.second, subgraph->GetName()) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to set subgraph %s index %u", subgraph->GetName().c_str(), name_idx.second); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; - } - - graphStatus MoveSubgraphToRoot(const ComputeGraphPtr &graph) { - const ComputeGraphPtr &root_graph = GraphUtils::FindRootGraph(graph); - if (root_graph == nullptr) { - GELOGE(GRAPH_FAILED, "Graph: %s, Find root graph failed.", graph->GetName().c_str()); - return GRAPH_FAILED; - } - - if (root_graph == graph) { - auto subgraphs = graph->GetAllSubgraphs(); - for (auto &subgraph : subgraphs) { - if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } else { - auto subgraphs = graph->GetAllSubgraphs(); - for (auto &subgraph : subgraphs) { - if (root_graph->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - graph->RemoveSubgraph(subgraph->GetName()); - if (MoveSubgraphToRoot(subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } - - return GRAPH_SUCCESS; - } - - graphStatus AddEdge() { - for (const auto &node_info : all_nodes_info_) { - auto src_op_impl_ptr = node_info.first; - auto src_node_ptr = node_info.second; - - GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue); - auto out_links = src_op_impl_ptr->output_links_; - GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED, - "Src operator impl's op_desc is null."); - auto &op_desc = src_op_impl_ptr->op_desc_; - GE_IF_BOOL_EXEC(op_desc == nullptr, continue); - for (const auto &out : out_links) { - auto src_idx = op_desc->GetOutputIndexByName(out.first); - GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed"); - - auto src_anchor = src_node_ptr->GetOutDataAnchor(src_idx); - GE_CHK_BOOL_EXEC(src_anchor != nullptr, return GRAPH_FAILED, "GetOutDataAnchor failed."); - - for (const auto &dst_opio : out.second) { - auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner()); - GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed."); - - GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); - - auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); - GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed."); - - auto ret = GraphUtils::AddEdge(src_anchor, dst_anchor); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, - "from node[%s][%d] to node[%s][%d]AddEdge failed.", src_node_ptr->GetName().c_str(), - src_anchor->GetIdx(), dst_node_info->second->GetName().c_str(), dst_anchor->GetIdx()); - } - } - auto out_control_anchor = src_node_ptr->GetOutControlAnchor(); - for (const auto &control_out : src_op_impl_ptr->control_output_link_) { - auto dst_node_info = all_nodes_info_.find(control_out.lock()); - if (dst_node_info == all_nodes_info_.end()) { - GELOGE(GRAPH_FAILED, "Find Dst node failed."); - return GRAPH_FAILED; - } - GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); - auto in_control_anchor = dst_node_info->second->GetInControlAnchor(); - auto ret = GraphUtils::AddEdge(out_control_anchor, in_control_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(ret, "AddEdge failed. srcNode %s:%s, dstNode %s:%s", op_desc->GetName().c_str(), - op_desc->GetType().c_str(), dst_node_info->second->GetName().c_str(), - dst_node_info->second->GetType().c_str()); - return ret; - } - } - } - return GRAPH_SUCCESS; - } - - ComputeGraphPtr graph_ = nullptr; - std::map all_nodes_info_{}; -}; - -inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { - for (const auto &graph : compute_graph->GetAllSubgraphs()) { - std::set node_names; - for (auto const &node : graph->GetDirectNode()) { - auto result = node_names.insert(node->GetName()); - if (!result.second) { - GELOGE(GRAPH_FAILED, "graph %s has same name node%s", graph->GetName().c_str(), node->GetName().c_str()); - return true; - } - } - } - - std::set node_names; - for (auto const &node : compute_graph->GetDirectNode()) { - auto result = node_names.insert(node->GetName()); - if (!result.second) { - GELOGE(GRAPH_FAILED, "graph %s has same name node%s", compute_graph->GetName().c_str(), node->GetName().c_str()); - return true; - } - } - return false; -} - -ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector &inputs) { - auto graph_builder_impl = GraphBuilderImpl(name); - ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); - GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Computer graph is nullptr"); - compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); - if (HasSameNameNode(compute_graph)) { - GELOGW("Compute do not allow has same name nodes."); - compute_graph = nullptr; - } - - return compute_graph; -} - -void GraphUtils::BreakConnect(const std::map &all_nodes_infos) { - for (const auto &it : all_nodes_infos) { - OperatorImplPtr op_impl = it.first; - if (op_impl == nullptr) { - GELOGW("operator impl is nullptr."); - continue; - } - op_impl->ClearOutputLinks(); - op_impl->ClearInputLinks(); - OperatorKeeper::GetInstance().CheckOutOperator(op_impl); - } -} -} // namespace ge -/*lint +e446 +e732*/ -/*lint +e665*/ diff --git a/metadef/graph/operator_factory.cc b/metadef/graph/operator_factory.cc deleted file mode 100644 index 43d61a7c..00000000 --- a/metadef/graph/operator_factory.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/operator_factory_impl.h" -#include "debug/ge_log.h" - -namespace ge { -Operator OperatorFactory::CreateOperator(const std::string &operator_name, const std::string &operator_type) { - return OperatorFactoryImpl::CreateOperator(operator_name, operator_type); -} - -graphStatus OperatorFactory::GetOpsTypeList(std::vector &all_ops) { - return OperatorFactoryImpl::GetOpsTypeList(all_ops); -} - -bool OperatorFactory::IsExistOp(const string &operator_type) { return OperatorFactoryImpl::IsExistOp(operator_type); } - -OperatorCreatorRegister::OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator) { - (void)OperatorFactoryImpl::RegisterOperatorCreator(operator_type, op_creator); -} - -InferShapeFuncRegister::InferShapeFuncRegister(const std::string &operator_type, - const InferShapeFunc &infer_shape_func) { - (void)OperatorFactoryImpl::RegisterInferShapeFunc(operator_type, infer_shape_func); -} - -InferFormatFuncRegister::InferFormatFuncRegister(const std::string &operator_type, - const InferFormatFunc &infer_format_func) { - (void)OperatorFactoryImpl::RegisterInferFormatFunc(operator_type, infer_format_func); -} - -VerifyFuncRegister::VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func) { - (void)OperatorFactoryImpl::RegisterVerifyFunc(operator_type, verify_func); -} -} // namespace ge diff --git a/metadef/graph/operator_factory_impl.cc b/metadef/graph/operator_factory_impl.cc deleted file mode 100644 index 026a85bc..00000000 --- a/metadef/graph/operator_factory_impl.cc +++ /dev/null @@ -1,149 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/operator_factory_impl.h" -#include "debug/ge_log.h" -#include "framework/common/debug/ge_log.h" - -namespace ge { -shared_ptr> OperatorFactoryImpl::operator_creators_; -shared_ptr> OperatorFactoryImpl::operator_infershape_funcs_; -shared_ptr> OperatorFactoryImpl::operator_inferformat_funcs_; -shared_ptr> OperatorFactoryImpl::operator_verify_funcs_; - -Operator OperatorFactoryImpl::CreateOperator(const std::string &operator_name, const std::string &operator_type) { - if (operator_creators_ == nullptr) { - return Operator(); - } - auto it = operator_creators_->find(operator_type); - if (it == operator_creators_->end()) { - GELOGW("no OpProto of [%s] registered", operator_type.c_str()); - return Operator(); - } - return it->second(operator_name); -} - -graphStatus OperatorFactoryImpl::GetOpsTypeList(std::vector &all_ops) { - all_ops.clear(); - if (operator_creators_ != nullptr) { - for (auto it = operator_creators_->begin(); it != operator_creators_->end(); ++it) { - all_ops.emplace_back(it->first); - } - } else { - GELOGE(GRAPH_FAILED, "no operator creators found"); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -bool OperatorFactoryImpl::IsExistOp(const string &operator_type) { - if (operator_creators_ == nullptr) { - return false; - } - auto it = operator_creators_->find(operator_type); - if (it == operator_creators_->end()) { - return false; - } - return true; -} - -InferShapeFunc OperatorFactoryImpl::GetInferShapeFunc(const std::string &operator_type) { - if (operator_infershape_funcs_ == nullptr) { - return nullptr; - } - auto it = operator_infershape_funcs_->find(operator_type); - if (it == operator_infershape_funcs_->end()) { - return nullptr; - } - return it->second; -} - -InferFormatFunc OperatorFactoryImpl::GetInferFormatFunc(const std::string &operator_type) { - if (operator_inferformat_funcs_ == nullptr) { - GELOGI("operator_inferformat_funcs_ is null"); - return nullptr; - } - auto it = operator_inferformat_funcs_->find(operator_type); - if (it == operator_inferformat_funcs_->end()) { - return nullptr; - } - return it->second; -} - -VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) { - if (operator_verify_funcs_ == nullptr) { - return nullptr; - } - auto it = operator_verify_funcs_->find(operator_type); - if (it == operator_verify_funcs_->end()) { - return nullptr; - } - return it->second; -} - -graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { - if (operator_creators_ == nullptr) { - operator_creators_.reset(new (std::nothrow) std::map()); - } - auto it = operator_creators_->find(operator_type); - if (it != operator_creators_->end()) { - return GRAPH_FAILED; - } - (void)operator_creators_->emplace(operator_type, op_creator); - return GRAPH_SUCCESS; -} - -graphStatus OperatorFactoryImpl::RegisterInferShapeFunc(const std::string &operator_type, - InferShapeFunc const infer_shape_func) { - if (operator_infershape_funcs_ == nullptr) { - GELOGI("operator_infershape_funcs_ init"); - operator_infershape_funcs_.reset(new (std::nothrow) std::map()); - } - auto it = operator_infershape_funcs_->find(operator_type); - if (it != operator_infershape_funcs_->end()) { - return GRAPH_FAILED; - } - (void)operator_infershape_funcs_->emplace(operator_type, infer_shape_func); - return GRAPH_SUCCESS; -} - -graphStatus OperatorFactoryImpl::RegisterInferFormatFunc(const std::string &operator_type, - InferFormatFunc const infer_format_func) { - if (operator_inferformat_funcs_ == nullptr) { - GELOGI("operator_inferformat_funcs_ init"); - operator_inferformat_funcs_.reset(new (std::nothrow) std::map()); - } - auto it = operator_inferformat_funcs_->find(operator_type); - if (it != operator_inferformat_funcs_->end()) { - return GRAPH_FAILED; - } - (void)operator_inferformat_funcs_->emplace(operator_type, infer_format_func); - return GRAPH_SUCCESS; -} - -graphStatus OperatorFactoryImpl::RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func) { - if (operator_verify_funcs_ == nullptr) { - GELOGI("operator_verify_funcs_ init"); - operator_verify_funcs_.reset(new (std::nothrow) std::map()); - } - auto it = operator_verify_funcs_->find(operator_type); - if (it != operator_verify_funcs_->end()) { - return GRAPH_FAILED; - } - (void)operator_verify_funcs_->emplace(operator_type, verify_func); - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/metadef/graph/opsproto/opsproto_manager.cc b/metadef/graph/opsproto/opsproto_manager.cc deleted file mode 100644 index d482715b..00000000 --- a/metadef/graph/opsproto/opsproto_manager.cc +++ /dev/null @@ -1,187 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/opsproto_manager.h" -#include -#include -#include -#include -#include -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/debug/ge_log.h" - -namespace ge { -OpsProtoManager *OpsProtoManager::Instance() { - static OpsProtoManager instance; - return &instance; -} - -bool OpsProtoManager::Initialize(const std::map &options) { - std::lock_guard lock(mutex_); - - if (is_init_) { - GELOGI("OpsProtoManager is already initialized."); - return true; - } - - /*lint -e1561*/ - auto proto_iter = options.find("ge.opsProtoLibPath"); - /*lint +e1561*/ - if (proto_iter == options.end()) { - GELOGW("ge.opsProtoLibPath option not set, return."); - return false; - } - - pluginPath_ = proto_iter->second; - LoadOpsProtoPluginSo(pluginPath_); - - is_init_ = true; - - return true; -} - -void OpsProtoManager::Finalize() { - std::lock_guard lock(mutex_); - - if (!is_init_) { - GELOGI("OpsProtoManager is not initialized."); - return; - } - - for (auto handle : handles_) { - if (handle != nullptr) { - if (dlclose(handle) != 0) { - GELOGW("failed to close handle, message: %s", dlerror()); - continue; - } - GELOGI("close opsprotomanager handler success"); - } else { - GELOGW("close opsprotomanager handler failure, handler is nullptr"); - } - } - - is_init_ = false; -} - -static std::vector Split(const std::string &str, char delim) { - std::vector elems; - if (str.empty()) { - elems.emplace_back(""); - return elems; - } - - std::stringstream ss(str); - std::string item; - - while (getline(ss, item, delim)) { - elems.push_back(item); - } - - auto str_size = str.size(); - if (str_size > 0 && str[str_size - 1] == delim) { - elems.emplace_back(""); - } - - return elems; -} - -static void FindParserSo(const std::string &path, std::vector &file_list) { - // Lib plugin path not exist - if (path.empty()) { - GELOGI("realPath is empty"); - return; - } - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path.size() >= PATH_MAX, return, "path is invalid"); - - char resolved_path[PATH_MAX] = {0}; - - // Nullptr is returned when the path does not exist or there is no permission - // Return absolute path when path is accessible - if (realpath(path.c_str(), resolved_path) == nullptr) { - GELOGW("the path [%s] not exsit.", path.c_str()); - return; - } - - struct dirent *dent = nullptr; - DIR *dir = opendir(resolved_path); - // Lib plugin path not exist - if (dir == nullptr) { - GELOGW("Open directory %s failed,maybe it is not exit or not a dir", resolved_path); - return; - } - - while ((dent = readdir(dir)) != nullptr) { - if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) { - continue; - } - std::string name = dent->d_name; - std::string full_name = path + "/" + name; - const std::string so_suff = ".so"; - - if (dent->d_type != DT_DIR && name.size() >= so_suff.size() && - name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) { - file_list.push_back(full_name); - GELOGI("OpsProtoManager Parse full name = %s \n", full_name.c_str()); - } - } - if (closedir(dir) != 0) { - GELOGW("close dir fail."); - } -} - -static void GetPluginSoFileList(const std::string &path, std::vector &file_list) { - // Support multi lib directory with ":" as delimiter - std::vector v_path = Split(path, ':'); - - for (size_t i = 0; i < v_path.size(); ++i) { - FindParserSo(v_path[i], file_list); - GELOGI("OpsProtoManager full name = %s", v_path[i].c_str()); - } -} - -void OpsProtoManager::LoadOpsProtoPluginSo(std::string &path) { - if (path.empty()) { - GELOGE(GRAPH_FAILED, "filePath is invalid. please check your text file %s.", path.c_str()); - return; - } - std::vector file_list; - - // If there is .so file in the lib path - GetPluginSoFileList(path, file_list); - - // Not found any .so file in the lib path - if (file_list.empty()) { - GELOGE(GRAPH_FAILED, "OpsProtoManager can not find any plugin file in pluginPath: %s \n", path.c_str()); - return; - } - // Warning message - GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted."); - - // Load .so file - for (auto elem : file_list) { - void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); - if (handle == nullptr) { - GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); - continue; - } else { - // Close dl when the program exist, not close here - GELOGI("OpsProtoManager plugin load %s success.", elem.c_str()); - handles_.push_back(handle); - } - } -} -} // namespace ge diff --git a/metadef/graph/option/ge_context.cc b/metadef/graph/option/ge_context.cc deleted file mode 100644 index 421e0aff..00000000 --- a/metadef/graph/option/ge_context.cc +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "./ge_context.h" -#include "./ge_global_options.h" -#include "./ge_local_context.h" -#include "framework/common/ge_types.h" -#include "framework/common/debug/ge_log.h" - -namespace ge { -namespace { -const int64_t kMinTrainingTraceJobId = 256; -const int kDecimal = 10; -const char *kHostExecPlacement = "HOST"; -} // namespace -GEContext &GetContext() { - static GEContext ge_context{}; - return ge_context; -} - -graphStatus GEContext::GetOption(const std::string &key, std::string &option) { - return GetThreadLocalContext().GetOption(key, option); -} - -bool GEContext::GetHostExecFlag() { - std::string exec_placement; - if (GetThreadLocalContext().GetOption(GE_OPTION_EXEC_PLACEMENT, exec_placement) != GRAPH_SUCCESS) { - GELOGW("get option OPTION_EXEC_PLACEMENT failed."); - return false; - } - GELOGD("Option ge.exec.placement is %s.", exec_placement.c_str()); - return exec_placement == kHostExecPlacement; -} - -std::map &GetMutableGlobalOptions() { - static std::map global_options{}; - return global_options; -} - -void GEContext::Init() { - string session_id; - (void)GetOption("ge.exec.sessionId", session_id); - try { - session_id_ = static_cast(std::stoi(session_id.c_str())); - } catch (std::invalid_argument &) { - GELOGW("%s transform to int failed.", session_id.c_str()); - } catch (std::out_of_range &) { - GELOGW("%s transform to int failed.", session_id.c_str()); - } - - string device_id; - (void)GetOption("ge.exec.deviceId", device_id); - try { - device_id_ = static_cast(std::stoi(device_id.c_str())); - } catch (std::invalid_argument &) { - GELOGW("%s transform to int failed.", device_id.c_str()); - } catch (std::out_of_range &) { - GELOGW("%s transform to int failed.", device_id.c_str()); - } - - string job_id; - (void)GetOption("ge.exec.jobId", job_id); - std::string s_job_id = ""; - for (auto c : job_id) { - if (c >= '0' && c <= '9') { - s_job_id += c; - } - } - if (s_job_id == "") { - trace_id_ = kMinTrainingTraceJobId; - return; - } - int64_t d_job_id = std::strtoll(s_job_id.c_str(), nullptr, kDecimal); - if (d_job_id < kMinTrainingTraceJobId) { - trace_id_ = d_job_id + kMinTrainingTraceJobId; - } else { - trace_id_ = d_job_id; - } -} - -uint64_t GEContext::SessionId() { return session_id_; } - -uint32_t GEContext::DeviceId() { return device_id_; } - -uint64_t GEContext::TraceId() { return trace_id_; } - -void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; } - -void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } - -} // namespace ge diff --git a/metadef/graph/option/ge_local_context.cc b/metadef/graph/option/ge_local_context.cc deleted file mode 100644 index 82b1cb01..00000000 --- a/metadef/graph/option/ge_local_context.cc +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "./ge_local_context.h" -#include - -namespace ge { -namespace { -thread_local GEThreadLocalContext thread_context; -} - -GEThreadLocalContext &GetThreadLocalContext() { return thread_context; } - -graphStatus GEThreadLocalContext::GetOption(const string &key, string &option) { - auto graph_iter = graph_options_.find(key); - if (graph_iter != graph_options_.end()) { - option = graph_iter->second; - return GRAPH_SUCCESS; - } - auto session_iter = session_options_.find(key); - if (session_iter != session_options_.end()) { - option = session_iter->second; - return GRAPH_SUCCESS; - } - auto global_iter = global_options_.find(key); - if (global_iter != global_options_.end()) { - option = global_iter->second; - return GRAPH_SUCCESS; - } - return GRAPH_PARAM_INVALID; -} - -void GEThreadLocalContext::SetGlobalOption(map options_map) { - global_options_.clear(); - global_options_ = std::move(options_map); -} - -void GEThreadLocalContext::SetSessionOption(map options_map) { - session_options_.clear(); - session_options_ = std::move(options_map); -} - -void GEThreadLocalContext::SetGraphOption(map options_map) { - graph_options_.clear(); - graph_options_ = std::move(options_map); -} -} // namespace ge diff --git a/metadef/graph/ref_relation.cc b/metadef/graph/ref_relation.cc deleted file mode 100644 index 9a9f66ba..00000000 --- a/metadef/graph/ref_relation.cc +++ /dev/null @@ -1,455 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/ref_relation.h" - -#include -#include - -#include "utils/mem_utils.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "debug/ge_attr_define.h" -#include "graph/ge_error_codes.h" -#include "graph/utils/graph_utils.h" -#include "framework/common/debug/ge_log.h" - -using namespace std; -using namespace ge; -namespace ge { -namespace { -const char *kRefIndex = "_parent_node_index"; -const string kWhile = "While"; -const string kIf = "If"; -const string kCase = "Case"; - -const uint16_t kMaxElementNum = 100; - -std::unordered_set function_op = {kWhile, kIf, kCase}; -} // namespace - -/* Impl */ -class RefRelations::Impl { - public: - graphStatus LookUpRefRelations(const RefCell &key, unordered_set &result) { - unsigned long number = static_cast(reinterpret_cast(key.node.get())); - std::string lookup_key = - key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx) + std::to_string(number); - auto iter = look_up_table_.find(lookup_key); - if (iter != look_up_table_.end()) { - for (auto &c : iter->second) { - result.insert(c); - } - return GRAPH_SUCCESS; - } - GELOGW("can not find any relations! key value is %s", lookup_key.c_str()); - return GRAPH_SUCCESS; - }; - graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); - graphStatus Clear() { - GELOGD("Start clear boundary reflections between main graph and sub graph!"); - look_up_table_.clear(); - values_.clear(); - return GRAPH_SUCCESS; - }; - - private: - graphStatus BuildLookUpTables(); - graphStatus BuildRefRelationsForBranch(const NodePtr &root_node, const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, - vector> &node_refs); - graphStatus BuildRefRelationsForWhile(const NodePtr &root_node, const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, - vector> &node_refs); - graphStatus BuildRelationsWithFuncNodeType(const NodePtr &root_node, - const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, - vector> &node_refs); - void GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector &data_nodes, - vector &netoutput_nodes, const std::vector &sub_graph_names, - const std::string &node_type); - - graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph); - graphStatus ProcessSubgraphDataNodes(vector &data_nodes, vector> &classed_data_nodes); - graphStatus ProcessSubgraphNetoutput(const vector &netoutput_nodes, - vector>> &classed_netoutput_nodes); - - std::unordered_map> look_up_table_; - std::vector>> values_; -}; - -// Node Level -graphStatus RefRelations::Impl::BuildRefRelationsForBranch( - const NodePtr &root_node, const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, vector> &node_refs) { - GELOGD("Enter BuildRefRelationsForBranch!"); - - size_t ref_i = 0; - for (const auto &ref_i_data_nodes : classed_data_nodes) { - vector in_ref_i_all_refs; - RefCell cell_root; - cell_root.node_name = root_node->GetName(); - cell_root.node = root_node; - cell_root.in_out = NODE_IN; - cell_root.in_out_idx = ref_i; - in_ref_i_all_refs.emplace_back(cell_root); - for (const auto &data : ref_i_data_nodes) { - RefCell cell_in; - RefCell cell_out; - cell_in.node_name = data->GetName(); - cell_in.node = data; - cell_in.in_out = NODE_IN; - cell_in.in_out_idx = 0; - cell_out.node_name = data->GetName(); - cell_out.node = data; - cell_out.in_out = NODE_OUT; - cell_out.in_out_idx = 0; - in_ref_i_all_refs.emplace_back(cell_in); - in_ref_i_all_refs.emplace_back(cell_out); - } - node_refs.emplace_back(in_ref_i_all_refs); - ref_i++; - } - - size_t ref_o = 0; - for (const auto &ref_o_net_nodes : classed_netoutput_nodes) { - vector out_ref_i_all_refs; - RefCell cell_root; - cell_root.node_name = root_node->GetName(); - cell_root.node = root_node; - cell_root.in_out = NODE_OUT; - cell_root.in_out_idx = ref_o; - out_ref_i_all_refs.emplace_back(cell_root); - for (const auto &ele : ref_o_net_nodes) { - RefCell cell_netoutput_in; - cell_netoutput_in.node_name = (ele.first)->GetName(); - cell_netoutput_in.node = ele.first; - cell_netoutput_in.in_out = NODE_IN; - cell_netoutput_in.in_out_idx = ele.second; - out_ref_i_all_refs.emplace_back(cell_netoutput_in); - } - node_refs.emplace_back(out_ref_i_all_refs); - ref_o++; - } - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::BuildLookUpTables() { - GELOGD("start to build look up table!"); - for (size_t i = 0; i < values_.size(); i++) { - vector> &val = values_[i]; - for (const auto &ele : val) { - for (const auto &ref_cell : ele) { - string key = ref_cell.node_name + std::to_string(ref_cell.in_out) + std::to_string(ref_cell.in_out_idx) + - std::to_string(static_cast(reinterpret_cast(ref_cell.node.get()))); - look_up_table_[key] = ele; - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::BuildRefRelationsForWhile( - const NodePtr &root_node, const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, vector> &node_refs) { - GELOGD("Enter BuildRefRelations for while op!"); - // data_nodes has been sorted - // for while, input num must be same as output num - auto input_num = root_node->GetAllInDataAnchorsSize(); - NodePtr netoutput = nullptr; - - size_t ref_i = 0; - while (ref_i < input_num) { - auto &ref_i_data_nodes = classed_data_nodes[ref_i]; - auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i]; - - vector ref_i_all_refs; - RefCell cell_root_i; - RefCell cell_root_o; - cell_root_i.node_name = root_node->GetName(); - cell_root_i.node = root_node; - cell_root_i.in_out = NODE_IN; - cell_root_i.in_out_idx = ref_i; - ref_i_all_refs.emplace_back(cell_root_i); - cell_root_o.node_name = root_node->GetName(); - cell_root_o.node = root_node; - cell_root_o.in_out = NODE_OUT; - cell_root_o.in_out_idx = ref_i; - ref_i_all_refs.emplace_back(cell_root_o); - for (const auto &data : ref_i_data_nodes) { - RefCell cell_in; - RefCell cell_out; - cell_in.node_name = data->GetName(); - cell_in.node = data; - cell_in.in_out = NODE_IN; - cell_in.in_out_idx = 0; - cell_out.node_name = data->GetName(); - cell_out.node = data; - cell_out.in_out = NODE_OUT; - cell_out.in_out_idx = 0; - ref_i_all_refs.emplace_back(cell_in); - ref_i_all_refs.emplace_back(cell_out); - } - - for (const auto &ele : ref_i_net_nodes) { - RefCell cell_netoutput_in; - RefCell cell_netoutput_out; - cell_netoutput_in.node_name = (ele.first)->GetName(); - cell_netoutput_in.node = ele.first; - cell_netoutput_in.in_out = NODE_IN; - cell_netoutput_in.in_out_idx = ele.second; - ref_i_all_refs.emplace_back(cell_netoutput_in); - netoutput = ele.first; - } - node_refs.emplace_back(ref_i_all_refs); - ref_i++; - } - /* There exist scene like the follows, it means data0 data1 netoutput 0'th - * and 1'th tensor should be the same addr. - * Data0 Data1 - * \/ - * /\ - * netoutput - */ - if (netoutput == nullptr) { - return GRAPH_SUCCESS; - } - for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) { - auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_data_anchor == nullptr) { - continue; - } - auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); - if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { - GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (netoutput->GetName()).c_str()); - continue; - } - if (peer_out_data_node->GetType() != DATA) { - continue; - } - auto in_data_anchor_idx = in_anchor->GetIdx(); - auto net_in_desc = netoutput->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); - int ref_d = 0; - int ref_n = 0; - (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d); - (void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n); - - node_refs[ref_d].insert(node_refs[ref_d].end(), node_refs[ref_n].begin(), node_refs[ref_n].end()); - node_refs[ref_n].insert(node_refs[ref_n].end(), node_refs[ref_d].begin(), node_refs[ref_d].end()); - } - - return GRAPH_SUCCESS; -} -// build ref relations according to diff func op type -graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( - const NodePtr &root_node, const vector> &classed_data_nodes, - const vector>> &classed_netoutput_nodes, vector> &node_refs) { - // data_nodes has been sorted - auto node_type = root_node->GetType(); - - auto status = GRAPH_SUCCESS; - if (node_type != kWhile) { - status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); - } else { - status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); - } - return status; -} - -void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector &data_nodes, - vector &netoutput_nodes, - const std::vector &sub_graph_names, - const std::string &node_type) { - int sub_graph_idx = 0; - for (const auto &name : sub_graph_names) { - auto sub_graph = root_graph.GetSubgraph(name); - if (sub_graph == nullptr) { - GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str()); - continue; - } - for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { - auto sub_graph_node_type = sub_graph_node->GetType(); - - if (sub_graph_node_type == DATA) { - data_nodes.emplace_back(sub_graph_node); - } else if (sub_graph_node_type == NETOUTPUT) { - // if while, the first subgraph must be cond subgraph. - // There is no meaning for refs ,so continue - if (node_type == kWhile && sub_graph_idx == 0) { - continue; - } - netoutput_nodes.emplace_back(sub_graph_node); - } - continue; - } - sub_graph_idx++; - } -} - -graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) { - auto parent_graph_ptr = graph.GetParentGraph(); - if (parent_graph_ptr == nullptr) { - root_graph = graph; - return GRAPH_SUCCESS; - } - auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr); - if (root_graph_ptr == nullptr) { - GE_LOGE("Get null root graph"); - return GRAPH_PARAM_INVALID; - } - root_graph = *root_graph_ptr; - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector &data_nodes, - vector> &classed_data_nodes) { - GELOGD("start to process subgraph data nodes!"); - int max_ref_idx = 0; - for (const auto &e : data_nodes) { - int i; - bool is_exist = true; - is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i); - if (!is_exist) { - GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s", e->GetName().c_str(), kRefIndex); - return GRAPH_FAILED; - } - max_ref_idx = (i > max_ref_idx) ? i : max_ref_idx; - } - - while (!data_nodes.empty()) { - auto data = data_nodes.back(); - data_nodes.pop_back(); - int ref_idx = 0; - (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx); - if (ref_idx >= static_cast(classed_data_nodes.size())) { - return GRAPH_FAILED; - } - classed_data_nodes[ref_idx].emplace_back(data); - } - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( - const vector &netoutput_nodes, vector>> &classed_netoutput_nodes) { - GELOGD("[RefRelations]Start to process subgraph netoutput!"); - for (const auto &sub_netoutput_node : netoutput_nodes) { - auto op_desc = sub_netoutput_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) { - auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx()); - if (in_desc == nullptr) { - GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it", - sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); - return GRAPH_FAILED; - } - int ref_o; - if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { - if (ref_o >= static_cast(classed_netoutput_nodes.size())) { - return GRAPH_FAILED; - } - classed_netoutput_nodes[ref_o].emplace_back( - std::pair({sub_netoutput_node, static_cast(in_data_anchor->GetIdx())})); - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { - GELOGD("Start to build ref relations!"); - /* First Step: Get root graph */ - ge::ComputeGraph &root_graph = graph; - auto status = GetRootGraph(graph, root_graph); - if (status != GRAPH_SUCCESS) { - return status; - } - - for (const auto &node : graph.GetAllNodes()) { - auto node_type = node->GetType(); - std::vector ref_nodes; - auto op_desc = node->GetOpDesc(); - auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - continue; - } - vector data_nodes; - vector netoutput_nodes; - // Get data and netoutput of sub_graph - GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type); - size_t max_elem_num = (data_nodes.size() > kMaxElementNum) ? data_nodes.size() : kMaxElementNum; - vector> classed_data_nodes(max_elem_num); // according to ref_idx - vector>> classed_netoutput_nodes(max_elem_num); // according to ref_idx - status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "classfy data nodes failed!"); - return status; - } - - // for netoutput - // check netoutput - // here main graph output number must be the same as every sub_graph netoutput node - // key: netoutput node_ptr , - status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "process netoutput failed!"); - return status; - } - - vector> node_refs; - status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs); - if (status != GRAPH_SUCCESS) { - GELOGE(status, "BuildRelationsWithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str()); - return status; - } - if (!node_refs.empty()) { - values_.push_back(node_refs); - } - } - /* Seconde Step: generate map */ - status = BuildLookUpTables(); - if (status != GRAPH_SUCCESS) { - GELOGE(status, "Build look up tables failed!"); - return status; - } - return GRAPH_SUCCESS; -} - -/* Ref Relations Interface */ -RefRelations::RefRelations() { - impl_ = MakeShared(); - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "MakeShared failed!"); - return; - } -} - -graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set &result) { - GE_CHECK_NOTNULL(impl_); - return impl_->LookUpRefRelations(key, result); -} - -graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) { - GE_CHECK_NOTNULL(impl_); - return impl_->BuildRefRelations(root_graph); -} - -graphStatus RefRelations::Clear() { - GE_CHECK_NOTNULL(impl_); - return impl_->Clear(); -} -} // namespace ge \ No newline at end of file diff --git a/metadef/graph/runtime_inference_context.cc b/metadef/graph/runtime_inference_context.cc deleted file mode 100644 index 95068481..00000000 --- a/metadef/graph/runtime_inference_context.cc +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/runtime_inference_context.h" -#include -#include "framework/common/debug/ge_log.h" - -namespace ge { -std::map> RuntimeInferenceContext::contexts_; -std::mutex RuntimeInferenceContext::ctx_mu_; - -graphStatus RuntimeInferenceContext::CreateContext(const std::string &context_id) { - GELOGI("To create context. session id = %s", context_id.c_str()); - auto ctx = std::unique_ptr(new (std::nothrow) RuntimeInferenceContext()); - if (ctx == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to create instance of RuntimeInferenceContext. context_id = %s", context_id.c_str()); - return GRAPH_FAILED; - } - - std::lock_guard lk(ctx_mu_); - auto emplace_ret = contexts_.emplace(context_id, std::move(ctx)); - if (!emplace_ret.second) { - GELOGE(GRAPH_FAILED, "Old context not destroyed"); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -void RuntimeInferenceContext::DestroyContext(const std::string &context_id) { - GELOGI("To destroy context. session id = %s", context_id.c_str()); - std::lock_guard lk(ctx_mu_); - contexts_.erase(context_id); -} - -graphStatus RuntimeInferenceContext::GetContext(const std::string &context_id, RuntimeInferenceContext **ctx) { - std::lock_guard lk(ctx_mu_); - auto it = contexts_.find(context_id); - if (it != contexts_.end()) { - *ctx = it->second.get(); - return GRAPH_SUCCESS; - } - - GELOGD("Runtime inference context not created. session id = %s", context_id.c_str()); - return GRAPH_FAILED; -} - -graphStatus RuntimeInferenceContext::SetTensor(int64_t node_id, int output_id, Tensor &&tensor) { - std::lock_guard lk(mu_); - auto &output_tensors = tensors_[node_id]; - if (static_cast(output_id) >= output_tensors.size()) { - output_tensors.resize(output_id + 1); - } - - GELOGD("Set tensor for node_id = %ld, output_id = %d", node_id, output_id); - output_tensors[output_id] = std::move(tensor); - return GRAPH_SUCCESS; -} - -graphStatus RuntimeInferenceContext::GetTensor(int64_t node_id, int output_id, Tensor &tensor) { - if (output_id < 0) { - GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id); - return GRAPH_PARAM_INVALID; - } - - std::lock_guard lk(mu_); - auto iter = tensors_.find(node_id); - if (iter == tensors_.end()) { - GELOGE(INTERNAL_ERROR, "Node not register. Id = %ld", node_id); - return INTERNAL_ERROR; - } - - auto &output_tensors = iter->second; - if (static_cast(output_id) >= output_tensors.size()) { - GELOGE(GRAPH_FAILED, "Node output is not registered. node_id = %ld, output index = %d", node_id, output_id); - return GRAPH_FAILED; - } - - GELOGD("Get tensor for node_id = %ld, output_id = %d", node_id, output_id); - tensor = output_tensors[output_id]; - return GRAPH_SUCCESS; -} -} // namespace ge \ No newline at end of file diff --git a/metadef/graph/shape_refiner.cc b/metadef/graph/shape_refiner.cc deleted file mode 100644 index 17423da4..00000000 --- a/metadef/graph/shape_refiner.cc +++ /dev/null @@ -1,688 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/shape_refiner.h" - -#include -#include -#include -#include -#include -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" - -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "external/graph/operator.h" -#include "external/graph/operator_factory.h" -#include "framework/common/debug/ge_log.h" -#include "graph/compute_graph.h" -#include "utils/node_utils.h" -#include "utils/op_desc_utils.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" - -namespace ge { -namespace { -const uint32_t kWhileBodySubGraphIdx = 1; - -graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { - GELOGD("Enter reverse brush while body subgraph process!"); - - auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); - if (sub_graph_body == nullptr) { - GELOGE(GRAPH_FAILED, "Get while body graph failed!"); - return GRAPH_FAILED; - } - - for (const auto &node_sub : sub_graph_body->GetAllNodes()) { - for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) { - auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i); - GE_IF_BOOL_EXEC(input_desc == nullptr, - GELOGW("Get null input by index %zu from node %s ", i, node_sub->GetName().c_str()); - continue); - (void)input_desc->SetUnknownDimNumShape(); - } - for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) { - auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i); - (void)output_desc->SetUnknownDimNumShape(); - } - } - - return GRAPH_SUCCESS; -} - -graphStatus UpdataOutputForMultiBatcch(const ConstNodePtr &node, - std::vector> &ref_out_tensors) { - // check sub_graph shape. Get max for update. - for (size_t i = 0; i < ref_out_tensors.size(); ++i) { - if (ref_out_tensors[i].empty()) { - continue; - } - - int64_t max_size = 0; - size_t max_shape_index = 0; - auto &ref_out_tensor = ref_out_tensors[i].at(0); - const auto &ref_out_tensor_shape = ref_out_tensor.MutableShape(); - for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) { - auto &tensor = ref_out_tensors[i].at(j); - if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { - GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); - return GRAPH_FAILED; - } - - auto shape = tensor.MutableShape(); - if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { - GELOGE(GRAPH_FAILED, "node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", - node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); - return GRAPH_FAILED; - } - - int64_t size = 1; - for (auto dim : shape.GetDims()) { - if (INT64_MAX / dim < size) { - GELOGE(PARAM_INVALID, "The shape size overflow"); - return PARAM_INVALID; - } - size *= dim; - } - - if (size > max_size) { - max_size = size; - max_shape_index = j; - } - } - - (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index)); - } - - return GRAPH_SUCCESS; -} - -graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node, - std::vector> &ref_out_tensors) { - GELOGD("Enter update parent node shape for class branch op process"); - if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { - return UpdataOutputForMultiBatcch(node, ref_out_tensors); - } - - // check sub_graph shape.If not same ,do unknown shape process - for (size_t i = 0; i < ref_out_tensors.size(); i++) { - if (ref_out_tensors[i].empty()) { - continue; - } - auto ref_out_tensor = ref_out_tensors[i].at(0); - ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); - for (auto &tensor : ref_out_tensors[i]) { - if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { - GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); - return GRAPH_FAILED; - } - auto shape = tensor.MutableShape(); - if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { - GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, - shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); - ref_out_tensor_shape = GeShape(UNKNOWN_RANK); - break; - } - for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { - if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { - continue; - } - GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, - j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); - (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); - } - } - (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); - } - return GRAPH_SUCCESS; -} - -graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector> &ref_data_tensors, - std::vector> &ref_out_tensors) { - GELOGD("Enter update parent node shape for class while op process"); - if (ref_data_tensors.size() != ref_out_tensors.size()) { - GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(), - ref_data_tensors.size(), ref_out_tensors.size()); - return GRAPH_FAILED; - } - for (size_t i = 0; i < ref_data_tensors.size(); i++) { - if (ref_out_tensors[i].size() != 1) { - GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!"); - return GRAPH_FAILED; - } - } - bool is_need_reverse_brush = false; - // check input and output - for (size_t i = 0; i < ref_out_tensors.size(); i++) { - if (ref_out_tensors[i].empty()) { - continue; - } - auto ref_out_tensor = ref_out_tensors[i].at(0); - auto tmp_shape = ref_out_tensor.MutableShape(); - // ref_i's data and output tensor shape should be same - for (auto &tensor : ref_data_tensors[i]) { - if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { - GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str()); - return GRAPH_FAILED; - } - auto shape = tensor.MutableShape(); - if (shape.GetDims() != tmp_shape.GetDims()) { - ref_out_tensor.SetUnknownDimNumShape(); - is_need_reverse_brush = true; - break; - } - } - (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); - } - // reverse refresh while body shape - if (is_need_reverse_brush) { - return ReverseBrushWhileBodySubGraph(node); - } - return GRAPH_SUCCESS; -} - -graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { - auto op_desc = node->GetOpDesc(); - auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - return GRAPH_SUCCESS; - } - - auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); - for (const auto &name : sub_graph_names) { - if (name.empty()) { - GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); - continue; - } - auto sub_graph = root_graph->GetSubgraph(name); - if (sub_graph == nullptr) { - GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - for (const auto &node_sub : sub_graph->GetDirectNode()) { - if (node_sub->GetType() != DATA) { - continue; - } - int ref_i; - auto data_opdesc = node_sub->GetOpDesc(); - if (data_opdesc == nullptr) { - GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), - node->GetName().c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { - GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), - node->GetName().c_str()); - return GRAPH_FAILED; - } - if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { - continue; - } - auto input_desc = op_desc->MutableInputDesc(ref_i); - if (input_desc == nullptr) { - GE_LOGE( - "The ref index(%d) on the data %s on the sub graph %s " - "parent node %s are incompatible, inputs num %u", - ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); - return GRAPH_FAILED; - } - GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), - node->GetName().c_str()); - auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); - - if (ret != GRAPH_SUCCESS) { - GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", - node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); - return ret; - } - ret = data_opdesc->UpdateOutputDesc(0, *input_desc); - if (ret != GRAPH_SUCCESS) { - GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s", - node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); - return ret; - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr &sub_graph, NodePtr &netoutput, - const ConstNodePtr &node, - std::vector> &ref_data_tensors) { - auto sub_nodes = sub_graph->GetDirectNode(); - for (size_t i = sub_nodes.size(); i > 0; --i) { - auto sub_node = sub_nodes.at(i - 1); - if (sub_node->GetType() == NETOUTPUT) { - netoutput = sub_node; - } - if (sub_node->GetType() == DATA) { - if (sub_node->GetOpDesc() == nullptr) { - return GRAPH_FAILED; - } - - int ref_i; - if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { - GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); - return GRAPH_FAILED; - } - if (ref_i < 0 || static_cast(ref_i) >= node->GetAllInDataAnchorsSize()) { - GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(), - ref_i, node->GetAllInDataAnchorsSize()); - return GRAPH_FAILED; - } - ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); - } - } - return GRAPH_SUCCESS; -} - -graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { - auto op_desc = node->GetOpDesc(); - auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - return GRAPH_SUCCESS; - } - - std::vector> ref_data_tensors(node->GetAllInDataAnchorsSize()); - std::vector> ref_out_tensors(node->GetAllOutDataAnchorsSize()); - auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); - - for (const auto &name : sub_graph_names) { - if (name.empty()) { - GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); - continue; - } - auto sub_graph = root_graph->GetSubgraph(name); - if (sub_graph == nullptr) { - GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - NodePtr netoutput = nullptr; - auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); - if (ret != GRAPH_SUCCESS) { - return ret; - } - if (netoutput == nullptr) { - GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - auto netoutput_opdesc = netoutput->GetOpDesc(); - if (netoutput_opdesc == nullptr) { - GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(), - node->GetName().c_str()); - return GRAPH_FAILED; - } - for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { - auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx()); - if (edge_desc == nullptr) { - GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(), - node->GetName().c_str(), edge_anchor->GetIdx()); - return GRAPH_FAILED; - } - GELOGI("Netoutput in anchor index is %zu, input tensor dim is %zu", edge_anchor->GetIdx(), - edge_desc->GetShape().GetDimNum()); - int ref_i; - if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { - // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. - continue; - } - GELOGI("Parent node index of edge desc is %d", ref_i); - if (ref_i < 0 || static_cast(ref_i) >= node->GetAllOutDataAnchorsSize()) { - return GRAPH_FAILED; - } - ref_out_tensors[ref_i].emplace_back(*edge_desc); - } - } - - if (node->GetType() == WHILE) { - return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); - } - return UpdateParentNodeForBranch(node, ref_out_tensors); -} - -string Serial(const vector &dims) { - string serial_string; - serial_string += "["; - for (int64_t dim : dims) { - serial_string += std::to_string(dim) + " "; - } - serial_string += "]"; - return serial_string; -} - -graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) { - GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); - GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); - for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) { - auto in_idx = in_anchor->GetIdx(); - auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_data_anchor == nullptr) { - continue; - } - auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); - if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { - continue; - } - int peer_out_idx = peer_out_data_anchor->GetIdx(); - auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast(peer_out_idx)); - - // check shape and dtype continuity. do not stop process - auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast(in_idx)); - if (in_desc == nullptr) { - continue; - } - auto in_shape = in_desc->GetShape().GetDims(); - auto in_dtype = in_desc->GetDataType(); - auto peer_out_shape = peer_out_desc->GetShape().GetDims(); - auto peer_out_dtype = peer_out_desc->GetDataType(); - if (peer_out_dtype != in_dtype) { - GELOGW( - "current node [%s] [%d]\'th out_dtype is [%s].peer output node [%s] [%d]\'th " - "output_dtype is [%s].The two dtype should be same! Please check graph and fix it", - node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(), - peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str()); - } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) { - string in_shape_str = Serial(in_shape); - string peer_out_shape_str = Serial(peer_out_shape); - GELOGW( - "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " - "input_shape is [%s].The two shape should be same! Please check graph and fix it", - node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx, - peer_out_shape_str.c_str()); - } - // refresh current node input desc - in_desc->SetOriginShape(peer_out_desc->GetOriginShape()); - in_desc->SetShape(peer_out_desc->GetShape()); - in_desc->SetDataType(peer_out_desc->GetDataType()); - in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType()); - std::vector> shape_range; - (void)peer_out_desc->GetShapeRange(shape_range); - in_desc->SetShapeRange(shape_range); - ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast(peer_out_desc->GetShape().GetDims().size())); - } - return GRAPH_SUCCESS; -} -} // namespace -void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { - if (!IsLogEnable(GE, DLOG_DEBUG)) { - return; - } - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "node is null"); - return; - } - ge::OpDescPtr op_desc = node->GetOpDesc(); - GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); - std::string str; - if (op_desc->GetInputsSize() != 0) { - std::string input_desc_str = "input shape: "; - for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { - input_desc_str += "["; - for (int64_t dim : input_desc->GetShape().GetDims()) { - input_desc_str += std::to_string(dim) + " "; - } - input_desc_str += "]"; - input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) + ":" + - TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " "; - } - str += input_desc_str; - - input_desc_str = "input origin shape: "; - for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { - input_desc_str += "["; - for (int64_t dim : input_desc->GetOriginShape().GetDims()) { - input_desc_str += std::to_string(dim) + " "; - } - input_desc_str += "]"; - input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) + ":" + - TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) + " "; - } - str += input_desc_str; - } - - if (op_desc->GetAllOutputsDescSize() != 0) { - std::string output_desc_str = "output shape: "; - for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { - if (output_desc == nullptr) { - continue; - } - output_desc_str += "["; - for (int64_t dim : output_desc->GetShape().GetDims()) { - output_desc_str += std::to_string(dim) + " "; - } - output_desc_str += "]"; - output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) + ":" + - TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " "; - } - str += output_desc_str; - - output_desc_str = "output origin shape: "; - for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { - if (output_desc == nullptr) { - continue; - } - output_desc_str += "["; - for (int64_t dim : output_desc->GetOriginShape().GetDims()) { - output_desc_str += std::to_string(dim) + " "; - } - output_desc_str += "]"; - output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) + ":" + - TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) + " "; - } - str += output_desc_str; - } - GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), str.c_str()); -} - -graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { - return InferShapeAndType(node, op, true); -} -graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) { - auto op_desc = node->GetOpDesc(); - const auto &op_type = op_desc->GetType(); - - graphStatus ret; - if (before_subgraph) { - ret = UpdateSubGraphDataNodes(node); - if (ret != GRAPH_SUCCESS) { - return ret; - } - } - // Get infer func and execute - ret = op_desc->CallInferFunc(op); - if (ret == GRAPH_PARAM_INVALID) { - // Op ir no infer func, try to get infer func from operator factory - auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); - if (node_op.IsEmpty()) { - GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); - return ret; - } - - GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); - auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); - node_op.BreakConnect(); - if (temp_op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "temp op desc is null"); - return GRAPH_FAILED; - } - if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { - GELOGW("InferShapeAndType UpdateInputName failed"); - for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { - if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { - break; - } - return GRAPH_SUCCESS; - } - } - if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { - GELOGW("InferShapeAndType UpdateOutputName failed"); - } - op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); - ret = op_desc->CallInferFunc(op); - GELOGI("op CallInferFunc second. ret: %u", ret); - } - if (ret != GRAPH_SUCCESS) { - return ret; - } - - if (!before_subgraph) { - return UpdateParentNodeOutTensor(node); - } - return GRAPH_SUCCESS; -} - -InferenceContextPtr CreateInferenceContext(const std::unordered_map &context_map, - const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "node is null"); - return nullptr; - } - InferenceContextPtr inference_context = std::shared_ptr(InferenceContext::Create()); - if (inference_context == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext"); - return nullptr; - } - - auto all_in_data_anchors = node->GetAllInDataAnchors(); - std::vector> input_shapes_and_types(all_in_data_anchors.size()); - std::vector marks; - - bool has_input_shapes_and_types = false; - for (const auto &in_anchor : all_in_data_anchors) { - const auto &out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - continue; - } - - auto input_node = out_anchor->GetOwnerNode(); - if (input_node == nullptr) { - continue; - } - - auto iter = context_map.find(input_node); - if (iter != context_map.end()) { - const auto &src_context = iter->second; - GE_IF_BOOL_EXEC(src_context == nullptr, GELOGE(GRAPH_FAILED, "src_context is null."); return nullptr); - GELOGD("node:%s get %ld marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(), - input_node->GetName().c_str()); - for (auto mark : src_context->GetMarks()) { - marks.push_back(mark); - } - auto output_idx = out_anchor->GetIdx(); - auto input_idx = in_anchor->GetIdx(); - auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes(); - if (output_idx < static_cast(output_shape_and_type.size())) { - GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx, - node->GetName().c_str(), input_idx); - input_shapes_and_types[input_idx] = output_shape_and_type[output_idx]; - has_input_shapes_and_types = true; - } else { - GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx, - output_shape_and_type.size()); - } - } - } - - if (has_input_shapes_and_types) { - inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types)); - } - inference_context->SetMarks(marks); - - return inference_context; -} - -namespace { -thread_local std::unordered_map context_map; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ShapeRefiner::ClearContextMap() { context_map.clear(); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { - return InferShapeAndType(node, true); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, - bool before_subgraph) { - GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); - bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); - auto opdesc = node->GetOpDesc(); - GE_IF_BOOL_EXEC(opdesc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); - // some op can not infershape twice such as aipp - bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified"); - if (need_update_input) { - auto status = UpdateOpInputDesc(node); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "update op input_desc failed!"); - return status; - } - } - - if (node->Verify() != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); - return GRAPH_FAILED; - } - PrintInOutTensorShape(node, "before_infershape"); - Operator op = OpDescUtils::CreateOperatorFromNode(node); - - if (!is_unknown_graph) { - auto inference_context = CreateInferenceContext(context_map, node); - if (inference_context == nullptr) { - GELOGE(GRAPH_FAILED, "inference context is null"); - return GRAPH_FAILED; - } - GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); - op.SetInferenceContext(inference_context); - } - - graphStatus status = InferShapeAndType(node, op, before_subgraph); - if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { - if (is_unknown_graph) { - PrintInOutTensorShape(node, "after_infershape when running"); - return GRAPH_SUCCESS; - } - auto op_desc = node->GetOpDesc(); - for (const auto &out_anchor : node->GetAllOutDataAnchors()) { - auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); - ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); - output_tensor->SetOriginShape(output_tensor->GetShape()); - output_tensor->SetOriginDataType(output_tensor->GetDataType()); - - GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", - node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), - TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), - TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); - } - } else { - GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str()); - return GRAPH_FAILED; - } - if (!is_unknown_graph) { - auto ctx_after_infer = op.GetInferenceContext(); - if (ctx_after_infer != nullptr) { - GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); - if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { - GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), - ctx_after_infer->GetMarks().size()); - (void)context_map.emplace(node, ctx_after_infer); - } - } - } - PrintInOutTensorShape(node, "after_infershape"); - - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/metadef/graph/tensor.cc b/metadef/graph/tensor.cc deleted file mode 100644 index 1f30c876..00000000 --- a/metadef/graph/tensor.cc +++ /dev/null @@ -1,704 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "external/graph/tensor.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/ge_tensor.h" -#include "securec.h" -#include "utils/attr_utils.h" -#include "utils/tensor_adapter.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" - -namespace { -/// Extra 8 bytes store pointer of string -/// Extra 1 byte store '\0' -const int EXTRA_STORE_POINTER_FOR_STRING = 8; -const int EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL = 9; -const int64_t UNKNOWN_DIM_SIZE = -1; -} // namespace - -namespace ge { -// If not overflow return true -static bool Int64MulNotOverflow(int64_t a, int64_t b) { - if (a > 0) { - if (b > 0) { - if (a > (INT64_MAX / b)) { - return false; - } - } else { - if (b < (INT64_MIN / a)) { - return false; - } - } - } else { - if (b > 0) { - if (a < (INT64_MIN / b)) { - return false; - } - } else { - if ((a != 0) && (b < (INT64_MAX / a))) { - return false; - } - } - } - return true; -} - -class TensorDescImpl { - public: - TensorDescImpl() = default; - ~TensorDescImpl() = default; - TensorDescImpl(const Shape &shape, Format format, DataType dt) : shape_(shape), format_(format), data_type_(dt) {} - - Shape shape_; - std::vector> range_; - Format format_ = FORMAT_ND; - Format origin_format_ = FORMAT_ND; - DataType data_type_ = DT_FLOAT; - Shape origin_shape_; - int64_t size_ = 0; - int64_t real_dim_cnt_ = 0; - std::string name_; -}; - -class TensorImpl { - public: - TensorImpl() = default; - ~TensorImpl() = default; - - explicit TensorImpl(const TensorDesc &tensor_desc) : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)) {} - TensorImpl(const TensorDesc &tensor_desc, const std::vector &data) - : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data) {} - TensorImpl(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) - : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data, size) {} - TensorImpl(TensorDesc &&tensor_desc, std::vector &&data) - : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), std::move(data)) {} - - GeTensor ge_tensor; -}; - -class ShapeImpl { - public: - ShapeImpl() = default; - ~ShapeImpl() = default; - explicit ShapeImpl(const std::vector &dims) { - bool is_unknown_dim_num = false; - for (const auto &dim : dims) { - if (dim == UNKNOWN_DIM_NUM) { - is_unknown_dim_num = true; - break; - } - } - dims_ = is_unknown_dim_num ? std::vector({UNKNOWN_DIM_NUM}) : dims; - } - - std::vector dims_; -}; - -Shape::Shape() { impl_ = ComGraphMakeShared(); } - -Shape::Shape(const std::vector &dims) { impl_ = ComGraphMakeShared(dims); } - -size_t Shape::GetDimNum() const { - if (impl_ != nullptr) { - for (auto i : impl_->dims_) { - if (i == UNKNOWN_DIM_NUM) { - return 0; - } - } - return impl_->dims_.size(); - } - return 0; -} - -int64_t Shape::GetDim(size_t idx) const { - if (impl_ != nullptr) { - if (idx >= impl_->dims_.size()) { - return 0; - } - return impl_->dims_[idx]; - } - return 0; -} - -graphStatus Shape::SetDim(size_t idx, int64_t value) { - if (impl_ != nullptr) { - if (idx >= impl_->dims_.size()) { - return GRAPH_FAILED; - } - impl_->dims_[idx] = value; - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -std::vector Shape::GetDims() const { - vector dims; - if (impl_ != nullptr) { - return impl_->dims_; - } - return dims; -} - -int64_t Shape::GetShapeSize() const { - if (impl_ != nullptr) { - if (impl_->dims_.empty()) { - return 0; - } - int64_t size = 1; - for (auto i : impl_->dims_) { - if (i == UNKNOWN_DIM_NUM || i == UNKNOWN_DIM) { - return UNKNOWN_DIM_SIZE; - } - - if (!Int64MulNotOverflow(size, i)) { - GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); - size = 0; - return size; - } - size *= i; - } - return size; - } - return 0; -} - -TensorDesc::TensorDesc() { - impl = ComGraphMakeShared(); // lint !e665 -} - -TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) { - impl = ComGraphMakeShared(shape, format, dt); // lint !e665 - SetRealDimCnt(shape.GetDimNum()); -} - -TensorDesc::TensorDesc(const TensorDesc &desc) { - // Copy - impl = ComGraphMakeShared(); // lint !e665 - if (desc.impl != nullptr && impl != nullptr) { - *impl = *desc.impl; - } -} - -TensorDesc::TensorDesc(TensorDesc &&desc) { - // Move - impl = std::move(desc.impl); -} - -TensorDesc &TensorDesc::operator=(const TensorDesc &desc) { - // Copy - if (&desc != this) { - impl = ComGraphMakeShared(); - if (desc.impl != nullptr && impl != nullptr) { - *impl = *desc.impl; - } - } - return *this; -} - -TensorDesc &TensorDesc::operator=(TensorDesc &&desc) { - if (&desc != this) { - impl = std::move(desc.impl); - } - return *this; -} - -void TensorDesc::Update(const Shape &shape, Format format, DataType dt) { - if (impl != nullptr) { - impl->shape_ = shape; - impl->format_ = format; - impl->data_type_ = dt; - } -} - -Shape TensorDesc::GetShape() const { - if (impl != nullptr) { - return impl->shape_; - } - return Shape(); -} - -void TensorDesc::SetShape(const Shape &shape) { - if (impl != nullptr) { - impl->shape_ = shape; - } -} - -// set shape with -2, it stand for unknown shape -graphStatus TensorDesc::SetUnknownDimNumShape() { - if (impl != nullptr) { - impl->shape_ = Shape({UNKNOWN_DIM_NUM}); - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Set unknown shape failed,because no impl class!"); - return GRAPH_FAILED; -} - -// for unknown shape -graphStatus TensorDesc::SetShapeRange(const std::vector> &range) { - if (impl != nullptr) { - impl->range_ = range; - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "SetShapeRange failed!impl is nullptr!"); - return GRAPH_FAILED; -} -graphStatus TensorDesc::GetShapeRange(std::vector> &range) const { - if (impl != nullptr) { - range = impl->range_; - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "impl is nullptr!"); - return GRAPH_FAILED; -} - -Shape TensorDesc::GetOriginShape() const { - if (impl != nullptr) { - return impl->origin_shape_; - } - return Shape(); -} - -void TensorDesc::SetOriginShape(const Shape &origin_shape) { - if (impl != nullptr) { - impl->origin_shape_ = origin_shape; - } -} - -Format TensorDesc::GetFormat() const { - if (impl != nullptr) { - return impl->format_; - } - return FORMAT_RESERVED; -} - -void TensorDesc::SetFormat(Format format) { - if (impl != nullptr) { - impl->format_ = format; - } -} - -Format TensorDesc::GetOriginFormat() const { - if (impl != nullptr) { - return impl->origin_format_; - } - return FORMAT_RESERVED; -} - -void TensorDesc::SetOriginFormat(Format origin_format) { - if (impl != nullptr) { - impl->origin_format_ = origin_format; - } -} - -DataType TensorDesc::GetDataType() const { - if (impl != nullptr) { - return impl->data_type_; - } - return DT_UNDEFINED; -} - -void TensorDesc::SetDataType(DataType dt) { - if (impl != nullptr) { - impl->data_type_ = dt; - } -} - -void TensorDesc::SetSize(int64_t size) { - if (impl != nullptr) { - impl->size_ = size; - } -} - -int64_t TensorDesc::GetSize() const { - if (impl != nullptr) { - return impl->size_; - } - return 0; -} - -void TensorDesc::SetRealDimCnt(const int64_t real_dim_cnt) { - if (impl != nullptr) { - impl->real_dim_cnt_ = real_dim_cnt; - } -} - -int64_t TensorDesc::GetRealDimCnt() const { - if (impl != nullptr) { - return impl->real_dim_cnt_; - } - return 0; -} - -std::string TensorDesc::GetName() const { - if (impl != nullptr) { - return impl->name_; - } - return ""; -} - -void TensorDesc::SetName(const std::string &name) { - if (impl != nullptr) { - impl->name_ = name; - } -} - -Tensor::Tensor() { impl = ComGraphMakeShared(); } - -Tensor::Tensor(const TensorDesc &tensor_desc) { - impl = ComGraphMakeShared(tensor_desc); // lint !e665 -} - -Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector &data) { - uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); - DataType data_type = tensor_desc.GetDataType(); - uint32_t type_length; - bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); - if (!ret) { - GELOGW("datatype %d is not found.", data_type); - } - - auto data_size = data.size(); - if (ret && (shape_size || (data_size != type_length))) { - if (type_length != 0 && UINT64_MAX / type_length < shape_size) { - GELOGW("mul overflow: %lu, %u", shape_size, type_length); - } else { - if (shape_size * type_length != data_size) { - GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, - data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); - } - } - } - impl = ComGraphMakeShared(tensor_desc, data); // lint !e665 -} - -Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) { - uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); - DataType data_type = tensor_desc.GetDataType(); - uint32_t type_length; - bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); - if (!ret) { - GELOGW("datatype %d is not found.", data_type); - } - if (ret && (shape_size || (size != type_length))) { - if (type_length != 0 && UINT64_MAX / type_length < shape_size) { - GELOGW("mul overflow: %lu, %u", shape_size, type_length); - } else { - if (shape_size * type_length != size) { - GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, - size, TypeUtils::DataTypeToSerialString(data_type).c_str()); - } - } - } - - impl = ComGraphMakeShared(tensor_desc, data, size); // lint !e665 -} - -Tensor::Tensor(TensorDesc &&tensor_desc, std::vector &&data) { - uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); - DataType data_type = tensor_desc.GetDataType(); - uint32_t type_length; - bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); - if (!ret) { - GELOGW("datatype %d is not found.", data_type); - } - - auto data_size = data.size(); - if (ret && (shape_size || (data_size != type_length))) { - if (type_length != 0 && UINT64_MAX / type_length < shape_size) { - GELOGW("mul overflow: %lu, %u", shape_size, type_length); - } else { - if (shape_size * type_length != data_size) { - GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, - data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); - } - } - } - impl = ComGraphMakeShared(std::move(tensor_desc), std::move(data)); // lint !e665 -} - -TensorDesc Tensor::GetTensorDesc() const { - if (impl != nullptr) { - return TensorAdapter::GeTensorDesc2TensorDesc(impl->ge_tensor.MutableTensorDesc()); - } - return TensorDesc(); -} - -graphStatus Tensor::SetTensorDesc(const TensorDesc &tensor_desc) { - if (impl != nullptr) { - impl->ge_tensor.SetTensorDesc(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -const uint8_t *Tensor::GetData() const { - if (impl != nullptr) { - return impl->ge_tensor.GetData().data(); - } - return nullptr; -} - -uint8_t *Tensor::GetData() { - if (impl != nullptr) { - return impl->ge_tensor.MutableData().data(); - } - return nullptr; -} - -size_t Tensor::GetSize() const { - if (impl != nullptr) { - return impl->ge_tensor.GetData().size(); - } - return 0; -} - -graphStatus Tensor::SetData(std::vector &&data) { - if (impl != nullptr) { - (void)impl->ge_tensor.SetData(data); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const std::vector &data) { - if (impl != nullptr) { - (void)impl->ge_tensor.SetData(data); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const uint8_t *data, size_t size) { - if (impl != nullptr) { - (void)impl->ge_tensor.SetData(data, size); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const std::string &data) { - if (impl != nullptr && (!data.empty())) { - /// Extra 8 bytes store pointer of string - /// Extra 1 byte store '\0' - size_t total_size = data.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL; - std::unique_ptr buff(new (std::nothrow) char[total_size]()); - if (buff == nullptr) { - GELOGE(GRAPH_FAILED, "allocate string raw data buff failed"); - return GRAPH_FAILED; - } - uint64_t *p = reinterpret_cast(buff.get()); - // Front 8 bytes store pointer of string - char *raw_data = buff.get() + EXTRA_STORE_POINTER_FOR_STRING; - p[0] = reinterpret_cast(raw_data); - int32_t memcpy_ret = memcpy_s(raw_data, total_size - EXTRA_STORE_POINTER_FOR_STRING, data.c_str(), data.size() + 1); - GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed"); - (void)impl->ge_tensor.SetData(reinterpret_cast(buff.get()), total_size); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} -graphStatus Tensor::SetData(const std::vector &data) { - if (impl != nullptr) { - if (data.empty()) { - GELOGE(GRAPH_FAILED, "there is no data, please check the input variable"); - return GRAPH_FAILED; - } - size_t total_size = 0; - for (auto str : data) { - /// Extra 8 bytes store pointer of each string - /// Extra 1 byte store '\0' - total_size += (str.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL); - } - std::unique_ptr buff(new (std::nothrow) char[total_size]); - if (buff == nullptr) { - GELOGE(GRAPH_FAILED, "allocate string raw data buff failed"); - return GRAPH_FAILED; - } - uint64_t *p = reinterpret_cast(buff.get()); - // Front some bytes store pointer of each string - char *raw_data = buff.get() + data.size() * sizeof(uint64_t); - uint64_t ptr_size = data.size() * sizeof(uint64_t); - for (size_t i = 0; i < data.size(); ++i) { - p[i] = reinterpret_cast(raw_data); - if (total_size < ptr_size) { - GELOGE(GRAPH_FAILED, "Subtraction invalid, total_size: %zu, ptr_size: %lu", total_size, ptr_size); - return GRAPH_FAILED; - } - int32_t memcpy_ret = memcpy_s(raw_data, total_size - ptr_size, data[i].c_str(), data[i].size() + 1); - GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed"); - raw_data += (data[i].size() + 1); - ptr_size += (data[i].size() + 1); - } - - (void)impl->ge_tensor.SetData(reinterpret_cast(buff.get()), total_size); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::IsValid() { - uint64_t shape_size = GetTensorDesc().GetShape().GetShapeSize(); - DataType data_type = GetTensorDesc().GetDataType(); - uint32_t type_length; - bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); - if (!ret) { - GELOGW("datatype %d is not found.", data_type); - return GRAPH_SUCCESS; - } - - size_t data_size = GetSize(); - if (data_type != DT_STRING) { - if (shape_size || (data_size != type_length)) { - if (type_length != 0 && UINT64_MAX / type_length < shape_size) { - GELOGW("mul overflow: %lu, %u", shape_size, type_length); - } else { - if (shape_size * type_length != data_size) { - GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, - data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); - return GRAPH_FAILED; - } - } - } - } - - return GRAPH_SUCCESS; -} - -Tensor Tensor::Clone() const { - Tensor tensor; - if (impl != nullptr && tensor.impl != nullptr) { - tensor.impl->ge_tensor = impl->ge_tensor.Clone(); - } - return tensor; -} - -GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_desc) { - GeTensorDesc ge_tensor_desc(GeShape(tensor_desc.GetShape().GetDims()), tensor_desc.GetFormat(), - tensor_desc.GetDataType()); - ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); - ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); - ge_tensor_desc.SetName(tensor_desc.GetName()); - std::vector> shape_range; - auto status = tensor_desc.GetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Get shape range failed!"); - return ge_tensor_desc; - } - status = ge_tensor_desc.SetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set shape range failed!"); - return ge_tensor_desc; - } - auto size = tensor_desc.GetSize(); - TensorUtils::SetSize(ge_tensor_desc, size); - - auto real_dim_cnt = static_cast(tensor_desc.GetRealDimCnt()); - TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); - return ge_tensor_desc; -} - -TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_desc) { - TensorDesc tensor_desc(Shape(ge_tensor_desc.GetShape().GetDims()), ge_tensor_desc.GetFormat(), - ge_tensor_desc.GetDataType()); - tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); - tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); - tensor_desc.SetName(ge_tensor_desc.GetName()); - std::vector> shape_range; - auto status = ge_tensor_desc.GetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Get shape range failed!"); - return tensor_desc; - } - status = tensor_desc.SetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set shape range failed!"); - return tensor_desc; - } - int64_t size = 0; - (void)TensorUtils::GetSize(ge_tensor_desc, size); - tensor_desc.SetSize(size); - - uint32_t real_dim_cnt = 0; - (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); - tensor_desc.SetRealDimCnt(real_dim_cnt); - return tensor_desc; -} - -GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) { - GeTensorPtr ge_tensor; - if (tensor.impl != nullptr) { - ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor.Clone()); // lint !e665 - } - return ge_tensor; -} - -Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) { - Tensor tensor; - if (ge_tensor != nullptr && tensor.impl != nullptr) { - tensor.impl->ge_tensor = ge_tensor->Clone(); - } - return tensor; -} - -ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { - GeTensorPtr ge_tensor; - if (tensor.impl != nullptr) { - ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); // lint !e665 - } - return ge_tensor; -} - -GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) { - GeTensorPtr ge_tensor; - if (tensor.impl != nullptr) { - ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); // lint !e665 - } - return ge_tensor; -} - -const GeTensor TensorAdapter::AsGeTensor(const Tensor &tensor) { - if (tensor.impl != nullptr) { - return tensor.impl->ge_tensor; - } - return GeTensor(); -} - -GeTensor TensorAdapter::AsGeTensor(Tensor &tensor) { - if (tensor.impl != nullptr) { - return tensor.impl->ge_tensor; - } - return GeTensor(); -} - -const Tensor TensorAdapter::AsTensor(const GeTensor &ge_tensor) { - Tensor tensor; - if (tensor.impl != nullptr) { - tensor.impl->ge_tensor = ge_tensor; - } - return tensor; -} - -Tensor TensorAdapter::AsTensor(GeTensor &ge_tensor) { - Tensor tensor; - if (tensor.impl != nullptr) { - tensor.impl->ge_tensor = ge_tensor; - } - return tensor; -} -} // namespace ge diff --git a/metadef/graph/utils/anchor_utils.cc b/metadef/graph/utils/anchor_utils.cc deleted file mode 100644 index 5a042283..00000000 --- a/metadef/graph/utils/anchor_utils.cc +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/anchor_utils.h" -#include -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" - -namespace ge { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Format AnchorUtils::GetFormat(const DataAnchorPtr &data_anchor) { - if (data_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "The input data anchor is invalid."); - return FORMAT_RESERVED; - } - return data_anchor->format_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetFormat(const DataAnchorPtr &data_anchor, - Format data_format) { - if ((data_anchor == nullptr) || (data_format == FORMAT_RESERVED)) { - GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid ."); - return GRAPH_FAILED; - } - data_anchor->format_ = data_format; - return GRAPH_SUCCESS; -} - -// Get anchor status -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorStatus AnchorUtils::GetStatus(const DataAnchorPtr &data_anchor) { - if (data_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "The input data anchor is invalid."); - return ANCHOR_RESERVED; - } - return data_anchor->status_; -} - -// Set anchor status -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetStatus(const DataAnchorPtr &data_anchor, - AnchorStatus anchor_status) { - if ((data_anchor == nullptr) || (anchor_status == ANCHOR_RESERVED)) { - GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid ."); - return GRAPH_FAILED; - } - data_anchor->status_ = anchor_status; - return GRAPH_SUCCESS; -} - -bool AnchorUtils::HasControlEdge(const AnchorPtr &anchor) { - auto control_anchor = Anchor::DynamicAnchorCast(anchor); - if (control_anchor != nullptr) { - return (control_anchor->GetPeerAnchors().size() != 0); - } - - auto data_anchor = Anchor::DynamicAnchorCast(anchor); - if (data_anchor) { - for (const auto &peer : data_anchor->GetPeerAnchors()) { - auto peer_cast = Anchor::DynamicAnchorCast(peer); - if (peer_cast) { - return true; - } - } - return false; - } - GELOGE(GRAPH_FAILED, "the anchor is neither control anchor nor data anchor"); - return false; -} - -bool AnchorUtils::IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst) { - GE_CHK_BOOL_EXEC(src != nullptr, return false, "src is null."); - GE_CHK_BOOL_RET_STATUS_NOLOG(src->IsLinkedWith(dst), false); - auto src_control_anchor = Anchor::DynamicAnchorCast(src); - auto dst_control_anchor = Anchor::DynamicAnchorCast(dst); - return (src_control_anchor || dst_control_anchor); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int AnchorUtils::GetIdx(const AnchorPtr &anchor) { - // Check if it can add edge between DataAnchor - auto data_anchor = Anchor::DynamicAnchorCast(anchor); - if (data_anchor != nullptr) { - return data_anchor->GetIdx(); - } - // Check if it can add edge between ControlAnchor - auto control_anchor = Anchor::DynamicAnchorCast(anchor); - if (control_anchor != nullptr) { - return control_anchor->GetIdx(); - } - return -1; -} -} // namespace ge diff --git a/metadef/graph/utils/ge_ir_utils.cc b/metadef/graph/utils/ge_ir_utils.cc deleted file mode 100644 index f238c6e8..00000000 --- a/metadef/graph/utils/ge_ir_utils.cc +++ /dev/null @@ -1,1178 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/utils/ge_ir_utils.h" -#include -#include "framework/common/debug/ge_log.h" - -namespace { -const char *const kControlAnchorIndex = ":-1"; -const char *const kNodeTypeForSubgraph = "subgraph"; -const char *const kPrefixForInputDesc = "input_desc_attr_"; -const char *const kPrefixForOutputDesc = "output_desc_attr_"; -const char *const kDumpGEGraph = "DUMP_GE_GRAPH"; -const int8_t kMaxRecursionDepth = 10; -const char *const kDumpGeGraph = std::getenv(kDumpGEGraph); -const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP; -const int64_t kInputPrefixLength = 5; -const int64_t kOutputPrefixLength = 6; -using AttrDefPair = ::google::protobuf::MapPair; -} // namespace - -namespace ge { -// Part 1: from IR convert to ONNX Protobuf -static const std::map kGeDataTypeToOnnxMap = { - {DT_INT64, onnx::TensorProto_DataType_INT64}, {DT_UINT64, onnx::TensorProto_DataType_UINT64}, - {DT_FLOAT, onnx::TensorProto_DataType_FLOAT}, {DT_INT32, onnx::TensorProto_DataType_INT32}, - {DT_UINT32, onnx::TensorProto_DataType_UINT32}, {DT_INT8, onnx::TensorProto_DataType_INT8}, - {DT_UINT8, onnx::TensorProto_DataType_UINT8}, {DT_INT16, onnx::TensorProto_DataType_INT16}, - {DT_UINT16, onnx::TensorProto_DataType_UINT16}, {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16}, - {DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE}, {DT_BOOL, onnx::TensorProto_DataType_BOOL}, -}; - -onnx::TensorProto_DataType OnnxUtils::EncodeDataType(DataType data_type) { - auto it = kGeDataTypeToOnnxMap.find(data_type); - if (it != kGeDataTypeToOnnxMap.end()) { - return it->second; - } else { - GELOGW("EncodeDataType: datatype not support %u", data_type); - return onnx::TensorProto_DataType_UNDEFINED; - } -} - -void OnnxUtils::AddAttrProtoFromAttribute(const std::pair &string_attr_value, - onnx::NodeProto *node_proto) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node proto is nullptr."); - return; - } - auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - auto attr_name = string_attr_value.first; - attr->set_name(attr_name); - auto attr_value = string_attr_value.second; - auto value_type = attr_value.GetValueType(); - switch (value_type) { - case GeAttrValue::VT_FLOAT: { - GeAttrValue::FLOAT data_f = 0; - (void)attr_value.GetValue(data_f); - attr->set_f(data_f); - attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); - break; - } - case GeAttrValue::VT_LIST_FLOAT: { - GeAttrValue::LIST_FLOAT data_fs = {}; - (void)attr_value.GetValue(data_fs); - attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); - for (auto &v : data_fs) { - attr->add_floats(v); - } - break; - } - case GeAttrValue::VT_INT: { - GeAttrValue::INT data_i = 0; - (void)attr_value.GetValue(data_i); - attr->set_type(onnx::AttributeProto_AttributeType_INT); - attr->set_i(data_i); - break; - } - case GeAttrValue::VT_LIST_INT: { - GeAttrValue::LIST_INT data_is = {}; - (void)attr_value.GetValue(data_is); - attr->set_type(onnx::AttributeProto_AttributeType_INTS); - for (auto &v : data_is) { - attr->add_ints(v); - } - break; - } - case GeAttrValue::VT_STRING: { - GeAttrValue::STR data_s; - (void)attr_value.GetValue(data_s); - attr->set_type(onnx::AttributeProto_AttributeType_STRING); - attr->set_s(data_s); - break; - } - case GeAttrValue::VT_LIST_STRING: { - GeAttrValue::LIST_STR data_ss = {}; - (void)attr_value.GetValue(data_ss); - attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); - for (auto &v : data_ss) { - attr->add_strings(v); - } - break; - } - default: - GELOGW("GeAttrValue ValueType: %u is not supported for now", value_type); - break; - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, - void *data) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); - return; - } - auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - attr->set_name(name); - switch (type) { - case onnx::AttributeProto_AttributeType_FLOAT: - attr->set_f((*(static_cast(data)))); - attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); - break; - - case onnx::AttributeProto_AttributeType_FLOATS: - attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); - for (auto &v : (*(static_cast *>(data)))) { - attr->add_floats(v); - } - break; - - case onnx::AttributeProto_AttributeType_INT: - attr->set_type(onnx::AttributeProto_AttributeType_INT); - attr->set_i((*(static_cast(data)))); - break; - - case onnx::AttributeProto_AttributeType_INTS: - attr->set_type(onnx::AttributeProto_AttributeType_INTS); - for (auto &v : *(static_cast *>(data))) { - attr->add_ints(v); - } - break; - - case onnx::AttributeProto_AttributeType_STRING: - attr->set_type(onnx::AttributeProto_AttributeType_STRING); - attr->set_s((*(static_cast(data)))); - break; - - case onnx::AttributeProto_AttributeType_STRINGS: - attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); - for (auto &v : *(static_cast *>(data))) { - attr->add_strings(v); - } - break; - - default: - GELOGW("AttributeProto AttributeType: %u is not supported for now", type); - break; - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, - ::google::protobuf::RepeatedField<::google::protobuf::int64> data) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); - return; - } - if (!data.empty()) { - auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_ints(v); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, - ::google::protobuf::RepeatedField data) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); - return; - } - if (!data.empty()) { - auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_ints(static_cast(v)); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, - ::google::protobuf::RepeatedField data) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); - return; - } - if (!data.empty()) { - auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_floats(v); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, - ::google::protobuf::RepeatedPtrField<::std::string> data) { - if (node_proto == nullptr) { - GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); - return; - } - if (!data.empty()) { - auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - GELOGE(GRAPH_FAILED, "attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_strings(v); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc) { - if (node_proto == nullptr || op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "node_proto or op_desc is nullptr"); - return; - } - // Input describes - auto size_in = op_desc->GetAllInputsSize(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_nums", &size_in); - if (size_in > 0) { - for (uint32_t i = 0; i < size_in; i++) { - auto input_desc = op_desc->GetInputDescPtrDfault(i); - if (input_desc != nullptr) { - auto data_type = TypeUtils::DataTypeToSerialString(input_desc->GetDataType()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_dtype:" + std::to_string(i), - &data_type); - auto data_type_origin = TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "input_desc_origin_dtype:" + std::to_string(i), &data_type_origin); - auto dims = input_desc->GetShape().GetDims(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_desc_shape:" + std::to_string(i), - &dims); - auto dims_origin = input_desc->GetOriginShape().GetDims(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, - "input_desc_origin_shape:" + std::to_string(i), &dims_origin); - auto layout = TypeUtils::FormatToSerialString(input_desc->GetFormat()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_layout:" + std::to_string(i), - &layout); - auto layout_origin = TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "input_desc_origin_layout:" + std::to_string(i), &layout_origin); - auto tensor_descriptor = input_desc->tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor != nullptr) { - auto size = tensor_descriptor->size(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_size:" + std::to_string(i), - &size); - auto weight_size = tensor_descriptor->weight_size(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_weight_size:" + std::to_string(i), &weight_size); - auto reuse_input = tensor_descriptor->reuse_input(); - auto reuse_input_int = static_cast(reuse_input); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_reuse_input:" + std::to_string(i), &reuse_input_int); - auto output_tensor = tensor_descriptor->output_tensor(); - auto output_tensor_int = static_cast(output_tensor); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_output_tensor:" + std::to_string(i), &output_tensor_int); - auto device_type = tensor_descriptor->device_type(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "input_desc_device_type:" + std::to_string(i), &device_type); - auto input_tensor = tensor_descriptor->input_tensor(); - auto input_tensor_int = static_cast(input_tensor); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_input_tensor:" + std::to_string(i), &input_tensor_int); - auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); - auto data_offset = tensor_descriptor->data_offset(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_data_offset:" + std::to_string(i), &data_offset); - auto cmps_size = tensor_descriptor->cmps_size(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_cmps_size:" + std::to_string(i), - &cmps_size); - auto cmps_tab = tensor_descriptor->cmps_tab(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "input_desc_cmps_tab:" + std::to_string(i), &cmps_tab); - auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset); - const auto &tensor_desc_map = tensor_descriptor->attr(); - std::string suffix = ":" + std::to_string(i); - AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix); - } else { - GELOGW("Tensor descriptor is nullptr"); - continue; - } - } else { - GELOGW("Input desc is nullptr"); - continue; - } - } - } - // Output describes - auto size_out = op_desc->GetOutputsSize(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_nums", &size_out); - if (size_out > 0) { - for (uint32_t i = 0; i < size_out; i++) { - auto output_desc = op_desc->GetOutputDescPtr(i); - if (output_desc != nullptr) { - auto data_type = TypeUtils::DataTypeToSerialString(output_desc->GetDataType()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_dtype:" + std::to_string(i), - &data_type); - auto origin_data_type = TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "output_desc_origin_dtype:" + std::to_string(i), &origin_data_type); - auto dims = output_desc->GetShape().GetDims(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_desc_shape:" + std::to_string(i), - &dims); - auto dims_origin = output_desc->GetOriginShape().GetDims(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, - "output_desc_origin_shape:" + std::to_string(i), &dims_origin); - auto layout = TypeUtils::FormatToSerialString(output_desc->GetFormat()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_layout:" + std::to_string(i), - &layout); - auto layout_origin = TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "output_desc_origin_layout:" + std::to_string(i), &layout_origin); - auto tensor_descriptor = output_desc->tensor_descriptor_.GetProtoMsg(); - if (tensor_descriptor != nullptr) { - auto size = tensor_descriptor->size(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_size:" + std::to_string(i), - &size); - auto weight_size = tensor_descriptor->weight_size(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "output_desc_weight_size:" + std::to_string(i), &weight_size); - auto device_type = tensor_descriptor->device_type(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "output_desc_device_type:" + std::to_string(i), &device_type); - auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, - "output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); - const auto &tensor_desc_map = tensor_descriptor->attr(); - std::string suffix = ":" + std::to_string(i); - AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix); - } else { - GELOGW("Tensor descriptor is nullptr"); - continue; - } - } else { - GELOGW("Output desc is nullptr"); - continue; - } - } - } -} - -void OnnxUtils::AddAttrProtoForAttrsFromAttrMap( - const ::google::protobuf::Map &attr_map, onnx::NodeProto *node_proto, - const std::string &prefix, const std::string &suffix) { - for (const auto &item : attr_map) { - auto attr_name = item.first; - auto attr_def = item.second; - auto attr_type = attr_def.value_case(); - if (attr_type == ge::proto::AttrDef::kT) { - const auto &tensor_def = attr_def.t(); - const auto &tensor_desc = tensor_def.desc(); - auto data_type = ge::proto::DataType_Name(tensor_desc.dtype()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_dtype" + suffix, - &data_type); - auto dims = tensor_desc.shape().dim(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + "_desc_shape" + suffix, - dims); - auto layout = tensor_desc.layout(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_layout" + suffix, - &layout); - auto device_type = tensor_desc.device_type(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - prefix + attr_name + "_desc_device_type" + suffix, &device_type); - if (kDumpLevel == DUMP_ALL) { - auto data = tensor_def.data(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix, - &data); - } - } - if (attr_type == ge::proto::AttrDef::kS) { - if (kDumpLevel == DUMP_ALL) { - auto str_value = attr_def.s(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value); - } - } - if (attr_type == ge::proto::AttrDef::kI) { - auto int_value = attr_def.i(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); - } - if (attr_type == ge::proto::AttrDef::kF) { - auto float_value = attr_def.f(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value); - } - if (attr_type == ge::proto::AttrDef::kB) { - auto int_value = static_cast(attr_def.b()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); - } - if (attr_type == ge::proto::AttrDef::kList) { - const auto &list_value = attr_def.list(); - auto list_value_type = list_value.val_type(); - if (list_value_type == - ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) { - if (kDumpLevel == DUMP_ALL) { - const auto &strings = list_value.s(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings); - } - } - if (list_value_type == - ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) { - const auto &floats = list_value.f(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats); - } - if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) { - const auto &ints = list_value.i(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints); - } - if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) { - const auto &bools = list_value.b(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools); - } - } - } -} - -void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "node is nullptr"); - return; - } - // 1.Attributes added from node's methods - auto send_list = node->send_event_id_list_; - if (!send_list.empty()) { - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "send_event_id_list", &send_list); - } - auto recv_list = node->recv_event_id_list_; - if (!recv_list.empty()) { - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "recv_event_id_list", &recv_list); - } - auto op_desc = node->op_; - if (op_desc != nullptr) { - // for input_name_idx_ in opdesc - auto input_name_2_indexs = op_desc->GetAllInputName(); - ::google::protobuf::RepeatedPtrField<::std::string> input_names; - ::google::protobuf::RepeatedField<::google::protobuf::int64> input_indexes; - for (const auto &input_name_2_index : input_name_2_indexs) { - std::string input_name = input_name_2_index.first; - input_names.Add(std::move(input_name)); - input_indexes.Add(input_name_2_index.second); - } - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "_input_name_key", input_names); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "_input_name_value", input_indexes); - // 2.Attributes added from node's op_(message OpDef) - // Input and out describes - AddAttrProtoForOpInAndOutDesc(node_proto, op_desc); - // Others - auto op_def = op_desc->op_def_.GetProtoMsg(); - if (op_def != nullptr) { - auto id = op_def->id(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "id", &id); - auto stream_id = op_def->stream_id(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "stream_id", &stream_id); - const auto &input_name = op_def->input_name(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "input_name", input_name); - const auto &src_name = op_def->src_name(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "src_name", src_name); - const auto &src_index = op_def->src_index(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "src_index", src_index); - const auto &dst_name = op_def->dst_name(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "dst_name", dst_name); - const auto &dst_index = op_def->dst_index(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "dst_index", dst_index); - const auto &input_i = op_def->input_i(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_i", input_i); - const auto &output_i = op_def->output_i(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_i", output_i); - const auto &workspace = op_def->workspace(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace", workspace); - const auto &workspace_bytes = op_def->workspace_bytes(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes); - const auto &is_input_const = op_def->is_input_const(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const); - const auto &op_def_attr_map = op_def->attr(); - AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto); - } else { - GELOGE(FAILED, "Opdef is nullptr"); - return; - } - } else { - GELOGE(FAILED, "Opdesc is nullptr"); - return; - } -} - -bool OnnxUtils::EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto) { - if ((node == nullptr) || (node_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeOpDesc: Input Para Node Invalid"); - return false; - } - - // 2.Encode map attrs_ to AttributeProto - for (auto &node_attr : node->attrs_) { - AddAttrProtoFromAttribute(node_attr, node_proto); - } - // 3.Encode ge::Node members to AttributeProto - AddAttrProtoFromNodeMembers(node, node_proto); - return true; -} - -void OnnxUtils::EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto) { - if ((node == nullptr) || (node_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeNodeLinkForNetronVisual: Input Para Node Invalid"); - return; - } - const auto &node_name = node->GetName(); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - if ((out_data_anchor != nullptr) && (!out_data_anchor->GetPeerInDataAnchors().empty())) { - node_proto->add_output(node_name + ":" + std::to_string(out_data_anchor->GetIdx())); - } - } - auto out_control_anchor = node->GetOutControlAnchor(); - if ((out_control_anchor != nullptr) && (!out_control_anchor->GetPeerInControlAnchors().empty())) { - node_proto->add_output(node_name + kControlAnchorIndex); - } -} - -bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) { - if ((node == nullptr) || (node_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeNodeLink: Input Para Node Invalid"); - return false; - } - node_proto->clear_input(); - // 1. Add input by in data edge - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { - node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + - std::to_string(peer_out_anchor->GetIdx())); - } else { - // Add "" input - node_proto->add_input(""); - } - } - - // 2. Add input by in control edge - auto in_control_anchor = node->GetInControlAnchor(); - if (in_control_anchor != nullptr) { - auto peer_out_anchors = in_control_anchor->GetPeerOutControlAnchors(); - for (const auto &peer_out_anchor : peer_out_anchors) { - if (peer_out_anchor->GetOwnerNode()) { - node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); - } - } - } else { - GELOGE(FAILED, "Incontrol anchor is nullptr"); - return false; - } - - // 3. Add output for Netron visual support - EncodeNodeLinkForNetronVisual(node, node_proto); - return true; -} - -bool OnnxUtils::EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto) { - if ((node == nullptr) || (node_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeNode: Input Para Node Invalid"); - return false; - } - // 1. Encode name and type - node_proto->set_name(node->GetName()); - /// Netron believes that some operators, such as the activation operator of softplus, only have one input, - /// while the link relation of control anchor may exist in ge, resulting in two inputs. Therefore, "ge:" prefix - /// is added to correctly display the link relation at the expense of some color features - node_proto->set_op_type("ge:" + node->GetType()); - - if (kDumpLevel != DUMP_WITH_OUT_DESC) { - // 2.for attr - if (!EncodeNodeDesc(node, node_proto)) { - GELOGE(GRAPH_FAILED, "Encode NodeDesc: %s failed", node->GetName().c_str()); - return false; - } - } - // 3.for link info - return EncodeNodeLink(node, node_proto); -} - -void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type) { - if ((node == nullptr) || (tensor_type == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeTypeProtoTensorType: Input Para Node or tensor_type Invalid"); - return; - } - const auto &op_desc = node->GetOpDesc(); - if (op_desc != nullptr) { - uint32_t size_out = static_cast(op_desc->GetOutputsSize()); - if (size_out > 0) { - for (uint32_t i = 0; i < size_out; i++) { - const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i); - if (ge_tensor != nullptr) { - auto ge_data_type = ge_tensor->GetDataType(); - auto onnx_data_type = EncodeDataType(ge_data_type); - tensor_type->set_elem_type(onnx_data_type); - onnx::TensorShapeProto *shape = tensor_type->mutable_shape(); - if (shape != nullptr) { - for (auto d : ge_tensor->GetShape().GetDims()) { - auto dim = shape->add_dim(); - dim->set_dim_value(d); - } - } else { - GELOGW("Shape is nullptr"); - continue; - } - } else { - GELOGW("Ge tensor is nullptr"); - continue; - } - } - } - } else { - GELOGW("OpDesc Is Empty, nodeName %s nodeType %s", node->GetName().c_str(), node->GetType().c_str()); - return; - } -} - -void OnnxUtils::EncodeValueInfo(const NodePtr &node, onnx::ValueInfoProto *value_info_proto) { - if ((node == nullptr) || (value_info_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeValueInfo: Input Para Node or value_info_proto Invalid"); - return; - } - value_info_proto->set_name(node->GetName()); - onnx::TypeProto *t = value_info_proto->mutable_type(); - onnx::TypeProto_Tensor *tensor_type = t->mutable_tensor_type(); - EncodeTypeProtoTensorType(node, tensor_type); -} - -bool OnnxUtils::EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto) { - if ((graph == nullptr) || (graph_proto == nullptr)) { - GELOGE(GRAPH_FAILED, "EncodeGraph: Input para Invalid"); - return false; - } - graph_proto->set_name(graph->GetName()); - // 1. Add graph inputs - for (const auto &input : graph->GetInputNodes()) { - auto value_info_proto = graph_proto->add_input(); - EncodeValueInfo(input, value_info_proto); - } - // 2. Add graph outputs - for (const auto &output : graph->GetOutputNodes()) { - auto value_info_proto = graph_proto->add_output(); - EncodeValueInfo(output, value_info_proto); - } - // 3. Add nodes - for (const auto &node : graph->GetDirectNode()) { - if (!EncodeNode(node, graph_proto->add_node())) { - GELOGW("EncodeNode failed"); - continue; - } - } - return true; -} - -bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto) { - model_proto.set_model_version(model.GetVersion()); - model_proto.set_ir_version(onnx::IR_VERSION); - model_proto.set_producer_name(model.GetName()); - auto &graph = model.graph_; - auto compute_graph = GraphUtils::GetComputeGraph(graph); - if (compute_graph == nullptr) { - GELOGE(GRAPH_FAILED, "GetComputeGraph: return nullptr"); - return false; - } - auto graph_proto = model_proto.mutable_graph(); - if (graph_proto == nullptr) { - GELOGE(GRAPH_FAILED, "mutable_graph: %s return nullptr", compute_graph->GetName().c_str()); - return false; - } - if (!EncodeGraph(compute_graph, graph_proto)) { - GELOGE(GRAPH_FAILED, "EncodeGraph: %s fail", compute_graph->GetName().c_str()); - return false; - } - - // For subgraphs: a subgraph is represented by a node - for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) { - if (sub_compute_graph != nullptr) { - auto node_proto = graph_proto->add_node(); - if (node_proto == nullptr) { - GELOGW("Node proto is nullptr"); - continue; - } - node_proto->set_name(sub_compute_graph->GetName()); - node_proto->set_op_type(kNodeTypeForSubgraph); - auto attr = node_proto->add_attribute(); - attr->set_name("graph"); - attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); - auto sub_graph_proto = attr->mutable_g(); - if (sub_graph_proto == nullptr) { - GELOGW("Sub graph proto is nullptr"); - continue; - } - if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { - GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str()); - continue; - } - } else { - GELOGW("Graph: %s subgraph is nullptr, skip EncodeGraph", compute_graph->GetName().c_str()); - continue; - } - } - return true; -} - -// Part 2: from ONNX Protobuf convert to IR -static std::map onnxDataTypeToGeMap = { - {onnx::TensorProto_DataType_INT64, DT_INT64}, {onnx::TensorProto_DataType_UINT64, DT_UINT64}, - {onnx::TensorProto_DataType_FLOAT, DT_FLOAT}, {onnx::TensorProto_DataType_INT32, DT_INT32}, - {onnx::TensorProto_DataType_UINT32, DT_UINT32}, {onnx::TensorProto_DataType_INT8, DT_INT8}, - {onnx::TensorProto_DataType_UINT8, DT_UINT8}, {onnx::TensorProto_DataType_INT16, DT_INT16}, - {onnx::TensorProto_DataType_UINT16, DT_UINT16}, {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16}, - {onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE}, {onnx::TensorProto_DataType_BOOL, DT_BOOL}, -}; - -ge::DataType OnnxUtils::DecodeDataType(onnx::TensorProto_DataType data_type) { - auto it = onnxDataTypeToGeMap.find(data_type); - if (it != onnxDataTypeToGeMap.end()) { - return it->second; - } else { - GELOGW("DecodeDataType: datatype not support %u", data_type); - return ge::DT_UNDEFINED; - } -} - -bool OnnxUtils::ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index) { - auto sep = node_name_index.rfind(':'); - if (sep == std::string::npos) { - return false; - } - node_name = node_name_index.substr(0, sep); - auto index_str = node_name_index.substr(sep + 1); - index = static_cast(std::strtol(index_str.c_str(), nullptr, 10)); - return true; -} - -bool OnnxUtils::DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr) { - if (node_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp: node_ptr is nullptr"); - return false; - } - // Data edge - if (item.src_out_index >= 0) { - auto src_anchor = node_ptr->GetOutDataAnchor(item.src_out_index); - auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index); - if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { - GELOGE(GRAPH_FAILED, "Get data anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, - item.dst_node_name.c_str(), item.dst_in_index); - return false; - } - if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Data Anchor: src_anchor->LinkTo(dst_anchor) failed"); - return false; - } - // Control edge - } else { - auto src_anchor = node_ptr->GetOutControlAnchor(); - auto dst_anchor = item.dst_node->GetInControlAnchor(); - if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { - GELOGE(GRAPH_FAILED, "Get control anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, - item.dst_node_name.c_str(), item.dst_in_index); - return false; - } - if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Control Anchor: src_anchor->LinkTo(dst_anchor) failed"); - return false; - } - } - return true; -} - -bool OnnxUtils::DecodeNodeLink(const std::vector &node_proto_vector, - const std::map &node_map) { - for (const auto &node_proto : node_proto_vector) { - const auto &node_name = node_proto.name(); - auto dst_node = node_map.find(node_name); - if ((dst_node == node_map.end()) || (dst_node->second == nullptr)) { - GELOGE(GRAPH_FAILED, "destination node: %s find failed or is nullptr", node_name.c_str()); - return false; - } - int32_t dst_index = 0; - for (const auto &input : node_proto.input()) { - std::string input_node_name; - int32_t index = 0; - if (ParseNameIndex(input, input_node_name, index)) { - auto item = NodeLinkInfo{input_node_name, index, dst_node->second, dst_index, node_proto.name()}; - auto src_node = node_map.find(input_node_name); - if (src_node == node_map.end()) { - GELOGE(GRAPH_FAILED, "find src node: %s failed", input_node_name.c_str()); - return false; - } - auto node_ptr = src_node->second; - if (node_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "src node: %s is nullptr", input_node_name.c_str()); - return false; - } - if (!DecodeNodeLinkImp(item, node_ptr)) { - GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp node: %s failed", input_node_name.c_str()); - return false; - } - } - if (index >= 0) { - dst_index++; - } - } - } - return true; -} - -void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &strings) { - if (attr_proto.type() != onnx::AttributeProto_AttributeType_STRINGS) { - GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - for (int i = 0; i < attr_proto.strings_size(); i++) { - strings.push_back(attr_proto.strings(i)); - } -} - -void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value) { - if (attr_proto.type() != onnx::AttributeProto_AttributeType_STRING) { - GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - value = attr_proto.s(); -} - -void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &ints) { - if (attr_proto.type() != onnx::AttributeProto_AttributeType_INTS) { - GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - for (int i = 0; i < attr_proto.ints_size(); i++) { - ints.push_back(attr_proto.ints(i)); - } -} - -void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value) { - if (attr_proto.type() != onnx::AttributeProto_AttributeType_INT) { - GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - value = attr_proto.i(); -} - -void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_desc, int32_t index, - OpDescPtr &op_desc) { - if (op_desc->MutableInputDesc(static_cast(index)) == nullptr) { - GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast(index)) is nullptr", - op_desc->GetName().c_str(), attr_name_for_input_desc.c_str()); - return; - } - if (attr_name_for_input_desc == "input_desc_dtype") { - auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); - op_desc->MutableInputDesc(static_cast(index))->SetDataType(data_type); - } else if (attr_name_for_input_desc == "input_desc_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - GeShape ge_shape(ints); - op_desc->MutableInputDesc(static_cast(index))->SetShape(ge_shape); - } else if (attr_name_for_input_desc == "input_desc_layout") { - auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - op_desc->MutableInputDesc(static_cast(index))->SetFormat(data_format); - } else if (attr_name_for_input_desc == "input_desc_origin_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - GeShape ge_shape(ints); - op_desc->MutableInputDesc(static_cast(index))->SetOriginShape(ge_shape); - } else if (attr_name_for_input_desc == "input_desc_origin_layout") { - auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - op_desc->MutableInputDesc(static_cast(index))->SetOriginFormat(data_format); - } else if (attr_name_for_input_desc == "input_desc_size") { - int64_t input_size = 0; - auto tensor_descriptor = op_desc->MutableInputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); - DecodeAttribute(attr_proto, input_size); - tensor_descriptor->set_size(input_size); - } else if (attr_name_for_input_desc == "input_desc_data_offset") { - auto tensor_descriptor = op_desc->MutableInputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); - int64_t offset = 0; - DecodeAttribute(attr_proto, offset); - tensor_descriptor->set_data_offset(offset); - } else { - return; - } -} - -void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_output_desc, int32_t index, - OpDescPtr &op_desc) { - if (op_desc->MutableOutputDesc(static_cast(index)) == nullptr) { - GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast(index)) is nullptr", - op_desc->GetName().c_str(), attr_name_for_output_desc.c_str()); - return; - } - if (attr_name_for_output_desc == "output_desc_dtype") { - auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); - op_desc->MutableOutputDesc(static_cast(index))->SetDataType(data_type); - } else if (attr_name_for_output_desc == "output_desc_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - GeShape ge_shape(ints); - op_desc->MutableOutputDesc(static_cast(index))->SetShape(ge_shape); - } else if (attr_name_for_output_desc == "output_desc_layout") { - auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - op_desc->MutableOutputDesc(static_cast(index))->SetFormat(data_format); - } else if (attr_name_for_output_desc == "output_desc_origin_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - GeShape ge_shape(ints); - op_desc->MutableOutputDesc(static_cast(index))->SetOriginShape(ge_shape); - } else if (attr_name_for_output_desc == "output_desc_origin_layout") { - auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - op_desc->MutableOutputDesc(static_cast(index))->SetOriginFormat(data_format); - } else if (attr_name_for_output_desc == "output_desc_size") { - int64_t output_size = 0; - auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); - DecodeAttribute(attr_proto, output_size); - tensor_descriptor->set_size(output_size); - } else if (attr_name_for_output_desc == "output_desc_data_offset") { - auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast(index))->tensor_descriptor_.GetProtoMsg(); - int64_t offset = 0; - DecodeAttribute(attr_proto, offset); - tensor_descriptor->set_data_offset(offset); - } else { - return; - } -} - -void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_output_desc, int32_t index, - OpDescPtr &op_desc) { - if (op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "op_desc is nullptr"); - return; - } - if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == "input") { - DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); - } else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") { - DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); - } else { - return; - } -} - -void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) { - auto attr_map = op_def.mutable_attr(); - const auto &attr_name = attr_proto.name(); - ge::proto::AttrDef op_attr; - int64_t value = 0; - DecodeAttribute(attr_proto, value); - op_attr.set_i(value); - attr_map->insert(AttrDefPair(attr_name, op_attr)); -} - -void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { - if (op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr"); - return; - } - const auto &attr_name = attr_proto.name(); - std::string attr_name_for_input_output_desc; - int32_t index = 0; - if (!ParseNameIndex(attr_name, attr_name_for_input_output_desc, index)) { - if (attr_name == "id") { - op_desc->SetId(attr_proto.i()); - } else if (attr_name == "stream_id") { - op_desc->SetStreamId(attr_proto.i()); - } else if (attr_name == "src_name") { - std::vector strings; - DecodeAttribute(attr_proto, strings); - op_desc->SetSrcName(strings); - } else if (attr_name == "dst_name") { - std::vector strings; - DecodeAttribute(attr_proto, strings); - op_desc->SetDstName(strings); - } else if (attr_name == "src_index") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetSrcIndex(ints); - } else if (attr_name == "dst_index") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetDstIndex(ints); - } else if (attr_name == "fusion_scope") { - DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg()); - } else if (attr_name == "input_i") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetInputOffset(ints); - } else if (attr_name == "output_i") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetOutputOffset(ints); - } else { - return; - } - // Update input and output desc - } else { - DecodeNodeAttributeForOpInAndOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); - } -} - -bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_desc) { - if (op_desc == nullptr || node_proto == nullptr) { - GELOGE(GRAPH_FAILED, " Op_desc is nullptr or node_proto is nullptr"); - return false; - } - // 1. Decode node_proto name and type - op_desc->SetName(node_proto->name()); - const auto &node_type_with_ge_prefix = node_proto->op_type(); - auto sep = node_type_with_ge_prefix.find(':'); - if (sep == std::string::npos) { - return false; - } - auto node_type = node_type_with_ge_prefix.substr(sep + 1); - op_desc->SetType(node_type); - // 2. Add empty input and output desc - for (const auto &attr : node_proto->attribute()) { - if (attr.name() == "input_desc_nums") { - auto size_in = attr.i(); - for (int64_t i = 0; i < size_in; i++) { - GeTensorDesc ge_tensor_desc; - GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed."); - } - } - if (attr.name() == "output_desc_nums") { - auto size_out = attr.i(); - for (int64_t i = 0; i < size_out; i++) { - GeTensorDesc ge_tensor_desc; - GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed."); - } - } - } - // 3.Decode node_proto attributes - for (int i = 0; i < node_proto->attribute_size(); i++) { - DecodeNodeAttributeForOpDesc(node_proto->attribute(i), op_desc); - } - return true; -} - -bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph) { - if (recursion_depth > kMaxRecursionDepth) { - GELOGE(GRAPH_FAILED, "DecodeGraph: recursion depth is too large, abort"); - return false; - } - - graph = ComGraphMakeShared(graph_proto.name()); - GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed"); - /// 1. Decode all nodes first, node should include input - /// and output nodes and nodes which represent sub graphs - std::map node_map; - std::vector node_proto_vector; - for (const auto &node_proto : graph_proto.node()) { - // a. nodes represent sub graphs - if (node_proto.op_type() == kNodeTypeForSubgraph) { - ComputeGraphPtr compute_graph; - // in this case, node only have one attr, whose type is AttributeProto_AttributeType_GRAPH - const auto &node_attr = node_proto.attribute(0); - if ((node_attr.type() == onnx::AttributeProto_AttributeType_GRAPH) && - DecodeGraph(recursion_depth + 1, node_attr.g(), compute_graph)) { - (void)graph->AddSubGraph(compute_graph); - } else { - GELOGE(GRAPH_FAILED, "Decode sub graph %s failed with node type:%d", node_proto.name().c_str(), - node_attr.type()); - return false; - } - // b. direct nodes in graph - } else { - node_proto_vector.push_back(node_proto); - OpDescPtr op_desc = ComGraphMakeShared(); - // b.1 For node desc - if (!DecodeNodeDesc(&node_proto, op_desc)) { - GELOGE(GRAPH_FAILED, "Decode node desc %s failed ", node_proto.name().c_str()); - return false; - } - auto node = graph->AddNode(op_desc); - node_map.insert(std::make_pair(node_proto.name(), node)); - } - } - /// We get all nodes in graph here - /// b.2 For node link - if (!DecodeNodeLink(node_proto_vector, node_map)) { - GELOGE(GRAPH_FAILED, "Decode node link failed"); - return false; - } - - // 2. Add inputs nodes for graph - for (const auto &input : graph_proto.input()) { - const auto &input_node_name = input.name(); - auto input_node_item = node_map.find(input_node_name); - if (input_node_item == node_map.end()) { - GELOGE(GRAPH_FAILED, "cannot find graph's input node %s in node_", input_node_name.c_str()); - return false; - } - auto ret = graph->AddInputNode(input_node_item->second); - GE_CHK_BOOL_EXEC(ret != nullptr, continue, "Add inputnode failed"); - } - // 3. Add outputs nodes for graph - for (const auto &output : graph_proto.output()) { - const auto &output_node_name = output.name(); - auto output_node_item = node_map.find(output_node_name); - if (output_node_item == node_map.end()) { - GELOGE(GRAPH_FAILED, "cannot find graph's output node %s in node_", output_node_name.c_str()); - return false; - } - auto ret = graph->AddOutputNode(output_node_item->second); - if (ret == nullptr) { - GELOGW("Add outputnode failed,out put node is %s", output_node_name.c_str()); - continue; - } - } - return true; -} - -bool OnnxUtils::ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model) { - model.name_ = model_proto.producer_name(); - model.version_ = static_cast(model_proto.model_version()); - - auto &graph_proto = model_proto.graph(); - ComputeGraphPtr compute_graph; - // 0 means recursion depth, father call - if (!DecodeGraph(0, graph_proto, compute_graph)) { - GELOGE(GRAPH_FAILED, "Decode compute graph from graph_proto failed"); - return false; - } - model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph); - return true; -} -} // namespace ge diff --git a/metadef/graph/utils/ge_ir_utils.h b/metadef/graph/utils/ge_ir_utils.h deleted file mode 100644 index 9b16be18..00000000 --- a/metadef/graph/utils/ge_ir_utils.h +++ /dev/null @@ -1,206 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ -#define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "proto/ge_ir.pb.h" -#include "proto/onnx.pb.h" - -namespace ge { -const int kOffsetToString = 2; - -/// -/// @ingroup ge_ir_utils -/// @brief RepeatedField->String -/// @param [in] const rpd_field RepeatedField -/// @return String -/// -template -const std::string ToString(const google::protobuf::RepeatedField &rpd_field) { - std::stringstream ss; - ss << "["; - for (const T &x : rpd_field) { - ss << x; - ss << ", "; - } - std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); - str_ret += "]"; - return str_ret; -} - -/// -/// @ingroup ge_ir_utils -/// @brief RepeatedPtrField->String -/// @param [in] const rpd_field RepeatedPtrField -/// @return String -/// -template -const std::string ToString(const google::protobuf::RepeatedPtrField &rpd_ptr_field) { - std::stringstream ss; - ss << "["; - for (const T &x : rpd_ptr_field) { - ss << x; - ss << ", "; - } - std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); - str_ret += "]"; - return str_ret; -} - -/// -/// @ingroup ge_ir_utils -/// @brief check, if not equal, log with tag -/// @param [in] const left_value, right_value reference, log_info_tag -/// @return bool -/// -template -bool IsEqual(const T &l_value, const T &r_value, const std::string &log_info_tag) { - if (l_value == r_value) { - return true; - } else { - GELOGE(GRAPH_FAILED, "Check failed with %s", log_info_tag.c_str()); - return false; - } -} - -class OnnxUtils { - public: - enum DumpLevel { NO_DUMP = 0, DUMP_ALL = 1, DUMP_WITH_OUT_DATA = 2, DUMP_WITH_OUT_DESC = 3, DUMP_LEVEL_END }; - - static bool ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto); - - static bool ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model); - - private: - // Part 1: from IR convert to ONNX Protobuf - static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, - const std::string &name, void *data); - - static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, - const std::string &name, ::google::protobuf::RepeatedField<::google::protobuf::int64> data); - - static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, - const std::string &name, ::google::protobuf::RepeatedField data); - - static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, - const std::string &name, ::google::protobuf::RepeatedField data); - - static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, - const std::string &name, ::google::protobuf::RepeatedPtrField<::std::string> data); - - static void AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto); - - static void AddAttrProtoFromAttribute(const std::pair &string_attr_value, - onnx::NodeProto *node_proto); - - static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc); - - static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map &attr_map, - onnx::NodeProto *node_proto, const std::string &prefix = "", - const std::string &suffix = ""); - - static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto); - - static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); - - static void EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto); - - static bool EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto); - - static bool EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto); - - static bool EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto); - - static void EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type); - - static void EncodeValueInfo(const NodePtr &n, onnx::ValueInfoProto *v); - - static bool EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto); - - /// Part 2: from ONNX Protobuf convert to IR - /// Describes node's link relationships - struct NodeLinkInfo { - std::string src_node_name; - int32_t src_out_index; - NodePtr dst_node; - int32_t dst_in_index; - std::string dst_node_name; - }; - - // Parse node name and index - static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index); - - static ge::DataType DecodeDataType(onnx::TensorProto_DataType data_type); - - static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &strings); - - static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &ints); - - static void DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value); - - static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value); - - static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_output_desc, int32_t index, - OpDescPtr &op_desc); - - static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_desc, int32_t index, - OpDescPtr &op_desc); - - static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_output_desc, int32_t index, - OpDescPtr &op_desc); - - static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def); - - static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); - - static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); - - static bool DecodeNodeLink(const std::vector &node_proto_vector, - const std::map &node_map); - - static bool DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &node); - - static bool DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph); -}; -} // namespace ge - -#endif // COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ diff --git a/metadef/graph/utils/graph_utils.cc b/metadef/graph/utils/graph_utils.cc deleted file mode 100644 index c741a316..00000000 --- a/metadef/graph/utils/graph_utils.cc +++ /dev/null @@ -1,2767 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/graph_utils.h" - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "./ge_context.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "proto/ge_ir.pb.h" -#include "utils/attr_utils.h" -#include "utils/ge_ir_utils.h" -#include "utils/node_utils.h" -#include "debug/ge_op_types.h" -#include "external/ge/ge_api_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" - -using google::protobuf::io::FileOutputStream; - -namespace ge { -enum DumpGraphLevel { - kDumpLevel1 = 1, - kDumpLevel2 = 2, - kDumpLevel3 = 3, - kDumpLevelOther, -}; - -namespace { -const int32_t kBaseOfIntegerValue = 10; -#ifdef FMK_SUPPORT_DUMP -const char *const kDumpGeGraph = "DUMP_GE_GRAPH"; -const int kDumpGraphIndexWidth = 5; -#endif -const char *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; -const char *const kDumpStrBuild = "Build"; -const char *const kDumpStrPartition = "partition"; -const char *const kDumpStrOptimizeSubgraph = "OptimizeSubGraph"; -const char *const kDumpStrSubgraphFunc = "sub_graph"; -const char *const kDumpStrAicpu = "Aicpu"; -}; // namespace - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, - const InDataAnchorPtr &dst) { - if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Add edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const AnchorPtr &src, - const AnchorPtr &dst) { - OutDataAnchorPtr src_data = Anchor::DynamicAnchorCast(src); - InDataAnchorPtr dst_data = Anchor::DynamicAnchorCast(dst); - OutControlAnchorPtr src_control = Anchor::DynamicAnchorCast(src); - InControlAnchorPtr dst_control = Anchor::DynamicAnchorCast(dst); - if ((src_data != nullptr) && (dst_data != nullptr) && (src_data->LinkTo(dst_data) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - if ((src_data != nullptr) && (dst_control != nullptr) && (src_data->LinkTo(dst_control) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - if ((src_control != nullptr) && (dst_control != nullptr) && (src_control->LinkTo(dst_control) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - if ((src_control != nullptr) && (dst_data != nullptr) && (src_control->LinkTo(dst_data) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Add edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, - const Format &src_format, - const InDataAnchorPtr &dst, - const Format &dst_format) { - if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { - auto ret = AnchorUtils::SetFormat(src, src_format); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed, format is %d", static_cast(src_format)); - return ret; - } - ret = AnchorUtils::SetFormat(dst, dst_format); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(dst_format)); - return ret; - } - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Add edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutControlAnchorPtr &src, - const InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Add edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, - const InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Add edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutDataAnchorPtr &src, - const InDataAnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Remove edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const AnchorPtr &src, - const AnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Remove edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutControlAnchorPtr &src, - const InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Remove edge Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutDataAnchorPtr &src, - const InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Remove edge Failed."); - return GRAPH_FAILED; -} - -graphStatus GraphUtils::ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, - const InDataAnchorPtr &new_dst) { - if (RemoveEdge(src, dst) == GRAPH_SUCCESS && AddEdge(src, new_dst) == GRAPH_SUCCESS) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Replace edge dst Failed."); - return GRAPH_FAILED; -} - -graphStatus GraphUtils::ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, - const InControlAnchorPtr &new_dst) { - if (RemoveEdge(src, dst) == GRAPH_SUCCESS && AddEdge(src, new_dst) == GRAPH_SUCCESS) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Replace edge dst Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertNodeBetweenDataAnchors( - const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, const NodePtr &new_node) { - GE_CHECK_NOTNULL(src); - GE_CHECK_NOTNULL(dst); - GE_CHECK_NOTNULL(new_node); - - InDataAnchorPtr node_in_anchor = new_node->GetInDataAnchor(0); - GE_CHK_BOOL_RET_STATUS(node_in_anchor != nullptr, GRAPH_FAILED, "this node has not inDataAnchor"); - OutDataAnchorPtr node_out_anchor = new_node->GetOutDataAnchor(0); - GE_CHK_BOOL_RET_STATUS(node_out_anchor != nullptr, GRAPH_FAILED, "this node has not outDataAnchor"); - GE_CHK_STATUS_RET(src->ReplacePeer(dst, node_in_anchor, node_out_anchor), "ReplacePeer Failed"); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node) { - GE_CHECK_NOTNULL(compute_graph); - if (remove_node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return GRAPH_FAILED; - } - - // Check if this node is belong to this compute graph, maybe a little slow - const auto &all_nodes_in_graph = compute_graph->GetDirectNode(); - if (std::find(all_nodes_in_graph.begin(), all_nodes_in_graph.end(), remove_node) == all_nodes_in_graph.end()) { - GELOGE(GRAPH_FAILED, "Can not find node %s in graph %s.", remove_node->GetName().c_str(), - compute_graph->GetName().c_str()); - return GRAPH_FAILED; - } - // Find all subgraph of this node - const auto &root_graph = GraphUtils::FindRootGraph(compute_graph); - std::vector subgraphs; - std::vector all_nodes; - std::deque candidates; - NodePtr remove_node_new = remove_node; - candidates.emplace_back(remove_node_new); - while (!candidates.empty()) { - const NodePtr node = candidates.front(); - all_nodes.emplace_back(node); - candidates.pop_front(); - - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - - const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); - for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { - auto subgraph = root_graph->GetSubgraph(*name_iter); - if (subgraph != nullptr) { - subgraphs.emplace_back(subgraph); - candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); - } - } - } - // Remove all subgraph - for (const auto &remove_graph : subgraphs) { - if (root_graph->RemoveSubGraph(remove_graph) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Remove subgraph failed, sub graph name is %s, compute graph is %s.", - remove_node->GetName().c_str(), compute_graph->GetName().c_str()); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node) { - GE_CHECK_NOTNULL(compute_graph); - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should not be null."); - return GRAPH_FAILED; - } - - // If the node save as input node, delete it - (void)compute_graph->RemoveInputNode(node); - - // If the node save as output node, delete it - (void)compute_graph->RemoveOutputNode(node); - - // If the node has sub-graphs, delete them - auto ret = RemoveSubgraphRecursively(compute_graph, node); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Remove subgraph recursively failed."); - return GRAPH_FAILED; - } - - auto iter = find(compute_graph->nodes_.begin(), compute_graph->nodes_.end(), node); - if (iter != compute_graph->nodes_.end()) { - compute_graph->nodes_.erase(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -/// Add two edges to the new node, respectively connecting the SRC and DST -/// associated with the original edge -/// A ---> B transfered to A ---> N ---> B -graphStatus InsertTransNode(ComputeGraph &compute_graph, const InDataAnchorPtr &in_data_anchor, - const std::vector &vec_op_desc) { - GE_CHECK_NOTNULL(in_data_anchor); - for (const auto &op_desc : vec_op_desc) { - GE_CHECK_NOTNULL(op_desc); - - auto ret = op_desc->AddInputDesc(GeTensorDesc()); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, "Add input desc failed"); - ret = op_desc->AddOutputDesc(GeTensorDesc()); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return GRAPH_FAILED, "Add input desc failed"); - auto node_to_insert = compute_graph.AddNode(op_desc); - - GE_CHECK_NOTNULL(node_to_insert); - GE_CHECK_NOTNULL(in_data_anchor->GetPeerOutAnchor()); - - auto src = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); - if (!src) { - GELOGE(GRAPH_FAILED, "src nullptr error."); - return GRAPH_FAILED; - } - - auto src_out_index = in_data_anchor->GetPeerOutAnchor()->GetIdx(); - - auto dst = in_data_anchor->GetOwnerNode(); - if (!dst) { - GELOGE(GRAPH_FAILED, "dst nullptr error."); - return GRAPH_FAILED; - } - - auto dst_in_index = in_data_anchor->GetIdx(); - - auto in_data_anchor_src_format = AnchorUtils::GetFormat(in_data_anchor->GetPeerOutAnchor()); - auto in_data_anchor_dst_format = AnchorUtils::GetFormat(in_data_anchor); - - GE_CHECK_NOTNULL(src->GetOutDataAnchor(src_out_index)); - GE_CHECK_NOTNULL(dst->GetInDataAnchor(dst_in_index)); - - ret = GraphUtils::RemoveEdge(src->GetOutDataAnchor(src_out_index), dst->GetInDataAnchor(dst_in_index)); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Remove edge failed"); - return GRAPH_FAILED; - } - - GE_CHECK_NOTNULL(node_to_insert->GetInDataAnchor(0)); - GE_CHECK_NOTNULL(node_to_insert->GetOutDataAnchor(0)); - - ret = GraphUtils::AddEdge(src->GetOutDataAnchor(src_out_index), node_to_insert->GetInDataAnchor(0)); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add edge failed"); - return ret; - } - ret = GraphUtils::AddEdge(node_to_insert->GetOutDataAnchor(0), dst->GetInDataAnchor(dst_in_index)); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add edge failed"); - return ret; - } - - if (op_desc->HasAttr("input_format")) { - int64_t input_format = 0; - int64_t output_format = 0; - if (!AttrUtils::GetInt(op_desc, "input_format", input_format)) { - GELOGW("get attr input_format failed"); - continue; - } - if (!AttrUtils::GetInt(op_desc, "output_format", output_format)) { - GELOGW("get attr output_format failed"); - continue; - } - - GE_CHECK_NOTNULL(node_to_insert->GetInDataAnchor(0)->GetPeerOutAnchor()); - GE_CHK_BOOL_RET_STATUS(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().empty(), GRAPH_FAILED, - "Vistor is empty"); - GE_CHECK_NOTNULL(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)); - - auto status = - AnchorUtils::SetFormat(node_to_insert->GetInDataAnchor(0)->GetPeerOutAnchor(), in_data_anchor_src_format); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(in_data_anchor_src_format)); - return status; - } - status = AnchorUtils::SetFormat(node_to_insert->GetInDataAnchor(0), static_cast(input_format)); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed,format is %ld", input_format); - return status; - } - status = AnchorUtils::SetFormat(node_to_insert->GetOutDataAnchor(0), static_cast(output_format)); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed,format is %ld", output_format); - return status; - } - status = AnchorUtils::SetFormat(node_to_insert->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0), - in_data_anchor_dst_format); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set format failed,format is %d", static_cast(in_data_anchor_dst_format)); - return status; - } - } - std::vector original_nodes; - GraphUtils::RecordOriginalNames(original_nodes, node_to_insert); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTransNode( - ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, const std::vector &vec_op_desc) { - GE_CHECK_NOTNULL(compute_graph); - GE_CHECK_NOTNULL(in_data_anchor); - graphStatus ret = - ge::InsertTransNode(*compute_graph, in_data_anchor, vec_op_desc) == GRAPH_SUCCESS ? GRAPH_SUCCESS : GRAPH_FAILED; - return ret; -} - -/// -/// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst -/// @param [in] src -/// @param [in] dsts -/// @param [in] insert_node -/// @param [in] input_index -/// @param [in] output_index -/// @return graphStatus -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, - const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { - GE_CHECK_NOTNULL(src); - GE_CHECK_NOTNULL(insert_node); - - NodePtr src_node = src->GetOwnerNode(); - if (src_node->GetOwnerComputeGraph() != insert_node->GetOwnerComputeGraph()) { - GELOGE(GRAPH_FAILED, "src:%s and insert_node:%s not exist in the same graph.", src_node->GetName().c_str(), - insert_node->GetName().c_str()); - return GRAPH_FAILED; - } - - if (AddEdge(src, insert_node->GetInDataAnchor(input_index)) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "AddEdge %s->%s failed.", src_node->GetName().c_str(), insert_node->GetName().c_str()); - return GRAPH_FAILED; - } - - OutControlAnchorPtr src_out_ctrl_anchor = src_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(src_out_ctrl_anchor); - - bool ctrl_edge_flag = true; - std::string type = NodeUtils::GetNodeType(src->GetOwnerNode()); - if ((type == SWITCH) || (type == REFSWITCH) || (type == SWITCHN)) { - ctrl_edge_flag = false; - } - - for (auto &dst : dsts) { - GE_CHECK_NOTNULL(dst); - NodePtr dst_node = dst->GetOwnerNode(); - GELOGI("Insert node %s between %s->%s.", insert_node->GetName().c_str(), src_node->GetName().c_str(), - dst_node->GetName().c_str()); - if (src_node->GetOwnerComputeGraph() != dst_node->GetOwnerComputeGraph()) { - GELOGE(GRAPH_FAILED, "src:%s and dst:%s not exist in the same graph.", src_node->GetName().c_str(), - dst_node->GetName().c_str()); - return GRAPH_FAILED; - } - - (void)RemoveEdge(src, dst); - if (AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), - dst_node->GetName().c_str(), insert_node->GetName().c_str(), dst_node->GetName().c_str()); - return GRAPH_FAILED; - } - - if (!ctrl_edge_flag) { - continue; - } - for (const InControlAnchorPtr &peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { - if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || - (AddEdge(insert_node->GetOutControlAnchor(), peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { - GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), - peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(), insert_node->GetName().c_str(), - peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED; - } - } - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveJustNode(ComputeGraph &compute_graph, - const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should be not null."); - return GRAPH_FAILED; - } - auto iter = find(compute_graph.nodes_.begin(), compute_graph.nodes_.end(), node); - if (iter != compute_graph.nodes_.end()) { - compute_graph.nodes_.erase(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveJustNode(ComputeGraphPtr compute_graph, - const NodePtr &node) { - GE_CHECK_NOTNULL(compute_graph); - GE_CHECK_NOTNULL(node); - graphStatus ret = (RemoveJustNode(*compute_graph, node) == GRAPH_SUCCESS ? GRAPH_SUCCESS : GRAPH_FAILED); - return ret; -} - -void GraphUtils::RecordOriginalNames(std::vector original_nodes, const ge::NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, return, "node is null."); - std::vector original_names; - for (const auto &node_tmp : original_nodes) { - std::vector names_tmp; - ge::OpDescPtr opdesc_tmp = node_tmp->GetOpDesc(); - if (opdesc_tmp == nullptr) { - GELOGE(GRAPH_FAILED, "Node %s get opdesc is nullptr", node_tmp->GetName().c_str()); - continue; - } - auto ret = ge::AttrUtils::GetListStr(opdesc_tmp, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, names_tmp); - if (!ret) { - GELOGW("Get list str failed"); - continue; - } - if (names_tmp.size() != 0) { - original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); - } else { - original_names.push_back(opdesc_tmp->GetName()); - } - } - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names), - return, "Set original_op_names fail."); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::RecordOriginalNames(std::vector names_tmp, - const ge::NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, return, "node is null."); - std::vector original_names; - if (names_tmp.size() != 0) { - original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); - } else { - std::string tmp; - original_names.push_back(tmp); - } - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names), - return, "Set original_op_names fail."); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::MatchDumpStr(const std::string &suffix) { - char *dump_level = std::getenv(kDumpGraphLevel); - int64_t dump_graph_level = - (dump_level != nullptr) ? std::strtol(dump_level, nullptr, kBaseOfIntegerValue) : kDumpLevel2; - - if (dump_graph_level == kDumpLevel1) { - return false; - } - - if (dump_graph_level == kDumpLevel2 && - ((suffix.find(kDumpStrPartition) != std::string::npos) || - (suffix.find(kDumpStrOptimizeSubgraph) != std::string::npos) || - (suffix.find(kDumpStrAicpu) != std::string::npos) || (suffix.find(kDumpStrSubgraphFunc) != std::string::npos))) { - return true; - } - - if (dump_graph_level == kDumpLevel3 && suffix.compare(kDumpStrBuild) != 0) { - return true; - } - - return false; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(const ge::ComputeGraphPtr &graph, - const std::string &suffix, - bool is_always_dump, - const std::string &user_graph_name) { -#ifdef FMK_SUPPORT_DUMP - char *dump_ge_graph = std::getenv(kDumpGeGraph); - GE_IF_BOOL_EXEC(dump_ge_graph == nullptr && !is_always_dump, return;); - - // dump the graph according to different graph level - if (GraphUtils::MatchDumpStr(suffix)) { - return; - } - - // file name - static std::atomic_long atomic_file_index(0); - auto file_index = atomic_file_index.fetch_add(1); - GELOGD("Start to dump om txt: %ld", file_index); - - thread_local long max_dump_file_num = 0; - if (max_dump_file_num == 0) { - string opt = "0"; - (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); - max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); - } - if (max_dump_file_num != 0 && file_index > max_dump_file_num) { - GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileCnt=%ld.", max_dump_file_num); - return; - } - - std::stringstream stream_file_name; - stream_file_name << "ge_proto_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; - stream_file_name << "_" << suffix << ".txt"; - std::string proto_file = user_graph_name.empty() ? stream_file_name.str() : user_graph_name; - - // Create buffer - ge::Model model("", ""); - model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast(graph))); - Buffer buffer; - const int64_t kDumpLevel = - (dump_ge_graph != nullptr) ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : ge::OnnxUtils::NO_DUMP; - model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL); - - // Write file - ge::proto::ModelDef ge_proto; - if (buffer.GetData() != nullptr) { - std::string str(reinterpret_cast(buffer.GetData()), buffer.GetSize()); - if (!ge_proto.ParseFromString(str)) { - GELOGE(GRAPH_FAILED, "parse from string failed."); - return; - } - char real_path[PATH_MAX] = {0x00}; - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(proto_file.c_str()) >= PATH_MAX, return, "file path is too longer!"); - GE_IF_BOOL_EXEC(realpath(proto_file.c_str(), real_path) == nullptr, - GELOGI("file %s does not exist, it will be created.", proto_file.c_str())); - - GraphUtils::WriteProtoToTextFile(ge_proto, real_path); - } -#else - GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); -#endif -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file, - ge::ComputeGraph &compute_graph) { - ge::proto::ModelDef model_def; - // Get ModelDef object from file generated by DumpGEGraph() - if (!ReadProtoFromTextFile(file, &model_def)) { - GELOGE(GRAPH_FAILED, "Get ModelDef failed from file"); - return false; - } - ge::Model model; - // Get Model object from ModelDef by deserialize ModelDef - if (model.Load(model_def) == GRAPH_SUCCESS) { - GE_CHK_BOOL_EXEC(GraphUtils::GetComputeGraph(model.GetGraph()) != nullptr, return false, - "Get computer graph is nullptr"); - compute_graph = *(GraphUtils::GetComputeGraph(model.GetGraph())); - return true; - } else { - GELOGE(GRAPH_FAILED, "Get Model failed from ModelDef"); - return false; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file, - ge::ComputeGraphPtr &compute_graph) { - ge::proto::ModelDef model_def; - // Get ModelDef object from file generated by DumpGEGraph() - if (!ReadProtoFromTextFile(file, &model_def)) { - GELOGE(GRAPH_FAILED, "Get ModelDef failed from file"); - return false; - } - ge::Model model; - // Get Model object from ModelDef by deserialize ModelDef - if (model.Load(model_def) == GRAPH_SUCCESS) { - GE_CHK_BOOL_EXEC(GraphUtils::GetComputeGraph(model.GetGraph()) != nullptr, return false, - "Get computer graph is nullptr"); - compute_graph = GraphUtils::GetComputeGraph(model.GetGraph()); - for (const auto &node : compute_graph->GetDirectNode()) { - GELOGI("Node %s set owner graph", node->GetName().c_str()); - GE_CHECK_NOTNULL(node); - if (node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Node %s set owner graph failed", node->GetName().c_str()); - return false; - } - } - return true; - } else { - GELOGE(GRAPH_FAILED, "Get Model failed from ModelDef"); - return false; - } -} - -// Printing protocol messages in text format is useful for debugging and human editing of messages. -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToTextFile( - const google::protobuf::Message &proto, const char *real_path) { -#ifdef FMK_SUPPORT_DUMP - const int FILE_AUTHORITY = 0600; - int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, FILE_AUTHORITY); - if (fd < 0) { - GELOGE(GRAPH_FAILED, "fail to open the file: %s, %s", real_path, strerror(errno)); - return; - } - google::protobuf::io::FileOutputStream *output = new (std::nothrow) FileOutputStream(fd); - if (output == nullptr) { - GELOGE(GRAPH_FAILED, "Output is nullptr"); - if (close(fd) != 0) { - GELOGE(GRAPH_FAILED, "Close fileoutputstream failed"); - } - return; - } - bool ret = google::protobuf::TextFormat::Print(proto, output); - if (!ret) { - GELOGE(GRAPH_FAILED, "Fail to write the file: %s", real_path); - delete output; - output = nullptr; - GE_CHK_BOOL_EXEC(close(fd) == 0, return, "Close fileoutputstream failed"); - return; - } - delete output; - output = nullptr; - GE_CHK_BOOL_EXEC(close(fd) == 0, return, "Close fileoutputstream failed"); - - FILE *file = fopen(real_path, "rb"); - if (file == nullptr) { - return; - } - if (fseek(file, 0L, SEEK_END) == 0) { - long fileSize = ftell(file); - thread_local long max_dump_file_size = 0; - if (max_dump_file_size == 0) { - string opt = "0"; - // Can not check return value - (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); - max_dump_file_size = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); - } - if (max_dump_file_size != 0 && fileSize != -1 && fileSize > max_dump_file_size) { - GELOGW("dump graph file size > maxDumpFileSize, maxDumpFileSize=%ld.", max_dump_file_size); - GE_IF_BOOL_EXEC(std::remove(real_path) != 0, GELOGW("remove %s failed", real_path)); - GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose %s failed", real_path); - return; - } - } - GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose fileoutputstream failed"); -#else - GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); -#endif -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::ReadProtoFromTextFile( - const char *file, google::protobuf::Message *proto) { - if (file == nullptr || proto == nullptr) { - GELOGE(GRAPH_FAILED, "incorrect parameter. file path or message is invalid"); - return false; - } - std::ifstream fs(file, std::ifstream::in); - if (!fs.is_open()) { - GELOGE(GRAPH_FAILED, "proto file '%s' open fail.", file); - return false; - } - google::protobuf::io::IstreamInputStream input(&fs); - bool ret = google::protobuf::TextFormat::Parse(&input, proto); - if (!ret) { - GELOGE(GRAPH_FAILED, "parse proto from text ret fail, please check your text file '%s'.", file); - } - fs.close(); - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, - const std::string &suffix) { -#ifdef FMK_SUPPORT_DUMP - char *dump_ge_graph = std::getenv(kDumpGeGraph); - int64_t dump_ge_graph_level = - (dump_ge_graph != nullptr) ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : OnnxUtils::NO_DUMP; - if ((dump_ge_graph_level == OnnxUtils::NO_DUMP) || (dump_ge_graph_level >= OnnxUtils::DUMP_LEVEL_END)) { - GELOGD("Skip DumpGEGraphToOnnx with dump_ge_graph_level %ld.", dump_ge_graph_level); - return; - } - - // dump the graph according to different graph level - if (GraphUtils::MatchDumpStr(suffix)) { - return; - } - - // 1.Get ge::onnx::ModelProto from ge::Model - ge::Model model("GE", ""); - std::shared_ptr compute_graph_ptr = ComGraphMakeShared(compute_graph); - model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast(compute_graph_ptr))); - onnx::ModelProto model_proto; - if (!OnnxUtils::ConvertGeModelToModelProto(model, model_proto)) { - GELOGE(GRAPH_FAILED, "DumpGEGraphToOnnx failed."); - return; - } - - // 2.Set file name - static std::atomic_long atomic_file_index(0); - auto file_index = atomic_file_index.fetch_add(1); - GELOGD("Start to dump ge onnx file: %ld", file_index); - - thread_local long max_dump_file_num = 0; - if (max_dump_file_num == 0) { - string opt = "0"; - (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); - max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); - } - if (max_dump_file_num != 0 && file_index > max_dump_file_num) { - GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileNum=%ld.", max_dump_file_num); - return; - } - - std::stringstream stream_file_name; - stream_file_name << "ge_onnx_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; - stream_file_name << "_graph_" << compute_graph.GetGraphID(); - stream_file_name << "_" << suffix << ".pbtxt"; - std::string proto_file = stream_file_name.str(); - if ((proto_file.length()) >= NAME_MAX) { - GELOGE(GRAPH_FAILED, "File name is too longer!"); - return; - } - std::unique_ptr real_path(new (std::nothrow) char[PATH_MAX]{0}); - if (real_path == nullptr) { - GELOGE(GRAPH_FAILED, "New real_path failed."); - return; - } - /// Returning nullptr means 3 case as follows: - /// a.path is PATH_MAX chars or more - /// b.the file does not exist - /// c.the path has no permissions - /// Distinguish between last the two cases in the function WriteProtoToTextFile call open() - if (realpath(proto_file.c_str(), real_path.get()) == nullptr) { - // For case a - if (errno == ENAMETOOLONG) { - GELOGE(GRAPH_FAILED, "Call realpath failed: path is PATH_MAX chars or more."); - return; - } - } - - // 3. Serialize to file in current path - GraphUtils::WriteProtoToTextFile(model_proto, real_path.get()); -#else - GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); -#endif -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraphFromOnnx(const char *file, - ge::ComputeGraph &compute_graph) { - if (file == nullptr) { - GELOGE(GRAPH_FAILED, "incorrect parameter. file path is invalid"); - return false; - } - onnx::ModelProto model_proto; - // 1. Get ModelDef object from file generated by DumpGEGraphToOnnx() - if (!ReadProtoFromTextFile(file, &model_proto)) { - GELOGE(GRAPH_FAILED, "Get ModelDef from file failed"); - return false; - } - // 2.Convert onnx::ModelProto To ge::Model - ge::Model model; - if (!OnnxUtils::ConvertModelProtoToGeModel(model_proto, model)) { - GELOGE(GRAPH_FAILED, "Convert ModelDef to Model failed"); - return false; - } - auto compute_graph_ptr = GraphUtils::GetComputeGraph(model.GetGraph()); - if (compute_graph_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "Get compute graph from Model failed"); - return false; - } - compute_graph = *(compute_graph_ptr); - return true; -} - -namespace { -using InNodesToOut = std::unordered_map>; - -inline std::string GetNodeNameByAnchor(const Anchor *anchor) { - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Anchor is nullptr"); - return "Null"; - } - auto node = anchor->GetOwnerNode(); - return node == nullptr ? "Null" : node->GetName(); -} - -graphStatus ReplaceOutDataAnchor(const OutDataAnchorPtr &new_anchor, const OutDataAnchorPtr &old_anchor, - InNodesToOut *in_nodes_to_out = nullptr) { - if (new_anchor == nullptr || old_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "new_anchor or old_anchor is nullptr"); - return GRAPH_PARAM_INVALID; - } - auto new_node = new_anchor->GetOwnerNode(); - for (const auto &peer_in_anchor : old_anchor->GetPeerInDataAnchors()) { - auto ret = peer_in_anchor->Unlink(old_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to unlink old anchor link from %s(%d) to %s(%d)", - GetNodeNameByAnchor(old_anchor.get()).c_str(), old_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - ret = peer_in_anchor->LinkFrom(new_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to relink new anchors from %s(%d) to %s(%d)", - GetNodeNameByAnchor(new_anchor.get()).c_str(), new_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - - if (in_nodes_to_out != nullptr) { - (*in_nodes_to_out)[new_node].insert(peer_in_anchor->GetOwnerNode()); - } - } - return GRAPH_SUCCESS; -} - -graphStatus RelinkDataIO(const NodePtr &node, const std::vector &io_map, InNodesToOut &in_nodes_to_out) { - GE_CHECK_NOTNULL(node); - auto in_data_anchors = node->GetAllInDataAnchors(); - auto out_data_anchors = node->GetAllOutDataAnchors(); - if (out_data_anchors.size() < io_map.size()) { - GELOGE(GRAPH_FAILED, "The io_map specified for node %s type %s is larger %zu than the actual size %zu", - node->GetName().c_str(), node->GetType().c_str(), io_map.size(), out_data_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - for (size_t i = 0; i < out_data_anchors.size(); ++i) { - auto out_data_anchor = out_data_anchors.at(i); - if (out_data_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to relink for node %s type %s, the out data anchor at index %zu is null", - node->GetName().c_str(), node->GetType().c_str(), i); - return GRAPH_FAILED; - } - - int in_index = -1; - if (i < io_map.size()) { - in_index = io_map.at(i); - } - if (in_index < 0) { - out_data_anchor->UnlinkAll(); - continue; - } - - if (in_index >= static_cast(in_data_anchors.size())) { - GELOGE(GRAPH_PARAM_INVALID, "Failed to relink for node %s type %s, invalid index %d specified for input(%zu)", - node->GetName().c_str(), node->GetType().c_str(), in_index, in_data_anchors.size()); - return GRAPH_PARAM_INVALID; - } - auto in_anchor = in_data_anchors.at(in_index); - if (in_anchor == nullptr) { - GELOGW("Invalid in data anchors(null) found at node %s type %s index %d, ignore it.", node->GetName().c_str(), - node->GetType().c_str(), in_index); - continue; - } - auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - continue; - } - if (peer_out_anchor->Unlink(in_anchor) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, - "Failed relink node %s type %s, failed to unlink the data link" - " from %s(%d) to it at input-index %d", - node->GetName().c_str(), node->GetType().c_str(), GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), - peer_out_anchor->GetIdx(), in_index); - return GRAPH_FAILED; - } - auto ret = ReplaceOutDataAnchor(peer_out_anchor, out_data_anchor, &in_nodes_to_out); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to relink node %s type %s for relinking data anchors", node->GetName().c_str(), - node->GetType().c_str()); - return GRAPH_FAILED; - } - } - - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - in_anchor->UnlinkAll(); - } - return GRAPH_SUCCESS; -} - -InNodesToOut GetFullConnectIONodes(const NodePtr &node) { - InNodesToOut in_nodes_to_out; - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Node is nullptr"); - return in_nodes_to_out; - } - auto in_nodes_list = node->GetInNodes(); - auto out_nodes_list = node->GetOutNodes(); - auto out_nodes = std::unordered_set(out_nodes_list.begin(), out_nodes_list.end()); - - for (const auto &in_node : in_nodes_list) { - in_nodes_to_out.insert(std::make_pair(in_node, out_nodes)); - } - return in_nodes_to_out; -} - -graphStatus RelinkControlNodeIfNeed(const NodePtr &node, InNodesToOut &in_nodes_to_out, - InNodesToOut &connected_data_in_to_out) { - GE_CHECK_NOTNULL(node); - for (const auto &in_node_to_out : in_nodes_to_out) { - auto &in_node = in_node_to_out.first; - GE_CHECK_NOTNULL(in_node); - auto &connected_data_out = connected_data_in_to_out[in_node]; - for (const auto &out_node : in_node_to_out.second) { - GE_CHECK_NOTNULL(out_node); - if (connected_data_out.count(out_node) == 0) { - GE_CHECK_NOTNULL(in_node->GetOutControlAnchor()); - if (in_node->GetOutControlAnchor()->IsLinkedWith(out_node->GetInControlAnchor())) { - continue; - } - auto ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), out_node->GetInControlAnchor()); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to add control edge from %s to %s when isolating node %s type %s", - in_node->GetName().c_str(), out_node->GetName().c_str(), node->GetName().c_str(), - node->GetType().c_str()); - return GRAPH_FAILED; - } - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus ReplaceOutDataAnchors(const Node::Vistor &new_outs, - const Node::Vistor &old_outs, const std::vector &outputs_map) { - auto new_out_size = new_outs.size(); - if (new_out_size < outputs_map.size()) { - GELOGE(GRAPH_PARAM_INVALID, - "Failed to replace out data anchors, the actual size %zu is less than the mapping size %zu", new_out_size, - outputs_map.size()); - return GRAPH_PARAM_INVALID; - } - for (size_t i = 0; i < new_out_size; ++i) { - auto &new_out_anchor = new_outs.at(i); - if (new_out_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to replace out data anchors, the out data anchor on new node is null, index %zu", i); - return GRAPH_FAILED; - } - if (i >= outputs_map.size()) { - continue; - } - auto old_index = outputs_map.at(i); - if (old_index < 0) { - continue; - } - - const OutDataAnchorPtr &old_out_anchor = old_outs.at(old_index); - if (old_out_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to replace out data anchors, the out data anchor on old node is null, index %d", - old_index); - return GRAPH_FAILED; - } - auto ret = ReplaceOutDataAnchor(new_out_anchor, old_out_anchor); - if (ret != GRAPH_SUCCESS) { - return ret; - } - } - - return GRAPH_SUCCESS; -} - -graphStatus ReplaceInDataAnchors(const Node::Vistor &new_ins, - const Node::Vistor &old_ins, const std::vector &inputs_map) { - auto new_in_size = new_ins.size(); - if (new_in_size < inputs_map.size()) { - GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the actual size %zu is less than the mapping size %zu", - new_in_size, inputs_map.size()); - return GRAPH_PARAM_INVALID; - } - - for (size_t i = 0; i < new_in_size; ++i) { - auto &new_in_anchor = new_ins.at(i); - if (new_in_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the out data anchor on new node is null, index %zu", i); - return GRAPH_FAILED; - } - if (i >= inputs_map.size()) { - continue; - } - auto old_index = inputs_map.at(i); - if (old_index < 0) { - continue; - } - const InDataAnchorPtr &old_in_anchor = old_ins.at(old_index); - if (old_in_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Failed to replace in data anchors, the out data anchor on old node is null, index %d", - old_index); - return GRAPH_FAILED; - } - - auto peer_out_anchor = old_in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - GELOGW("Peer out anchor is nullptr"); - continue; - } - auto ret = peer_out_anchor->Unlink(old_in_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to unlink old anchors, unlink from %s(%d) to %s(%d)", - GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), - GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - ret = peer_out_anchor->LinkTo(new_in_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to link new anchors, link from %s(%d) to %s(%d)", - GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), - GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -graphStatus ReplaceControlAnchors(const NodePtr &new_node, const NodePtr &old_node) { - GE_CHECK_NOTNULL(new_node); - GE_CHECK_NOTNULL(new_node->GetInControlAnchor()); - GE_CHECK_NOTNULL(old_node); - GE_CHECK_NOTNULL(old_node->GetInControlAnchor()); - auto peer_out_anchors = old_node->GetInControlAnchor()->GetPeerAnchors(); - auto new_in_control_anchor = new_node->GetInControlAnchor(); - auto exists_out_anchors = new_in_control_anchor->GetPeerAnchors(); - auto exists_out_anchors_set = std::set(exists_out_anchors.begin(), exists_out_anchors.end()); - for (const auto &peer_out_anchor : peer_out_anchors) { - if (peer_out_anchor != nullptr) { - if (exists_out_anchors_set.count(peer_out_anchor) > 0) { - continue; - } - auto ret = GraphUtils::AddEdge(peer_out_anchor, new_in_control_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add edge failed"); - return GRAPH_FAILED; - } - } else { - GELOGW("peer outanchor is nullptr"); - continue; - } - } - auto old_out_control_anchor = old_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(old_out_control_anchor); - auto peer_in_anchors = old_out_control_anchor->GetPeerAnchors(); - auto new_out_control_anchor = new_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(new_out_control_anchor); - auto exists_in_anchors = new_out_control_anchor->GetPeerAnchors(); - auto exists_in_anchors_set = std::set(exists_in_anchors.begin(), exists_in_anchors.end()); - for (const auto &peer_in_anchor : peer_in_anchors) { - if (peer_in_anchor != nullptr) { - if (exists_in_anchors_set.count(peer_in_anchor) > 0) { - continue; - } - auto ret = GraphUtils::AddEdge(new_out_control_anchor, peer_in_anchor); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add edge failed"); - return GRAPH_FAILED; - } - } else { - GELOGW("Peer inanchor is nullptr"); - continue; - } - } - - return GRAPH_SUCCESS; -} -} // namespace - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::IsolateNode(const NodePtr &node, - const std::vector &io_map) { - if (node == nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "Failed to isolate node(null)"); - return GRAPH_PARAM_INVALID; - } - - /// We must get full connections info before re-link data io, because the data - /// edges may be unlinked when relink data io - auto in_nodes_to_out = GetFullConnectIONodes(node); - - InNodesToOut data_in_to_out; - auto ret = RelinkDataIO(node, io_map, data_in_to_out); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to isolate node %s type %s when relink data IO", node->GetName().c_str(), - node->GetType().c_str()); - return ret; - } - - ret = RelinkControlNodeIfNeed(node, in_nodes_to_out, data_in_to_out); - if (ret != GRAPH_SUCCESS) { - return ret; - } - NodeUtils::UnlinkAll(*node); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::IsolateNode(const NodePtr &node, const std::initializer_list &io_map) { - return IsolateNode(node, std::vector(io_map)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::IsolateNodeOneIO(const NodePtr &node) { - if (node == nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "incorrect parameter. node is invalid"); - return GRAPH_PARAM_INVALID; - } - if (node->GetAllInDataAnchorsSize() != 1) { - return GRAPH_PARAM_INVALID; - } - if (node->GetAllOutDataAnchorsSize() != 1) { - return GRAPH_PARAM_INVALID; - } - return IsolateNode(node, {0}); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, const std::vector &inputs_map, - const std::vector &outputs_map) { - if ((new_node == nullptr) || (old_node == nullptr)) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - auto ret = ReplaceNodeDataAnchors(new_node, old_node, inputs_map, outputs_map); - if (ret != GRAPH_SUCCESS) { - // The error log was printed in `ReplaceNodeDataAnchors` - return GRAPH_FAILED; - } - ret = ReplaceControlAnchors(new_node, old_node); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, - "Failed to replace control anchors when replace node from old node %s type %s to new node %s type %s", - old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), - new_node->GetType().c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::ReplaceNodeAnchors( - const NodePtr &new_node, const NodePtr &old_node, const std::initializer_list inputs_map, - const std::initializer_list outputs_map) { - return ReplaceNodeAnchors(new_node, old_node, std::vector(inputs_map), std::vector(outputs_map)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, - std::initializer_list inputs_map, std::initializer_list outputs_map) { - return ReplaceNodeDataAnchors(new_node, old_node, std::vector(inputs_map), std::vector(outputs_map)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, const std::vector &inputs_map, - const std::vector &outputs_map) { - if (new_node == nullptr || old_node == nullptr) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - - auto ret = ReplaceOutDataAnchors(new_node->GetAllOutDataAnchors(), old_node->GetAllOutDataAnchors(), outputs_map); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, - "Failed to replace out data anchors when replace node from old node %s type %s to new node %s type %s", - old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), - new_node->GetType().c_str()); - return GRAPH_FAILED; - } - ret = ReplaceInDataAnchors(new_node->GetAllInDataAnchors(), old_node->GetAllInDataAnchors(), inputs_map); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, - "Failed to replace in data anchors when replace node from old node %s type %s to new node %s type %s", - old_node->GetName().c_str(), old_node->GetType().c_str(), new_node->GetName().c_str(), - new_node->GetType().c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInCtrlEdges(const NodePtr &src_node, - NodePtr &dst_node) { - if ((src_node == nullptr) || (dst_node == nullptr)) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - auto src_ctrl_in_nodes = src_node->GetInControlNodes(); - if (src_ctrl_in_nodes.empty()) { - return GRAPH_SUCCESS; - } - - std::unordered_set exist_in_ctrl_nodes_set; - auto exist_in_ctrl_nodes = dst_node->GetInControlNodes(); - if (!exist_in_ctrl_nodes.empty()) { - exist_in_ctrl_nodes_set.insert(exist_in_ctrl_nodes.begin(), exist_in_ctrl_nodes.end()); - } - - auto dst_ctrl = dst_node->GetInControlAnchor(); - for (const auto &in_node : src_ctrl_in_nodes) { - if (exist_in_ctrl_nodes_set.count(in_node) > 0) { - continue; - } - auto ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), dst_ctrl); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to add control edge from %s to %s when copy control dependencies from %s to %s", - in_node->GetName().c_str(), dst_node->GetName().c_str(), src_node->GetName().c_str(), - dst_node->GetName().c_str()); - return ret; - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveInCtrlEdges(const NodePtr &src_node, - NodePtr &dst_node) { - if (src_node == nullptr || dst_node == nullptr) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_FAILED; - } - auto ret = CopyInCtrlEdges(src_node, dst_node); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Copy in ctrl edges failed"); - return ret; - } - GE_CHECK_NOTNULL(src_node->GetInControlAnchor()); - src_node->GetInControlAnchor()->UnlinkAll(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyOutCtrlEdges(const NodePtr &src_node, - NodePtr &dst_node) { - if (src_node == nullptr || dst_node == nullptr) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_FAILED; - } - auto out_ctrl_nodes = src_node->GetOutControlNodes(); - if (out_ctrl_nodes.empty()) { - return GRAPH_SUCCESS; - } - - std::unordered_set exists_out_ctrl_nodes_set; - for (const auto &node : dst_node->GetOutControlNodes()) { - exists_out_ctrl_nodes_set.insert(node.get()); - } - - auto dst_out_ctrl = dst_node->GetOutControlAnchor(); - for (const auto &node : out_ctrl_nodes) { - if (exists_out_ctrl_nodes_set.count(node.get()) > 0) { - continue; - } - auto ret = GraphUtils::AddEdge(dst_out_ctrl, node->GetInControlAnchor()); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to add control edge from %s to %s when copy control dependencies from %s to %s", - dst_node->GetName().c_str(), node->GetName().c_str(), src_node->GetName().c_str(), - dst_node->GetName().c_str()); - return ret; - } - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCtrlEdges(NodePtr &src_node, - NodePtr &dst_node) { - if (src_node == nullptr || dst_node == nullptr) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_FAILED; - } - auto ret = CopyOutCtrlEdges(src_node, dst_node); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Copyout ctrl edges failed"); - return ret; - } - GE_CHECK_NOTNULL(src_node->GetOutControlAnchor()); - src_node->GetOutControlAnchor()->UnlinkAll(); - return GRAPH_SUCCESS; -} - -/// -/// Copy all in-data edges from `src_node` to `dst_node`. -/// @param src_node -/// @param dst_node -/// @return -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node, - NodePtr &dst_node) { - if ((src_node == nullptr) || (dst_node == nullptr)) { - GELOGE(GRAPH_FAILED, "Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - auto src_data_in_nodes = src_node->GetInDataNodes(); - if (src_data_in_nodes.empty()) { - return GRAPH_SUCCESS; - } - for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) { - auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); - auto ret = - GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx())); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s", - in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(), - src_node->GetName().c_str(), dst_node->GetName().c_str()); - return ret; - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph, - const NodePtr &node) { - if (graph->AddInputNode(node) == nullptr) { - GELOGE(GRAPH_FAILED, "Copyout ctrl edges failed"); - return GRAPH_FAILED; - } - graph->SetInputSize(graph->GetInputSize() + 1); - graph->inputs_order_.emplace_back(node->GetName()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::FindRootGraph(ComputeGraphPtr graph) { - ComputeGraphPtr result = nullptr; - while (graph != nullptr) { - result = std::move(graph); - graph = result->GetParentGraph(); - } - return result; -} - -/// -/// Make a copy of ComputeGraph. -/// @param graph: original graph. -/// @param prefix: node name prefix of new graph. -/// @param output_nodes: output nodes of new graph. -/// @return ComputeGraphPtr -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr -GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix, std::vector &input_nodes, - std::vector &output_nodes) { - GE_CHK_BOOL_EXEC(graph != nullptr, return nullptr, "Original graph is null"); - ComputeGraphPtr new_graph = ComGraphMakeShared(graph->GetName()); - GE_CHK_BOOL_EXEC(new_graph != nullptr, return nullptr, "Create new graph failed"); - - std::unordered_map all_new_nodes; - for (const auto &n : graph->GetDirectNode()) { - OpDescPtr op_desc = AttrUtils::CopyOpDesc(n->GetOpDesc()); - GE_CHK_BOOL_EXEC(op_desc != nullptr, return nullptr, "Create new node failed"); - - if (CopyTensorAttrs(op_desc, n) != GRAPH_SUCCESS) { - return nullptr; - } - - op_desc->SetName(prefix + n->GetName()); - NodePtr node = new_graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Add node[%s] to graph failed", op_desc->GetName().c_str()); - all_new_nodes[node->GetName()] = node; - - if (node->GetType() == DATA) { - input_nodes.emplace_back(node); - } else if (node->GetType() == NETOUTPUT) { - output_nodes.emplace_back(node); - } - } - - for (const auto &n : graph->GetDirectNode()) { - if (RelinkGraphEdges(n, prefix, all_new_nodes) != GRAPH_SUCCESS) { - return nullptr; - } - } - - std::string session_graph_id; - if (AttrUtils::GetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) { - bool ret = AttrUtils::SetStr(*new_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); - if (!ret) { - GELOGE(GRAPH_FAILED, "Set attr ATTR_NAME_SESSION_GRAPH_ID failed."); - return nullptr; - } - } - return new_graph; -} - -/// -/// Copy tensor attribute to new node. -/// @param [in] dst_node: cloned node. -/// @param [in] src_node: original node. -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node) { - if (dst_desc == nullptr) { - GELOGE(GRAPH_FAILED, "Input param dst node not valid"); - return GRAPH_FAILED; - } - if (src_node == nullptr || src_node->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "Input param src node not valid"); - return GRAPH_FAILED; - } - - const auto &src_desc = src_node->GetOpDesc(); - dst_desc->CopyAttrsFrom(*src_desc); - - for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) { - auto input_desc = dst_desc->MutableInputDesc(i); - if (input_desc == nullptr) { - continue; - } - input_desc->CopyAttrsFrom(src_desc->GetInputDesc(i)); - } - - for (uint32_t i = 0; i < src_node->GetAllOutDataAnchorsSize(); ++i) { - auto output_desc = dst_desc->MutableOutputDesc(i); - if (output_desc == nullptr) { - GELOGE(GRAPH_FAILED, "Param dst node not valid"); - return GRAPH_FAILED; - } - output_desc->CopyAttrsFrom(src_desc->GetOutputDesc(i)); - } - - return GRAPH_SUCCESS; -} - -/// -/// Relink all edges for cloned ComputeGraph. -/// @param [in] node: original node. -/// @param [in] prefix: node name prefix of new node. -/// @param [in] all_nodes: all nodes in new graph. -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &prefix, - const std::unordered_map &all_nodes) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "Input node not valid"); - return GRAPH_FAILED; - } - - auto it = all_nodes.find(prefix + node->GetName()); - if (it == all_nodes.end()) { - GELOGE(GRAPH_FAILED, "node[%s] not found", node->GetName().c_str()); - return GRAPH_FAILED; - } - const auto &new_node = it->second; - - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - GE_CHK_BOOL_EXEC(in_anchor != nullptr, return GRAPH_FAILED, "In data anchor is null"); - const auto &out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - GELOGW("Peer out anchor is null: %s", node->GetName().c_str()); - continue; - } - GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); - - it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); - if (it == all_nodes.end()) { - GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED; - } - const auto &new_out_node = it->second; - - auto rslt = - GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInAnchor(in_anchor->GetIdx())); - GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", - new_out_node->GetName().c_str(), new_node->GetName().c_str()); - } - - if (node->GetInControlAnchor() != nullptr) { - for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) { - GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "Peer out anchor is null: %s", node->GetName().c_str()); - GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); - - it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); - if (it == all_nodes.end()) { - GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED; - } - const auto &new_out_node = it->second; - - auto rslt = GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInControlAnchor()); - GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", - new_out_node->GetName().c_str(), new_node->GetName().c_str()); - } - } - - return GRAPH_SUCCESS; -} - -/// -/// Get reference-mapping of all data_anchors in graph -/// @param [in] graph -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(graph); - for (const auto &node : graph->GetAllNodes()) { - // in_data_anchor - if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); - return GRAPH_FAILED; - } - - // out_data_anchor - if (HandleOutAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("Find ref_mapping for out_data_anchors of node %s failed.", node->GetName().c_str()); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr GraphUtils::FindNodeFromAllNodes(ComputeGraphPtr &graph, - const std::string &name) { - auto root_graph = FindRootGraph(graph); - if (root_graph == nullptr) { - GE_LOGE("Failed find node %s, null root graph", name.c_str()); - return nullptr; - } - - for (const auto &node : root_graph->GetAllNodes()) { - if (node == nullptr) { - continue; - } - if (node->GetName() == name) { - return node; - } - } - - return nullptr; -} - -/// -/// Get reference-mapping for in_data_anchors of node -/// @param [in] node -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - - if (NodeUtils::IsSubgraphOutput(node)) { - return HandleSubgraphOutput(node, symbol_to_anchors, anchor_to_symbol); - } - - if (NodeUtils::IsSubgraphInput(node)) { - return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); - } - - const std::string &type = node->GetType(); - if ((type == MERGE) || (type == STREAMMERGE)) { - return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); - } - - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - const std::string &symbol = cur_node_info.ToString(); - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - symbol_to_anchors[symbol] = {cur_node_info}; - anchor_to_symbol[symbol] = symbol; - } else { - NodeIndexIO exist_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); - if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("Update symbol mapping failed."); - return GRAPH_FAILED; - } - } - } - - return GRAPH_SUCCESS; -} - -/// -/// Get reference-mapping for out_data_anchors of node -/// @param [in] node -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut); - if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { - continue; - } - - int32_t reuse_in_index = -1; - if (IsRefFromInput(out_data_anchor, reuse_in_index)) { - NodeIndexIO exist_node_info(node, reuse_in_index, kIn); - if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("Update symbol mapping failed."); - return GRAPH_FAILED; - } - } else { - const std::string &symbol = cur_node_info.ToString(); - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - symbol_to_anchors.emplace(std::make_pair(symbol, std::list{cur_node_info})); - anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); - } - } - - return GRAPH_SUCCESS; -} - -/// -/// Handle input of subgraph -/// @param [in] node -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::HandleSubgraphInput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - - // Data in subgraph - uint32_t index = 0; - if (!ge::AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index)) { - GE_LOGE("Get attr ATTR_NAME_PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); - return GRAPH_FAILED; - } - NodePtr parent_node = node->GetOwnerComputeGraph()->GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - InDataAnchorPtr parent_in_anchor = parent_node->GetInDataAnchor(index); - GE_CHECK_NOTNULL(parent_in_anchor); - OutDataAnchorPtr peer_out_anchor = parent_in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor != nullptr) { - // Data has and only has one input - NodeIndexIO cur_node_info(node, 0, kIn); - NodeIndexIO exist_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); - if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("Update symbol mapping failed."); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; -} - -/// -/// Handle input of Merge op -/// @param [in] node -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - std::vector exist_node_infos; - std::vector cur_node_infos; - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - std::string next_name; - if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name) && !next_name.empty()) { - ComputeGraphPtr graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - ge::NodePtr next_node = graph->FindNode(next_name); - GE_CHECK_NOTNULL(next_node); - // NextIteration has and only has one output - peer_out_anchor = next_node->GetOutDataAnchor(0); - GE_CHECK_NOTNULL(peer_out_anchor); - cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); - cur_node_infos.emplace_back(NodeIndexIO(next_node, peer_out_anchor->GetIdx(), kOut)); - } - } else { - cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); - exist_node_infos.emplace_back(NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut)); - } - } - - size_t anchor_nums = 0; - NodeIndexIO max_node_index_io(nullptr, 0, kOut); - for (const auto &temp_node_info : exist_node_infos) { - auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); - if (iter1 != anchor_to_symbol.end()) { - const std::string &temp_symbol = iter1->second; - auto iter2 = symbol_to_anchors.find(temp_symbol); - if (iter2 != symbol_to_anchors.end()) { - if (iter2->second.size() > anchor_nums) { - max_node_index_io = temp_node_info; - anchor_nums = iter2->second.size(); - } - } - } - } - - std::string symbol; - for (const auto &temp_node_info : exist_node_infos) { - if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != - GRAPH_SUCCESS) || - symbol.empty()) { - GE_LOGE("Union symbol map anchor1:%s & anchor2:%s.", max_node_index_io.ToString().c_str(), - temp_node_info.ToString().c_str()); - return GRAPH_FAILED; - } - } - - auto iter = symbol_to_anchors.find(symbol); - if (iter != symbol_to_anchors.end()) { - for (const auto &temp_node_info : cur_node_infos) { - GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); - iter->second.emplace_back(temp_node_info); - anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); - } - } - - return GRAPH_SUCCESS; -} - -/// -/// Handle output of subgraph -/// @param [in] node -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(owner_graph); - NodePtr parent_node = owner_graph->GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - - GeTensorDesc in_tensor = op_desc->GetInputDesc(in_data_anchor->GetIdx()); - uint32_t index = 0; - if (!ge::AttrUtils::GetInt(in_tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { - continue; - } - GE_CHECK_NOTNULL(parent_node->GetOutDataAnchor(index)); - // Union symbol of peer_out_anchor & parent_out_anchor - NodeIndexIO peer_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); - NodeIndexIO parent_node_info(parent_node, index, kOut); - std::string symbol; - if ((UnionSymbolMapping(peer_node_info, parent_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != - GRAPH_SUCCESS) || - symbol.empty()) { - GE_LOGE("Union symbol map anchor1:%s, anchor2:%s.", peer_node_info.ToString().c_str(), - parent_node_info.ToString().c_str()); - return GRAPH_FAILED; - } - - NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - symbol_to_anchors[symbol].emplace_back(cur_node_info); - anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); - } - - return GRAPH_SUCCESS; -} - -/// -/// Union ref-mapping -/// @param [in] exist_node_info1 -/// @param [in] exist_node_info2 -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @param [out] symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol, std::string &symbol) { - const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; - const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; - if (symbol1 == symbol2) { - symbol = symbol1; - GELOGI("no need to union."); - return GRAPH_SUCCESS; - } - - auto iter1 = symbol_to_anchors.find(symbol1); - auto iter2 = symbol_to_anchors.find(symbol2); - if ((iter1 == symbol_to_anchors.end()) || (iter2 == symbol_to_anchors.end())) { - GE_LOGE("symbol %s or %s not exist.", symbol1.c_str(), symbol2.c_str()); - return GRAPH_FAILED; - } - - auto &max_iter = (iter1->second.size() > iter2->second.size() ? iter1 : iter2); - auto &min_iter = (iter1->second.size() > iter2->second.size() ? iter2 : iter1); - symbol = (iter1->second.size() > iter2->second.size() ? symbol1 : symbol2); - std::string min_symbol = (iter1->second.size() > iter2->second.size() ? symbol2 : symbol1); - for (auto &node_index_io : min_iter->second) { - GELOGD("Update anchor %s, symbol %s.", node_index_io.ToString().c_str(), symbol.c_str()); - max_iter->second.emplace_back(node_index_io); - auto iter = anchor_to_symbol.find(node_index_io.ToString()); - if (iter == anchor_to_symbol.end()) { - GE_LOGE("anchor %s not exist.", node_index_io.ToString().c_str()); - return GRAPH_FAILED; - } - if (iter->second != min_symbol) { - GELOGW("not expected symbol of anchor %s, expect %s but %s exactly.", iter->first.c_str(), min_symbol.c_str(), - iter->second.c_str()); - } - iter->second = symbol; - } - - GELOGI("Union symbol %s and %s succ.", symbol.c_str(), min_symbol.c_str()); - symbol_to_anchors.erase(min_iter); - return GRAPH_SUCCESS; -} - -/// -/// Update symbol mapping with a new reference pair -/// @param [in] cur_node_info -/// @param [in] exist_node_info -/// @param [out] symbol_to_anchors -/// @param [out] anchor_to_symbol -/// @return success: GRAPH_SUCESS -/// -graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol) { - auto iter1 = anchor_to_symbol.find(exist_node_info.ToString()); - if (iter1 == anchor_to_symbol.end()) { - GE_LOGE("data_anchor %s is not visible before data_anchor %s, maybe TopoSorting is missing.", - exist_node_info.ToString().c_str(), cur_node_info.ToString().c_str()); - return GRAPH_FAILED; - } - - const std::string &symbol = iter1->second; - auto iter2 = symbol_to_anchors.find(symbol); - if (iter2 == symbol_to_anchors.end()) { - GE_LOGE("symbol %s not found.", symbol.c_str()); - return GRAPH_FAILED; - } - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - iter2->second.emplace_back(cur_node_info); - anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); - - return GRAPH_SUCCESS; -} - -/// -/// Check if out_data_anchor is reference of input -/// @param [in] out_data_anchor -/// @param [out] reuse_in_index -/// @return bool -/// -bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index) { - if (out_data_anchor == nullptr) { - GELOGW("out_data_anchor is NULL."); - return false; - } - int32_t output_index = out_data_anchor->GetIdx(); - - // pass-through op - NodePtr node = out_data_anchor->GetOwnerNode(); - const std::string &type = node->GetType(); - const std::set pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; - if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { - reuse_in_index = output_index; - GELOGI("Pass-Through node name[%s] index[%u].", node->GetName().c_str(), reuse_in_index); - return true; - } - - // Merge op 0th output - if ((type == MERGE) && (output_index == 0)) { - reuse_in_index = 0; - GELOGI("Merge name[%s] output_index[0].", node->GetName().c_str()); - return true; - } - - // ref op - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - GELOGW("op_desc is NULL."); - return false; - } - bool is_ref = false; - (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); - if (is_ref) { - const string &output_name = op_desc->GetOutputNameByIndex(output_index); - for (const auto &input_name : op_desc->GetAllInputNames()) { - if (!input_name.empty() && (output_name == input_name)) { - reuse_in_index = op_desc->GetInputIndexByName(input_name); - GELOGI("Reference name[%s] output[%s][%d] ref to input[%s][%d].", op_desc->GetName().c_str(), - output_name.c_str(), output_index, input_name.c_str(), reuse_in_index); - return true; - } - } - } - - // reuse input - auto output_op_desc = op_desc->GetOutputDescPtr(output_index); - bool reuse_input = false; - if (output_op_desc != nullptr) { - if ((TensorUtils::GetReuseInput(*output_op_desc, reuse_input) == GRAPH_SUCCESS) && reuse_input) { - uint32_t reuse_input_index = 0; - if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { - reuse_in_index = static_cast(reuse_input_index); - GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index, - reuse_in_index); - return true; - } - } - } - - return false; -} - -/// -/// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs -/// of the graph have UNKNOWN_SHAPE operators or not. -/// Note: This function will only look 'down' from the graph, not 'up'. For example, the following -/// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE -/// ROOT graph: A -----> B -----> C -/// K subgraph U -/// | -/// V -/// SUB graph: D --> E --> F -/// K K K -/// @param [in] graph -/// @return bool -/// -bool GraphUtils::IsUnknownShapeGraph(const ComputeGraphPtr &graph) { - if (graph == nullptr) { - GELOGW("Input graph is nullptr."); - return false; - } - for (const auto &node : graph->GetDirectNode()) { - bool is_unknown = false; - auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); - if (ret != GRAPH_SUCCESS) { - GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), - node->GetType().c_str()); - continue; - } - if (is_unknown) { - GELOGD("Node %s, type %s is unknown shape in graph %s.", node->GetName().c_str(), node->GetType().c_str(), - graph->GetName().c_str()); - return true; - } - } - GELOGD("Graph %s does not have unknown shape node.", graph->GetName().c_str()); - return false; -} - -/// -/// @brief Add node to graph -/// @param [in] op_desc -/// @return ComputeGraphBuilder -/// -ComputeGraphBuilder &ComputeGraphBuilder::AddNode(const OpDescPtr &op_desc) { - nodes_.emplace_back(op_desc); - return *this; -} - -/// -/// @brief Add data-link among nodes in graph -/// @param [in] src_name -/// @param [in] out_anchor_ind -/// @param [in] dst_name -/// @param [in] in_anchor_ind -/// @return ComputeGraphBuilder -/// -ComputeGraphBuilder &ComputeGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, - const std::string &dst_name, uint32_t in_anchor_ind) { - data_links_.emplace_back( - std::make_pair(std::make_pair(src_name, out_anchor_ind), std::make_pair(dst_name, in_anchor_ind))); - return *this; -} - -/// -/// @brief Add ctrl-link among nodes in graph -/// @param [in] src_name -/// @param [in] dst_name -/// @return ComputeGraphBuilder -/// -ComputeGraphBuilder &ComputeGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { - ctrl_links_.emplace_back(std::make_pair(src_name, dst_name)); - return *this; -} - -/// -/// @brief Build nodes -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void ComputeGraphBuilder::BuildNodes(graphStatus &error_code, std::string &error_msg) { - if (owner_graph_ == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "graph is NULL."; - return; - } - - std::string node_name; - for (auto &op_desc : nodes_) { - if (op_desc == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "op_desc is NULL."; - return; - } - - node_name = op_desc->GetName(); - NodePtr node = owner_graph_->AddNode(op_desc); - if (node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "Add node " + node_name + " failed."; - return; - } - - GELOGD("Add node name:%s, type:%s.", node_name.c_str(), op_desc->GetType().c_str()); - node_names_[node_name] = node; - } - - GELOGD("BuildNodes succ."); -} - -/// -/// @brief Build data-links -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void ComputeGraphBuilder::BuildDataLinks(graphStatus &error_code, std::string &error_msg) { - for (auto &pair : data_links_) { - std::string src_name = pair.first.first; - uint32_t out_ind = pair.first.second; - std::string dst_name = pair.second.first; - uint32_t in_ind = pair.second.second; - std::string log_msg = "Add data-edge "; - log_msg.append(src_name) - .append(":") - .append(std::to_string(out_ind)) - .append("->") - .append(dst_name) - .append(":") - .append(std::to_string(in_ind)); - - auto src_iter = node_names_.find(src_name); - auto dst_iter = node_names_.find(dst_name); - if ((src_iter == node_names_.end()) || (dst_iter == node_names_.end())) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node not exist in graph."; - return; - } - - NodePtr src_node = node_names_[src_name]; - NodePtr dst_node = node_names_[dst_name]; - if ((src_node == nullptr) || (dst_node == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node is NULL."; - return; - } - - if (GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_ind), dst_node->GetInDataAnchor(in_ind)) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed."; - return; - } - - GELOGD("%s succ.", log_msg.c_str()); - } - - GELOGD("BuildDataLinks succ."); -} - -/// -/// @brief Build ctrl-links -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void ComputeGraphBuilder::BuildCtrlLinks(graphStatus &error_code, std::string &error_msg) { - for (auto &pair : ctrl_links_) { - std::string src_name = pair.first; - std::string dst_name = pair.second; - std::string log_msg = "Add ctrl-edge "; - log_msg.append(src_name).append("->").append(dst_name); - - auto src_iter = node_names_.find(src_name); - auto dst_iter = node_names_.find(dst_name); - if ((src_iter == node_names_.end()) || (dst_iter == node_names_.end())) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node not exist in graph."; - return; - } - - NodePtr src_node = node_names_[src_name]; - NodePtr dst_node = node_names_[dst_name]; - if ((src_node == nullptr) || (dst_node == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node is NULL."; - return; - } - - if (GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed."; - return; - } - - GELOGD("%s succ.", log_msg.c_str()); - } - - GELOGD("BuildCtrlLinks succ."); -} - -/// @brief Get node with name -/// @param [in] name -/// @return NodePtr -/// -NodePtr ComputeGraphBuilder::GetNode(const std::string &name) { - auto iter = node_names_.find(name); - if (iter == node_names_.end()) { - GE_LOGE("node %s not exist.", name.c_str()); - return nullptr; - } - return iter->second; -} - -/// @brief Get all nodes -/// @return std::vector -/// -std::vector ComputeGraphBuilder::GetAllNodes() { - std::vector nodes; - for (const auto &iter : node_names_) { - nodes.emplace_back(iter.second); - } - return nodes; -} - -/// -/// @brief Add node to graph -/// @param [in] op_desc -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::AddNode(const OpDescPtr &op_desc) { - ComputeGraphBuilder::AddNode(op_desc); - return *this; -} - -/// -/// @brief Add data-link among nodes in graph -/// @param [in] src_name -/// @param [in] out_anchor_ind -/// @param [in] dst_name -/// @param [in] in_anchor_ind -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, - const std::string &dst_name, uint32_t in_anchor_ind) { - ComputeGraphBuilder::AddDataLink(src_name, out_anchor_ind, dst_name, in_anchor_ind); - return *this; -} - -/// -/// @brief Add ctrl-link among nodes in graph -/// @param [in] src_name -/// @param [in] dst_name -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { - ComputeGraphBuilder::AddControlLink(src_name, dst_name); - return *this; -} - -/// -/// @brief Set index_th input anchor for graph -/// @param [in] index -/// @param [in] node_names -/// @param [in] anchor_inds -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::SetInput(uint32_t index, const std::vector &node_names, - const std::vector &anchor_inds) { - graph_inputs_[index] = std::make_pair(node_names, anchor_inds); - return *this; -} - -/// -/// @brief Set index_th input of graph as useless -/// @param [in] index -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::SetUselessInput(uint32_t index) { - graph_inputs_[index] = std::make_pair(std::vector(), std::vector()); - return *this; -} - -/// -/// @brief Add output anchor for graph -/// @param [in] owner_node_name -/// @param [in] anchor_ind -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::AddOutput(const std::string &owner_node_name, uint32_t anchor_ind) { - graph_outputs_.emplace_back(std::make_pair(owner_node_name, anchor_ind)); - return *this; -} - -/// -/// @brief Add target for graph -/// @param [in] target_name -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::AddTarget(const std::string &target_name) { - graph_targets_.emplace_back(target_name); - return *this; -} - -/// -/// @brief Set parent-node of graph -/// @param [in] parent_node -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::SetParentNode(const NodePtr &parent_node) { - parent_node_ = parent_node; - return *this; -} - -/// -/// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node -/// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::SetInputMapping(const std::map &input_mapping) { - for (auto &item : input_mapping) { - input_mapping_[item.first] = item.second; - } - return *this; -} - -/// -/// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind -/// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node -/// @return CompleteGraphBuilder -/// -CompleteGraphBuilder &CompleteGraphBuilder::SetOutputMapping(const std::map &output_mapping) { - for (auto &item : output_mapping) { - output_mapping_[item.first] = item.second; - } - return *this; -} - -/// -/// @brief Build graph -/// @param [out] error_code -/// @param [out] error_msg -/// @return ComputeGraphPtr -/// -ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { - owner_graph_ = shared_ptr(new (std::nothrow) ComputeGraph(name_)); - if ((owner_graph_ == nullptr) || (parent_node_ == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = "graph / parent_node is NULL."; - return nullptr; - } - - owner_graph_->SetParentNode(parent_node_); - owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph()); - - BuildNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildDataLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildCtrlLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - AddDataNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - AddRetValNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildGraphTargets(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - // ATTR_NAME_SESSION_GRAPH_ID - std::string graph_id; - if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { - error_code = GRAPH_FAILED; - error_msg = "Get attr session_graph_id failed."; - return nullptr; - } - if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { - error_code = GRAPH_FAILED; - error_msg = "Set attr session_graph_id failed."; - return nullptr; - } - - // refresh node name - for (const NodePtr &node : owner_graph_->GetDirectNode()) { - if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) { - continue; - } - node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName()); - } - - return owner_graph_; -} - -/// -/// @brief Add data nodes -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void CompleteGraphBuilder::AddDataNodes(graphStatus &error_code, std::string &error_msg) { - for (auto &input : graph_inputs_) { - NodePtr data_node = AddDataNode(input.first, error_code, error_msg); - if (data_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: add node Data:" + std::to_string(input.first) + +" failed."; - return; - } - - if (owner_graph_->AddInputNode(data_node) == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: add input node Data:" + std::to_string(input.first) + +" failed."; - return; - } - - // useless input - std::vector input_names = input.second.first; - std::vector anchor_indes = input.second.second; - if (input_names.size() != anchor_indes.size()) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: num of input_names and indexs not equal."; - return; - } - if (input_names.empty()) { - continue; - } - - size_t input_num = input_names.size(); - for (size_t i = 0; i < input_num; i++) { - std::string input_name = input_names[i]; - uint32_t ind = anchor_indes[i]; - auto iter = node_names_.find(input_name); - if (iter == node_names_.end()) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: node " + input_name + " not exist in graph."; - return; - } - - NodePtr in_node = node_names_[input_name]; - if (in_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: node " + input_name + " is NULL."; - return; - } - - if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), in_node->GetInDataAnchor(ind)) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + input_name + - ":" + std::to_string(ind) + " failed."; - return; - } - } - - GELOGD("AddDataNodes : Add %u input succ.", input.first); - } - - GELOGD("AddDataNodes succ."); -} - -/// -/// @brief Add data node -/// @param [in] index -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -NodePtr CompleteGraphBuilder::AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { - std::string data_name = "Data_" + std::to_string(index); - OpDescBuilder op_desc_builder(data_name, "Data"); - OpDescPtr op_desc = op_desc_builder.AddInput("x").AddOutput("y").Build(); - if (op_desc == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNode failed: create op_desc " + data_name + " failed."; - return nullptr; - } - - auto index_iter = input_mapping_.find(index); - if (index_iter != input_mapping_.end()) { - if (!ge::AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, index_iter->second)) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNode failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; - return nullptr; - } - } - - NodePtr data_node = owner_graph_->AddNode(op_desc); - if (data_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNode failed: add node " + data_name + " failed."; - return nullptr; - } - node_names_[data_name] = data_node; - - return data_node; -} - -/// -/// @brief Add RetVal nodes -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &error_msg) { - size_t output_num = graph_outputs_.size(); - for (size_t i = 0; i < output_num; i++) { - int32_t index = graph_outputs_[i].second; - auto out_iter = node_names_.find(graph_outputs_[i].first); - if (out_iter == node_names_.end()) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode failed: node " + graph_outputs_[i].first + " not exist in graph."; - return; - } - NodePtr node = out_iter->second; - if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode failed: node is NULL."; - return; - } - - std::string name = node->GetName() + "_RetVal_" + std::to_string(index); - OpDescPtr ret_val_desc = shared_ptr(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); - if (ret_val_desc == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: op_desc is NULL."; - return; - } - ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index); - if ((ret_val_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) || - (ret_val_desc->AddOutputDesc(tensor) != GRAPH_SUCCESS)) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: add input_desc / output_desc failed."; - return; - } - - if (!(ge::AttrUtils::SetStr(ret_val_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_RetVal") && - ge::AttrUtils::SetInt(ret_val_desc, RETVAL_ATTR_NAME_INDEX, i))) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: set FRAMEWORK_ORIGINAL_TYPE / RETVAL_ATTR_NAME_INDEX failed."; - return; - } - auto iter = output_mapping_.find(i); - if (iter != output_mapping_.end()) { - if (!ge::AttrUtils::SetInt(ret_val_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: set attr PARENT_NODE_INDEX failed."; - return; - } - } - - NodePtr ret_val_node = owner_graph_->AddNode(ret_val_desc); - if (ret_val_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: add node failed."; - return; - } - - if (GraphUtils::AddEdge(node->GetOutDataAnchor(index), ret_val_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: add data-edge " + node->GetName() + ":" + std::to_string(index) + - "->" + ret_val_node->GetName() + ":0 failed."; - return; - } - } - - GELOGD("AddRetValNodes succ."); -} - -/// -/// @brief Build target-nodes for graph -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void CompleteGraphBuilder::BuildGraphTargets(graphStatus &error_code, std::string &error_msg) { - std::vector target_nodes; - for (const std::string &target_name : graph_targets_) { - auto target_iter = node_names_.find(target_name); - if ((target_iter == node_names_.end()) || (target_iter->second == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = "BuildGraphTargets failed: target_node " + target_name + " not exist in graph."; - return; - } - target_nodes.emplace_back(target_iter->second); - } - owner_graph_->SetGraphTargetNodesInfo(target_nodes); - return; -} - -/// -/// @brief Add node to graph -/// @param [in] op_desc -/// @return PartialGraphBuilder -/// -PartialGraphBuilder &PartialGraphBuilder::AddNode(const OpDescPtr &op_desc) { - ComputeGraphBuilder::AddNode(op_desc); - return *this; -} - -/// -/// @brief Add data-link among nodes in graph -/// @param [in] src_name -/// @param [in] out_anchor_ind -/// @param [in] dst_name -/// @param [in] in_anchor_ind -/// @return PartialGraphBuilder -/// -PartialGraphBuilder &PartialGraphBuilder::AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, - const std::string &dst_name, uint32_t in_anchor_ind) { - ComputeGraphBuilder::AddDataLink(src_name, out_anchor_ind, dst_name, in_anchor_ind); - return *this; -} - -/// -/// @brief Add ctrl-link among nodes in graph -/// @param [in] src_name -/// @param [in] dst_name -/// @return PartialGraphBuilder -/// -PartialGraphBuilder &PartialGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { - ComputeGraphBuilder::AddControlLink(src_name, dst_name); - return *this; -} - -/// -/// @brief Set owner graph -/// @param [in] graph -/// @return PartialGraphBuilder -/// -PartialGraphBuilder &PartialGraphBuilder::SetOwnerGraph(const ComputeGraphPtr &graph) { - owner_graph_ = graph; - return *this; -} - -/// -/// @brief Add exist node -/// @param [in] node -/// @return PartialGraphBuilder -/// -PartialGraphBuilder &PartialGraphBuilder::AddExistNode(const NodePtr &node) { - exist_nodes_.emplace_back(node); - return *this; -} - -/// -/// @brief Build partial graph -/// @param [out] error_code -/// @param [out] error_msg -/// @return ComputeGraphPtr -/// -ComputeGraphPtr PartialGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { - if (owner_graph_ == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "graph is NULL."; - return nullptr; - } - - BuildNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildExistNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildDataLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildCtrlLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - return owner_graph_; -} - -/// -/// @brief Build exist nodes -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -/// -void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string &error_msg) { - std::string node_name; - for (auto &node : exist_nodes_) { - if (node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "Build exist nodes failed: node is NULL."; - return; - } - - node_name = node->GetName(); - if (node->GetOwnerComputeGraph() != owner_graph_) { - error_code = GRAPH_FAILED; - error_msg = "Build exist nodes failed: node " + node_name + " not belongs to this graph."; - return; - } - - GELOGD("Add exist_node name:%s.", node_name.c_str()); - node_names_[node_name] = node; - } - - GELOGD("Build exist nodes succ."); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec) { - std::vector stack_input; - std::map map_in_edge_num; - graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Sort nodes failed."); - return GRAPH_FAILED; - } - const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1; - std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index, - [](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); }); - - std::queue stack; - NodePtr cur_node = nullptr; - std::map name_node_map; - vector nodes_name; - while (!stack_input.empty() || !stack.empty()) { - if (!stack.empty()) { - cur_node = stack.front(); - stack.pop(); - } else { - cur_node = stack_input.back(); - stack_input.pop_back(); - } - node_vec.emplace_back(cur_node); - compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map); - for (const auto &iter : name_node_map) { - nodes_name.emplace_back(iter.first); - } - std::sort(nodes_name.begin(), nodes_name.end()); - for (const auto &iter : nodes_name) { - stack.push(name_node_map[iter]); - } - name_node_map.clear(); - nodes_name.clear(); - } - // If they are not equal, there is a closed loop - if (node_vec.size() != compute_graph->nodes_.size()) { - std::set itered_nodes_set; - for (auto &node : node_vec) { - itered_nodes_set.insert(node.get()); - } - GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", - compute_graph->nodes_.size(), node_vec.size()); - for (auto &node : compute_graph->nodes_) { - if (itered_nodes_set.count(node.get()) == 0) { - GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); - } - } - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -} // namespace ge diff --git a/metadef/graph/utils/mem_utils.h b/metadef/graph/utils/mem_utils.h deleted file mode 100644 index 7e8dd9fd..00000000 --- a/metadef/graph/utils/mem_utils.h +++ /dev/null @@ -1,32 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_UTILS_MEM_UTILS_H_ -#define COMMON_GRAPH_UTILS_MEM_UTILS_H_ - -#include -#include - -namespace ge { -template -static inline std::shared_ptr<_Tp> MakeShared(_Args &&... __args) { - typedef typename std::remove_const<_Tp>::type _Tp_nc; - std::shared_ptr<_Tp> ret(new (std::nothrow) _Tp_nc(std::forward<_Args>(__args)...)); - return ret; -} -} - -#endif // COMMON_GRAPH_UTILS_MEM_UTILS_H_ diff --git a/metadef/graph/utils/node_utils.cc b/metadef/graph/utils/node_utils.cc deleted file mode 100644 index 72981d10..00000000 --- a/metadef/graph/utils/node_utils.cc +++ /dev/null @@ -1,956 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/node_utils.h" -#include "utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/anchor.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/types.h" -#include "utils/tensor_utils.h" -#include "utils/type_utils.h" - -namespace ge { -std::map> NodeUtils::map_send_info_{}; -std::map> NodeUtils::map_recv_info_{}; - -const std::set kConstOpTypes = {"Const", "Constant"}; - -const std::set kIfOpTypes = {"If", "_If", "StatelessIf"}; -const std::set kWhileOpTypes = {"While", "_While", "StatelessWhile"}; -const std::set kCaseOpTypes = {"Case"}; -const std::set kForOpTypes = {"For"}; - -bool OpShapeIsUnknown(const OpDescPtr &desc) { - for (const auto &ptr : desc->GetAllInputsDescPtr()) { - auto ge_shape = ptr->GetShape(); - for (const auto &dim : ge_shape.GetDims()) { - if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { - return true; - } - } - } - for (const auto &ptr : desc->GetAllOutputsDescPtr()) { - auto ge_shape = ptr->GetShape(); - for (const auto &dim : ge_shape.GetDims()) { - if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { - return true; - } - } - } - return false; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node, - const uint32_t &event_id) { - GE_CHECK_NOTNULL(node); - map_send_info_[node].push_back(event_id); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node, - const uint32_t &event_id) { - GE_CHECK_NOTNULL(node); - map_recv_info_[node].push_back(event_id); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector &vec_send) { - GE_CHECK_NOTNULL(node); - auto find = map_send_info_.find(node); - if (find == map_send_info_.end()) { - return GRAPH_FAILED; - } else { - vec_send = find->second; - return GRAPH_SUCCESS; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector &vec_recv) { - GE_CHECK_NOTNULL(node); - auto find = map_recv_info_.find(node); - if (find == map_recv_info_.end()) { - return GRAPH_FAILED; - } else { - vec_recv = find->second; - return GRAPH_SUCCESS; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() { - map_send_info_.clear(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() { - map_recv_info_.clear(); - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) { - GE_CHECK_NOTNULL(src); - NodePtr cur_ptr; - if (depth < 1) { - return GRAPH_FAILED; - } - for (int i = 0; i < depth; i++) { - if (src->GetOutDataNodes().size() != 1) { - return GRAPH_FAILED; - } - cur_ptr = src->GetOutDataNodes().at(0); - GE_CHECK_NOTNULL(cur_ptr); - } - dst = cur_ptr; - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data, - InControlAnchorPtr &in_control) { - GE_CHECK_NOTNULL(node_ptr); - for (const auto &p : node_ptr->GetAllOutDataAnchors()) { - GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr"); - for (const auto &p_in : p->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr"); - out_data = p; - in_control = p_in; - return GRAPH_SUCCESS; - } - } - return GRAPH_FAILED; -} - -graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { - GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, - "node or in_data_anchor is nullptr"); - - bool find_flag = false; - uint32_t index = 0; - vector::iterator it = node_ptr->in_data_anchors_.end(); - for (const auto &tmp : node_ptr->in_data_anchors_) { - if (tmp == in_data_anchor) { - find_flag = true; - auto iter = node_ptr->in_data_anchors_.begin() + index; - if (iter != node_ptr->in_data_anchors_.end()) { - it = node_ptr->in_data_anchors_.erase(iter); - } - break; - } - index++; - } - for (; it != node_ptr->in_data_anchors_.end(); ++it) { - (*it)->SetIdx(index); - index++; - } - - if (!find_flag) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) { - GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr"); - GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed"); - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::SetAllAnchorStatus(Node &node) { - node.anchor_status_updated_ = true; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) { - GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr"); - return IsAnchorStatusSet(*node_ptr); -} - -bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; } - -graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) { - if ((origin_node == nullptr) || (new_node == nullptr)) { - return GRAPH_FAILED; - } - auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors(); - auto new_out_data_anchors = new_node->GetAllOutDataAnchors(); - if (origin_out_data_anchors.size() != new_out_data_anchors.size()) { - return GRAPH_FAILED; - } - - for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) { - for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, - "unlink peer_anchor failed"); - GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, - "linkto peer_anchor failed"); - } - - for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, - "unlink peer_anchor failed"); - GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, - "linkto peer_anchor failed"); - } - } - - auto origin_out_control_anchor = origin_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(origin_out_control_anchor); - auto new_out_control_anchor = new_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(new_out_control_anchor); - for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, - "linkto peer_anchor failed"); - } - for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, - "linkto peer_anchor failed"); - } - origin_out_control_anchor->UnlinkAll(); - - return GRAPH_SUCCESS; -} - -bool NodeUtils::IsConst(const Node &node) { - auto src_node_type = node.GetType(); - bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP)); - return is_const; -} - -void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) { - if (node_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "node is null"); - return; - } - UpdateIsInputConst(*node_ptr); -} - -/// -/// update is_input_const -/// @param node -/// @return void -/// -void NodeUtils::UpdateIsInputConst(Node &node) { - std::vector is_input_const; - size_t anchor_num = node.GetAllInDataAnchors().size(); - for (size_t i = 0; i < anchor_num; i++) { - auto in_anchor = node.GetInDataAnchor(static_cast(i)); - if (in_anchor == nullptr) { - is_input_const.push_back(false); - continue; - } - auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - is_input_const.push_back(false); - continue; - } - auto src_node = peer_out_anchor->GetOwnerNode(); - if (src_node == nullptr) { - is_input_const.push_back(false); - continue; - } - if (IsConst(*(src_node))) { - is_input_const.push_back(true); - } else { - is_input_const.push_back(false); - } - } - if (node.GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr"); - return; - } - node.GetOpDesc()->SetIsInputConst(is_input_const); -} - -void NodeUtils::UnlinkAll(const Node &node) { - for (const auto &anchor : node.GetAllOutAnchors()) { - anchor->UnlinkAll(); - } - for (const auto &anchor : node.GetAllInAnchors()) { - anchor->UnlinkAll(); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) { - if (node_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); - return GRAPH_FAILED; - } - auto op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - return GRAPH_FAILED; - } - bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); - if (is_unknown_graph) { - return GRAPH_SUCCESS; - } - for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { - auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); - auto out_dims = output_tensor->GetShape().GetDims(); - auto out_dtype = output_tensor->GetDataType(); - ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); - output_tensor->SetOriginShape(output_tensor->GetShape()); - output_tensor->SetOriginDataType(output_tensor->GetDataType()); - - GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", - node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), - TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), - TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); - - for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { - if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); - continue; - } - auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); - if (peer_input_desc == nullptr) { - GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); - continue; - } - // check shape and dtype continuity. do not stop process - auto peer_input_dims = peer_input_desc->GetShape().GetDims(); - auto peer_input_dtype = peer_input_desc->GetDataType(); - if (out_dtype != peer_input_dtype) { - GELOGW( - "current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th " - "input_dtype is [%s].The two dtype should be same! Please check graph and fix it", - node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(), - peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), - TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str()); - } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) { - string out_shape_str, peer_in_shape_str; - out_shape_str += "["; - for (int64_t dim : out_dims) { - out_shape_str += std::to_string(dim) + " "; - } - out_shape_str += "]"; - peer_in_shape_str += "["; - for (int64_t dim : peer_input_dims) { - peer_in_shape_str += std::to_string(dim) + " "; - } - peer_in_shape_str += "]"; - - GELOGW( - "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " - "input_shape is [%s].The two shape should be same! Please check graph and fix it", - node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(), - peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str()); - } - GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", - peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), - output_tensor->GetDataType(), output_tensor->GetOriginDataType()); - peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); - peer_input_desc->SetShape(output_tensor->GetShape()); - peer_input_desc->SetDataType(output_tensor->GetDataType()); - peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); - std::vector> shape_range; - (void)output_tensor->GetShapeRange(shape_range); - peer_input_desc->SetShapeRange(shape_range); - ge::TensorUtils::SetRealDimCnt(*peer_input_desc, - static_cast(output_tensor->GetShape().GetDims().size())); - GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", - peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), - peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, - uint32_t num) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Input node is null"); - return GRAPH_FAILED; - } - - GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); - const auto &op_desc = node->GetOpDesc(); - for (size_t i = op_desc->GetInputsSize(); i < num; ++i) { - if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add input desc failed"); - return GRAPH_FAILED; - } - - auto anchor = ComGraphMakeShared(node, i); - if (anchor == nullptr) { - GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed."); - return GRAPH_FAILED; - } - node->in_data_anchors_.push_back(anchor); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, - uint32_t num) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Input node is null"); - return GRAPH_FAILED; - } - - const auto &op_desc = node->GetOpDesc(); - while (op_desc->GetInputsSize() > num) { - if (!OpDescUtils::ClearInputDesc(op_desc, num)) { - return GRAPH_FAILED; - } - } - - auto input_names = op_desc->GetAllInputName(); - (void)op_desc->UpdateInputName(input_names); - auto is_input_const = op_desc->GetIsInputConst(); - is_input_const.resize(num); - op_desc->SetIsInputConst(is_input_const); - - while (node->in_data_anchors_.size() > num) { - node->in_data_anchors_.pop_back(); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node, - uint32_t num) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Input node is null"); - return GRAPH_FAILED; - } - - GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); - const OpDescPtr &op_desc = node->GetOpDesc(); - for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) { - if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add output desc failed"); - return GRAPH_FAILED; - } - - auto anchor = ComGraphMakeShared(node, i); - if (anchor == nullptr) { - GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed."); - return GRAPH_FAILED; - } - node->out_data_anchors_.push_back(anchor); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node, - uint32_t num) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Input node is null"); - return GRAPH_FAILED; - } - - const auto &op_desc = node->GetOpDesc(); - auto output_names = op_desc->GetAllOutputName(); - while (op_desc->GetOutputsSize() > num) { - if (!OpDescUtils::ClearOutputDesc(op_desc, num)) { - return GRAPH_FAILED; - } - } - (void)op_desc->UpdateOutputName(output_names); - - while (node->out_data_anchors_.size() > num) { - node->out_data_anchors_.pop_back(); - } - - return GRAPH_SUCCESS; -} - -bool NodeUtils::IsInNodesEmpty(const Node &node) { - for (const auto &in_anchor : node.in_data_anchors_) { - if (in_anchor != nullptr) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor != nullptr) { - if (out_anchor->GetOwnerNode() != nullptr) { - return false; - } - } - } - } - - if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) { - auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors(); - for (const auto &out_control_anchor : peer_out_control_anchors) { - if (out_control_anchor != nullptr) { - if (out_control_anchor->GetOwnerNode() != nullptr) { - return false; - } - } - } - } - - return true; -} -GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) { - auto desc = node.GetOpDesc(); - if (desc == nullptr) { - return GeTensorDesc(); - } - return desc->GetOutputDesc(index); -} -GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) { - auto desc = node.GetOpDesc(); - if (desc == nullptr) { - return GeTensorDesc(); - } - return desc->GetInputDesc(index); -} -graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) { - auto desc = node.GetOpDesc(); - if (desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - auto output_desc = desc->MutableOutputDesc(index); - if (output_desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - output_desc->SetShape(shape); - return GRAPH_SUCCESS; -} -graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) { - auto desc = node.GetOpDesc(); - if (desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - auto input_desc = desc->MutableInputDesc(index); - if (input_desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - input_desc->SetShape(shape); - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { - auto desc = node.GetOpDesc(); - GE_CHECK_NOTNULL(desc); - // check self - is_unknow = OpShapeIsUnknown(desc); - if (is_unknow) { - return GRAPH_SUCCESS; - } - auto sub_graph_names = desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - return GRAPH_SUCCESS; - } else { - auto owner_graph = node.GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(owner_graph); - auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); - if (root_graph == nullptr) { - GE_LOGE("Node %s gets null root graph", node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - for (auto &sub_graph_name : sub_graph_names) { - auto sub_graph = root_graph->GetSubgraph(sub_graph_name); - GE_CHECK_NOTNULL(sub_graph); - for (const auto &node_ptr : sub_graph->GetDirectNode()) { - auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow); - if (status != GRAPH_SUCCESS) { - GE_LOGE("get node unknown shape status failed!"); - return status; - } - if (is_unknow) { - return GRAPH_SUCCESS; - } - } - } - } - return GRAPH_SUCCESS; -} - -std::string NodeUtils::GetNodeType(const Node &node) { - if (node.GetType() != FRAMEWORKOP) { - return node.GetType(); - } - - std::string type; - (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); - return type; -} - -std::string NodeUtils::GetNodeType(const NodePtr &node) { return node == nullptr ? "" : GetNodeType(*node); } - -graphStatus NodeUtils::GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor) { - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor) { - return GRAPH_SUCCESS; -} - -ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { - auto op_desc = node.GetOpDesc(); - if (op_desc == nullptr) { - return nullptr; - } - auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); - if (root_graph == nullptr) { - return nullptr; - } - return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); -} - -graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) { - if (subgraph == nullptr) { - GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index); - return GRAPH_PARAM_INVALID; - } - auto op_desc = node.GetOpDesc(); - if (op_desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); - if (root_graph == nullptr) { - GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); - if (ret != GRAPH_SUCCESS) { - GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); - return ret; - } - subgraph->SetParentNode(node.shared_from_this()); - subgraph->SetParentGraph(node.GetOwnerComputeGraph()); - return root_graph->AddSubgraph(subgraph); -} - -/// -/// Check if node is input of subgraph -/// @param [in] node -/// @return bool -/// -bool NodeUtils::IsSubgraphInput(const NodePtr &node) { - if ((node == nullptr) || (node->GetOpDesc() == nullptr) || - (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { - return false; - } - - auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); - if (parent_op_desc == nullptr) { - return false; - } - - // dynamic shape unknown graph false - // dynamic shape known graph with functional subgraph maybe true - if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { - if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { - return false; - } else { - if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { - return false; - } - } - } - - return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); -} - -/// -/// Check if node is output of subgraph -/// @param [in] node -/// @return bool -/// -bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { - if ((node == nullptr) || (node->GetOpDesc() == nullptr) || - (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) { - return false; - } - - auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); - if (parent_op_desc == nullptr) { - return false; - } - - if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { - if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { - return false; - } else { - if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { - return false; - } - } - } - - for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { - if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { - return true; - } - } - - return false; -} - -/// -/// @brief Get subgraph original input node. -/// @param [in] node -/// @return Node -/// -NodePtr NodeUtils::GetParentInput(const Node &node) { - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - return nullptr; - } - - // Subgraph Data Node, check for constant input. - const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); - GE_CHECK_NOTNULL_EXEC(graph, return nullptr); - - const NodePtr &parent_node = graph->GetParentNode(); - GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr); - - const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index); - GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr); - - const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr); - - return peer_out_anchor->GetOwnerNode(); -} - -NodePtr NodeUtils::GetParentInput(const NodePtr &node) { return node == nullptr ? node : GetParentInput(*node); } - -/// -/// @brief Get is dynamic shape graph from node. -/// @param [in] node -/// @return bool -/// -bool NodeUtils::IsDynamicShape(const Node &node) { - const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); - if (graph == nullptr) { - return false; - } - - bool is_dynamic_shape = false; - (void)AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); - return is_dynamic_shape; -} - -bool NodeUtils::IsDynamicShape(const NodePtr &node) { return node == nullptr ? false : IsDynamicShape(*node); } - -/// -/// @brief Check is varying_input for while node -/// @param [in] node: Data node for subgraph -/// @return bool -/// -bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { - if (node == nullptr) { - return false; - } - if (node->GetType() != DATA) { - return false; // not input_node for subgraph - } - - const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); - if (parent_node == nullptr) { - return false; // root graph - } - - if (kWhileOpTypes.count(parent_node->GetType()) == 0) { - return false; // not input_node for while subgraph - } - - uint32_t index_i = 0; - if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { - GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); - return false; - } - bool varying_flag = true; - for (const auto &item : node->GetOutDataNodesAndAnchors()) { - if (item.first->GetType() != NETOUTPUT) { - continue; - } - OpDescPtr op_desc = item.first->GetOpDesc(); - uint32_t index_o = 0; - if ((op_desc == nullptr) || - !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { - continue; // input for while-cond subgraph - } - if (index_i != index_o) { - continue; // varying input for while-body subgraph - } - varying_flag = false; - break; - } - return varying_flag; -} - -/// -/// @brief Get subgraph input is constant. -/// @param [in] node -/// @param [out] string -/// @return bool -/// -bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) { - if (node == nullptr) { - return false; - } - - if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) { - type = node->GetType(); - return true; - } - - if (node->GetType() != DATA) { - return false; // not subgraph input node - } - - const auto &parent = GetParentInput(node); - return GetConstOpType(parent, type); -} - -/// -/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. -/// @param [in] node -/// @return return GRAPH_SUCCESS if remove successfully, other for failed. -/// -Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { - GE_CHECK_NOTNULL(node); - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - if (subgraph_names.empty()) { - return GRAPH_SUCCESS; - } else { - auto owner_graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(owner_graph); - auto root_graph = GraphUtils::FindRootGraph(owner_graph); - GE_CHECK_NOTNULL(root_graph); - - std::unordered_set subgraph_to_remove; - for (auto &subgraph_name : subgraph_names) { - std::deque queue; - queue.push_back(subgraph_name); - subgraph_to_remove.insert(subgraph_name); - op_desc->RemoveSubgraphInstanceName(subgraph_name); - while (!queue.empty()) { - auto graph_name = queue.front(); - queue.pop_front(); - - auto subgraph = root_graph->GetSubgraph(graph_name); - GE_CHECK_NOTNULL(subgraph); - for (const auto &sub_node : subgraph->GetDirectNode()) { - auto sub_op_desc = sub_node->GetOpDesc(); - GE_CHECK_NOTNULL(sub_op_desc); - auto sub_names = sub_op_desc->GetSubgraphInstanceNames(); - // Subgraph and all nodes in it will be removed later, - // no need to remove 'SubgraphInstanceName' in op desc here. - for (auto &name : sub_names) { - if (subgraph_to_remove.insert(name).second) { - queue.push_back(name); - } - } - } - } - } - // Remove subgraph from root_graph - for (const auto &name : subgraph_to_remove) { - GELOGI("Remove subgraph:%s.", name.c_str()); - root_graph->RemoveSubgraph(name); - } - } - - return GRAPH_SUCCESS; -} -/// -/// @brief Get subgraph input data node by index. -/// @param [in] node -/// @return Node -/// -vector NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { - vector in_data_node_vec; - auto op_desc = node.GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); - auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - if (subgraph_names.empty()) { - GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); - return in_data_node_vec; - } - auto compute_graph = node.GetOwnerComputeGraph(); - for (const std::string &instance_name : subgraph_names) { - auto subgraph = compute_graph->GetSubgraph(instance_name); - for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { - int parent_index = -1; - if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { - (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); - if (parent_index == index) { - in_data_node_vec.emplace_back(node_in_subgraph); - } - } - } - } - return in_data_node_vec; -} -/// -/// @brief Get subgraph input data node by index. -/// @param [in] node -/// @return Node -/// -vector NodeUtils::GetSubgraphOutputNodes(const Node &node) { - vector out_data_node_vec; - auto op_desc = node.GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); - auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - if (subgraph_names.empty()) { - GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); - return out_data_node_vec; - } - auto compute_graph = node.GetOwnerComputeGraph(); - for (const std::string &instance_name : subgraph_names) { - auto subgraph = compute_graph->GetSubgraph(instance_name); - for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { - if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { - out_data_node_vec.emplace_back(node_in_subgraph); - } - } - } - return out_data_node_vec; -} - -NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) { - if (node.GetInDataAnchor(index) == nullptr) { - return nullptr; - } - if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { - return nullptr; - } - return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); -} - -vector> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) { - vector> out_data_nodes; - auto out_data_anchor = node.GetOutDataAnchor(index); - if (out_data_anchor == nullptr) { - return out_data_nodes; - } - - for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - if (peer_in_anchor == nullptr) { - continue; - } - if (peer_in_anchor->GetOwnerNode() == nullptr) { - continue; - } - out_data_nodes.emplace_back(std::make_pair(peer_in_anchor, peer_in_anchor->GetOwnerNode())); - } - return out_data_nodes; -} - -ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); } -} // namespace ge diff --git a/metadef/graph/utils/op_desc_utils.cc b/metadef/graph/utils/op_desc_utils.cc deleted file mode 100644 index 63fff177..00000000 --- a/metadef/graph/utils/op_desc_utils.cc +++ /dev/null @@ -1,778 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "utils/op_desc_utils.h" -#include -#include "debug/ge_attr_define.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/anchor.h" -#include "graph/compute_graph.h" -#include "graph/ge_attr_value.h" -#include "utils/graph_utils.h" -#include "utils/node_utils.h" - -using std::vector; - -/*lint -e512 -e737 -e752*/ -namespace ge { -const char OP_DESC_QUANT_PARAMS[] = "quantize_factor"; -static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1; - -bool OpDescUtils::ClearInputDesc(const NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); - GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr"); - vector index_list; - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - if (in_anchor->GetPeerOutAnchor() == nullptr) { - index_list.push_back(in_anchor->GetIdx()); - } - } - std::sort(index_list.begin(), index_list.end()); - // Node's in anchor index need shrink - for (size_t i = 0; i < index_list.size(); ++i) { - auto iter = node->GetOpDesc()->inputs_desc_.begin() + index_list[i]; - if (iter < node->GetOpDesc()->inputs_desc_.end()) { - (void)node->GetOpDesc()->inputs_desc_.erase(iter); - } else { - GELOGW("inputs_desc_ iterator out of range."); - } - } - - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(OpDescPtr op_desc, - const uint32_t index) { - GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr"); - GE_CHK_BOOL_EXEC(index < op_desc->inputs_desc_.size(), return false, "index %u is invalid.", index); - - auto iter = op_desc->inputs_desc_.begin() + index; - if (iter < op_desc->inputs_desc_.end()) { - (void)op_desc->inputs_desc_.erase(iter); - } else { - GELOGW("inputs_desc_ iterator out of range."); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) { - GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return false, "op_desc is nullptr"); - return op_desc->HasAttr(OP_DESC_QUANT_PARAMS); -} - -bool OpDescUtils::ClearOutputDesc(const NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); - GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr"); - vector index_list; - for (const auto &out_anchor : node->GetAllOutDataAnchors()) { - if (out_anchor->GetPeerInDataAnchors().empty()) { - index_list.push_back(out_anchor->GetIdx()); - } - } - std::sort(index_list.begin(), index_list.end()); - // Node's out anchor index need shrink - for (size_t i = 0; i < index_list.size(); ++i) { - auto iter = node->GetOpDesc()->outputs_desc_.begin() + index_list[i]; - if (iter < node->GetOpDesc()->outputs_desc_.end()) { - (void)node->GetOpDesc()->outputs_desc_.erase(iter); - } else { - GELOGW("outputs_desc_ iterator out of range."); - } - } - - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc, - uint32_t index) { - GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr"); - GE_CHK_BOOL_EXEC(index < op_desc->outputs_desc_.size(), return false, "index %u is invalid.", index); - - auto iter = op_desc->outputs_desc_.begin() + index; - if (iter < op_desc->outputs_desc_.end()) { - (void)op_desc->outputs_desc_.erase(iter); - } else { - GELOGW("outputs_desc_ iterator out of range."); - } - return true; -} - -bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant) { - GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); - GeAttrValue attr_value; - GE_CHK_BOOL_EXEC_INFO(op_desc->GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED, - "GetQuantizeFactorParams failed"); - return attr_value.GetValue(quant); -} - -graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant) { - GeAttrValue attr_value; - GE_CHK_BOOL_EXEC_INFO(op_desc.GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED, - "GetQuantizeFactorParams failed"); - return attr_value.GetValue(quant); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) { - GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); - return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); // lint !e732 -} - -graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) { - return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); // lint !e732 -} - -GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { - GeTensorPtr weight = nullptr; - if (!AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight)) { - GELOGW("MutableTensor error"); - } - - return weight; -} - -GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(OpDescPtr op_desc) { - if (op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "op_desc is null"); - return nullptr; - } - return MutableWeights(*op_desc); -} - -graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) { - if (weight == nullptr) { - GELOGE(GRAPH_FAILED, "weight is null"); - return GRAPH_FAILED; - } - return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) { - GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(weight); - return SetWeights(*op_desc, weight); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetWeights(const ge::Node &node) { - auto weights = MutableWeights(node); - vector ret(weights.size()); - std::copy(weights.begin(), weights.end(), ret.begin()); - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetWeights( - const ge::ConstNodePtr &node) { - if (node == nullptr) { - return vector(); - } - return GetWeights(*node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputNode( - const ge::Node &node) { - vector ret; - auto in_anchors = node.GetAllInDataAnchors(); - for (const auto &in_anchor : in_anchors) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - // normally out_anchor could be null, this is ok - GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str()); - continue; - } - auto in_node = out_anchor->GetOwnerNode(); - while (true) { - if (in_node == nullptr) { - break; - } - if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { - ret.push_back(in_node); - break; - } else if (in_node->GetType() == DATA) { - if (NodeUtils::IsWhileVaryingInput(in_node)) { - break; - } - in_node = NodeUtils::GetParentInput(in_node); - } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { - bool is_constant = false; - (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); - if (!is_constant) { - break; - } - // Enter node has and only has one input - if (in_node->GetInDataNodes().size() != 1) { - GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(), - in_node->GetInDataNodes().size()); - break; - } - in_node = in_node->GetInDataNodes().at(0); - } else { - break; - } - } - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetInputData( - const vector &input_nodes) { - vector ret; - - for (const auto &input_node : input_nodes) { - auto temp_weight = MutableWeights(input_node->GetOpDesc()); - if (temp_weight == nullptr) { - GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str()); - return vector(); - } - ret.push_back(temp_weight); - } - - return ret; -} -size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) { - if (NodeUtils::IsAnchorStatusSet(node)) { - size_t input_num = 0; - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { - input_num++; - continue; - } - } - return input_num; // lint !e712 - } else { - GE_IF_BOOL_EXEC( - node.GetInDataNodes().size() < GetConstInputs(node).size(), - GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size()); - return 0); - return node.GetInDataNodes().size() - GetConstInputs(node).size(); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Node is nullptr"); - return 0; - } - return GetNonConstInputsSize(*node); -} - -GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const) { - GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GeTensorDesc(), "node.GetOpDesc() is nullptr!"); - size_t i = 0; - if (NodeUtils::IsAnchorStatusSet(node)) { - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { - if (index_non_const == i) { - return node.GetOpDesc()->GetInputDesc(static_cast(anchor->GetIdx())); - } - ++i; - } - } - } else { - for (const auto &anchor : node.GetAllInDataAnchors()) { - auto peer_anchor = anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - continue; - } - auto owner_node = peer_anchor->GetOwnerNode(); - if (owner_node == nullptr) { - continue; - } - if (owner_node->GetType() == CONSTANT) { - continue; - } - if (index_non_const == i) { - return node.GetOpDesc()->GetInputDesc(anchor->GetIdx()); - } - ++i; - } - } - return GeTensorDesc(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc -OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const) { - CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc()); - return GetNonConstInputTensorDesc(*node, index_non_const); -} - -bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) { - bool ret = false; - size_t i = 0; - if (NodeUtils::IsAnchorStatusSet(node)) { - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { - if (index_non_const == i) { - index = static_cast(anchor->GetIdx()); - ret = true; - } - ++i; - } - } - } else { - for (const auto &anchor : node.GetAllInDataAnchors()) { - auto peer_anchor = anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - continue; - } - auto owner_node = peer_anchor->GetOwnerNode(); - if (owner_node == nullptr) { - continue; - } - if (owner_node->GetType() == CONSTANT) { - continue; - } - if (index_non_const == i) { - index = static_cast(anchor->GetIdx()); - ret = true; - } - ++i; - } - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node, - size_t index_non_const, - size_t &index) { - CHECK_FALSE_EXEC(node != nullptr, return false); - return GetNonConstInputIndex(*node, index_non_const, index); -} - -bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) { - bool ret = false; - if (index < node.GetAllInDataAnchors().size()) { - if (NodeUtils::IsAnchorStatusSet(node)) { - ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast(index))) == ANCHOR_DATA); // lint !e712 - } else { - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (anchor->GetIdx() != static_cast(index)) { - continue; - } - auto peer_anchor = anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - break; - } - auto owner_node = peer_anchor->GetOwnerNode(); - if (owner_node == nullptr) { - break; - } - ret = (owner_node->GetType() != CONSTANT); - } - } - } - - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node, - size_t index) { - CHECK_FALSE_EXEC(node != nullptr, return false); - return IsNonConstInput(*node, index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputs( - const ge::ConstNodePtr &node) { - if (node == nullptr) { - return vector(); - } - return GetConstInputs(*node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetNonConstTensorDesc( - const ge::ConstNodePtr &node) { - if (node == nullptr || node->GetOpDesc() == nullptr) { - return vector(); - } - vector ret; - if (NodeUtils::IsAnchorStatusSet(*node)) { - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { - ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); - } - } - } else { - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr || out_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { - continue; - } - if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { - ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); - } - } - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::GetConstInputs(const ge::Node &node) { - vector ret; - auto in_anchors = node.GetAllInDataAnchors(); - for (const auto &in_anchor : in_anchors) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) continue; - - auto in_node = out_anchor->GetOwnerNode(); - if (in_node->GetType() == CONSTANT) { - ret.push_back(in_node); - } else if (in_node->GetType() == SWITCH && node.GetType() == MATMUL) { - // const --> switch --> matmul - auto switch_input = GetConstInputs(*in_node); - if (switch_input.size() > 0) { - ret.insert(ret.end(), switch_input.begin(), switch_input.end()); - } - } else if (in_node->GetType() == DATA) { - auto parent = NodeUtils::GetParentInput(in_node); - if ((parent != nullptr) && (parent->GetType() == CONSTANT)) { - ret.push_back(parent); - } - } - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::MutableWeights(const ge::Node &node) { - vector ret; - auto op_desc = node.GetOpDesc(); - GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!"); - // Place holder operator, try to get the weight from parent node - // when parent node is const operator - if (node.GetType() == PLACEHOLDER) { - std::string parent_op; - (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op); - // This if judgment is necessary because the current subgraph optimization is multithreaded - // and the parent node of the PLD operation should be a stable type, such as const - if (parent_op == CONSTANT || parent_op == CONSTANTOP) { - NodePtr parent_node = nullptr; - parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); - if (parent_node != nullptr) { - op_desc = parent_node->GetOpDesc(); - GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str()); - } - } - } - // Const operator, take the weight directly - if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) { - auto weight = MutableWeights(op_desc); - if (weight == nullptr) { - GELOGI("const op has no weight, op name:%s", node.GetName().c_str()); - return ret; - } - ret.push_back(weight); - return ret; - } - - if (node.GetType() == DATA) { - auto parent = NodeUtils::GetParentInput(node); - if ((parent != nullptr) && NodeUtils::IsConst(*parent)) { - auto weight = MutableWeights(parent->GetOpDesc()); - if (weight == nullptr) { - GELOGI("const op has no weight, op name:%s", parent->GetName().c_str()); - return ret; - } - ret.push_back(weight); - } - return ret; - } - - // Other operators, get weights from connected constop - auto input_nodes = GetConstInputs(node); - for (const auto &input_node : input_nodes) { - auto temp_weight = MutableWeights(input_node->GetOpDesc()); - if (temp_weight == nullptr) { - GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str()); - return vector(); - } - ret.push_back(temp_weight); - } - - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::MutableWeights(const ge::NodePtr node) { - if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Node is nullptr"); - return vector(); - } - return MutableWeights(*node); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::SetWeights(ge::Node &node, const vector &weights) { - GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GRAPH_PARAM_INVALID, "node.GetOpDesc is nullptr!"); - if (node.GetOpDesc()->GetType() == CONSTANT) { - if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) { - return SetWeights(node.GetOpDesc(), weights[0]); - } - GELOGI("const op weight size %zu should be 1", weights.size()); - return GRAPH_PARAM_INVALID; - } - - auto input_nodes = GetConstInputs(node); - if (weights.size() < input_nodes.size()) { - GELOGE(GRAPH_FAILED, "weights count can't be less than const input count"); - return GRAPH_PARAM_INVALID; - } - - ge::GeAttrValue::NAMED_ATTRS named_attrs; - (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights); - vector copy_weights; - (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights); - - for (size_t i = 0; i < input_nodes.size(); ++i) { - if (input_nodes[i]->GetOpDesc() != nullptr) { - SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]); - } - } - - // If set more weights than constop, need to add constop - for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) { - // Use org weight before SetWeights Overwrite - auto const_opdesc = CreateConstOp(copy_weights[i]); - GE_CHECK_NOTNULL(const_opdesc); - - auto owner_graph = node.GetOwnerComputeGraph(); - if (owner_graph == nullptr) { - GELOGE(GRAPH_FAILED, "node's graph is empty, name: %s", node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - auto const_node = owner_graph->AddNodeFront(const_opdesc); - GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS, return GRAPH_FAILED, "graph add link failed锛"); - std::vector original_nodes; - ge::GraphUtils::RecordOriginalNames(original_nodes, const_node); - } - return GRAPH_SUCCESS; -} - -OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) { - GE_CHK_BOOL_EXEC(tensor_ptr != nullptr, return nullptr, "tensor_ptr is nullptr!"); - shared_ptr const_opdesc = ComGraphMakeShared(); - if (const_opdesc == nullptr) { - GELOGE(GRAPH_FAILED, "failed to make_shared "); - return nullptr; - } - - CHECK_FALSE_EXEC(SetWeights(const_opdesc, tensor_ptr) == ge::GRAPH_SUCCESS, return nullptr); - - const_opdesc->SetType(CONSTANT); - - thread_local int64_t const_count = 0; - const_opdesc->SetName("dynamic_const_" + std::to_string(GetTid()) + "_" + std::to_string(const_count)); - GELOGI("add const op: %s", const_opdesc->GetName().c_str()); - ++const_count; - - (void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc()); - - GELOGI("after add const op: %s", const_opdesc->GetName().c_str()); - - return const_opdesc; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) { - GE_CHECK_NOTNULL(in_anchor); - GE_CHECK_NOTNULL(tensor_ptr); - auto const_opdesc = CreateConstOp(tensor_ptr); - GE_CHECK_NOTNULL(const_opdesc); - auto in_node = in_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(in_node); - auto owner_graph = in_node->GetOwnerComputeGraph(); - if (owner_graph == nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", in_node->GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc); - GE_CHECK_NOTNULL(const_node); - if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) { - GELOGE(GRAPH_PARAM_INVALID, "Addedge const to node failed."); - return GRAPH_PARAM_INVALID; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::SetWeights(ge::NodePtr node, const vector &weights) { - GE_CHECK_NOTNULL(node); - return SetWeights(*node, weights); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) { - GE_CHECK_NOTNULL(node); - auto const_ops = GetConstInputs(node); - auto graph = node->GetOwnerComputeGraph(); - if (graph == nullptr) { - GELOGE(GRAPH_FAILED, "Graph is nullptr"); - return GRAPH_PARAM_INVALID; - } - for (const auto &const_op : const_ops) { - GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "Isolate removed node: %s, type: %s failed", - const_op->GetName().c_str(), const_op->GetType().c_str()); - GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op), - "Remove node: %s, type: %s without relink failed", const_op->GetName().c_str(), - const_op->GetType().c_str()); - } - return GRAPH_SUCCESS; -} - -/// -/// @brief Add input -/// @param [in] name -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { - inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); - return *this; -} - -/// -/// @brief Add input -/// @param [in] name -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name, - const GeTensorDesc &tensor) { - inputs_.emplace_back(std::make_pair(name, tensor)); - return *this; -} - -/// -/// @brief Add dynamic input -/// @param [in] name -/// @param [in] num -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, - uint32_t num) { - for (uint32_t i = 0; i < num; i++) { - inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); - } - return *this; -} - -/// -/// @brief Add dynamic input -/// @param [in] name -/// @param [in] num -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput( - const std::string &name, uint32_t num, const GeTensorDesc &tensor) { - for (uint32_t i = 0; i < num; i++) { - inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); - } - return *this; -} - -/// -/// @brief Add output -/// @param [in] name -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { - outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); - return *this; -} - -/// -/// @brief Add output -/// @param [in] name -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name, - const GeTensorDesc &tensor) { - outputs_.emplace_back(std::make_pair(name, tensor)); - return *this; -} - -/// -/// @brief Add dynamic output -/// @param [in] name -/// @param [in] num -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, - uint32_t num) { - for (uint32_t i = 0; i < num; i++) { - outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); - } - return *this; -} - -/// -/// @brief Add dynamic output -/// @param [in] name -/// @param [in] num -/// @param [in] tensor -/// @return OpDescBuilder -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput( - const std::string &name, uint32_t num, const GeTensorDesc &tensor) { - for (uint32_t i = 0; i < num; i++) { - outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); - } - return *this; -} - -/// -/// @brief Build op_desc -/// @return OpDescPtr -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { - OpDescPtr op_desc = shared_ptr(new (std::nothrow) OpDesc(name_, type_)); - if (op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "OpDesc is nullptr"); - return nullptr; - } - - for (auto &input : inputs_) { - if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add input_desc failed."); - return nullptr; - } - } - - for (auto &output : outputs_) { - if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Add output_desc failed."); - return nullptr; - } - } - - return op_desc; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName( - const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) { - const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); - auto iter = subgraph_names_to_index.find(subgraph_name); - if (iter == subgraph_names_to_index.end()) { - GELOGE(GRAPH_PARAM_INVALID, - "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists", - subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), - subgraph_name.c_str()); - return GRAPH_PARAM_INVALID; - } - - return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); -} -} // namespace ge -/*lint +e512 +e737 +e752*/ diff --git a/metadef/graph/utils/string_utils.h b/metadef/graph/utils/string_utils.h deleted file mode 100644 index a9700469..00000000 --- a/metadef/graph/utils/string_utils.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_UTILS_STRING_UTILS_H_ -#define COMMON_GRAPH_UTILS_STRING_UTILS_H_ - -#include -#include -#include -#include -#include -#include "securec.h" - -namespace ge { -class StringUtils { - public: - static std::string &Ltrim(std::string &s) { - (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); - return s; - } - - static std::string &Rtrim(std::string &s) { - (void)s.erase(std::find_if(s.rbegin(), s.rend(), [](int c) { return !std::isspace(c); }).base(), s.end()); - return s; - } - - /// @ingroup domi_common - /// @brief trim space - static std::string &Trim(std::string &s) { return Ltrim(Rtrim(s)); } - - // split string - static std::vector Split(const std::string &str, char delim) { - std::vector elems; - - if (str.empty()) { - elems.emplace_back(""); - return elems; - } - - std::stringstream ss(str); - std::string item; - - while (getline(ss, item, delim)) { - elems.push_back(item); - } - auto str_size = str.size(); - if (str_size > 0 && str[str_size - 1] == delim) { - elems.emplace_back(""); - } - - return elems; - } -}; -} // namespace ge -#endif // COMMON_GRAPH_UTILS_STRING_UTILS_H_ diff --git a/metadef/graph/utils/tensor_utils.cc b/metadef/graph/utils/tensor_utils.cc deleted file mode 100644 index 26ac8cc8..00000000 --- a/metadef/graph/utils/tensor_utils.cc +++ /dev/null @@ -1,401 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/utils/tensor_utils.h" -#include - -#include "debug/ge_log.h" -#include "framework/common/debug/ge_log.h" -#include "common/util/error_manager/error_manager.h" -#include "graph/ge_tensor.h" -#include "graph/types.h" -#include "graph/utils/type_utils.h" - -namespace ge { -namespace { -// When nc1hwc0 dim size = 5, calc element count directly. -const uint32_t kNc1hwc0CalcByDimsSize = 5; - -// Unknown shape element num -const int64_t kElementCntUnknownShape = -1; - -// Unknown shape mem size -const int64_t kMemSizeUnknownShape = -1; - -// Nchw and nhwc dim size must be 4 -const uint32_t kDimSize4d = 4; - -// C1HWNCoC0 dim size must be 6 -const uint32_t kDimSizeC1hwncoc0 = 6; - -// Cube size is 16 -const uint32_t kTheCubeSize = 16; - -// Default c0 size equals cube size. -const uint32_t kC0SizeDefault = kTheCubeSize; - -// Size equals int8 cube size is 32 -const uint32_t kC0SizeInt8 = 32; - -// NCHW dim N index -const int32_t kNchwDimIdxN = 0; -// NCHW dim C index -const int32_t kNchwDimIdxC = 1; -// NCHW dim H index -const int32_t kNchwDimIdxH = 2; -// NCHW dim W index -const int32_t kNchwDimIdxW = 3; - -const int kDataMemAlignSize = 32; -const int kNum2 = 2; -} // namespace - -/// -/// Check if a * b overflow. -/// @param a multiplier -/// @param b Multiplicand -/// @return true: overflow -/// false: not overflow -/// -static bool CheckMultiplyOverflowInt64(const int64_t &a, const int64_t &b) { - if (a > 0) { - if (b > 0) { - if (a > (INT64_MAX / b)) { - return true; - } - } else { - if (b < (INT64_MIN / a)) { - return true; - } - } - } else { - if (b > 0) { - if (a < (INT64_MIN / b)) { - return true; - } - } else { - if ((a != 0) && (b < (INT64_MAX / a))) { - return true; - } - } - } - return false; -} - -/// -/// Calculate element num by dims directly. -/// @param dims dim info -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcElementCntByDims(const std::vector &dims, int64_t &element_cnt) { - element_cnt = 1; - for (int64_t dim : dims) { - if (CheckMultiplyOverflowInt64(element_cnt, dim)) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19013", {"function", "var1", "var2"}, - {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)}); - GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim); - return GRAPH_FAILED; - } - element_cnt *= dim; - } - return GRAPH_SUCCESS; -} - -/// -/// Calculate fixed dims element num. -/// @param dims dim info -/// @param fixed_dim_size fixed dim size -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcElementCntOfFixedDims(const std::vector &dims, Format format, uint32_t fixed_dim_size, - int64_t &element_cnt) { - if (dims.size() != fixed_dim_size) { - GELOGW("Format %d(%s) need dim size=%u but %zu, calc as ND.", format, - TypeUtils::FormatToSerialString(format).c_str(), fixed_dim_size, dims.size()); - } - return CalcElementCntByDims(dims, element_cnt); -} - -/// -/// Get dim c0 size by type -/// @param data_type data type -/// @return c0 size -/// -static uint32_t GetDimC0(DataType &data_type) { - bool is_int8_size = (data_type == DT_INT8) || (data_type == DT_UINT8) || (data_type == DT_DUAL_SUB_UINT8) || - (data_type == DT_DUAL_SUB_INT8) || (data_type == DT_BOOL) || (data_type == DT_QINT8); - return is_int8_size ? kC0SizeInt8 : kC0SizeDefault; -} - -/// -/// Calculate nc1hwc0 element num. -/// @param dims dim info -/// @param data_type data type -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcElementCntOfNc1hwc0(const std::vector &dims, DataType data_type, int64_t &element_cnt) { - // When nc1hwc0 dims size = 5, no need split dim c - if (dims.size() == kNc1hwc0CalcByDimsSize) { - return CalcElementCntByDims(dims, element_cnt); - } else if (dims.size() != kDimSize4d) { - GELOGE(GRAPH_FAILED, "CalcElementCntOfNc1hwc0 failed as dims.size=%zu is not %u or %u.", dims.size(), kDimSize4d, - kNc1hwc0CalcByDimsSize); - return GRAPH_FAILED; - } - - auto c0 = static_cast(GetDimC0(data_type)); - // Nc1hwc0 dims is according to nchw, dim c index is 1. - auto c1 = static_cast(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0)); - // Store dims is split c to c1 and c0. - std::vector store_dims = {dims[kNchwDimIdxN], c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0}; - return CalcElementCntByDims(store_dims, element_cnt); -} - -/// -/// Calculate FractalZ element num. -/// @param dims dim info -/// @param data_type data type -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcElementCntOfFractalZ(const std::vector &dims, DataType data_type, - int64_t &element_cnt) { - static char *parser_priority = std::getenv("PARSER_PRIORITY"); - if (parser_priority != nullptr && string(parser_priority) == "cce") { - if (dims.size() != kDimSize4d) { - GELOGE(GRAPH_FAILED, "CalcElementCntOfFractalZ failed as dims.size=%zu is not %u.", dims.size(), kDimSize4d); - return GRAPH_FAILED; - } - auto c0 = static_cast(GetDimC0(data_type)); - // FractalZ dims is according to nchw, dim c index is 1. - auto c1 = static_cast(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0)); - - // Spread NC1HWC0 as a two dimension array, n as column dimension, - // C1HWC0 as row dimension - std::vector r_count_vec = {c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0}; - - int64_t r_count = 1; - graphStatus graph_status = CalcElementCntByDims(r_count_vec, r_count); - if (graph_status != GRAPH_SUCCESS) { - GELOGE(graph_status, "Calc [%ld, %ld, %ld, %ld] element count failed.", c1, dims[kNchwDimIdxH], - dims[kNchwDimIdxW], c0); - return graph_status; - } - - // Cube count in n - auto nc_cnt = static_cast(std::ceil(dims[kNchwDimIdxN] * 1.0 / kTheCubeSize)); - - // Cube count in vertical direction(C1HWC0) - int64_t vc_cnt = r_count / c0; - // Element count in each cube - int64_t cube_elem_cnt = c0 * kTheCubeSize; - - if (CheckMultiplyOverflowInt64(nc_cnt, vc_cnt)) { - GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", nc_cnt, vc_cnt); - return GRAPH_FAILED; - } - // Read data times needed by cube - int64_t c_cnt = nc_cnt * vc_cnt; - - if (CheckMultiplyOverflowInt64(c_cnt, cube_elem_cnt)) { - GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", c_cnt, cube_elem_cnt); - return GRAPH_FAILED; - } - // Element count after fractal arrangement - element_cnt = c_cnt * cube_elem_cnt; - return GRAPH_SUCCESS; - } else { - return CalcElementCntByDims(dims, element_cnt); - } -} - -/// -/// Calculate tensor element num. -/// @param dims dim info -/// @param format tensor format -/// @param data_type data type -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcTensorElementCnt(const std::vector &dims, Format format, DataType data_type, - int64_t &element_cnt) { - const string format_str = TypeUtils::FormatToSerialString(format); - // Check dims - for (size_t i = 0; i < dims.size(); ++i) { - int64_t dim = dims[i]; - if (dim < 0) { - GELOGI("It's unknown shape, as dims[%zu]=%ld negative, format=%d(%s).", i, dim, format, format_str.c_str()); - element_cnt = kElementCntUnknownShape; - return GRAPH_SUCCESS; - } else if (dim == 0) { - GELOGI("No need calc element count, as dims[%zu]=%ld, format=%d(%s).", i, dim, format, format_str.c_str()); - element_cnt = 0; - return GRAPH_SUCCESS; - } - } - - graphStatus graph_status; - switch (format) { - case FORMAT_ND: - case FORMAT_MD: - graph_status = CalcElementCntByDims(dims, element_cnt); - break; - case FORMAT_NCHW: - case FORMAT_HWCN: - case FORMAT_NHWC: - case FORMAT_CHWN: - graph_status = CalcElementCntOfFixedDims(dims, format, kDimSize4d, element_cnt); - break; - case FORMAT_C1HWNCoC0: - graph_status = CalcElementCntOfFixedDims(dims, format, kDimSizeC1hwncoc0, element_cnt); - break; - case FORMAT_NC1HWC0: - graph_status = CalcElementCntOfNc1hwc0(dims, data_type, element_cnt); - break; - case FORMAT_FRACTAL_Z: - graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt); - break; - case FORMAT_FRACTAL_NZ: - case FORMAT_FRACTAL_ZZ: - case FORMAT_NDHWC: - case FORMAT_NCDHW: - case FORMAT_DHWCN: - case FORMAT_DHWNC: - case FORMAT_FRACTAL_Z_3D: - case FORMAT_FRACTAL_Z_3D_TRANSPOSE: - case FORMAT_NDC1HWC0: - case FORMAT_FRACTAL_Z_C04: - case FORMAT_FRACTAL_ZN_LSTM: - case FORMAT_NC1HWC0_C04: - graph_status = CalcElementCntByDims(dims, element_cnt); - break; - default: - GELOGE(GRAPH_FAILED, "unsupported format, format=%d(%s).", format, format_str.c_str()); - graph_status = GRAPH_FAILED; - break; - } - - const string type_str = TypeUtils::DataTypeToSerialString(data_type); - if (graph_status == GRAPH_SUCCESS) { - GELOGD( - "CalcTensorElementCnt end, format=%d(%s)," - " data_type=%d(%s), element_cnt=%ld.", - format, format_str.c_str(), data_type, type_str.c_str(), element_cnt); - } else { - GELOGE(GRAPH_FAILED, "CalcTensorElementCnt failed, format=%d(%s), data_type=%d(%s).", format, format_str.c_str(), - data_type, type_str.c_str()); - } - return graph_status; -} - -/// -/// Calculate tensor mem size. -/// @param shape tensor shape -/// @param format tensor format -/// @param data_type tensor data type -/// @param mem_size -1 means unknown shape,other means mem size -/// @return GRAPH_SUCCESS:success, other:failed -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTensorMemSize(const GeShape &shape, - Format format, - DataType data_type, - int64_t &mem_size) { - const string format_str = TypeUtils::FormatToSerialString(format); - const string type_str = TypeUtils::DataTypeToSerialString(data_type); - uint32_t type_size = 0; - bool result = TypeUtils::GetDataTypeLength(data_type, type_size); - if (!result) { - GELOGE(GRAPH_FAILED, "GetDataTypeLength failed, data_type=%d(%s).", data_type, type_str.c_str()); - return GRAPH_FAILED; - } - - std::vector dims = shape.GetDims(); - int64_t element_cnt = 0; - graphStatus status = CalcTensorElementCnt(dims, format, data_type, element_cnt); - if (status != GRAPH_SUCCESS) { - GELOGE(status, "CalcTensorElementCnt failed, status=%u format=%d(%s) data_type=%d(%s).", status, format, - format_str.c_str(), data_type, type_str.c_str()); - return status; - } - // Support unknown shape - if (element_cnt < 0) { - mem_size = kMemSizeUnknownShape; - GELOGD( - "element_cnt is unknown. " - "format=%d(%s), data_type=%d(%s), mem_size=%ld", - format, format_str.c_str(), data_type, type_str.c_str(), mem_size); - return GRAPH_SUCCESS; - } - auto type_size_int64 = static_cast(type_size); - if (CheckMultiplyOverflowInt64(element_cnt, type_size_int64)) { - GELOGE(GRAPH_FAILED, "CalcTensorMemSize overflow, when multiplying %ld and %ld, format=%d(%s), data_type=%d(%s).", - element_cnt, type_size_int64, format, format_str.c_str(), data_type, type_str.c_str()); - return GRAPH_FAILED; - } - mem_size = element_cnt * type_size_int64; - - GELOGD( - "CalcTensorMemSize end, " - "format=%d(%s), data_type=%d(%s), mem_size=%ld", - format, format_str.c_str(), data_type, type_str.c_str(), mem_size); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { - graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); - if (graph_status != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - // 64-byte alignment, if size is 0, align to 32 bytes - if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) { - GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp); - } else { - size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; - } - return GRAPH_SUCCESS; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { - GeShape output_shape = desc_temp.GetShape(); - Format format = desc_temp.GetFormat(); - DataType data_type = desc_temp.GetDataType(); - int64_t output_mem_size = 0; - graphStatus graph_status = CalcTensorMemSize(output_shape, format, data_type, output_mem_size); - if (graph_status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "CalcTensorMemSize failed!"); - return GRAPH_FAILED; - } - - if (output_mem_size < 0) { - GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]", - output_mem_size, INT64_MAX); - return GRAPH_FAILED; - } - - size_temp = output_mem_size; - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/metadef/graph/utils/tuning_utils.cc b/metadef/graph/utils/tuning_utils.cc deleted file mode 100644 index 0f07a197..00000000 --- a/metadef/graph/utils/tuning_utils.cc +++ /dev/null @@ -1,684 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/tuning_utils.h" -#include "../debug/ge_util.h" -#include "../debug/ge_op_types.h" - -namespace ge { -const std::string peer_node_name_attr = "_peerNodeName"; -const std::string parent_node_name_attr = "_parentNodeName"; -const std::string alias_name_attr = "_aliasName"; -const std::string parent_node_attr = "parentNode"; -const std::string parent_node_anchor_index_attr = "_parentNodeAnchorIndex"; -const std::string tuning_subgraph_prefix = "/aicore_subgraph_"; -const std::string non_tuning_subgraph_prefix = "/subgraph_"; -const std::set kPartitionOpTypes = {PLACEHOLDER, END}; -const std::set kExeTypes = {DATA, NETOUTPUT}; -NodeNametoNodeNameMap TuningUtils::data_2_netoutput_; -NodetoNodeNameMap TuningUtils::data_node_2_netoutput_; -NodetoNodeMap TuningUtils::data_node_2_netoutput_node_; -NodeSet TuningUtils::netoutput_nodes_; -NodeSet TuningUtils::merged_graph_nodes_; -SubgraphCreateOutNode TuningUtils::create_output_; -std::mutex TuningUtils::mutex_; - -std::string TuningUtils::PrintCheckLog() { - std::stringstream ss; - ss << "d2n:{"; - for (const auto &pair : data_2_netoutput_) { - ss << "data:" << pair.first << "-" - << "netoutput:" << pair.second; - ss << " | "; - } - ss << "}"; - ss << "netoutputs:{"; - for (const auto &node : netoutput_nodes_) { - ss << "netoutput:" << node->GetName(); - ss << " | "; - } - ss << "}"; - return ss.str(); -} - -std::string TuningUtils::GetNodeNameByAnchor(const Anchor *anchor) { - if (anchor == nullptr) { - GELOGE(GRAPH_FAILED, "Anchor is nullptr"); - return "Null"; - } - auto node = anchor->GetOwnerNode(); - return node == nullptr ? "Null" : node->GetName(); -} - -// part 1 -graphStatus TuningUtils::ConvertGraphToFile(std::vector tuning_subgraphs, - std::vector non_tuning_subgraphs, bool exe_flag, - const std::string &path, const std::string &user_path) { - int64_t i = 0; - int64_t j = 0; - std::lock_guard lock(mutex_); - for (auto &subgraph : tuning_subgraphs) { - create_output_.emplace(subgraph, nullptr); - auto help_info = HelpInfo{i, exe_flag, true, path, user_path}; - if (MakeExeGraph(subgraph, help_info) != SUCCESS) { - GELOGE(GRAPH_FAILED, "TUU:subgraph %zu generate exe graph failed", i); - return GRAPH_FAILED; - } - i++; - } - - for (auto &subgraph : non_tuning_subgraphs) { - create_output_.emplace(subgraph, nullptr); - auto help_info = HelpInfo{j, true, false, path, user_path}; - if (MakeExeGraph(subgraph, help_info) != SUCCESS) { - GELOGE(GRAPH_FAILED, "TUU:non tuning_subgraph %zu generate exe graph failed", j); - return GRAPH_FAILED; - } - j++; - } - create_output_.clear(); - return SUCCESS; -} - -// +---------------+ -// | pld pld | -// | \ / | -// | relu relu | -// | \ / | -// | add | -// | | | -// | end | -// +---------------+ -// | -// | -// V -// +---------------+ -// | data data | -// | \ / | -// | relu relu | -// | \ / | -// | add | -// | | | -// | netoutput | -// +---------------+ -graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) { - GE_CHECK_NOTNULL(exe_graph); - // if not make exe, just dump and return - if (!help_info.exe_flag) { - DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); - GELOGI("TUU:just return, dump original sub_graph[%s]index[%d]", exe_graph->GetName().c_str(), help_info.index); - return SUCCESS; - } - // modify sub graph - for (NodePtr &node : exe_graph->GetDirectNode()) { - // 1.handle pld - if (node->GetType() == PLACEHOLDER) { - if (HandlePld(node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), - exe_graph->GetName().c_str()); - return FAILED; - } - } - // 2.handle end - if (node->GetType() == END) { - if (HandleEnd(node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), - exe_graph->GetName().c_str()); - return FAILED; - } - } - } - graphStatus ret = exe_graph->TopologicalSorting(); - if (ret != SUCCESS) { - GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret); - return ret; - } - // dump subgraphs which modified by us - if (help_info.user_path.empty()) { - DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); - } else { - GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path); - } - return SUCCESS; -} - -void TuningUtils::DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path) { - if (!path.empty()) { - if (is_tuning_graph) { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } else { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } - } else { - path = "./"; - if (is_tuning_graph) { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } else { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } - } -} - -graphStatus TuningUtils::CreateDataNode(NodePtr &node, NodePtr &data_node) { - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - auto data_op_desc = ComGraphMakeShared(node->GetName(), DATA); - GE_CHECK_NOTNULL(data_op_desc); - auto pld_op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(pld_op_desc); - auto output_desc = pld_op_desc->GetOutputDesc(0); // only one output for pld and data - // data inputdesc & outputdesc set as same - if (data_op_desc->AddInputDesc(output_desc) != SUCCESS) { - GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); - return FAILED; - } - if (data_op_desc->AddOutputDesc(output_desc) != SUCCESS) { - GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); - return FAILED; - } - data_node = graph->AddNode(data_op_desc); - GE_CHECK_NOTNULL(data_node); - if (data_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); - return FAILED; - } - return SUCCESS; -} - -graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node) { - auto op_desc = data_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - auto pld_desc = pld->GetOpDesc(); - GE_CHECK_NOTNULL(pld_desc); - // inherit - // a. set `end's input node type` as attr - std::string parent_op_type; - if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) { - GELOGE(FAILED, "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void)AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type); - // b. set `end's input node name` as attr - std::string parent_op_name; - if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) { - GELOGE(FAILED, "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void)AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name); - // c. set `end's input node's out anchor index` as attr - int parent_node_anchor_index; - if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) { - GELOGE(FAILED, "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void)AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index); - GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(), - data_node->GetName().c_str(), data_node->GetType().c_str()); - // d. set `end node name` as attr - std::string peer_end_name; - if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) { - GELOGE(FAILED, "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void)AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name); - GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(), - data_node->GetName().c_str(), data_node->GetType().c_str()); - return SUCCESS; -} - -graphStatus TuningUtils::ChangePld2Data(NodePtr &node, NodePtr &data_node) { - auto type_pld = node->GetType(); - auto type_data = data_node->GetType(); - if (type_pld != PLACEHOLDER || type_data != DATA) { - GELOGE(FAILED, "TUU:Failed to change node %s from type %s to type %s", node->GetName().c_str(), type_pld.c_str(), - type_data.c_str()); - return FAILED; - } - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - std::vector output_map(node->GetAllOutDataAnchorsSize()); - for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { - output_map[i] = static_cast(i); - } - - auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:Failed to replace node %s by node %s error node %u", node->GetName().c_str(), - data_node->GetName().c_str(), ret); - return FAILED; - } - - NodeUtils::UnlinkAll(*node); - - ret = GraphUtils::RemoveNodeWithoutRelink(graph, node); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); - return FAILED; - } - - GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)", node->GetName().c_str(), - node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str()); - return ret; -} - -graphStatus TuningUtils::HandlePld(NodePtr &node) { - GE_CHECK_NOTNULL(node); - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - NodePtr data_node = nullptr; - - // 1. create data node - if (CreateDataNode(node, data_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - // 2. add necessary info to data_node for recovery whole graph - if (AddAttrToDataNodeForMergeGraph(node, data_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - // 3. replace pld node by data node created before - if (ChangePld2Data(node, data_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - GELOGD("TUU:pld[%s] handle success", node->GetName().c_str()); - return SUCCESS; -} - -graphStatus TuningUtils::CreateNetOutput(NodePtr &node, NodePtr &out_node) { - GE_CHECK_NOTNULL(node); - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - auto search = create_output_.find(graph); - if (search == create_output_.end()) { - GELOGE(FAILED, "TUU:node %s's owner sub graph %s not exist in create_output map", node->GetName().c_str(), - graph->GetName().c_str()); - return FAILED; - } - if (search->second != nullptr) { - out_node = search->second; - GELOGD("TUU:sub graph %s has created output node, just return", graph->GetName().c_str()); - return SUCCESS; - } - auto out_op_desc = ComGraphMakeShared(node->GetName(), NETOUTPUT); - GE_CHECK_NOTNULL(out_op_desc); - out_node = graph->AddNode(out_op_desc); - GE_CHECK_NOTNULL(out_node); - if (out_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); - return FAILED; - } - create_output_[graph] = out_node; - return SUCCESS; -} - -graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node) { - GE_CHECK_NOTNULL(end); - GE_CHECK_NOTNULL(out_node); - auto op_desc = out_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::vector alias_names = {}; - (void)AttrUtils::GetListStr(op_desc, alias_name_attr, alias_names); - alias_names.push_back(end->GetName()); - (void)AttrUtils::SetListStr(op_desc, alias_name_attr, alias_names); - return SUCCESS; -} - -graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { - GE_CHECK_NOTNULL(end_node); - GE_CHECK_NOTNULL(out_node); - // get end in node is control node or normal node - AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr) - ? Anchor::DynamicAnchorCast(end_node->GetInControlAnchor()) - : Anchor::DynamicAnchorCast(end_node->GetInDataAnchor(0)); - auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1 - if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", - GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - // add edge between `end in node` and `out_node` - if (src_anchor->IsTypeOf()) { - std::shared_ptr anchor = - ComGraphMakeShared(out_node, out_node->GetAllInDataAnchors().size()); - GE_CHECK_NOTNULL(anchor); - out_node->in_data_anchors_.push_back(anchor); - if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", - GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - auto end_op_desc = end_node->GetOpDesc(); - GE_CHECK_NOTNULL(end_op_desc); - auto out_node_op_desc = out_node->GetOpDesc(); - GE_CHECK_NOTNULL(out_node_op_desc); - // end node always has one input - if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str()); - return FAILED; - } - } else if (src_anchor->IsTypeOf()) { - auto anchor = out_node->GetInControlAnchor(); - if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", - GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - } else { - GELOGE(FAILED, "TUU: node_name:%s, graph_name:%s handled failed", end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - - return SUCCESS; -} - -graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { - GE_CHECK_NOTNULL(end_node); - GE_CHECK_NOTNULL(out_node); - auto type_end = end_node->GetType(); - auto type_out = out_node->GetType(); - if (type_end != END || type_out != NETOUTPUT) { - GELOGE(FAILED, "TUU:Failed to change end_node %s from type %s to type %s", end_node->GetName().c_str(), - type_end.c_str(), type_out.c_str()); - return FAILED; - } - // link all `end nodes's in node` to this out_node - if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) { - GELOGE(FAILED, "TUU:end_node [%s] LinkEnd2NetOutput failed.", end_node->GetName().c_str()); - return FAILED; - } - // remove `end node` - NodeUtils::UnlinkAll(*end_node); - auto graph = end_node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - if (GraphUtils::RemoveNodeWithoutRelink(graph, end_node) != SUCCESS) { - GELOGE(FAILED, "TUU:end node [%s] RemoveNodeWithoutRelink failed.", end_node->GetName().c_str()); - return FAILED; - } - return SUCCESS; -} - -graphStatus TuningUtils::HandleEnd(NodePtr &node) { - GE_CHECK_NOTNULL(node); - auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - NodePtr out_node = nullptr; - - // 1. create net_output node , add only one NetOutput node to one subgraph - if (CreateNetOutput(node, out_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - // 2. add necessary info to out_node for recovery whole graph - if (AddAttrToNetOutputForMergeGraph(node, out_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - // 3. replace all end nodes by one output node created before - if (ChangeEnd2NetOutput(node, out_node) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - GELOGD("TUU:end[%s] handle success", node->GetName().c_str()); - return SUCCESS; -} - -// part 2 -graphStatus TuningUtils::ConvertFileToGraph(const map &options, ge::Graph &graph) { - // 1. get all subgraph object - std::vector graphs; - // options format like {index:"subgraph_path"} - for (const auto &pair : options) { - ComputeGraphPtr compute_graph = ComGraphMakeShared(std::to_string(pair.first)); - if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), *compute_graph)) { - GELOGE(FAILED, "TUU:load graph from file failed"); - } - graphs.push_back(compute_graph); - } - // 2. merge graph - ComputeGraphPtr merged_graph = ComGraphMakeShared("whole_graph_after_tune"); - GE_CHECK_NOTNULL(merged_graph); - if (MergeAllSubGraph(graphs, merged_graph) != SUCCESS) { - GELOGE(FAILED, "TUU:MergeGraph failed"); - return FAILED; - } - // 3. set parent graph - for (const auto &node : merged_graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node); - if (node->SetOwnerComputeGraph(merged_graph) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:node %s set owner graph failed", node->GetName().c_str()); - return FAILED; - } - } - graph = GraphUtils::CreateGraphFromComputeGraph(merged_graph); - return SUCCESS; -} - -// +----------------------------------+ -// | const const | -// | \ / | -// | netoutput(end,end) | -// +----------------------------------+ -// + -// +----------------------------------+ -// | data(pld) data(pld) | -// | \ / | -// | relu relu | -// | \ / | -// | \ / | -// | add | -// | | | -// | netoutput(end) | -// +----------------------------------+ -// + -// +----------------------------------+ -// | data(pld) | -// | / | -// | netoutput | -// +----------------------------------+ -// | -// | -// V -// +----------------------------------+ -// | const const | -// | \ / | -// | relu relu | -// | \ / | -// | \ / | -// | add | -// | | | -// | netoutput | -// +----------------------------------+ -graphStatus TuningUtils::MergeAllSubGraph(std::vector &subgraphs, - ComputeGraphPtr &output_merged_compute_graph) { - GE_CHECK_NOTNULL(output_merged_compute_graph); - // 1. handle all subgraphs - for (auto &subgraph : subgraphs) { - Status ret_status = MergeSubGraph(subgraph); - if (ret_status != SUCCESS) { - GELOGE(ret_status, "TUU:subgraph %s merge failed", subgraph->GetName().c_str()); - return ret_status; - } - } - - for (const auto &node : merged_graph_nodes_) { - (void)output_merged_compute_graph->AddNode(node); - GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(), node->GetName().c_str()); - } - - // 2. remove data and output node added by us - if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) { - GELOGE(FAILED, "TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str()); - return FAILED; - } - graphStatus ret = output_merged_compute_graph->TopologicalSorting(); - if (ret != SUCCESS) { - GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", output_merged_compute_graph->GetName().c_str(), ret); - return ret; - } - GELOGD("TUU:Print-%s", PrintCheckLog().c_str()); - GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str()); - return SUCCESS; -} - -graphStatus TuningUtils::MergeSubGraph(ComputeGraphPtr &subgraph) { - for (auto &node : subgraph->GetDirectNode()) { - if (kPartitionOpTypes.count(node->GetType()) > 0) { - GELOGE(FAILED, "TUU:subgraph passed in should not contain nodes of end or pld type"); - return FAILED; - } - // handle data converted from pld node - if (node->GetType() == DATA) { - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::string peer_out_name; - bool has_valid_str = (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty()); - if (has_valid_str) { - std::lock_guard lock(mutex_); - data_2_netoutput_.emplace(op_desc->GetName(), peer_out_name); - data_node_2_netoutput_.emplace(node, peer_out_name); - continue; - } - } - // handle netoutput converted from end node - if (node->GetType() == NETOUTPUT) { - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::vector out_alias_name; - bool has_valid_str = - (AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty()); - if (has_valid_str) { - std::lock_guard lock(mutex_); - netoutput_nodes_.insert(node); - } - } - { - std::lock_guard lock(mutex_); - merged_graph_nodes_.emplace(node); - } - GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str()); - } - GELOGI("TUU:merge subgraph %s success", subgraph->GetName().c_str()); - return SUCCESS; -} - -graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) { - GE_CHECK_NOTNULL(graph); - // 1. traverse - for (auto &pair : data_node_2_netoutput_) { - auto data_node = pair.first; - GE_CHECK_NOTNULL(data_node); - auto netoutput_name = pair.second; - auto netoutput_node = graph->FindNode(netoutput_name); - GE_CHECK_NOTNULL(netoutput_node); - data_node_2_netoutput_node_.emplace(data_node, netoutput_node); - // 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor` - AnchorPtr data_out_anchor = (data_node->GetOutDataAnchor(0)->GetFirstPeerAnchor() == nullptr) - ? Anchor::DynamicAnchorCast(data_node->GetOutControlAnchor()) - : Anchor::DynamicAnchorCast(data_node->GetOutDataAnchor(0)); - AnchorPtr net_output_in_anchor = nullptr; - AnchorPtr src_out_anchor = nullptr; - if (GetInAndOutAnchorPair(data_node, netoutput_node, net_output_in_anchor, src_out_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:get out node:%s 's in anchor related with data node:%s failed", - netoutput_node->GetName().c_str(), data_node->GetName().c_str()); - return FAILED; - } - // 3. relink - if (GraphUtils::RemoveEdge(src_out_anchor, net_output_in_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", - GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), - GetNodeNameByAnchor(net_output_in_anchor.get()).c_str(), net_output_in_anchor->GetIdx(), - data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - GE_CHECK_NOTNULL(data_out_anchor); - for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) { - if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", - GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), - data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", - GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), - data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - } - } - // 4. remove out nodes added by us - for (auto &node : netoutput_nodes_) { - NodeUtils::UnlinkAll(*node); - if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); - return FAILED; - } - GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str()); - } - return SUCCESS; -} - -graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor, - AnchorPtr &src_out_anchor) { - // 1. get `data parent node name`, i.e. `netoutput input node name` - std::string netoutput_input_name; - auto op_desc = data_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (!AttrUtils::GetStr(op_desc, parent_node_name_attr, netoutput_input_name)) { - GELOGE(FAILED, "TUU:Failed to get parent node attr from node %s", op_desc->GetName().c_str()); - return FAILED; - } - // 2. find index - int parent_node_anchor_index; - if (!AttrUtils::GetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index)) { - GELOGE(FAILED, "TUU:Failed to get parent node anchor index attr from node %s", op_desc->GetName().c_str()); - return FAILED; - } - // 3.find in data or ctrl anchor by 1&2 step - for (auto &in_anchor : out_node->GetAllInAnchors()) { - GE_CHECK_NOTNULL(in_anchor); - for (auto &src_anchor : in_anchor->GetPeerAnchors()) { // get all peer anchors for ctrl - GE_CHECK_NOTNULL(src_anchor); - auto src_node = src_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(src_node); - if (src_node->GetName() == netoutput_input_name && src_anchor->GetIdx() == parent_node_anchor_index) { - dest_in_anchor = in_anchor; - src_out_anchor = src_anchor; - GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s", - out_node->GetName().c_str(), dest_in_anchor->GetIdx(), netoutput_input_name.c_str(), - parent_node_anchor_index, data_node->GetName().c_str()); - break; - } - } - } - GE_CHECK_NOTNULL(dest_in_anchor); - GE_CHECK_NOTNULL(src_out_anchor); - return SUCCESS; -} - -} // namespace ge \ No newline at end of file diff --git a/metadef/graph/utils/type_utils.cc b/metadef/graph/utils/type_utils.cc deleted file mode 100644 index 2efc530e..00000000 --- a/metadef/graph/utils/type_utils.cc +++ /dev/null @@ -1,448 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/utils/type_utils.h" -#include "debug/ge_util.h" - -using domi::domiTensorFormat_t; - -namespace ge { -static const std::map kFormatToStringMap = { - {FORMAT_NCHW, "NCHW"}, - {FORMAT_NHWC, "NHWC"}, - {FORMAT_ND, "ND"}, - {FORMAT_NC1HWC0, "NC1HWC0"}, - {FORMAT_FRACTAL_Z, "FRACTAL_Z"}, - {FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, - {FORMAT_NHWC1C0, "NHWC1C0"}, - {FORMAT_FSR_NCHW, "FSR_NCHW"}, - {FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, - {FORMAT_C1HWNC0, "C1HWNC0"}, - {FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, - {FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, - {FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, - {FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, - {FORMAT_CHWN, "CHWN"}, - {FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, - {FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, - {FORMAT_BN_WEIGHT, "BN_WEIGHT"}, - {FORMAT_FILTER_HWCK, "FILTER_HWCK"}, - {FORMAT_HWCN, "HWCN"}, - {FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, - {FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, - {FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, - {FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, - {FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, - {FORMAT_MD, "MD"}, - {FORMAT_NDHWC, "NDHWC"}, - {FORMAT_NCDHW, "NCDHW"}, - {FORMAT_DHWCN, "DHWCN"}, - {FORMAT_DHWNC, "DHWNC"}, - {FORMAT_NDC1HWC0, "NDC1HWC0"}, - {FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, - {FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, - {FORMAT_C1HWNCoC0, "C1HWNCoC0"}, - {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, - {FORMAT_CN, "CN"}, - {FORMAT_NC, "NC"}, - {FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, - {FORMAT_FRACTAL_Z_G, "FRACTAL_Z_G"}, - {FORMAT_RESERVED, "FORMAT_RESERVED"}, - {FORMAT_ALL, "ALL"}}; - -static const std::map kDomiFormatToGeFormat = { - {domi::DOMI_TENSOR_NCHW, FORMAT_NCHW}, - {domi::DOMI_TENSOR_NHWC, FORMAT_NHWC}, - {domi::DOMI_TENSOR_ND, FORMAT_ND}, - {domi::DOMI_TENSOR_NC1HWC0, FORMAT_NC1HWC0}, - {domi::DOMI_TENSOR_FRACTAL_Z, FORMAT_FRACTAL_Z}, - {domi::DOMI_TENSOR_NC1C0HWPAD, FORMAT_NC1C0HWPAD}, - {domi::DOMI_TENSOR_NHWC1C0, FORMAT_NHWC1C0}, - {domi::DOMI_TENSOR_FSR_NCHW, FORMAT_FSR_NCHW}, - {domi::DOMI_TENSOR_FRACTAL_DECONV, FORMAT_FRACTAL_DECONV}, - {domi::DOMI_TENSOR_BN_WEIGHT, FORMAT_BN_WEIGHT}, - {domi::DOMI_TENSOR_CHWN, FORMAT_CHWN}, - {domi::DOMI_TENSOR_FILTER_HWCK, FORMAT_FILTER_HWCK}, - {domi::DOMI_TENSOR_NDHWC, FORMAT_NDHWC}, - {domi::DOMI_TENSOR_NCDHW, FORMAT_NCDHW}, - {domi::DOMI_TENSOR_DHWCN, FORMAT_DHWCN}, - {domi::DOMI_TENSOR_DHWNC, FORMAT_DHWNC}, - {domi::DOMI_TENSOR_RESERVED, FORMAT_RESERVED}}; - -static const std::unordered_set kInternalFormat = {"NC1HWC0", - "FRACTAL_Z", - "NC1C0HWPAD", - "NHWC1C0", - "FRACTAL_DECONV", - "C1HWNC0", - "FRACTAL_DECONV_TRANSPOSE", - "FRACTAL_DECONV_SP_STRIDE_TRANS", - "NC1HWC0_C04", - "FRACTAL_Z_C04", - "FRACTAL_DECONV_SP_STRIDE8_TRANS", - "NC1KHKWHWC0", - "C1HWNCoC0", - "FRACTAL_ZZ", - "FRACTAL_NZ", - "NDC1HWC0", - "FORMAT_FRACTAL_Z_3D", - "FORMAT_FRACTAL_Z_3D_TRANSPOSE", - "FORMAT_FRACTAL_ZN_LSTM", - "FORMAT_FRACTAL_Z_G"}; - -static const std::map kDataFormatMap = { - {"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; - -static const std::map kStringToFormatMap = { - {"NCHW", FORMAT_NCHW}, - {"NHWC", FORMAT_NHWC}, - {"ND", FORMAT_ND}, - {"NC1HWC0", FORMAT_NC1HWC0}, - {"FRACTAL_Z", FORMAT_FRACTAL_Z}, - {"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, - {"NHWC1C0", FORMAT_NHWC1C0}, - {"FSR_NCHW", FORMAT_FSR_NCHW}, - {"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, - {"C1HWNC0", FORMAT_C1HWNC0}, - {"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, - {"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, - {"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, - {"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, - {"CHWN", FORMAT_CHWN}, - {"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, - {"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, - {"BN_WEIGHT", FORMAT_BN_WEIGHT}, - {"FILTER_HWCK", FORMAT_FILTER_HWCK}, - {"HWCN", FORMAT_HWCN}, - {"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, - {"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, - {"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, - {"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, - {"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, - {"MD", FORMAT_MD}, - {"C1HWNCoC0", FORMAT_C1HWNCoC0}, - {"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, - {"NDHWC", FORMAT_NDHWC}, - {"NCDHW", FORMAT_NCDHW}, - {"DHWCN", FORMAT_DHWCN}, - {"DHWNC", FORMAT_DHWNC}, - {"NDC1HWC0", FORMAT_NDC1HWC0}, - {"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, - {"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, - {"CN", FORMAT_CN}, - {"NC", FORMAT_NC}, - {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, - {"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G}, - {"FORMAT_RESERVED", FORMAT_RESERVED}, - {"ALL", FORMAT_ALL}, - {"NULL", FORMAT_NULL}}; - -static const std::map kDataTypeToStringMap = { - {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. - {DT_FLOAT, "DT_FLOAT"}, // float type - {DT_FLOAT16, "DT_FLOAT16"}, // fp16 type - {DT_INT8, "DT_INT8"}, // int8 type - {DT_INT16, "DT_INT16"}, // int16 type - {DT_UINT16, "DT_UINT16"}, // uint16 type - {DT_UINT8, "DT_UINT8"}, // uint8 type - {DT_INT32, "DT_INT32"}, // uint32 type - {DT_INT64, "DT_INT64"}, // int64 type - {DT_UINT32, "DT_UINT32"}, // unsigned int32 - {DT_UINT64, "DT_UINT64"}, // unsigned int64 - {DT_BOOL, "DT_BOOL"}, // bool type - {DT_DOUBLE, "DT_DOUBLE"}, // double type - {DT_DUAL, "DT_DUAL"}, // dual output type - {DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type - {DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type - {DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type - {DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type - {DT_QINT8, "DT_QINT8"}, // qint8 type - {DT_QINT16, "DT_QINT16"}, // qint16 type - {DT_QINT32, "DT_QINT32"}, // qint32 type - {DT_QUINT8, "DT_QUINT8"}, // quint8 type - {DT_QUINT16, "DT_QUINT16"}, // quint16 type - {DT_RESOURCE, "DT_RESOURCE"}, // resource type - {DT_STRING_REF, "DT_STRING_REF"}, // string ref type - {DT_STRING, "DT_STRING"}, // string type -}; - -static const std::map kStringTodataTypeMap = { - {"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. - {"DT_FLOAT", DT_FLOAT}, // float type - { - "DT_FLOAT16", - DT_FLOAT16, - }, // fp16 type - {"DT_INT8", DT_INT8}, // int8 type - {"DT_INT16", DT_INT16}, // int16 type - {"DT_UINT16", DT_UINT16}, // uint16 type - {"DT_UINT8", DT_UINT8}, // uint8 type - {"DT_INT32", DT_INT32}, // uint32 type - {"DT_INT64", DT_INT64}, // int64 type - {"DT_UINT32", DT_UINT32}, // unsigned int32 - {"DT_UINT64", DT_UINT64}, // unsigned int64 - {"DT_BOOL", DT_BOOL}, // bool type - {"DT_DOUBLE", DT_DOUBLE}, // double type - {"DT_DUAL", DT_DUAL}, // dual output type - {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type - {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type - {"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type - {"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type - {"DT_QINT8", DT_QINT8}, // qint8 type - {"DT_QINT16", DT_QINT16}, // qint16 type - {"DT_QINT32", DT_QINT32}, // qint32 type - {"DT_QUINT8", DT_QUINT8}, // quint8 type - {"DT_QUINT16", DT_QUINT16}, // quint16 type - {"DT_RESOURCE", DT_RESOURCE}, // resource type - {"DT_STRING_REF", DT_STRING_REF}, // string ref type - {"DT_STRING", DT_STRING}, // string type -}; - -static const std::map kDataTypeToLength = { - {DT_BOOL, sizeof(bool)}, - {DT_INT64, sizeof(int64_t)}, - {DT_UINT64, sizeof(int64_t)}, - {DT_FLOAT, sizeof(float)}, - {DT_INT32, sizeof(int32_t)}, - {DT_UINT32, sizeof(int32_t)}, - {DT_INT8, sizeof(char)}, - {DT_UINT8, sizeof(char)}, - {DT_INT16, sizeof(int16_t)}, - {DT_UINT16, sizeof(int16_t)}, - {DT_FLOAT16, sizeof(int16_t)}, - {DT_DOUBLE, sizeof(double)}, - {DT_DUAL, sizeof(float) + sizeof(int8_t)}, - {DT_DUAL_SUB_INT8, sizeof(int8_t)}, - {DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, - {DT_COMPLEX64, sizeof(int64_t)}, - {DT_COMPLEX128, sizeof(int64_t) * 2}, - {DT_QINT8, sizeof(int8_t)}, - {DT_QINT16, sizeof(int16_t)}, - {DT_QINT32, sizeof(int32_t)}, - {DT_QUINT8, sizeof(uint8_t)}, - {DT_QUINT16, sizeof(uint16_t)}, - {DT_STRING_REF, sizeof(uint64_t) * 2}, - {DT_STRING, sizeof(uint64_t)}, - {DT_RESOURCE, sizeof(uint64_t)}, -}; - -static const std::map kFmkTypeToString = { - {domi::CAFFE, "caffe"}, {domi::MINDSPORE, "mindspore"}, {domi::TENSORFLOW, "tensorflow"}, - {domi::ANDROID_NN, "android_nn"}, {domi::ONNX, "onnx"}, {domi::FRAMEWORK_RESERVED, "framework_reserved"}, -}; - -static const std::map kImplyTypeToString = { - {domi::ImplyType::BUILDIN, "buildin"}, {domi::ImplyType::TVM, "tvm"}, {domi::ImplyType::CUSTOM, "custom"}, - {domi::ImplyType::AI_CPU, "ai_cpu"}, {domi::ImplyType::CCE, "cce"}, {domi::ImplyType::GELOCAL, "gelocal"}, - {domi::ImplyType::HCCL, "hccl"}, {domi::ImplyType::INVALID, "invalid"}}; - -std::string TypeUtils::ImplyTypeToSerialString(domi::ImplyType imply_type) { - auto it = kImplyTypeToString.find(imply_type); - if (it != kImplyTypeToString.end()) { - return it->second; - } else { - GELOGE(GRAPH_FAILED, "ImplyTypeToSerialString: imply_type not support %u", imply_type); - return "UNDEFINED"; - } -} - -bool TypeUtils::IsDataTypeValid(DataType dt) { - uint32_t num = static_cast(dt); - GE_CHK_BOOL_EXEC((num <= DT_UNDEFINED), return false, "The DataType is invalid"); - return true; -} - -std::string TypeUtils::DataTypeToSerialString(DataType data_type) { - auto it = kDataTypeToStringMap.find(data_type); - if (it != kDataTypeToStringMap.end()) { - return it->second; - } else { - GELOGE(GRAPH_FAILED, "DataTypeToSerialString: datatype not support %u", data_type); - return "UNDEFINED"; - } -} - -DataType TypeUtils::SerialStringToDataType(const std::string &str) { - auto it = kStringTodataTypeMap.find(str); - if (it != kStringTodataTypeMap.end()) { - return it->second; - } else { - GELOGE(GRAPH_FAILED, "SerialStringToDataType: datatype not support %s", str.c_str()); - return DT_UNDEFINED; - } -} - -bool TypeUtils::IsFormatValid(Format format) { - uint32_t num = static_cast(format); - GE_CHK_BOOL_EXEC((num <= FORMAT_RESERVED), return false, "The Format is invalid"); - return true; -} - -bool TypeUtils::IsInternalFormat(Format format) { - std::string serial_format = FormatToSerialString(format); - auto iter = kInternalFormat.find(serial_format); - bool result = (iter == kInternalFormat.end()) ? false : true; - return result; -} - -std::string TypeUtils::FormatToSerialString(Format format) { - auto it = kFormatToStringMap.find(format); - if (it != kFormatToStringMap.end()) { - return it->second; - } else { - GELOGE(GRAPH_FAILED, "Format not support %u", format); - return "RESERVED"; - } -} -Format TypeUtils::SerialStringToFormat(const std::string &str) { - auto it = kStringToFormatMap.find(str); - if (it != kStringToFormatMap.end()) { - return it->second; - } else { - GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str()); - return FORMAT_RESERVED; - } -} - -Format TypeUtils::DataFormatToFormat(const std::string &str) { - auto it = kDataFormatMap.find(str); - if (it != kDataFormatMap.end()) { - return it->second; - } else { - GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str()); - return FORMAT_RESERVED; - } -} - -Format TypeUtils::DomiFormatToFormat(domi::domiTensorFormat_t domi_format) { - auto it = kDomiFormatToGeFormat.find(domi_format); - if (it != kDomiFormatToGeFormat.end()) { - return it->second; - } - GELOGE(GRAPH_FAILED, "do not find domi Format %d from map", domi_format); - return FORMAT_RESERVED; -} - -std::string TypeUtils::FmkTypeToSerialString(domi::FrameworkType fmk_type) { - auto it = kFmkTypeToString.find(fmk_type); - if (it != kFmkTypeToString.end()) { - return it->second; - } else { - GELOGW("Framework type not support %d.", fmk_type); - return ""; - } -} - -static inline void CopyDataFromBuffer(vector &data, const Buffer &buffer) { - data.clear(); - if (buffer.GetData() != nullptr && buffer.GetSize() != 0) { - data.assign(buffer.GetData(), buffer.GetData() + buffer.GetSize()); - } -} - -graphStatus Usr2DefQuantizeFactor(const UsrQuantizeFactor &usr, QuantizeFactor &def) { - def.scale_mode = uint32_t(usr.scale_mode); - def.set_scale_value(usr.scale_value.data(), usr.scale_value.size()); - def.scale_offset = usr.scale_offset; - def.set_offset_data_value(usr.offset_data_value.data(), usr.offset_data_value.size()); - def.offset_data_offset = usr.offset_data_offset; - def.set_offset_weight_value(usr.offset_weight_value.data(), usr.offset_weight_value.size()); - def.offset_weight_offset = usr.offset_weight_offset; - def.set_offset_pad_value(usr.offset_pad_value.data(), usr.offset_pad_value.size()); - def.offset_pad_offset = usr.offset_pad_offset; - return GRAPH_SUCCESS; -} -graphStatus Def2UsrQuantizeFactor(const QuantizeFactor &def, UsrQuantizeFactor &usr) { - usr.scale_mode = UsrQuantizeScaleMode(def.scale_mode); - CopyDataFromBuffer(usr.scale_value, def.scale_value); - usr.scale_offset = def.scale_offset; - CopyDataFromBuffer(usr.offset_data_value, def.offset_data_value); - usr.offset_data_offset = def.offset_data_offset; - CopyDataFromBuffer(usr.offset_weight_value, def.offset_weight_value); - usr.offset_weight_offset = def.offset_weight_offset; - CopyDataFromBuffer(usr.offset_pad_value, def.offset_pad_value); - usr.offset_pad_offset = def.offset_pad_offset; - return GRAPH_SUCCESS; -} -graphStatus Usr2DefUsrQuantizeCalcFactor(const UsrQuantizeCalcFactor &usr, QuantizeCalcFactor &def) { - def.set_offsetw(usr.offsetw.data(), usr.offsetw.size()); - def.offsetw_offset = usr.offsetw_offset; - def.set_offsetd(usr.offsetd.data(), usr.offsetd.size()); - def.offsetd_offset = usr.offsetd_offset; - def.set_scalereq(usr.scalereq.data(), usr.scalereq.size()); - def.scaledreq_offset = usr.scaledreq_offset; - def.set_offsetdnext(usr.offsetdnext.data(), usr.offsetdnext.size()); - def.offsetdnext_offset = usr.offsetdnext_offset; - return GRAPH_SUCCESS; -} -graphStatus Def2UsrQuantizeCalcFactor(const QuantizeCalcFactor &def, UsrQuantizeCalcFactor &usr) { - CopyDataFromBuffer(usr.offsetw, def.offsetw); - usr.offsetw_offset = def.offsetw_offset; - CopyDataFromBuffer(usr.offsetd, def.offsetd); - usr.offsetd_offset = def.offsetd_offset; - CopyDataFromBuffer(usr.scalereq, def.scalereq); - usr.scaledreq_offset = def.scaledreq_offset; - CopyDataFromBuffer(usr.offsetdnext, def.offsetdnext); - usr.offsetdnext_offset = def.offsetdnext_offset; - return GRAPH_SUCCESS; -} -graphStatus TypeUtils::Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def) { - def.quantize_algo = uint32_t(usr.quantize_algo); - def.scale_type = uint32_t(usr.scale_type); - GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.quantize_param, def.quantize_param), - "Usr2DefQuantizeFactor quantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.dequantize_param, def.dequantize_param), - "Usr2DefQuantizeFactor dequantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.requantize_param, def.requantize_param), - "Usr2DefQuantizeFactor requantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefUsrQuantizeCalcFactor(usr.quantizecalc_param, def.quantizecalc_param), - "Usr2DefQuantizeFactor quantizecalc_param failed"); - return GRAPH_SUCCESS; -} -graphStatus TypeUtils::Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr) { - usr.quantize_algo = UsrQuantizeAlgorithm(def.quantize_algo); - usr.scale_type = UsrQuantizeScaleType(def.scale_type); - GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.quantize_param, usr.quantize_param), - "Def2UsrQuantizeFactor quantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.dequantize_param, usr.dequantize_param), - "Def2UsrQuantizeFactor dequantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.requantize_param, usr.requantize_param), - "Def2UsrQuantizeFactor requantize_param failed"); - GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeCalcFactor(def.quantizecalc_param, usr.quantizecalc_param), - "Def2UsrQuantizeCalcFactor quantizecalc_param failed"); - return GRAPH_SUCCESS; -} -bool TypeUtils::GetDataTypeLength(ge::DataType data_type, uint32_t &length) { - auto it = kDataTypeToLength.find(data_type); - if (it != kDataTypeToLength.end()) { - length = it->second; - return true; - } else { - GELOGE(GRAPH_FAILED, "data_type not support %d", data_type); - return false; - } -} -bool TypeUtils::CheckUint64MulOverflow(uint64_t a, uint32_t b) { - // Not overflow - if (a == 0) { - return false; - } - if ((ULLONG_MAX / a) >= b) { - return false; - } - return true; -} -} // namespace ge diff --git a/metadef/inc/external/graph/attr_value.h b/metadef/inc/external/graph/attr_value.h deleted file mode 100644 index af430f9b..00000000 --- a/metadef/inc/external/graph/attr_value.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ -#define INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ - -#include -#include -#include -#include - -#include "./ge_error_codes.h" - -using std::make_shared; -using std::map; -using std::pair; -using std::string; -using std::to_string; -using std::unique_ptr; -using std::vector; - -namespace ge { -class AttrValueImpl; -/*lint -e148*/ -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { - public: - using INT = int64_t; - using FLOAT = float; - using STR = std::string; - - AttrValue(); - ~AttrValue() = default; - - // GetValue, not list type - template - graphStatus GetValue(DT &val) const { - T valGet; - auto status = GetValue(valGet); - if (status != GRAPH_SUCCESS) { - return status; - } - val = DT(valGet); - return GRAPH_SUCCESS; - } - - template - static T CreateFrom(DT &&val) { - return val; - } - - std::shared_ptr impl; - - private: -#define VALUE_SET_GET_DEC(DT) graphStatus GetValue(DT &val) const; - VALUE_SET_GET_DEC(AttrValue::STR) - VALUE_SET_GET_DEC(AttrValue::INT) - VALUE_SET_GET_DEC(AttrValue::FLOAT) -#undef VALUE_SET_GET_DEC -}; -/*lint +e148*/ -} // namespace ge -#endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ diff --git a/metadef/inc/external/graph/ge_error_codes.h b/metadef/inc/external/graph/ge_error_codes.h deleted file mode 100644 index d815a22d..00000000 --- a/metadef/inc/external/graph/ge_error_codes.h +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ -#define INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ - -namespace ge { -#ifdef HOST_VISIBILITY -#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_HOST_VISIBILITY -#endif -#ifdef DEV_VISIBILITY -#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_DEV_VISIBILITY -#endif - -using graphStatus = uint32_t; -const graphStatus GRAPH_FAILED = 0xFFFFFFFF; -const graphStatus GRAPH_SUCCESS = 0; -const graphStatus GRAPH_PARAM_INVALID = 50331649; -} // namespace ge - -#endif // INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ diff --git a/metadef/inc/external/graph/graph.h b/metadef/inc/external/graph/graph.h deleted file mode 100644 index 30886733..00000000 --- a/metadef/inc/external/graph/graph.h +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_GRAPH_H_ -#define INC_EXTERNAL_GRAPH_GRAPH_H_ - -#include -#include -#include -#include - -#include "./operator.h" - -namespace ge { -class GraphImpl; - -using GraphImplPtr = std::shared_ptr; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { - friend class GraphUtils; - - public: - explicit Graph(const std::string &name); - - Graph() = default; - - ~Graph() = default; - - Graph &SetInputs(const std::vector &inputs); - - Graph &SetOutputs(const std::vector &outputs); - - Graph &SetOutputs(const std::vector>> &output_indexs); - - Graph &SetOutputs(const std::vector> &outputs); - - Graph &SetTargets(const std::vector &targets); - - bool IsValid() const; - - graphStatus AddOp(const ge::Operator &op); - - graphStatus FindOpByName(const string &name, ge::Operator &op) const; - - graphStatus FindOpByType(const string &type, std::vector &ops) const; - - graphStatus GetAllOpName(std::vector &op_name) const; - - graphStatus SaveToFile(const string &file_name) const; - - graphStatus LoadFromFile(const string &file_name); - - const std::string &GetName() const; - - /// - /// Set is need train iteration. - /// If set true, it means this graph need to be run iteration some - /// times(according variant "npu_runconfig/iterations_per_loop"). - /// @param need_iteration need_iteration:whether to set iteration or not - /// - void SetNeedIteration(bool need_iteration); - - private: - GraphImplPtr impl_{nullptr}; -}; -} // namespace ge - -#endif // INC_EXTERNAL_GRAPH_GRAPH_H_ diff --git a/metadef/inc/external/graph/inference_context.h b/metadef/inc/external/graph/inference_context.h deleted file mode 100644 index 69079142..00000000 --- a/metadef/inc/external/graph/inference_context.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ -#define INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ - -#include -#include -#include - -#include "./tensor.h" -#include "./types.h" - -namespace ge { -class InferenceContext; -using InferenceContextPtr = std::shared_ptr; - -class ShapeAndTypeImpl; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ShapeAndType { - public: - ShapeAndType(); - ~ShapeAndType() = default; - - ShapeAndType(const Shape &shape, DataType dataType); - - void SetShape(const Shape &shape); - - void SetType(DataType dataType); - - Shape GetShape() const; - - DataType GetDataType() const; - - private: - std::shared_ptr shape_and_type_impl_; -}; - -class InferenceContextImpl; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { - public: - ~InferenceContext() = default; - InferenceContext(const InferenceContext &context) = delete; - InferenceContext(const InferenceContext &&context) = delete; - InferenceContext &operator=(const InferenceContext &context) = delete; - InferenceContext &operator=(const InferenceContext &&context) = delete; - - void SetInputHandleShapesAndTypes(std::vector> &&shapes_and_types); - const std::vector> &GetInputHandleShapesAndTypes() const; - const std::vector> &GetOutputHandleShapesAndTypes() const; - void SetOutputHandleShapesAndTypes(const std::vector> &shapes_and_types); - void SetOutputHandleShapesAndTypes(std::vector> &&shapes_and_types); - - void SetMarks(const std::vector &marks); - const std::vector &GetMarks() const; - - static std::unique_ptr Create(); - - private: - explicit InferenceContext(std::unique_ptr &impl); - std::shared_ptr inference_context_impl_; -}; -} // namespace ge -#endif // INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ diff --git a/metadef/inc/external/graph/operator.h b/metadef/inc/external/graph/operator.h deleted file mode 100644 index 81d726eb..00000000 --- a/metadef/inc/external/graph/operator.h +++ /dev/null @@ -1,289 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_OPERATOR_H_ -#define INC_EXTERNAL_GRAPH_OPERATOR_H_ - -#include -#include -#include -#include -#include - -#include "./ge_error_codes.h" -#include "./inference_context.h" -#include "./tensor.h" - -#ifndef USER_GE_LOGI -#define USER_GE_LOGI(...) -#endif // USER_GE_LOGI - -#ifndef USER_GE_LOGW -#define USER_GE_LOGW(...) -#endif // USER_GE_LOGW - -#ifndef USER_GE_LOGE -#define USER_GE_LOGE(...) -#endif // USER_GE_LOGE - -#define DYNAMIC_OUTPUT_TD_NUM(name) ("__dynamic_output_" + name + "_cnt") -#define DYNAMIC_INPUT_TD_NUM(name) ("__dynamic_input_" + name + "_cnt") - -namespace ge { -class Operator; -class OperatorImpl; -class NodeUtils; -class NamedAttrs; -class Graph; -class AttrValue; -class Node; - -using SubgraphBuilder = std::function; -using OperatorImplPtr = std::shared_ptr; -using OperatorPtr = std::shared_ptr; - -class OpIO; -using OutHandler = std::shared_ptr; -using InHandler = std::shared_ptr; - -using std::function; -using std::shared_ptr; -using std::string; - -/*lint -e148*/ -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { - public: - friend class OperatorImpl; - friend class GraphBuilderImpl; - friend class NodeUtils; - - using OpInt = int64_t; - using OpFloat = float; - using OpString = string; - using OpBool = bool; - using OpTensor = Tensor; - using OpType = ge::DataType; - using OpNamedAttrs = ge::NamedAttrs; - using OpListInt = std::vector; - using OpListFloat = std::vector; - using OpListString = std::vector; - using OpListBool = std::vector; - using OpListTensor = std::vector; - using OpBytes = std::vector; - using OpListListInt = std::vector>; - using OpListType = std::vector; - using OpListNamedAttrs = std::vector; - - Operator() {} - - explicit Operator(const string &type); - - Operator(const string &name, const string &type); // lint !e148 - - virtual ~Operator() = default; - - bool IsEmpty() const; - - string GetName() const; - - string GetOpType() const; - - // Only has one output index = 0 - Operator &SetInput(const string &dst_name, const Operator &src_oprt); - - Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148 - - Operator &SetInput(const string &dst_name, const Operator &src_oprt, uint32_t index); - - Operator &AddControlInput(const Operator &src_oprt); - - graphStatus GetInputConstData(const string &dst_name, Tensor &data) const; - - TensorDesc GetInputDesc(const string &name) const; - - TensorDesc GetInputDesc(uint32_t index) const; - - int GetDynamicOutputNum(const string &name) const; - - int GetDynamicInputNum(const string &name) const; - - graphStatus TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const; - - graphStatus UpdateInputDesc(const string &name, const TensorDesc &tensor_desc); - - TensorDesc GetOutputDesc(const string &name) const; - - TensorDesc GetOutputDesc(uint32_t index) const; - - graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); // lint !e148 - - TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const; - - graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 - - TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const; - - graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 - - graphStatus InferShapeAndType(); // lint !e148 - - void SetInferenceContext(const InferenceContextPtr &inference_context); - InferenceContextPtr GetInferenceContext() const; - - graphStatus VerifyAllAttr(bool disable_common_verifier = false); // lint !e148 - - size_t GetInputsSize() const; - - size_t GetOutputsSize() const; - - const std::map GetAllAttrNamesAndTypes() const; - - Operator &SetAttr(const string &name, int64_t attr_value); - Operator &SetAttr(const string &name, int32_t attr_value); - Operator &SetAttr(const string &name, uint32_t attr_value); - graphStatus GetAttr(const string &name, int64_t &attr_value) const; - graphStatus GetAttr(const string &name, int32_t &attr_value) const; - graphStatus GetAttr(const string &name, uint32_t &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - Operator &SetAttr(const string &name, const std::vector &attr_value); - Operator &SetAttr(const string &name, const std::vector &attr_value); - Operator &SetAttr(const string &name, std::initializer_list &&attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - Operator &SetAttr(const string &name, float attr_value); - graphStatus GetAttr(const string &name, float &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - Operator &SetAttr(const string &name, AttrValue &&attr_value); - graphStatus GetAttr(const string &name, AttrValue &attr_value) const; - - Operator &SetAttr(const string &name, const string &attr_value); - graphStatus GetAttr(const string &name, string &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - Operator &SetAttr(const string &name, bool attr_value); - graphStatus GetAttr(const string &name, bool &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - Operator &SetAttr(const string &name, const Tensor &attr_value); - graphStatus GetAttr(const string &name, Tensor &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - // Bytes type - Operator &SetAttr(const string &name, const OpBytes &attr_value); - // Bytes type - graphStatus GetAttr(const string &name, OpBytes &attr_value) const; - - Operator &SetAttr(const string &name, const std::vector> &attr_value); - graphStatus GetAttr(const string &name, std::vector> &attr_value) const; - - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - Operator &SetAttr(const string &name, const ge::DataType &attr_value); - graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; - - // func type - Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value); - graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const; - Operator &SetAttr(const string &name, const std::vector &attr_value); - graphStatus GetAttr(const string &name, std::vector &attr_value) const; - - void BreakConnect() const; - - size_t GetSubgraphNamesCount() const; - std::vector GetSubgraphNames() const; - SubgraphBuilder GetSubgraphBuilder(const string &name) const; - Graph GetSubgraph(const string &name) const; - SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const; - Graph GetDynamicSubgraph(const string &name, uint32_t index) const; - - protected: - void AttrRegister(const string &name, float attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - void AttrRegister(const string &name, int64_t attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - void AttrRegister(const string &name, const string &attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - void AttrRegister(const string &name, bool attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - void AttrRegister(const string &name, const Tensor &attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - void AttrRegister(const string &name, const OpBytes &attr_value); - void AttrRegister(const string &name, const std::vector> &attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - void AttrRegister(const string &name, const ge::DataType &attr_value); - void AttrRegister(const string &name, const ge::NamedAttrs &attr_value); - void AttrRegister(const string &name, const std::vector &attr_value); - - explicit Operator(OperatorImplPtr &&op_impl); - - void InputRegister(const string &name); - - void OptionalInputRegister(const string &name); - - void InferFuncRegister(const std::function &func); - - void VerifierFuncRegister(const std::function &func); - - void InferFormatFuncRegister(const std::function &func); - - void OutputRegister(const string &name); - - void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); - - void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index); - - void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); - - void RequiredAttrRegister(const string &name); - - graphStatus VerifyAll(); // lint !e148 - - // Only has one output index = 0 - Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt); - - Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, - const string &name); // lint !e148 - - void SubgraphRegister(const string &ir_name, bool dynamic); - void SubgraphCountRegister(const string &ir_name, uint32_t count); - void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); - - private: - Operator &SetInput(const string &dst_name, const OutHandler &out_handler); // lint !e148 - - OutHandler GetOutput(const string &name) const; - - OutHandler GetOutput(uint32_t index) const; - - OperatorImplPtr GetOperatorImplPtr() const; - - OperatorImplPtr operator_impl_{nullptr}; - - graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const; - - std::shared_ptr GetNode() const; -}; -/*lint +e148*/ -} // namespace ge - -#endif // INC_EXTERNAL_GRAPH_OPERATOR_H_ diff --git a/metadef/inc/external/graph/operator_factory.h b/metadef/inc/external/graph/operator_factory.h deleted file mode 100644 index f9ec7669..00000000 --- a/metadef/inc/external/graph/operator_factory.h +++ /dev/null @@ -1,68 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ -#define INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ - -#include -#include -#include -#include - -#include "./operator.h" -#include "./ge_error_codes.h" - -namespace ge { -using OpCreator = std::function; -using InferShapeFunc = std::function; -using InferFormatFunc = std::function; -using VerifyFunc = std::function; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactory { - public: - static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); - - static graphStatus GetOpsTypeList(std::vector &all_ops); - - static bool IsExistOp(const string &operator_type); -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorCreatorRegister { - public: - OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator); - ~OperatorCreatorRegister() = default; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferShapeFuncRegister { - public: - InferShapeFuncRegister(const std::string &operator_type, const InferShapeFunc &infer_shape_func); - ~InferShapeFuncRegister() = default; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferFormatFuncRegister { - public: - InferFormatFuncRegister(const std::string &operator_type, const InferFormatFunc &infer_format_func); - ~InferFormatFuncRegister() = default; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY VerifyFuncRegister { - public: - VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func); - ~VerifyFuncRegister() = default; -}; -} // namespace ge - -#endif // INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ diff --git a/metadef/inc/external/graph/operator_reg.h b/metadef/inc/external/graph/operator_reg.h deleted file mode 100644 index 759c70f2..00000000 --- a/metadef/inc/external/graph/operator_reg.h +++ /dev/null @@ -1,376 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ -#define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ - -#include -#include -#include -#include - -#include "graph/operator.h" -#include "graph/operator_factory.h" -#include "graph/tensor.h" -#include "graph/types.h" -#include "graph/graph.h" - -namespace ge { -using std::function; -using std::string; -using std::vector; - -class OpReg { - public: - OpReg &N() { return *this; } - - OpReg &ATTR() { return *this; } - - OpReg &REQUIRED_ATTR() { return *this; } - - OpReg &INPUT() { return *this; } - - OpReg &OPTIONAL_INPUT() { return *this; } - - OpReg &OUTPUT() { return *this; } - - OpReg &GRAPH() { return *this; } - - OpReg &DYNAMIC_GRAPH() { return *this; } - - OpReg &INFER_SHAPE_AND_TYPE() { return *this; } -}; - -#define REG_OP(x) \ - namespace op { \ - class x : public Operator { \ - typedef x _THIS_TYPE; \ - \ - public: \ - explicit x(const string &name) : Operator(name, #x) { __##x(); } \ - x() : Operator(#x) { __##x(); } \ - \ - private: \ - void __##x() { \ - OpReg() - -#define ATTR(x, Type, ...) \ - N(); \ - __attr_##x(); \ - } \ - \ - public: \ - static const string name_attr_##x() { return #x; } \ - Op##Type get_attr_##x() const { \ - Op##Type ret = __VA_ARGS__; \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - return ret; \ - } \ - return ret; \ - } \ - _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ - Operator::SetAttr(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_attr_##x(const function &v) { return *this; } \ - \ - private: \ - void __attr_##x() { \ - Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \ - string attr_name(#x); \ - (void)OpReg() - -#define REQUIRED_ATTR(x, Type) \ - N(); \ - __required_attr_##x(); \ - } \ - \ - public: \ - static const string name_attr_##x() { return #x; } \ - Op##Type get_attr_##x() const { \ - Op##Type ret; \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - return ret; \ - } \ - return ret; \ - } \ - _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ - Operator::SetAttr(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_attr_##x(const function &v) { return *this; } \ - \ - private: \ - void __required_attr_##x() { \ - Operator::RequiredAttrRegister(#x); \ - string attr_name(#x); \ - (void)OpReg() - -#define INPUT(x, t) \ - N(); \ - __input_##x(); \ - } \ - \ - public: \ - static const string name_in_##x() { return #x; } \ - _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \ - Operator::SetInput(#x, v, srcName); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ - Operator::SetInput(#x, v, index); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x(Operator &v) { \ - Operator::SetInput(#x, v); \ - return *this; \ - } \ - TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \ - graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ - return Operator::UpdateInputDesc(#x, tensorDesc); \ - } \ - \ - private: \ - void __input_##x() { \ - Operator::InputRegister(#x); \ - (void)OpReg() - -#define OPTIONAL_INPUT(x, t) \ - N(); \ - __optional_input_##x(); \ - } \ - \ - public: \ - static const string name_in_##x() { return #x; } \ - _THIS_TYPE &set_input_##x(Operator &v) { \ - Operator::SetInput(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \ - Operator::SetInput(#x, v, srcName); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ - Operator::SetInput(#x, v, index); \ - return *this; \ - } \ - TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \ - graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ - return Operator::UpdateInputDesc(#x, tensorDesc); \ - } \ - \ - private: \ - void __optional_input_##x() { \ - Operator::OptionalInputRegister(#x); \ - (void)OpReg() - -#define OUTPUT(x, t) \ - N(); \ - __out_##x(); \ - } \ - \ - public: \ - static const string name_out_##x() { return #x; } \ - TensorDesc get_output_desc_##x() const { return Operator::GetOutputDesc(#x); } \ - graphStatus update_output_desc_##x(const TensorDesc &tensorDesc) { \ - return Operator::UpdateOutputDesc(#x, tensorDesc); \ - } \ - \ - private: \ - void __out_##x() { \ - Operator::OutputRegister(#x); \ - (void)OpReg() - -#define DYNAMIC_INPUT(x, t) \ - N(); \ - __dy_input_##x(); \ - } \ - \ - public: \ - _THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \ - Operator::DynamicInputRegister(#x, num, isPushBack); \ - return *this; \ - } \ - _THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \ - Operator::DynamicInputRegisterByIndex(#x, num, index); \ - return *this; \ - } \ - TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { return Operator::GetDynamicInputDesc(#x, index); } \ - graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ - return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ - } \ - _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \ - Operator::SetInput(#x, dstIndex, v); \ - return *this; \ - } \ - _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const string &srcName) { \ - Operator::SetInput(#x, dstIndex, v, srcName); \ - return *this; \ - } \ - \ - private: \ - void __dy_input_##x() { \ - Operator::DynamicInputRegister(#x, 0, true); \ - (void)OpReg() - -#define DYNAMIC_OUTPUT(x, t) \ - N(); \ - __dy_output_##x(); \ - } \ - \ - public: \ - _THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \ - Operator::DynamicOutputRegister(#x, num, isPushBack); \ - return *this; \ - } \ - TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { return Operator::GetDynamicOutputDesc(#x, index); } \ - graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ - return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \ - } \ - \ - private: \ - void __dy_output_##x() { \ - Operator::DynamicOutputRegister(#x, 0, true); \ - (void)OpReg() - -#define GRAPH(x) \ - N(); \ - __graph_##x(); \ - } \ - \ - public: \ - static const string name_graph_##x() { return #x; } \ - SubgraphBuilder get_subgraph_builder_##x() const { return Operator::GetSubgraphBuilder(#x); } \ - _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \ - Operator::SetSubgraphBuilder(#x, 0, v); \ - return *this; \ - } \ - Graph get_subgraph_##x() const { return Operator::GetSubgraph(#x); } \ - \ - private: \ - void __graph_##x() { \ - Operator::SubgraphRegister(#x, false); \ - Operator::SubgraphCountRegister(#x, 1); \ - (void)OpReg() - -#define DYNAMIC_GRAPH(x) \ - N(); \ - __graph_##x(); \ - } \ - \ - public: \ - static const string name_graph_##x() { return #x; } \ - _THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \ - Operator::SubgraphCountRegister(#x, num); \ - return *this; \ - } \ - SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \ - return Operator::GetDynamicSubgraphBuilder(#x, index); \ - } \ - Graph get_dynamic_subgraph_##x(uint32_t index) const { return Operator::GetDynamicSubgraph(#x, index); } \ - _THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index, const SubgraphBuilder &v) { \ - Operator::SetSubgraphBuilder(#x, index, v); \ - return *this; \ - } \ - \ - private: \ - void __graph_##x() { \ - Operator::SubgraphRegister(#x, true); \ - (void)OpReg() - -#define PASTE(g_register, y) g_register##y -#define __OP_END_IMPL__(x, y) \ - N(); \ - } \ - static_assert( \ - std::is_same::value, \ - "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ - } \ - ; \ - static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ - } -#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) - -// Specialized shape inferencer macro - -#define IMPLEMT_INFERFUNC(op_name, func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) - -#define IMPLEMT_COMMON_INFERFUNC(func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(Operator &op) - -#define IMPLEMT_INFERFORMAT_FUNC(op_name, func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) - -// Specialized verifier macro - -#define IMPLEMT_VERIFIER(op_name, func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name op) - -#define INFER_VERIFY_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); } - -#define COMMON_INFER_VERIFY_FUNC(x) [&](Operator &v) { return x(v); } - -#define INFER_FORMAT_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); } - -#define __INFER_FUNC_REG_IMPL__(op_name, x, n) static const InferShapeFuncRegister PASTE(if_register, n)(#op_name, x) - -#define __VERIFY_FUNC_REG_IMPL__(op_name, x, n) static const VerifyFuncRegister PASTE(vf_register, n)(#op_name, x) -// Infer format func register -#define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, x, n) \ - static const InferFormatFuncRegister PASTE(ff_register, n)(#op_name, x) - -// Shape inferencer & verifier register macro - -#define INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) - -#define COMMON_INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC(x), __COUNTER__) - -#define VERIFY_FUNC_REG(op_name, x) __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) - -// Infer format func reg -#define INFER_FORMAT_FUNC_REG(op_name, x) \ - __INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, x), __COUNTER__) - -// Common shape inferencer - -#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ - [](Operator op) -> graphStatus { \ - auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \ - auto x_type = op.GetInputDesc(in_name).GetDataType(); \ - TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ - op_output_desc.SetShape(ge::Shape(x_shape)); \ - op_output_desc.SetOriginShape(ge::Shape(x_shape)); \ - op_output_desc.SetDataType(x_type); \ - return op.UpdateOutputDesc(out_name, op_output_desc); \ - } - -graphStatus BroadCastInfer(const function()> &get_in1_shape, - const function()> &get_in2_shape, - const function &y_shape)> &set_out_shape); - -#define BROADCAST_INFER(in1_name, in2_name, out_name) \ - [](Operator op) -> graphStatus { \ - return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \ - [&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \ - [&](const vector &y_shape) { \ - TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ - op_output_desc.SetShape(ge::Shape(y_shape)); \ - (void)op.UpdateOutputDesc(out_name, op_output_desc); \ - }); \ - } -} // namespace ge -#endif // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ diff --git a/metadef/inc/external/graph/tensor.h b/metadef/inc/external/graph/tensor.h deleted file mode 100644 index 800e1037..00000000 --- a/metadef/inc/external/graph/tensor.h +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_TENSOR_H_ -#define INC_EXTERNAL_GRAPH_TENSOR_H_ - -#include -#include -#include -#include -#include - -#include "./ge_error_codes.h" -#include "./types.h" - -namespace ge { -class ShapeImpl; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape { - public: - Shape(); - ~Shape() = default; - explicit Shape(const std::vector &dims); - - size_t GetDimNum() const; - // If the idx is invalid, return 0 - int64_t GetDim(size_t idx) const; - graphStatus SetDim(size_t idx, int64_t value); - std::vector GetDims() const; - int64_t GetShapeSize() const; - - private: - std::shared_ptr impl_; -}; - -class TensorDescImpl; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { - public: - TensorDesc(); - ~TensorDesc() = default; - explicit TensorDesc(Shape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); - // Copy - TensorDesc(const TensorDesc &desc); - // Move - TensorDesc(TensorDesc &&desc); - // Copy - TensorDesc &operator=(const TensorDesc &desc); - // Move - TensorDesc &operator=(TensorDesc &&desc); - - void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); - Shape GetShape() const; - void SetShape(const Shape &shape); - // set shape with -2, it stand for unknown shape - graphStatus SetUnknownDimNumShape(); - // for unknown shape - graphStatus SetShapeRange(const std::vector> &range); - graphStatus GetShapeRange(std::vector> &range) const; - - Format GetFormat() const; - void SetFormat(Format format); - - Shape GetOriginShape() const; - void SetOriginShape(const Shape &originShape); - - Format GetOriginFormat() const; - void SetOriginFormat(Format originFormat); - - DataType GetDataType() const; - void SetDataType(DataType dt); - - std::string GetName() const; - void SetName(const std::string &name); - - // Attr acess - void SetSize(int64_t size); - int64_t GetSize() const; - - int64_t GetRealDimCnt() const; - void SetRealDimCnt(const int64_t realDimCnt); - - private: - std::shared_ptr impl; -}; - -class TensorImpl; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor { - public: - Tensor(); - ~Tensor() = default; - explicit Tensor(const TensorDesc &tensorDesc); - Tensor(const TensorDesc &tensorDesc, const std::vector &data); - Tensor(const TensorDesc &tensorDesc, const uint8_t *data, size_t size); - Tensor(TensorDesc &&tensorDesc, std::vector &&data); - - TensorDesc GetTensorDesc() const; - graphStatus SetTensorDesc(const TensorDesc &tensorDesc); - - const uint8_t *GetData() const; - uint8_t *GetData(); - size_t GetSize() const; - - graphStatus SetData(std::vector &&data); - graphStatus SetData(const std::vector &data); - graphStatus SetData(const uint8_t *data, size_t size); - graphStatus SetData(const std::string &data); - graphStatus SetData(const std::vector &data); - graphStatus IsValid(); - - Tensor Clone() const; - - private: - std::shared_ptr impl; - friend class TensorAdapter; -}; -} // namespace ge -/*lint +e148*/ - -#endif // INC_EXTERNAL_GRAPH_TENSOR_H_ diff --git a/metadef/inc/external/graph/types.h b/metadef/inc/external/graph/types.h deleted file mode 100644 index a1245c9d..00000000 --- a/metadef/inc/external/graph/types.h +++ /dev/null @@ -1,240 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_GRAPH_TYPES_H_ -#define INC_EXTERNAL_GRAPH_TYPES_H_ - -#include -#include -#include - -namespace ge { -static const int64_t UNKNOWN_DIM = -1; -static const int64_t UNKNOWN_DIM_NUM = -2; -static const std::vector UNKNOWN_SHAPE = {-1}; -static const std::vector UNKNOWN_RANK = {-2}; - -#ifdef HOST_VISIBILITY -#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_HOST_VISIBILITY -#endif -#ifdef DEV_VISIBILITY -#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_DEV_VISIBILITY -#endif - -enum DataType { - DT_FLOAT = 0, // float type - DT_FLOAT16 = 1, // fp16 type - DT_INT8 = 2, // int8 type - DT_INT16 = 6, // int16 type - DT_UINT16 = 7, // uint16 type - DT_UINT8 = 4, // uint8 type - DT_INT32 = 3, // - DT_INT64 = 9, // int64 type - DT_UINT32 = 8, // unsigned int32 - DT_UINT64 = 10, // unsigned int64 - DT_BOOL = 12, // bool type - DT_DOUBLE = 11, // double type - DT_STRING = 13, // string type - DT_DUAL_SUB_INT8 = 14, // dual output int8 type - DT_DUAL_SUB_UINT8 = 15, // dual output uint8 type - DT_COMPLEX64 = 16, // complex64 type - DT_COMPLEX128 = 17, // complex128 type - DT_QINT8 = 18, // qint8 type - DT_QINT16 = 19, // qint16 type - DT_QINT32 = 20, // qint32 type - DT_QUINT8 = 21, // quint8 type - DT_QUINT16 = 22, // quint16 type - DT_RESOURCE = 23, // resource type - DT_STRING_REF = 24, // string ref type - DT_DUAL = 25, // dual output type - DT_UNDEFINED // Used to indicate a DataType field has not been set. -}; - -inline int GetSizeByDataType(DataType data_type) { - static int data_type_size[DT_UNDEFINED] = { - 4, // DT_FLOAT = 0, float type - 2, // DT_FLOAT16 = 1, fp16 type - 1, // DT_INT8 = 2, int8 type - 4, // DT_INT32 = 3, - 1, // DT_UINT8 = 4, uint8 type - -1, - 2, // DT_INT16 = 6, int16 type - 2, // DT_UINT16 = 7, uint16 type - 4, // DT_UINT32 = 8, unsigned int32 - 8, // DT_INT64 = 9, int64 type - 8, // DT_UINT64 = 10, unsigned int64 - 8, // DT_DOUBLE = 11, double type - 1, // DT_BOOL = 12, bool type - -1, // DT_STRING = 13, string type - 1, // DT_DUAL_SUB_INT8 = 14, dual output int8 type - 1, // DT_DUAL_SUB_UINT8 = 15, dual output uint8 type - 8, // DT_COMPLEX64 = 16, complex64 type - 16, // DT_COMPLEX128 = 17, complex128 type - 1, // DT_QINT8 = 18, qint8 type - 2, // DT_QINT16 = 19, qint16 type - 4, // DT_QINT32 = 20, qint32 type - 1, // DT_QUINT8 = 21, quint8 type - 2, // DT_QUINT16 = 22, quint16 type - -1, // DT_RESOURCE = 23, resource type - -1, // DT_STRING_REF = 24, string ref type - 5, // DT_DUAL = 25, dual output type (float + int8) - // DT_UNDEFINED Used to indicate a DataType field has not been set. - }; - if (data_type >= DT_UNDEFINED) { - return -1; - } - return data_type_size[data_type]; -} - -enum Format { - FORMAT_NCHW = 0, // NCHW - FORMAT_NHWC, // NHWC - FORMAT_ND, // Nd Tensor - FORMAT_NC1HWC0, // NC1HWC0 - FORMAT_FRACTAL_Z, // FRACTAL_Z - FORMAT_NC1C0HWPAD, - FORMAT_NHWC1C0, - FORMAT_FSR_NCHW, - FORMAT_FRACTAL_DECONV, - FORMAT_C1HWNC0, - FORMAT_FRACTAL_DECONV_TRANSPOSE, - FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, - FORMAT_NC1HWC0_C04, // NC1HWC0, C0 =4 - FORMAT_FRACTAL_Z_C04, // FRACZ, C0 =4 - FORMAT_CHWN, - FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, - FORMAT_HWCN, - FORMAT_NC1KHKWHWC0, // KH,KW kernel h& kernel w maxpooling max output format - FORMAT_BN_WEIGHT, - FORMAT_FILTER_HWCK, // filter input tensor format - FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20, - FORMAT_HASHTABLE_LOOKUP_KEYS, - FORMAT_HASHTABLE_LOOKUP_VALUE, - FORMAT_HASHTABLE_LOOKUP_OUTPUT, - FORMAT_HASHTABLE_LOOKUP_HITS = 24, - FORMAT_C1HWNCoC0, - FORMAT_MD, - FORMAT_NDHWC, - FORMAT_FRACTAL_ZZ, - FORMAT_FRACTAL_NZ, - FORMAT_NCDHW, - FORMAT_DHWCN, // 3D filter input tensor format - FORMAT_NDC1HWC0, - FORMAT_FRACTAL_Z_3D, - FORMAT_CN, - FORMAT_NC, - FORMAT_DHWNC, - FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format - FORMAT_FRACTAL_ZN_LSTM, - FORMAT_FRACTAL_Z_G, - FORMAT_RESERVED, - FORMAT_ALL, - FORMAT_NULL -}; - -// for unknown shape op type -enum UnknowShapeOpType { - DEPEND_IN_SHAPE = 1, // op out shape get by input shape - DEPEND_CONST_VALUE = 2, // op out shape get by const op value - DEPEND_SHAPE_RANGE = 3, // op out shape get by range - DEPEND_COMPUTE = 4 // op out shape get by totally computing -}; - -struct TensorDescInfo { - Format format_ = FORMAT_RESERVED; // tbe op register support format - DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype -}; - -enum DeviceType { - NPU = 0, - CPU = 1, -}; - -class TensorTypeImpl; -struct TensorType { - explicit TensorType(DataType dt); - - TensorType(const std::initializer_list &types); - - static TensorType ALL() { - return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, - DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, - DT_QUINT8, DT_RESOURCE, DT_STRING, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static TensorType QuantifiedType() { return TensorType{DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, DT_QUINT8}; } - - static TensorType OrdinaryType() { - return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, - DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static TensorType BasicType() { - return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, - DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, - DT_QUINT16, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static TensorType NumberType() { - return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, - DT_INT8, DT_QINT32, DT_QINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static TensorType RealNumberType() { - return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, - DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static TensorType ComplexDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64}; } - - static TensorType IntegerDataType() { - return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; - } - - static TensorType SignedDataType() { return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8}; } - - static TensorType UnsignedDataType() { return TensorType{DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; } - - static TensorType FloatingDataType() { return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; } - - static TensorType IndexNumberType() { return TensorType{DT_INT32, DT_INT64}; } - - static TensorType UnaryDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; } - - static TensorType FLOAT() { return TensorType{DT_FLOAT, DT_FLOAT16}; } - - std::shared_ptr tensor_type_impl_; -}; -} // namespace ge - -namespace domi { -enum class ImplyType : unsigned int { - BUILDIN = 0, // Built in operator, normally executed by OME - TVM, // Compile to TVM bin file for execution - CUSTOM, // User defined calculation logic, executed by CPU - AI_CPU, // AICPU - CCE, // Cce - GELOCAL, // GE local, do node need execute by device - HCCL, // Hccl - INVALID = 0xFFFFFFFF, -}; -} // namespace domi - -#endif // INC_EXTERNAL_GRAPH_TYPES_H_ diff --git a/metadef/inc/external/register/register.h b/metadef/inc/external/register/register.h deleted file mode 100644 index f3091fae..00000000 --- a/metadef/inc/external/register/register.h +++ /dev/null @@ -1,163 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_REGISTER_REGISTER_H_ -#define INC_EXTERNAL_REGISTER_REGISTER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "graph/operator.h" -#include "register/register_error_codes.h" -#include "register/register_fmk_types.h" -#include "register/register_types.h" - -using std::make_shared; -using std::map; -using std::pair; -using std::string; -using std::to_string; -using std::unique_ptr; -using std::vector; - -/*lint -e148*/ -namespace ge { -class Operator; -class TensorDesc; -class Tensor; -class TBEPluginManager; -} // namespace ge - -namespace google { -namespace protobuf { -class Message; -} -} // namespace google - -namespace domi { -const int64_t kMaxNameLength = 1048576; // 1M - -enum DynamicType { kInvalid = 0, kInput = 1, kOutput = 2 }; -struct DynamicInputOutputInfo { - DynamicType type; // input/output - const char *port_name; - int64_t port_name_len; - const char *attr_name; - int64_t attr_name_len; - DynamicInputOutputInfo() - : type(kInvalid), port_name(nullptr), port_name_len(0), attr_name(nullptr), attr_name_len(0) {} - DynamicInputOutputInfo(DynamicType type, const char *port_name, int64_t port_name_len, const char *attr_name, - int64_t attr_name_len) - : type(type), - port_name(port_name), - port_name_len(port_name_len), - attr_name(attr_name), - attr_name_len(attr_name_len) {} -}; -Status AutoMappingByOpFn(const ge::Operator &op_src, ge::Operator &op); -Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, ge::Operator &op, - const vector &dynamic_name_attr_value); -Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); -Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, - std::map> dynamic_name_attr_value, - int in_pos = -1, int out_pos = -1); -Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function &input, - const std::function &output); -Status AutoMappingSubgraphIndex(const ge::Graph &graph, - const std::function &input, - const std::function &output); -using google::protobuf::Message; -class OpRegistrationDataImpl; - -using ParseParamFunc = std::function; -using ParseParamByOpFunc = std::function; -using FusionParseParamFunc = - std::function, ge::Operator &)>; -using FusionParseParamByOpFunc = std::function &, ge::Operator &)>; -using ParseSubgraphFunc = std::function; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { - public: - OpRegistrationData(const std::string &om_optype); - - ~OpRegistrationData(); - - OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type); - - OpRegistrationData &OriginOpType(const std::initializer_list &ori_optype_list); - - OpRegistrationData &OriginOpType(const std::string &ori_optype); - - OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); - - OpRegistrationData &ParseParamsByOperatorFn(const ParseParamByOpFunc &parse_param_by_op_fn); - - OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); - - OpRegistrationData &FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn); - - OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); - - OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); - - OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); - - OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); - - OpRegistrationData &InputReorderVector(const vector &input_order); - - domi::ImplyType GetImplyType() const; - std::string GetOmOptype() const; - std::set GetOriginOpTypeSet() const; - domi::FrameworkType GetFrameworkType() const; - ParseParamFunc GetParseParamFn() const; - ParseParamByOpFunc GetParseParamByOperatorFn() const; - FusionParseParamFunc GetFusionParseParamFn() const; - FusionParseParamByOpFunc GetFusionParseParamByOpFn() const; - ParseSubgraphFunc GetParseSubgraphPostFn() const; - - private: - std::shared_ptr impl_; - friend class OpRegistry; - friend class OpRegistrationTbe; - friend class ge::TBEPluginManager; -}; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { - public: - OpReceiver(OpRegistrationData ®_data); - ~OpReceiver() {} -}; - -#define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name) -#define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name) -#define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \ - static OpReceiver register_op##ctr __attribute__((unused)) = OpRegistrationData(name) -} // namespace domi - -namespace ge { -using OpRegistrationData = domi::OpRegistrationData; -using OpReceiver = domi::OpReceiver; -} // namespace ge -/*lint +e148*/ -#endif // INC_EXTERNAL_REGISTER_REGISTER_H_ diff --git a/metadef/inc/external/register/register_error_codes.h b/metadef/inc/external/register/register_error_codes.h deleted file mode 100644 index 5e0ed79f..00000000 --- a/metadef/inc/external/register/register_error_codes.h +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ -#define INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ - -#define SYSID_FWK 3 // Subsystem ID -#define MODID_COMMON 0 // Common module ID - -#define DECLARE_ERRORNO(sysid, modid, name, value) \ - const domi::Status name = \ - ((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); - -#define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) - -namespace domi { -using Status = uint32_t; - -// General error code -DECLARE_ERRORNO(0, 0, SUCCESS, 0); -DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF); -DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649 -DECLARE_ERRORNO(SYSID_FWK, 1, SCOPE_NOT_CHANGED, 201); -} // namespace domi - -#endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ diff --git a/metadef/inc/external/register/register_fmk_types.h b/metadef/inc/external/register/register_fmk_types.h deleted file mode 100644 index 97616060..00000000 --- a/metadef/inc/external/register/register_fmk_types.h +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ -#define INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ - -#include - -namespace domi { -/// -/// @ingroup domi_omg -/// @brief AI framework types -/// -enum FrameworkType { - CAFFE = 0, - MINDSPORE = 1, - TENSORFLOW = 3, - ANDROID_NN, - ONNX, - FRAMEWORK_RESERVED, -}; -} // namespace domi - -#endif // INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ diff --git a/metadef/inc/external/register/register_types.h b/metadef/inc/external/register/register_types.h deleted file mode 100644 index 08d72713..00000000 --- a/metadef/inc/external/register/register_types.h +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ -#define INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ - -namespace domi { -#ifdef HOST_VISIBILITY -#define FMK_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) -#else -#define FMK_FUNC_HOST_VISIBILITY -#endif -#ifdef DEV_VISIBILITY -#define FMK_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) -#else -#define FMK_FUNC_DEV_VISIBILITY -#endif - -/// CCE defined constant - -/// -/// @ingroup domi -/// @brief original tensor type -/// -typedef enum tagDomiTensorFormat { - DOMI_TENSOR_NCHW = 0, // < NCHW - DOMI_TENSOR_NHWC, // < NHWC - DOMI_TENSOR_ND, // < Nd Tensor - DOMI_TENSOR_NC1HWC0, // < NC1HWC0 - DOMI_TENSOR_FRACTAL_Z, // < FRACTAL_Z - DOMI_TENSOR_NC1C0HWPAD, - DOMI_TENSOR_NHWC1C0, - DOMI_TENSOR_FSR_NCHW, - DOMI_TENSOR_FRACTAL_DECONV, - DOMI_TENSOR_BN_WEIGHT, - DOMI_TENSOR_CHWN, // Android NN Depth CONV - DOMI_TENSOR_FILTER_HWCK, // filter input tensor format - DOMI_TENSOR_NDHWC, - DOMI_TENSOR_NCDHW, - DOMI_TENSOR_DHWCN, // 3D filter input tensor format - DOMI_TENSOR_DHWNC, - DOMI_TENSOR_RESERVED -} domiTensorFormat_t; -} // namespace domi - -#endif // INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ diff --git a/metadef/inc/external/register/scope/scope_fusion_pass_register.h b/metadef/inc/external/register/scope/scope_fusion_pass_register.h deleted file mode 100644 index 8e5605a7..00000000 --- a/metadef/inc/external/register/scope/scope_fusion_pass_register.h +++ /dev/null @@ -1,334 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ -#define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ - -#include -#include -#include -#include -#include -#include "ge/ge_api_error_codes.h" -#include "register/register_error_codes.h" -#include "register/register_types.h" -#include "graph/operator.h" - -#define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt) \ - do { \ - if (!(cond)) { \ - if ((fusion_rlt) != nullptr) { \ - (fusion_rlt)->SetType(ge::kScopeInvalidType); \ - } \ - return; \ - } \ - } while (0) - -namespace domi { -class TensorFlowModelParser; -} // namespace domi -namespace ge { -const int32_t kFusionDisableIndex = 99999; -const char *const kScopeToMultiNodes = "ScopeToMultiNodes"; -const char *const kScopeInvalidType = "ScopeInvalidType"; -const char *const kInputFromFusionScope = "InputFromFusionScope"; -const char *const kOutputToFusionScope = "OutputToFusionScope"; -class ScopePattern; -using ScopeFusionPatterns = std::vector>; - -class ScopePassManager; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope { - public: - Scope(); - Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr); - ~Scope(); - - const std::string &Name() const; - const std::string &SubType() const; - const std::unordered_map &AllNodesMap() const; - Scope *GetSubScope(const std::string &scope_name) const; - const std::string LastName() const; - const std::vector &GetAllSubScopes() const; - const Scope *GetFatherScope() const; - - private: - class ScopeImpl; - std::unique_ptr impl_; - friend class ScopeBasePass; - friend class ScopeTree; - friend class NodeOpTypeFeature; - friend class NodeAttrFeature; - friend class ScopeFeature; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult { - public: - FusionScopesResult(); - Status Init(); - ~FusionScopesResult(); - void SetName(const std::string &name); - void SetType(const std::string &type); - void SetDescription(const std::string &description); - const std::string &Name() const; - const std::vector &Nodes() const; - void InsertInputs(const std::string &inner_op_name, const std::vector &index_map); - void InsertOutputs(const std::string &inner_op_name, const std::vector &index_map); - - class InnerNodeInfo { - public: - explicit InnerNodeInfo(const std::string &fusion_node_name); - InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type); - InnerNodeInfo(InnerNodeInfo &&other) noexcept; - InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept; - InnerNodeInfo(const InnerNodeInfo &) = delete; - InnerNodeInfo &operator=(const InnerNodeInfo &) = delete; - ~InnerNodeInfo(); - InnerNodeInfo &SetName(const std::string &name); - InnerNodeInfo &SetType(const std::string &type); - InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx); - InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx); - ge::graphStatus BuildInnerNode(); - ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format); - ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format); - ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format); - ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format); - ge::Operator *MutableOperator(); - - std::string GetName() const; - std::string GetType() const; - std::vector> GetInputs() const; - std::vector> GetOutputs() const; - - private: - class InnerNodeInfoImpl; - std::unique_ptr impl_; - }; - - InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type); - InnerNodeInfo *MutableRecentInnerNode(); - InnerNodeInfo *MutableInnerNode(uint32_t index); - ge::graphStatus CheckInnerNodesInfo(); - - private: - class FusionScopesResultImpl; - std::unique_ptr impl_; - friend class ScopeGraph; - friend class ScopeBasePass; - friend class TensorFlowModelParser; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree { - public: - ScopeTree(); - Status Init(); - ScopeTree(const ScopeTree &scopetree) = delete; - ScopeTree &operator=(const ScopeTree &scopetree) = delete; - ~ScopeTree(); - - const std::vector &GetAllScopes() const; - - private: - class ScopeTreeImpl; - std::unique_ptr impl_; - friend class ScopeGraph; - friend class ScopeBasePass; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph { - public: - ScopeGraph(); - Status Init(); - ScopeGraph(const ScopeGraph &scope_graph) = delete; - ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete; - ~ScopeGraph(); - - const ScopeTree *GetScopeTree() const; - const std::unordered_map &GetNodesMap() const; - - private: - class ScopeGraphImpl; - std::unique_ptr impl_; - friend class ScopePassManager; - friend class ScopeBasePass; - friend class TensorFlowModelParser; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue { - public: - ScopeAttrValue(); - ScopeAttrValue(ScopeAttrValue const &attr_value); - ScopeAttrValue &operator=(ScopeAttrValue const &attr_value); - ~ScopeAttrValue(); - - void SetIntValue(int64_t value); - void SetFloatValue(float value); - void SetStringValue(std::string value); - void SetBoolValue(bool value); - - private: - class ScopeAttrValueImpl; - std::unique_ptr impl_; - friend class NodeAttrFeature; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature { - public: - virtual bool Match(const Scope *scope) = 0; - virtual ~ScopeBaseFeature(){}; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature { - public: - NodeOpTypeFeature(std::string nodeType, int num, int step = 0); - NodeOpTypeFeature(NodeOpTypeFeature const &feature); - NodeOpTypeFeature &operator=(NodeOpTypeFeature const &feature); - ~NodeOpTypeFeature(); - bool Match(const Scope *scope) override; - - private: - class NodeOpTypeFeatureImpl; - std::unique_ptr impl_; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature { - public: - NodeAttrFeature(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue &attr_value); - NodeAttrFeature(NodeAttrFeature const &feature); - NodeAttrFeature &operator=(NodeAttrFeature const &feature); - ~NodeAttrFeature(); - bool Match(const Scope *scope) override; - - private: - class NodeAttrFeatureImpl; - std::unique_ptr impl_; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature { - public: - ScopeFeature(std::string sub_type, int32_t num, std::string suffix = "", std::string sub_scope_mask = "", - int step = 0); - ScopeFeature(ScopeFeature const &feature); - ScopeFeature &operator=(ScopeFeature const &feature); - ~ScopeFeature(); - bool Match(const Scope *scope) override; - - private: - class ScopeFeatureImpl; - std::unique_ptr impl_; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern { - public: - ScopePattern(); - ~ScopePattern(); - - ScopePattern &SetSubType(const std::string &sub_type); - ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature); - ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature); - ScopePattern &AddScopeFeature(ScopeFeature feature); - - private: - class ScopePatternImpl; - std::unique_ptr impl_; - friend class ScopeBasePass; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult { - public: - ScopesResult(); - ScopesResult(ScopesResult const &result); - ScopesResult &operator=(ScopesResult const &result); - ~ScopesResult(); - - void SetScopes(std::vector &scopes); - void SetNodes(std::vector &nodes); - - private: - class ScopesResultImpl; - std::unique_ptr impl_; - friend class ScopeBasePass; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBasePass { - public: - ScopeBasePass(); - virtual ~ScopeBasePass(); - - protected: - // Subclasses implement respective fusion strategies and build the Patterns - virtual std::vector DefinePatterns() = 0; - // Define the name of the scope pass - virtual std::string PassName() = 0; - // Subclasses implement respective multi-scope or operator fusion methods across scopes - virtual Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, - std::vector &results) = 0; - // Subclasses implement their own results and set the input and output of the final fusion operator - virtual void GenerateFusionResult(const std::vector &scopes, FusionScopesResult *fusion_rlt) = 0; - - private: - class ScopeBasePassImpl; - std::unique_ptr impl_; - friend class ge::ScopePassManager; - friend class ScopeBasePassImpl; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry { - public: - using CreateFn = ScopeBasePass *(*)(); - ~ScopeFusionPassRegistry(); - - static ScopeFusionPassRegistry &GetInstance() { - static ScopeFusionPassRegistry instance; - return instance; - } - - void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general); - - private: - ScopeFusionPassRegistry(); - class ScopeFusionPassRegistryImpl; - /*lint -e148*/ - std::unique_ptr impl_; - friend class TensorFlowModelParser; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil { - public: - static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value); - static void FreeScopePatterns(ScopeFusionPatterns &patterns); - static void FreeOneBatchPattern(std::vector &one_batch_pattern); -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar { - public: - ScopeFusionPassRegistrar(const char *pass_name, ScopeBasePass *(*create_fn)(), bool is_general); - ~ScopeFusionPassRegistrar() {} -}; - -#define REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general) \ - REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, scope_pass, is_general) - -#define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \ - REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) - -#define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) \ - static ::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \ - ::ge::ScopeFusionPassRegistrar( \ - pass_name, []() -> ::ge::ScopeBasePass * { return new (std::nothrow) scope_pass(); }, is_general) -} // namespace ge - -#endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ diff --git a/metadef/inc/graph/anchor.h b/metadef/inc/graph/anchor.h deleted file mode 100644 index 565f0843..00000000 --- a/metadef/inc/graph/anchor.h +++ /dev/null @@ -1,284 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_ANCHOR_H_ -#define INC_GRAPH_ANCHOR_H_ - -#include -#include -#include -#include "graph/ge_error_codes.h" -#include "graph/range_vistor.h" -#include "graph/types.h" - -namespace ge { -enum AnchorStatus { - ANCHOR_SUSPEND = 0, // dat null - ANCHOR_CONST = 1, - ANCHOR_DATA = 2, // Effective - ANCHOR_RESERVED = 3 -}; -using std::string; -using std::vector; - -class Node; - -using NodePtr = std::shared_ptr; - -class Edge; - -using EdgePtr = std::shared_ptr; - -class Anchor; - -using AnchorPtr = std::shared_ptr; - -class DataAnchor; - -using DataAnchorPtr = std::shared_ptr; - -class InDataAnchor; - -using InDataAnchorPtr = std::shared_ptr; - -class OutDataAnchor; - -using OutDataAnchorPtr = std::shared_ptr; - -class ControlAnchor; - -using ControlAnchorPtr = std::shared_ptr; - -class InControlAnchor; - -using InControlAnchorPtr = std::shared_ptr; - -class OutControlAnchor; - -using OutControlAnchorPtr = std::shared_ptr; - -using ConstAnchor = const Anchor; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable_shared_from_this { - friend class AnchorUtils; - - public: - using TYPE = const char *; - template - using Vistor = RangeVistor>; - - Anchor(const NodePtr &ownerNode, int idx); - - virtual ~Anchor() = default; - - protected: - // Whether the two anchor is equal - virtual bool Equal(AnchorPtr anchor) const = 0; - virtual bool IsTypeOf(TYPE type) const; - - public: - // Get all peer anchors connected to current anchor - Vistor GetPeerAnchors() const; - // Get peer anchor size - size_t GetPeerAnchorsSize() const; - // Get first peer anchor - AnchorPtr GetFirstPeerAnchor() const; - - // Get the anchor belong to which node - NodePtr GetOwnerNode() const; - - // Remove all links with the anchor - void UnlinkAll() noexcept; - - // Remove link with the given anchor - graphStatus Unlink(const AnchorPtr &peer); - - // Replace peer with new peers - graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer); - - // Judge if the anchor is linked with the given anchor - bool IsLinkedWith(const AnchorPtr &peer); - - // Get anchor index of the node - int GetIdx() const; - - // set anchor index of the node - void SetIdx(int index); - - protected: - // All peer anchors connected to current anchor - vector> peer_anchors_; - // The owner node of anchor - std::weak_ptr owner_node_; - // The index of current anchor - int idx_; - template - static Anchor::TYPE TypeOf() { - static_assert(std::is_base_of::value, "T must be a Anchor!"); - return __PRETTY_FUNCTION__; - } - - public: - template - static std::shared_ptr DynamicAnchorCast(AnchorPtr anchorPtr) { - static_assert(std::is_base_of::value, "T must be a Anchor!"); - if (anchorPtr == nullptr || !anchorPtr->IsTypeOf()) { - return nullptr; - } - return std::static_pointer_cast(anchorPtr); - } - - template - bool IsTypeOf() { - return IsTypeOf(TypeOf()); - } -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY DataAnchor : public Anchor { - friend class AnchorUtils; - - public: - explicit DataAnchor(const NodePtr &ownerNode, int idx); - - virtual ~DataAnchor() = default; - - protected: - bool IsTypeOf(TYPE type) const override; - - private: - Format format_{FORMAT_ND}; - AnchorStatus status_{ANCHOR_SUSPEND}; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataAnchor { - friend class OutDataAnchor; - - friend class OutControlAnchor; - - public: - explicit InDataAnchor(const NodePtr &ownerNode, int idx); - - virtual ~InDataAnchor() = default; - - // Get source out data anchor - OutDataAnchorPtr GetPeerOutAnchor() const; - - // Build connection from OutDataAnchor to InDataAnchor - graphStatus LinkFrom(const OutDataAnchorPtr &src); - - protected: - bool Equal(AnchorPtr anchor) const override; - bool IsTypeOf(TYPE type) const override; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchor : public DataAnchor { - friend class InDataAnchor; - - friend class AnchorUtils; - - public: - template - using Vistor = RangeVistor>; - - explicit OutDataAnchor(const NodePtr &ownerNode, int idx); - - virtual ~OutDataAnchor() = default; - // Get dst in data anchor(one or more) - Vistor GetPeerInDataAnchors() const; - uint32_t GetPeerInDataNodesSize() const; - - // Get dst in control anchor(one or more) - Vistor GetPeerInControlAnchors() const; - - // Build connection from OutDataAnchor to InDataAnchor - graphStatus LinkTo(const InDataAnchorPtr &dest); - - // Build connection from OutDataAnchor to InControlAnchor - graphStatus LinkTo(const InControlAnchorPtr &dest); - - protected: - bool Equal(AnchorPtr anchor) const override; - bool IsTypeOf(TYPE type) const override; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ControlAnchor : public Anchor { - public: - explicit ControlAnchor(const NodePtr &ownerNode); - - explicit ControlAnchor(const NodePtr &ownerNode, int idx); - - virtual ~ControlAnchor() = default; - - protected: - bool IsTypeOf(TYPE type) const override; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchor : public ControlAnchor { - friend class OutControlAnchor; - - friend class OutDataAnchor; - - public: - explicit InControlAnchor(const NodePtr &ownerNode); - - explicit InControlAnchor(const NodePtr &ownerNode, int idx); - - virtual ~InControlAnchor() = default; - - // Get source out control anchors - Vistor GetPeerOutControlAnchors() const; - bool IsPeerOutAnchorsEmpty() const { return peer_anchors_.empty(); } - - // Get source out data anchors - Vistor GetPeerOutDataAnchors() const; - - // Build connection from OutControlAnchor to InControlAnchor - graphStatus LinkFrom(const OutControlAnchorPtr &src); - - protected: - bool Equal(AnchorPtr anchor) const override; - bool IsTypeOf(TYPE type) const override; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchor : public ControlAnchor { - friend class InControlAnchor; - - public: - template - using Vistor = RangeVistor>; - - explicit OutControlAnchor(const NodePtr &ownerNode); - - explicit OutControlAnchor(const NodePtr &ownerNode, int idx); - - virtual ~OutControlAnchor() = default; - - // Get dst in control anchor(one or more) - Vistor GetPeerInControlAnchors() const; - // Get dst data anchor in control anchor(one or more) - Vistor GetPeerInDataAnchors() const; - - // Build connection from OutControlAnchor to InControlAnchor - graphStatus LinkTo(const InControlAnchorPtr &dest); - // Build connection from OutDataAnchor to InDataAnchor - graphStatus LinkTo(const InDataAnchorPtr &dest); - - protected: - bool Equal(AnchorPtr anchor) const override; - bool IsTypeOf(TYPE type) const override; -}; -} // namespace ge -#endif // INC_GRAPH_ANCHOR_H_ diff --git a/metadef/inc/graph/attr_value_serializable.h b/metadef/inc/graph/attr_value_serializable.h deleted file mode 100644 index a69beb96..00000000 --- a/metadef/inc/graph/attr_value_serializable.h +++ /dev/null @@ -1,191 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ -#define INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ - -#include -#include -#include "graph/ge_attr_value.h" - -namespace ge { - -class GeAttrValue; -class _GeSerializable { - public: - template - struct ge_serializable_int64_t_support_type { - using DT = typename std::remove_cv::type; - static const bool value = std::is_same::value // by cast - || std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value; - }; - - template - static GeAttrValue SaveItemAsAttrValue(const T &t) { - return GeAttrValue::CreateFrom(t); - } - - template - static GeAttrValue SaveItemAsAttrValue(const vector &t) { - return GeAttrValue::CreateFrom(t); - } - - template = 0, typename DT = typename std::remove_cv::type> - static GeAttrValue SaveItemAsAttrValue(const T &t) { - return GeAttrValue::CreateFrom
(t); - } - // int64_t support type - template ::value, int>::type = 0> - static GeAttrValue SaveItemAsAttrValue(const T &t) { - return GeAttrValue::CreateFrom(t); - } - // vector int64_t support type - template ::value, int>::type = 0> - static GeAttrValue SaveItemAsAttrValue(const vector &t) { - return GeAttrValue::CreateFrom(t); - } - - template - static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { - return attrVal.GetValue(t); - } - - template - static graphStatus LoadItemFromAttrValue(vector &t, GeAttrValue &attrVal) { - return attrVal.GetValue(t); - } - - template = 0, typename DT = typename std::remove_cv::type> - static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { - return attrVal.GetValue
(t); - } - - template ::value, int>::type = 0> - static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { - return attrVal.GetValue(t); - } - - template ::value, int>::type = 0> - static graphStatus LoadItemFromAttrValue(vector &t, GeAttrValue &attrVal) { - return attrVal.GetValue(t); - } - - template - static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { - GeAttrValue itemVal = SaveItemAsAttrValue(item); - (void)namedAttrs.SetAttr(itemName, itemVal); - SaveItem(namedAttrs, args...); - } - - static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) {} - - template - static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { - auto itemVal = namedAttrs.GetItem(itemName); - auto status = LoadItemFromAttrValue(item, itemVal); - if (status != GRAPH_SUCCESS) { - return status; - } - return LoadItem(namedAttrs, args...); - } - - static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) { - return GRAPH_SUCCESS; - } -}; - -#define _GE_FI(a) #a, a -#define _GE_MAP_FIELDS1(a1) _GE_FI(a1) -#define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2) -#define _GE_MAP_FIELDS3(a1, a2, a3) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3) -#define _GE_MAP_FIELDS4(a1, a2, a3, a4) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4) -#define _GE_MAP_FIELDS5(a1, a2, a3, a4, a5) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5) -#define _GE_MAP_FIELDS6(a1, a2, a3, a4, a5, a6) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6) -#define _GE_MAP_FIELDS7(a1, a2, a3, a4, a5, a6, a7) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7) -#define _GE_MAP_FIELDS8(a1, a2, a3, a4, a5, a6, a7, a8) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8) -#define _GE_MAP_FIELDS9(a1, a2, a3, a4, a5, a6, a7, a8, a9) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9) -#define _GE_MAP_FIELDS10(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10) -#define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ - _GE_FI(a11) -#define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ - _GE_FI(a11), _GE_FI(a12) -#define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ - _GE_FI(a11), _GE_FI(a12), _GE_FI(a13) -#define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ - _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) -#define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \ - _GE_FI(a1) \ - , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ - _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) - -#define _GE_PRIVATE_ARGS_GLUE(x, y) x y - -#define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, \ - ...) \ - N -#define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL(args) _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT args -#define _GE_COUNT_MACRO_VAR_ARGS(...) \ - _GE_PRIVATE_MACRO_VAR_ARGS_IMPL((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)) - -#define _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) M##count -#define _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) -#define _GE_PRIVATE_MACRO_CHOOSE_HELPER(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) - -#define _GE_INVOKE_VAR_MACRO(...) \ - _GE_PRIVATE_ARGS_GLUE(_GE_PRIVATE_MACRO_CHOOSE_HELPER(_GE_MAP_FIELDS, _GE_COUNT_MACRO_VAR_ARGS(__VA_ARGS__)), \ - (__VA_ARGS__)) - -#define GE_SERIALIZABLE(...) \ - public: \ - friend class ge::GeAttrValue; \ - using __ge_serializable = int; \ - \ - private: \ - ge::graphStatus Save(GeAttrValue &ar) const { \ - GeAttrValue::NAMED_ATTRS named_attrs; \ - _GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ - return ar.SetValue(named_attrs); \ - } \ - ge::graphStatus Load(const GeAttrValue &ar) { \ - GeAttrValue::NAMED_ATTRS named_attrs; \ - ge::graphStatus status = ar.GetValue(named_attrs); \ - if (status != GRAPH_SUCCESS) { \ - return status; \ - } \ - return _GeSerializable::LoadItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ - } - -// end NamedAttrs Helper: GE_SERIALIZABLE -} // namespace ge -#endif // INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ diff --git a/metadef/inc/graph/buffer.h b/metadef/inc/graph/buffer.h deleted file mode 100644 index ca4355a7..00000000 --- a/metadef/inc/graph/buffer.h +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_BUFFER_H_ -#define INC_GRAPH_BUFFER_H_ - -#include -#include -#include -#include -#include "detail/attributes_holder.h" - -namespace ge { -#ifdef HOST_VISIBILITY -#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_HOST_VISIBILITY -#endif -#ifdef DEV_VISIBILITY -#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_DEV_VISIBILITY -#endif - -using std::shared_ptr; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { - public: - Buffer(); - Buffer(const Buffer &other); - - explicit Buffer(std::size_t bufferSize, std::uint8_t defualtVal = 0); - - ~Buffer() = default; - - Buffer &operator=(const Buffer &other); - - static Buffer CopyFrom(const std::uint8_t *data, std::size_t bufferSize); - - const std::uint8_t *GetData() const; - std::uint8_t *GetData(); - std::size_t GetSize() const; - void ClearBuffer(); - - // For compatibility - inline const std::uint8_t *data() const { return GetData(); } - inline std::uint8_t *data() { return GetData(); } // lint !e659 - inline std::size_t size() const { return GetSize(); } - inline void clear() { return ClearBuffer(); } - uint8_t operator[](size_t index) const { // lint !e1022 !e1042 - if (buffer_ != nullptr && index < buffer_->size()) { // lint !e574 - return (uint8_t)(*buffer_)[index]; - } - return 0xff; - } - - private: - GeIrProtoHelper data_; - std::string *buffer_ = nullptr; - - // Create from protobuf obj - Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer); - Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer); - - friend class GeAttrValueImp; - friend class GeTensor; -}; -} // namespace ge -#endif // INC_GRAPH_BUFFER_H_ diff --git a/metadef/inc/graph/compute_graph.h b/metadef/inc/graph/compute_graph.h deleted file mode 100644 index 2ec6b663..00000000 --- a/metadef/inc/graph/compute_graph.h +++ /dev/null @@ -1,308 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_COMPUTE_GRAPH_H_ -#define INC_GRAPH_COMPUTE_GRAPH_H_ - -#include -#include -#include -#include -#include -#include -#include "detail/attributes_holder.h" -#include "graph/anchor.h" -#include "graph/node.h" -#include "graph/op_desc.h" -#include "graph/range_vistor.h" - -namespace ge { -class Node; -using NodePtr = std::shared_ptr; -class Edge; -using EdgePtr = std::shared_ptr; - -class InDataAnchor; -using InDataAnchorPtr = std::shared_ptr; - -class OutDataAnchor; -using OutDataAnchorPtr = std::shared_ptr; - -class ControlAnchor; -using ControlAnchorPtr = std::shared_ptr; -class InControlAnchor; -using InControlAnchorPtr = std::shared_ptr; -class OutControlAnchor; -using OutControlAnchorPtr = std::shared_ptr; -class GeAttrValue; -using AttrValuePtr = std::shared_ptr; -using ConstComputeGraph = const ComputeGraph; - -class OperatorImpl; -using OperatorImplPtr = std::shared_ptr; - -class ComputeGraph : public std::enable_shared_from_this, public AttrHolder { - friend class GraphUtils; - - public: - template - using Vistor = RangeVistor>; - - explicit ComputeGraph(const std::string &name); - ~ComputeGraph() override; - - std::string GetName() const; - void SetName(const std::string &name); - - using AttrHolder::DelAttr; - using AttrHolder::GetAttr; - using AttrHolder::HasAttr; - using AttrHolder::SetAttr; - - size_t GetAllNodesSize() const; - Vistor GetAllNodes() const; - // is_unknown_shape: false, same with GetAllNodes func - // is_unknown_shape: true, same with GetDirectNodes func - Vistor GetNodes(bool is_unknown_shape) const; - size_t GetDirectNodesSize() const; - Vistor GetDirectNode() const; - Vistor GetInputNodes() const; - Vistor GetOutputNodes() const; - - NodePtr FindNode(const std::string &name) const; - NodePtr FindFirstNodeMatchType(const std::string &name) const; - /*lint -e504*/ - // AddNode with NodePtr - NodePtr AddNode(NodePtr node); - NodePtr AddNode(OpDescPtr op); - NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize - NodePtr AddNodeFront(NodePtr node); - NodePtr AddNodeFront(const OpDescPtr &op); - NodePtr AddInputNode(NodePtr node); - NodePtr AddOutputNode(NodePtr node); - NodePtr AddOutputNodeByIndex(NodePtr node, int32_t index); - // insert node with specific pre_node - NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node); - NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node); - - graphStatus RemoveNode(const NodePtr &node); - graphStatus RemoveInputNode(const NodePtr &node); - graphStatus RemoveOutputNode(const NodePtr &node); - graphStatus RemoveConstInput(const NodePtr &node); - - /// Add a subgraph to this graph. The subgraph must has a parent graph and parent node, - /// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph - /// must be called before add it to the root graph. and subgraph->GetParentNode()->GetOwnerGraph() - /// must equal to subgraph->GetOwnerGraph(). - /// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph. - /// The subgraph's name SHOULD(not must) be the same as the parameter `name` - graphStatus AddSubgraph(const std::string &name, const std::shared_ptr &subgraph); - graphStatus AddSubgraph(const std::shared_ptr &subgraph); - - void RemoveSubgraph(const std::string &name); - void RemoveSubgraph(const std::shared_ptr &subgraph); - - std::shared_ptr GetSubgraph(const std::string &name) const; - std::vector> GetAllSubgraphs() const; - - // obsolete - std::shared_ptr AddSubGraph(std::shared_ptr sub_graph); - // obsolete - graphStatus RemoveSubGraph(const std::shared_ptr &sub_graph); - - /// - /// @brief Update input-mapping - /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input - /// @return graphStatus - /// - graphStatus UpdateInputMapping(const std::map &input_mapping); - - /// - /// @brief Update output-mapping - /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output - /// @return graphStatus - /// - graphStatus UpdateOutputMapping(const std::map &output_mapping); - - graphStatus TopologicalSorting(); - bool IsValid() const; - void InValid() { is_valid_flag_ = false; } - void Dump() const; - - void Swap(ComputeGraph &graph); - - graphStatus IsolateNode(const NodePtr &node); - graphStatus Verify(); - graphStatus InferShape(); - graphStatus InferOriginFormat(); - graphStatus InferShapeInNeed(); - graphStatus InsertEventNodes(); - bool operator==(const ComputeGraph &r_compute_graph) const; - - /*lint +e504*/ - const std::map, std::vector> &GetShareParamLayer() const { - return params_share_map_; - } - - void SetShareParamLayer(const std::map, std::vector> params_share_map) { - params_share_map_ = params_share_map; - } - - void SetInputsOrder(const std::vector &inputs_order) { inputs_order_ = inputs_order; } - - void SetGraphOutNodes(std::map> out_nodes_map) { out_nodes_map_ = out_nodes_map; } - - void AppendGraphOutNodes(std::map> out_nodes_map) { - for (auto &item : out_nodes_map) { - (void)out_nodes_map_.emplace(item.first, item.second); - } - } - - shared_ptr GetParentGraph(); - void SetParentGraph(const shared_ptr &parent); - shared_ptr GetParentNode(); - void SetParentNode(const shared_ptr &parent); - - const std::map> &GetGraphOutNodes() const { return out_nodes_map_; } - - void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } - - ComputeGraphPtr GetOrigGraph(void) { return origGraph_; } - void SetOutputSize(uint32_t size) { output_size_ = size; } - uint32_t GetOutputSize() const { return output_size_; } - void SetInputSize(uint32_t size) { input_size_ = size; } - uint32_t GetInputSize() const { return input_size_; } - - // false: known shape true: unknow shape - bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; } - void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; } - - /// - /// Set is need train iteration. - /// If set true, it means this graph need to be run iteration some - /// times(according variant "npu_runconfig/iterations_per_loop"). - /// @param need_iteration is need iteration - /// - void SetNeedIteration(bool need_iteration) { need_iteration_ = need_iteration; } - - void SetUserDefOutput(const std::string &output_name); - - const std::string GetOutput(); - - /// - /// Get is need train iteration. - /// @return is need iteration - /// - bool GetNeedIteration() const { return need_iteration_; } - - void SetGraphOpName(const std::map &op_name_map) { op_name_map_ = op_name_map; } - const std::map &GetGraphOpName() const { return op_name_map_; } - - const std::map &GetAllNodesInfo() const; - - void SetAllNodesInfo(const std::map &nodes) { all_nodes_infos_ = nodes; } - - void SetGraphOutNodesInfo(std::vector> &out_nodes_info) { - output_nodes_info_ = out_nodes_info; - } - - void AppendGraphOutNodesInfo(std::vector> &out_nodes_info) { - output_nodes_info_.insert(output_nodes_info_.end(), out_nodes_info.begin(), out_nodes_info.end()); - } - - const std::vector> &GetGraphOutNodesInfo() const { return output_nodes_info_; } - - void SetGraphTargetNodesInfo(const std::vector &target_nodes_info) { - target_nodes_info_ = target_nodes_info; - } - const std::vector &GetGraphTargetNodesInfo() const { return target_nodes_info_; } - - void SetSessionID(uint64_t session_id) { session_id_ = session_id; } - uint64_t GetSessionID() const { return session_id_; } - - void SetGraphID(uint32_t graph_id) { graph_id_ = graph_id; } - uint32_t GetGraphID() const { return graph_id_; } - - void SaveDataFormat(ge::Format data_format) { data_format_ = data_format; } - ge::Format GetDataFormat() const { return data_format_; } - bool IsSummaryGraph() const { return is_summary_graph_; } - void SetSummaryFlag(bool is_summary_graph) { is_summary_graph_ = is_summary_graph; } - // Graph Before BFE - ComputeGraphPtr origGraph_; - - protected: - ProtoAttrMapHelper MutableAttrMap() override; - ConstProtoAttrMapHelper GetAttrMap() const override; - - private: - graphStatus DFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, - std::vector &stack); - graphStatus BFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, - std::deque &stack); - graphStatus CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, - std::map &breadth_node_map); - graphStatus TopologicalSortingGraph(); - graphStatus SortNodes(std::vector &stack, std::map &mapInEdgeNum); - Vistor AllGraphNodes(std::vector> &subgraphs) const; - size_t GetInEdgeSize(const NodePtr &node); - size_t GetOutEdgeSize(const NodePtr &node); - graphStatus RemoveExtraOutEdge(const NodePtr &node); - bool GraphMembersAreEqual(const ComputeGraph &r_graph) const; - bool GraphAttrsAreEqual(const ComputeGraph &r_graph) const; - bool VectorInputNodePtrIsEqual(const std::vector &r_node_ptr_vector, - const std::vector &l_node_ptr_vector) const; - - void SetNodesOwner(); - - friend class ModelSerializeImp; - friend class GraphDebugImp; - friend class OnnxUtils; - friend class TuningUtils; - - std::string name_; - uint32_t graph_id_ = 0; - ProtoAttrMapHelper attrs_; - std::vector nodes_; - std::map all_nodes_infos_; - std::vector target_nodes_info_; - - std::vector input_nodes_; - std::vector inputs_order_; - uint32_t input_size_ = 1; - std::map> out_nodes_map_; - uint32_t output_size_ = 1; - std::vector> output_nodes_info_; - - std::vector> sub_graph_; - std::map> names_to_subgraph_; - std::weak_ptr parent_graph_; - std::weak_ptr parent_node_; - - // the members followed should not in the ComputeGraph class - bool is_valid_flag_; - bool is_summary_graph_ = false; - // Indicates whether it is need iteration - bool need_iteration_ = false; - std::map, std::vector> params_share_map_; - // TaskIdx -> op_name Map - std::map op_name_map_; - uint64_t session_id_ = 0; - ge::Format data_format_ = ge::FORMAT_ND; - // unknown graph indicator, default is false, mean known shape - bool is_unknown_shape_graph_ = false; -}; -} // namespace ge -#endif // INC_GRAPH_COMPUTE_GRAPH_H_ diff --git a/metadef/inc/graph/debug/ge_attr_define.h b/metadef/inc/graph/debug/ge_attr_define.h deleted file mode 100644 index a32907bb..00000000 --- a/metadef/inc/graph/debug/ge_attr_define.h +++ /dev/null @@ -1,1122 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*lint -e618*/ -#ifndef INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ -#define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ - -#include -#include "graph/types.h" - -namespace ge { -#ifdef HOST_VISIBILITY -#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_HOST_VISIBILITY -#endif -#ifdef DEV_VISIBILITY -#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) -#else -#define GE_FUNC_DEV_VISIBILITY -#endif -// Public attribute -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WORKSPACE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHT_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_QUANTIZE_FACTOR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ALPHA; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BETA; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADMODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADMODES; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD_MODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SCALE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WINDOWS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GLOBAL_POOLING; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELU_FLAG; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ALGO; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_FORMAT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_SHAPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER_FORMAT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_K; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_NORM_REGION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_LOCAL_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_ALPHA; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_BETA; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROADCAST; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TIDX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TPADDINGS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_IMG_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_IMG_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NET_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NET_W; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TMULTIPLES; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTIPLES; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_T; - -extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string ATTR_NAME_N; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TSHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_INPUTS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_OUTPUTS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DIMS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_HAS_BEEN_ADDED; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_GRAPH_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_NODE_DEF; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_OP_DEF; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_TENSOR_DESC; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INFERRED_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_ORIGIN_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_INPUT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_OUTPUT; - -// to be deleted -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_LOC_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_CONF_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_OCR_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; - -// _Arg -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INDEX; -// _RetVal -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETVAL_ATTR_NAME_INDEX; -// Data -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DATA_ATTR_NAME_DATA_TYPE; - -// Send -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SEND_ATTR_EVENT_ID; - -// Recv -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RECV_ATTR_EVENT_ID; - -// Convolution -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COEF; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDES; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATIONS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_MODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_ALGO; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_GROUP; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_PAD_MODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_PAD; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_STRIDE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_DILATION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_NUM_OUTPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_KERNEL; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_FILTER; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_BIAS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_RELU_FLAG; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_ADJ; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_TARGET_SHAPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_BEFORE_PAD; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_HAS_BIAS; - -// Pooling -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_NAN_OPT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_PAD_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_GLOBAL_POOLING; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_WINDOW; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_PAD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_STRIDE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_CEIL_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_DATA_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_BEFORE_PAD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_NAME_ALGO; - -// Eltwise -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_COEFF; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_WEIGHT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_RELU_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_ALPHA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_BETA; - -// BatchNorm -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_EPSILON; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_MEAN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; - -// Huberloss -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; - -// SSDRealDivTileMul -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; - -// SSDSumMulRealDivMean -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; -/// ConcatFive2Four -/// ConcatFour2Five -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; -// Scale -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; - -// FullConnection -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_FILTER; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_BIAS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_RELU_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_ATTR_NAME_ALGO; - -// SoftmaxOpParams -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_ALGO; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_MODE; - -// SparseSoftmaxCrossEntropy -extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE; -extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD; -// Attr labelSmoothing -extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING; - -// ApplyMomentum -extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string APPLYMENTUM_ATTR_IS_GRAPH_FUSION; - -// Activation -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ACTIVATION_ATTR_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ACTIVATION_ATTR_COEF; - -// Concat -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_ATTR_NAME_AXIS; - -// Const -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_DATA_TRANSTYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; - -// Roipooling -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLED_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLED_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO; - -// DetectionOutput -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IMG_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IMG_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE; -// Ssd DetectionOutput -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_ETA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K; - -// Refinedet DetectionOutput -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE; - -// Yolo DetectionOutput -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_ClASSES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BIASES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_RELATIVE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION; - -// DetectionPostprocess -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_CLS_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_POST_NMS_TOPN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT; - -// Spatialtransfrom -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_OUTPUT_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_OUTPUT_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM; - -// Proposal -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_BASE_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_MIN_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_RATIO; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_NMS_THRESH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_TOP_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_IMG_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_IMG_W; -// Softmax -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_AXIS; - -// Permute -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; - -// SSD Normalize -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_EPS; - -// Flatten -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_ATTR_END_AXIS; - -// SsdPRIORBOX -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_FLIP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_CLIP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_IMG_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_IMG_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_STEP_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_STEP_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_VARIANCE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM; - -// PRelu -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PRELU_ATTR_CHANNEL_SHARED; - -// Psroi pooling -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_OUTPUT_DIM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_GROUP_SIZE; - -// Power -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_POWER; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; - -// Log -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; -// Pack -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; - -// Dynamic stitch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; -// Unpack -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; -// Gathernd -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND_ATTR_NAME_TINDICES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND_ATTR_NAME_TPARAMS; - -// Argmax -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; - -// Upsample -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; -// Relu -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; - -// FreeSpaceExtract -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT; - -// Split -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_SLICE_POINT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_SIZE_SPLIT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_NUM_SPLIT; - -// Tvm -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_WORKSPACE_TYPE; - -// Squeeze -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_DIMS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_OP_NAME; - -// Stride slice -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_BEGIN_MASK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_END_MASK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; - -// Slice -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SLICE_ATTR_NAME_BEGINS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SLICE_ATTR_NAME_SIZES; - -// Roialign -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SPATIAL_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; - -// Generate_rpn_proposal -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string - GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string - GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; -// Decode_bbox -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DECODE_BBOX_ATTR_DECODECLIP; - -// Cast -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_DSTT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_SRCT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_DST_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_TRUNCATE; - -// Fastrcnnn predications -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES; - -// REORG -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_ATTR_STRIDE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_ATTR_REVERSE; - -// MERGE -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; -static const std::string NOT_NET_OUTPUT = "not_net_output"; - -// ENTER -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_CONSTANT_FLAG; - -// Concatv2 -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_V2_ATTR_TIDX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_V2_ATTR_N; -// SUM -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_TIDX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_KEEP_DIMS; - -// ResizeBilinear -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_HEIGHT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_WIDTH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_PAD_END; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; - -// RetinaNet -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; -// MatMul -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_HAS_BIAS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_ATTR_IS_TRAINING; - -// Flatten -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_START_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_END_AXIS; - -// Reshape -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NUM_AXES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_ALPHA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_BETA; - -// Frameoworkop -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string T_IN_DATATYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string T_OUT_DATATYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_N; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_C; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_PAD_DEPTH_CONV; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_PAD_CONV; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BEFORE_PAD; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ANN_MEAN_KEEPDIMS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_ATTR_PADDINGDS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_ATTR_CONSTANT_VALUE; - -// ConvGradFilter -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE; -// ConvGradInput -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE; - -// Rnn -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_MODE_STATIC; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MUTI_RNN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CELL_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CNN_RNN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; - -// Upsample -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; - -// PadV2 -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; - -// MirrorPad -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; -// Filler -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; - -// Shufflechannel -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHUFFLE_CHANNEL_GROUP; - -// TopKV2 -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TOPKV2_ATTR_K; - -// Calibaration -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_H_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_W_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_TOP_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_BOTTOM_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_RIGHT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATION_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_EPSILON; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_POOLING_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CLASS_NUM; -// Model -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TARGET_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_STREAM_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_HUGE_STREAM_LIST; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OUT_NODES_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; - -// Public attribute -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BYTE_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_INFERENCE_ID; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_OPDEF; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IO_OP; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_SCOPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPATTR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUFLAG; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SEQLEN_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_X_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONT_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_XSTATIC_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_MINI; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_TINY; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_LITE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_OUTPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; - -// Used for operators that do not generate task -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOTASK; - -// Used for operators that output reuse input -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_REUSE_INPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; - -// L2_normalize -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; -// HCOM -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; - -// Log time stamp -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; -// SpaceToDepth/DepthToSpace -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; - -// SparseSoftmaxCrossEntropyWithLogits -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; - -// MaxPoolGradWithArgmax -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; - -// AvgPoolGrad -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; - -// Varible -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; - -// Assign -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VAR_NAME; - -// ShapeN -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; - -// Space2bacth batch2space -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; -// Depth_to_space space_to_depth -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; -// FakeQuantWithMinMaxVars -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; -// Mobilenet_ssd_conv_fusion -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; - -// Lsh project -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; - -// Control flow -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; - -// GatherV2 attr def -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; - -// Reshape attr def -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; - -// Axis attr def -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; -// The node link with SparseSoftmaxCrossEntropyWithLogits -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; -// For constant folding -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; - -// Used for mark the active label list to find stream of activated node -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE; - -// Multi batch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_LABEL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_BATCH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER; - -// Control flow -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_DATA_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ORIG_NODE_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION; - -// Function Op -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; - -// Used for mark the active node is for loop, type:bool -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_INPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_OUTPUT; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_RANGE; - -// Atomic addr clean attrs -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_INPUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_OUTPUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_IS_FUSION_NODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_IS_ATOMIC_NODE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string EXT_ATTR_ATOMIC_WORKSPACE_INFO; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string EXT_ATTR_ATOMIC_WORKSPACE_OFFSET; -// Used for find variable session_id -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MODEL_ATTR_SESSION_ID; - -// Source/dst format for Op FormatTransfer -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FORMAT_TRANSFER_SRC_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FORMAT_TRANSFER_DST_FORMAT; - -// For compile op by ge call -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEED_COMPILE; - -// For mutil-batch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_TYPE; - -// For inserted op -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; - -// For compress weight -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPRESS_WEIGHT; - -// For data dump -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_SUB_SPLITER_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_GROUP_OP_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; - -// used for lX fusion -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ENGINE_NAME_FOR_LX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_LX_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPTIMIZE_GROUP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_COMPILE_STRATEGY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_BUFFER; - -// for unregistered op -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_ATTRLIST; - -// op overflow dump -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE; - -// functional ops attr -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_ELSE_BRANCH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; - -// used for label switch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_END_NODE; - -// Variable -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; - -// HCOM -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; -// used for LX tiling -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_L1_SPACE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_TYPE_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST; - -// Dynamic stitch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; - -// Used for support Horovod -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INTER_EVENT_IDENTIFY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE; -// for gradient group -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG; - -// dynamic shape attrs -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX; - -// atc user def dtype&format -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT; - -// for fusion op plugin -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; - -// graph partition for aicpu -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME; - -// input and output memory type -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_VARIABLE_PLACEMENT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INPUT_MEMORY_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OUTPUT_MEMORY_TYPE; - -// input_output_offset -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_BASIC_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET; -} // namespace ge - -#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ -/*lint +e618*/ diff --git a/metadef/inc/graph/def_types.h b/metadef/inc/graph/def_types.h deleted file mode 100644 index 6d70fb18..00000000 --- a/metadef/inc/graph/def_types.h +++ /dev/null @@ -1,195 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_DEF_TYPES_H_ -#define INC_GRAPH_DEF_TYPES_H_ - -#include -#include -#include -#include "graph/attr_value_serializable.h" -#include "graph/buffer.h" -namespace ge { -#define DEF_TYPE_DEC(type, name) \ - inline void set_##name(const type &value) { name = value; } \ - type *mutable_##name() { return &name; } - -#define DEF_TYPE_HAS_DEC(type, name) \ - inline void set_##name(const type &value) { name = value; } \ - \ - private: \ - bool has_mutable_##name{false}; \ - \ - public: \ - bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \ - type *mutable_##name() { \ - has_mutable_##name = true; \ - return &name; \ - } - -#define DEF_TYPE_VEC_DEC(type, name) \ - inline int name##_size() const { return name.size(); } \ - inline void clear_##name() { name.clear(); } \ - inline void set_##name(int index, type value) { name[index] = value; } \ - inline void add_##name(type value) { name.push_back(value); } \ - inline std::vector *mutable_##name() { return &name; } - -#define DEF_TYPE_BYTES_DEC(name) \ - inline void clear_##name() { name.ClearBuffer(); } \ - inline void set_##name(const void *value, size_t size) { name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ - inline Buffer *mutable_##name() { return &name; } - -struct CompressInfo { - public: - CompressInfo() {} - CompressInfo(int32_t blockRow, int32_t blockCol, int32_t fractalK, int32_t fractalN, int32_t lastFractalK, - int32_t lastFractalN, int32_t cubeSize, int32_t loadDir) { - blockrow = blockRow; - blockcol = blockCol; - fractalk = fractalK; - fractaln = fractalN; - lastfractalk = lastFractalK; - lastfractaln = lastFractalN; - cubesize = cubeSize; - loaddir = loadDir; - } - - int32_t blockrow{0}; // Block row - int32_t blockcol{0}; // Block col - int32_t fractalk{0}; // Fractal K - int32_t fractaln{0}; // Fractal N - int32_t lastfractalk{0}; // K of last fractal - int32_t lastfractaln{0}; // N of last fractal - int32_t cubesize{0}; // Cube's length - int32_t loaddir{0}; // Data load directtiono 0:col load 1:row load - DEF_TYPE_DEC(int32_t, blockrow); - DEF_TYPE_DEC(int32_t, blockcol); - DEF_TYPE_DEC(int32_t, fractalk); - DEF_TYPE_DEC(int32_t, fractaln); - DEF_TYPE_DEC(int32_t, lastfractalk); - DEF_TYPE_DEC(int32_t, lastfractaln); - DEF_TYPE_DEC(int32_t, cubesize); - DEF_TYPE_DEC(int32_t, loaddir); - - GE_SERIALIZABLE(blockrow, blockcol, fractalk, fractaln, lastfractalk, lastfractaln, cubesize, loaddir); -}; - -enum QuantizeScaleType { VECTOR_SCALE = 0, SCALAR_SCALE = 1 }; -enum QuantizeScaleMode { NORMAL_MODE = 0, SQRT_MODE = 1 }; -enum QuantizeAlgorithm { - NON_OFFSET_ALGO = 0, - HALF_OFFSET_ALGO = 1, - ALL_OFFSET_ALGO = 2, -}; -struct QuantizeFactor { - public: - // QuantizeScaleMode scale_mode; - uint32_t scale_mode{0}; - Buffer scale_value; - int64_t scale_offset{0}; - Buffer offset_data_value; - int64_t offset_data_offset{0}; - Buffer offset_weight_value; - int64_t offset_weight_offset{0}; - Buffer offset_pad_value; - int64_t offset_pad_offset{0}; - - DEF_TYPE_DEC(uint32_t, scale_mode); - DEF_TYPE_BYTES_DEC(scale_value); - - DEF_TYPE_DEC(int64_t, scale_offset); - DEF_TYPE_BYTES_DEC(offset_data_value); - DEF_TYPE_DEC(int64_t, offset_data_offset); - - DEF_TYPE_BYTES_DEC(offset_weight_value); - DEF_TYPE_DEC(int64_t, offset_weight_offset); - DEF_TYPE_BYTES_DEC(offset_pad_value); - DEF_TYPE_DEC(int64_t, offset_pad_offset); - - GE_SERIALIZABLE(scale_mode, scale_value, scale_offset, offset_data_value, offset_data_offset, offset_weight_value, - offset_weight_offset, offset_pad_value, offset_pad_offset) -}; - -static inline bool QuantizeFactorHasData(const QuantizeFactor &factor) { - return factor.scale_value.GetSize() > 0 || factor.offset_data_value.GetSize() > 0 || - factor.offset_weight_value.GetSize() > 0 || factor.offset_pad_value.GetSize() > 0; -} - -struct AllOffsetQuantizeInfo { - public: - AllOffsetQuantizeInfo() {} - AllOffsetQuantizeInfo(float s, int32_t o) : scale(s), offset(o) {} - float scale{0}; - int32_t offset{0}; - - DEF_TYPE_DEC(float, scale); - DEF_TYPE_DEC(int32_t, offset); - - GE_SERIALIZABLE(scale, offset) -}; - -struct QuantizeCalcFactor { - public: - Buffer offsetw; - int64_t offsetw_offset{0}; - Buffer offsetd; - int64_t offsetd_offset{0}; - Buffer scalereq; - int64_t scaledreq_offset{0}; - Buffer offsetdnext; - int64_t offsetdnext_offset{0}; - - DEF_TYPE_BYTES_DEC(offsetw); - DEF_TYPE_DEC(int64_t, offsetw_offset); - DEF_TYPE_BYTES_DEC(offsetd); - DEF_TYPE_DEC(int64_t, offsetd_offset); - DEF_TYPE_BYTES_DEC(scalereq); - DEF_TYPE_DEC(int64_t, scaledreq_offset); - DEF_TYPE_BYTES_DEC(offsetdnext); - DEF_TYPE_DEC(int64_t, offsetdnext_offset); - - GE_SERIALIZABLE(offsetw, offsetw_offset, offsetd, offsetd_offset, scalereq, scaledreq_offset, offsetdnext, - offsetdnext_offset); -}; - -static inline bool QuantizeFactorHasData(const QuantizeCalcFactor &factor) { - return factor.offsetw.GetSize() > 0 || factor.offsetd.GetSize() > 0 || factor.scalereq.GetSize() > 0 || - factor.offsetdnext.GetSize() > 0; -} - -struct QuantizeFactorParams { - uint32_t quantize_algo{0}; - uint32_t scale_type{0}; - QuantizeFactor quantize_param; - QuantizeFactor dequantize_param; - QuantizeFactor requantize_param; - QuantizeCalcFactor quantizecalc_param; - DEF_TYPE_DEC(uint32_t, quantize_algo); - DEF_TYPE_DEC(uint32_t, scale_type); - DEF_TYPE_HAS_DEC(QuantizeFactor, quantize_param); - DEF_TYPE_HAS_DEC(QuantizeFactor, dequantize_param); - DEF_TYPE_HAS_DEC(QuantizeFactor, requantize_param); - DEF_TYPE_HAS_DEC(QuantizeCalcFactor, quantizecalc_param); - - GE_SERIALIZABLE(quantize_algo, scale_type, quantize_param, dequantize_param, requantize_param, quantizecalc_param, - has_mutable_quantize_param, has_mutable_dequantize_param, has_mutable_requantize_param, - has_mutable_quantizecalc_param); -}; - -#undef DEF_TYPE_DEC -} // namespace ge - -#endif // INC_GRAPH_DEF_TYPES_H_ diff --git a/metadef/inc/graph/detail/any_map.h b/metadef/inc/graph/detail/any_map.h deleted file mode 100644 index 70533ea1..00000000 --- a/metadef/inc/graph/detail/any_map.h +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_DETAIL_ANY_MAP_H_ -#define INC_GRAPH_DETAIL_ANY_MAP_H_ - -#include -#include -#include -#include - -namespace ge { -using std::shared_ptr; -using std::string; - -class TypeID { - public: - template - static TypeID Of() { - return TypeID(__PRETTY_FUNCTION__); - } - - ~TypeID() = default; - - bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; } - - private: - explicit TypeID(string type) : type_(std::move(type)) {} // lint !e30 !e32 - - string type_; -}; - -class AnyMap { - public: - template - bool Set(const string &name, const DT &val); - - template - bool Get(const string &name, T &retValue) const; - - bool Has(const string &name) const { return anyValues_.find(name) != anyValues_.end(); } - - void Swap(AnyMap &other) { anyValues_.swap(other.anyValues_); } - - private: - class Placeholder { - public: - virtual ~Placeholder() = default; - - virtual const TypeID &GetTypeInfo() const = 0; - }; - - template - class Holder : public Placeholder { - public: - explicit Holder(const VT &value) : value_(value) {} - - ~Holder() override = default; - - const TypeID &GetTypeInfo() const override { - static const TypeID typeId = TypeID::Of(); - return typeId; - } - - const VT value_; - }; - - std::map> anyValues_; -}; - -template -bool AnyMap::Set(const string &name, const DT &val) { - auto it = anyValues_.find(name); - - std::shared_ptr> tmp; - try { - tmp = std::make_shared>(val); - } catch (std::bad_alloc &e) { - tmp = nullptr; - } catch (...) { - tmp = nullptr; - } - - if (it == anyValues_.end()) { - (void)anyValues_.emplace(name, tmp); - } else { - if (it->second && it->second->GetTypeInfo() == TypeID::Of
()) { - it->second = tmp; - } else { - return false; - } - } - return true; -} - -template -bool AnyMap::Get(const string &name, T &retValue) const { - auto it = anyValues_.find(name); - if (it != anyValues_.end() && it->second && it->second->GetTypeInfo() == TypeID::Of()) { - auto retPtr = std::static_pointer_cast>(it->second); - retValue = retPtr->value_; - return true; - } - return false; -} -} // namespace ge -#endif // INC_GRAPH_DETAIL_ANY_MAP_H_ diff --git a/metadef/inc/graph/detail/attributes_holder.h b/metadef/inc/graph/detail/attributes_holder.h deleted file mode 100644 index 49741143..00000000 --- a/metadef/inc/graph/detail/attributes_holder.h +++ /dev/null @@ -1,165 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ -#define INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ - -#include -#include -#include -#include -#include -#include -#include "graph/detail/any_map.h" -#include "graph/ge_error_codes.h" -#include "graph/types.h" - -namespace google { -namespace protobuf { -class Message; -template -class Map; -} // namespace protobuf -} // namespace google - -namespace ge { -using std::string; -class GeAttrValue; - -namespace proto { -class AttrDef; -class TensorDef; -class TensorDescriptor; -class ShapeDef; -class NamedAttrs; -class ModelDef; -class OpDef; -class GraphDef; -} // namespace proto - -using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; // lint !e1073 -using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>; - -template -class GeIrProtoHelper { - public: - GeIrProtoHelper(const ProtoMsgOwner &protoOwner, ProtoType *protoMsg) - : protoOwner_(protoOwner), protoMsg_(protoMsg) {} - - GeIrProtoHelper() { - protoOwner_ = std::shared_ptr<::google::protobuf::Message>(nullptr); - protoMsg_ = nullptr; - } - virtual ~GeIrProtoHelper() = default; - - template - GeIrProtoHelper(const GeIrProtoHelper &other) { - protoOwner_ = other.protoOwner_; - protoMsg_ = other.protoMsg_; - } - template - GeIrProtoHelper &operator=(const GeIrProtoHelper &other) { - protoOwner_ = other.protoOnwer_; - protoMsg_ = other.protoMsg_; - return *this; - } - void InitDefault(); - template - bool operator==(const GeIrProtoHelper &other) const { - return protoOwner_ == other.protoOwner_ && protoMsg_ == other.protoMsg_; - } - - inline const ProtoMsgOwner &GetProtoOwner() const { return protoOwner_; } - inline ProtoType *GetProtoMsg() const { return protoMsg_; } - void CopyValueFrom(const GeIrProtoHelper &other) { - if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) { - *protoMsg_ = *other.protoMsg_; - } - } - void MoveValueFrom(GeIrProtoHelper &&other) { - if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) { - *protoMsg_ = std::move(*other.protoMsg_); - } - } - - void Swap(GeIrProtoHelper &other) { - protoOwner_.swap(other.protoOwner_); - - ProtoType *temp = protoMsg_; - protoMsg_ = other.protoMsg_; - other.protoMsg_ = temp; - } - - // protoMsg_ is part of protoOwner_, they have the same runtime - ProtoMsgOwner protoOwner_ = nullptr; - ProtoType *protoMsg_ = nullptr; - friend class GeIrProtoHelper::value, typename std::remove_const::type, const ProtoType>::type>; -}; - -using ProtoAttrMapHelper = GeIrProtoHelper; -using ConstProtoAttrMapHelper = GeIrProtoHelper; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { - public: - AttrHolder() = default; - virtual ~AttrHolder() = default; - - graphStatus SetAttr(const string &name, const GeAttrValue &value); - - graphStatus GetAttr(const string &name, GeAttrValue &value) const; - - bool HasAttr(const string &name) const; - - graphStatus DelAttr(const string &name); - - void CopyAttrsFrom(const AttrHolder &holder); - - void Swap(AttrHolder &holder) { - requiredAttrs_.swap(holder.requiredAttrs_); - extAttrs_.Swap(holder.extAttrs_); - } - - template - bool SetExtAttr(const string &name, const T &value) { - return extAttrs_.Set(name, value); - } - template - T TryGetExtAttr(const string &name, T defaultValue) const { - T ret(defaultValue); - (void)extAttrs_.Get(name, ret); - return ret; - } - - protected: - graphStatus AddRequiredAttr(const std::string &name); - const std::unordered_set GetAllAttrNames() const; - const std::map GetAllAttrs() const; // lint !e1073 - - virtual ProtoAttrMapHelper MutableAttrMap() = 0; - virtual ConstProtoAttrMapHelper GetAttrMap() const = 0; - - friend class ModelSerializeImp; - friend class AttrUtils; - friend class AttrUtilsHelper; - - std::vector requiredAttrs_; - - private: - AnyMap extAttrs_; -}; -} // namespace ge -#endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ diff --git a/metadef/inc/graph/detail/model_serialize_imp.h b/metadef/inc/graph/detail/model_serialize_imp.h deleted file mode 100644 index ff27335a..00000000 --- a/metadef/inc/graph/detail/model_serialize_imp.h +++ /dev/null @@ -1,93 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ -#define INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ - -#include -#include -#include -#include -#include "graph/anchor.h" -#include "graph/detail/attributes_holder.h" -#include "graph/ge_tensor.h" -#include "graph/graph.h" -#include "graph/node.h" - -namespace ge { -using ComputeGraphPtr = std::shared_ptr; - -struct NodeNameGraphReq { - string node_name; - int32_t index; - ComputeGraphPtr graph; -}; - -struct NodeNameNodeReq { - string src_node_name; - int32_t src_out_index; - NodePtr dst_node; - int32_t dst_in_index; - string dst_node_name; -}; - -class ModelSerializeImp { - public: - bool SerializeModel(const Model &model, proto::ModelDef *modeProto, bool is_dump = false); - - bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto, bool is_dump = false); - - bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto); - - bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto, bool is_dump = false); - - bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto, bool is_dump = false); - - bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto); - - bool UnserializeModel(Model &model, proto::ModelDef &modeProto); - - bool UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graphProto); - - bool UnserializeGraph(ComputeGraphPtr &graph, proto::GraphDef &graphProto); - - bool HandleNodeNameRef(); - - bool UnserializeOpDesc(OpDescPtr &opDesc, proto::OpDef &opDefProto); - void AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_in, std::vector &key_out, - std::vector &value_in, std::vector &value_out, std::vector &opt); - void OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto); - - bool UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &opDefProto); - - bool UnserializeTensor(GeTensorPtr &tensor, proto::TensorDef &tensorProto); - - bool ParseNodeIndex(const string &node_index, string &nodeName, int32_t &index); - - void SetProtobufOwner(const ProtoMsgOwner &bufferProtobufOnwer) { protobuf_owner_ = bufferProtobufOnwer; } - - private: - bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map &subgraphs); - - std::vector graph_input_node_names_; - std::vector graph_output_node_names_; - std::vector node_input_node_names_; - std::map node_map_; - ProtoMsgOwner protobuf_owner_; -}; -} // namespace ge - -#endif // INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ diff --git a/metadef/inc/graph/ge_attr_value.h b/metadef/inc/graph/ge_attr_value.h deleted file mode 100644 index 0c265c20..00000000 --- a/metadef/inc/graph/ge_attr_value.h +++ /dev/null @@ -1,343 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_GE_ATTR_VALUE_H_ -#define INC_GRAPH_GE_ATTR_VALUE_H_ - -#include -#include -#include -#include -#include -#include -#include "graph/buffer.h" -#include "detail/attributes_holder.h" -#include "graph/ge_error_codes.h" -#include "graph/ge_tensor.h" - -using std::map; -using std::string; -using std::vector; - -namespace ge { -class GeTensor; - -using GeTensorPtr = std::shared_ptr; -using ConstGeTensorPtr = std::shared_ptr; - -class ComputeGraph; -using ComputeGraphPtr = std::shared_ptr; -using ConstComputeGraphPtr = std::shared_ptr; - -class GeTensorDesc; -class GeAttrValue; -class GeAttrValueImp; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder { - public: - NamedAttrs(); - virtual ~NamedAttrs() = default; - void SetName(const std::string &name); - string GetName() const; - GeAttrValue GetItem(const string &key) const; - - protected: - ProtoAttrMapHelper MutableAttrMap() override; - ConstProtoAttrMapHelper GetAttrMap() const override; - - private: - // Create namedAttrs from protobuf obj - NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); - GeIrProtoHelper named_attrs_; - friend class GeAttrValueImp; - friend class GeAttrValue; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { - public: - using INT = int64_t; - using FLOAT = float; - using BOOL = bool; - using STR = std::string; - using TENSOR = GeTensorPtr; - using TENSOR_DESC = GeTensorDesc; - using GRAPH = ComputeGraphPtr; - using BYTES = Buffer; - using NAMED_ATTRS = ge::NamedAttrs; - using DATA_TYPE = ge::DataType; - - using LIST_INT = vector; - using LIST_FLOAT = vector; - using LIST_BOOL = vector; - using LIST_STR = vector; - using LIST_TENSOR = vector; - using LIST_TENSOR_DESC = vector; - using LIST_GRAPH = vector; - using LIST_BYTES = vector; - using LIST_NAMED_ATTRS = vector; - using LIST_LIST_INT = vector>; - using LIST_DATA_TYPE = vector; - - using NamedAttrs = ge::NamedAttrs; // for cce use (ge::GeAttrValue::NamedAttrs). - - enum ValueType { - VT_NONE = 0, - VT_STRING, - VT_FLOAT, - VT_BOOL, - VT_INT, - VT_TENSOR_DESC, - VT_TENSOR, - VT_BYTES, - VT_GRAPH, - VT_NAMED_ATTRS, - VT_LIST_LIST_INT, - VT_DATA_TYPE, - - VT_LIST_BASE = 1000, - VT_LIST_STRING = VT_LIST_BASE + VT_STRING, - VT_LIST_FLOAT = VT_LIST_BASE + VT_FLOAT, - VT_LIST_BOOL = VT_LIST_BASE + VT_BOOL, - VT_LIST_INT = VT_LIST_BASE + VT_INT, - VT_LIST_TENSOR_DESC = VT_LIST_BASE + VT_TENSOR_DESC, - VT_LIST_TENSOR = VT_LIST_BASE + VT_TENSOR, - VT_LIST_BYTES = VT_LIST_BASE + VT_BYTES, - VT_LIST_GRAPH = VT_LIST_BASE + VT_GRAPH, - VT_LIST_NAMED_ATTRS = VT_LIST_BASE + VT_NAMED_ATTRS, - VT_LIST_DATA_TYPE = VT_LIST_BASE + VT_DATA_TYPE, - }; - - template - struct IsAttrTypeEnable { - using DT = typename std::remove_cv::type; - - static bool const VALUE = std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value; - - // Not has list type of NamedAttrs - static bool const LIST_VALUE = std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || - std::is_same::value || - std::is_same::value || std::is_same::value; - }; - - template - // To cols - using enable_if_vector_type_valid_t = typename std::enable_if::LIST_VALUE, int>::type; - - template - using enable_if_one_type_valid_t = typename std::enable_if::VALUE, int>::type; - - template - using enable_if_type_valid_t = - typename std::enable_if::VALUE || IsAttrTypeEnable::LIST_VALUE, int>::type; - - template - using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable; - - GeAttrValue(); - ~GeAttrValue() = default; - // SetValue, Set initializer_list - template = 0> - graphStatus SetValue(std::initializer_list
&&val) { - T vectorVal; - for (auto &item : val) { - vectorVal.push_back(item); - } - return SetValue(vectorVal); - } - - // SetValue, Set vector - template = 0> - graphStatus SetValue(const std::vector
&val) { - T vectorVal; - for (auto item : val) { - vectorVal.push_back(item); - } - return SetValue(vectorVal); - } - - // SetValue, not list type - template = 0> - graphStatus SetValue(DT &&val) { - return SetValue(T(std::forward
(val))); - } - - // GE_SERIALIZABLE - template = 0> - graphStatus SetValue(const T &t) { - return t.Save(*this); - } - - template = 0> - graphStatus SetValue(const vector &t) { - vector attrs; - for (auto &item : t) { - GeAttrValue val; - item.Save(val); - NamedAttrs attrsItem; - (void)val.GetValue(attrsItem); - attrs.push_back(attrsItem); - } - return SetValue(attrs); - } - - // GetValue, list value - template = 0, - typename std::enable_if::value, int>::type = 0> - graphStatus GetValue(std::vector
&val) const { - T valGet; - val.clear(); - auto status = GetValue(valGet); - if (status != GRAPH_SUCCESS) { - return status; - } - for (auto item : valGet) { - val.push_back(item); - } - return GRAPH_SUCCESS; - } - - // GetValue, not list type - template = 0, - typename std::enable_if::value, int>::type = 0> - graphStatus GetValue(DT &val) const { - T valGet; - auto status = GetValue(valGet); - if (status != GRAPH_SUCCESS) { - return status; - } - val = DT(valGet); - return GRAPH_SUCCESS; - } - - // GE_SERIALIZABLE - template = 0> - graphStatus GetValue(T &t) { - return t.Load(*this); - } - - template = 0> - graphStatus GetValue(vector &t) { - graphStatus status; - t.clear(); - vector attrs; - status = this->GetValue(attrs); - if (status != GRAPH_SUCCESS) { - return status; - } - for (auto &attr : attrs) { - T item; - GeAttrValue val; - (void)val.SetValue(attr); - status = item.Load(val); - if (status != GRAPH_SUCCESS) { - return status; - } - t.push_back(item); - } - return GRAPH_SUCCESS; - } - - template = 0> - static GeAttrValue CreateFrom(DT &&val) { - GeAttrValue valRet; - (void)valRet.SetValue(std::forward
(val)); - return valRet; - } - - template = 0> - static GeAttrValue CreateFrom(std::initializer_list
&&val) { - GeAttrValue valRet; - (void)valRet.SetValue(std::move(val)); - return valRet; - } - - template = 0> - static GeAttrValue CreateFrom(const T &val) { - GeAttrValue valRet; - (void)valRet.SetValue(val); - return valRet; - } - - template = 0> - static GeAttrValue CreateFrom(const vector &val) { - GeAttrValue valRet; - (void)valRet.SetValue(val); - return valRet; - } - - ValueType GetValueType() const; - - bool IsEmpty() const; - - GeAttrValue Copy() const; - - // For map key - bool operator==(const GeAttrValue &other) const { return value_ == other.value_; } - - graphStatus MutableTensor(GeTensorPtr &tensor); - graphStatus MutableListTensor(vector &list_tensor); - - private: -#define VALUE_SET_GET_DEC(DT) \ - graphStatus SetValue(const DT &val); \ - graphStatus GetValue(DT &val) const; - VALUE_SET_GET_DEC(GeAttrValue::STR) - VALUE_SET_GET_DEC(GeAttrValue::INT) - VALUE_SET_GET_DEC(GeAttrValue::FLOAT) - VALUE_SET_GET_DEC(GeAttrValue::BOOL) - VALUE_SET_GET_DEC(GeTensorDesc) - VALUE_SET_GET_DEC(GeAttrValue::TENSOR) - VALUE_SET_GET_DEC(GeAttrValue::GRAPH) - VALUE_SET_GET_DEC(BYTES) - VALUE_SET_GET_DEC(NamedAttrs) - VALUE_SET_GET_DEC(ge::DataType) // lint !e665 - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector>) // lint !e665 - VALUE_SET_GET_DEC(vector) // lint !e665 -#undef VALUE_SET_GET_DEC - - GeIrProtoHelper value_; - GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val); - - friend class AttrHolder; - friend class ModelSerializeImp; - friend class OnnxUtils; -}; - -class AttrValueImpl { - public: - AttrValueImpl() = default; - ~AttrValueImpl() = default; - - GeAttrValue geAttrValue_; -}; -} // namespace ge -#endif // INC_GRAPH_GE_ATTR_VALUE_H_ diff --git a/metadef/inc/graph/ge_context.h b/metadef/inc/graph/ge_context.h deleted file mode 100644 index 53985e9c..00000000 --- a/metadef/inc/graph/ge_context.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_GE_CONTEXT_H_ -#define INC_GRAPH_GE_CONTEXT_H_ - -#include -#include "graph/ge_error_codes.h" - -namespace ge { -class GEContext { - public: - graphStatus GetOption(const std::string &key, std::string &option); - bool GetHostExecFlag(); - uint64_t SessionId(); - uint32_t DeviceId(); - uint64_t TraceId(); - void Init(); - void SetSessionId(uint64_t session_id); - void SetCtxDeviceId(uint32_t device_id); - - private: - uint64_t session_id_ = 0; - uint32_t device_id_ = 0; - uint64_t trace_id_ = 0; -}; // class GEContext - -/// Get context -/// @return -GEContext &GetContext(); -} // namespace ge - -#endif // INC_GRAPH_GE_CONTEXT_H_ diff --git a/metadef/inc/graph/ge_global_options.h b/metadef/inc/graph/ge_global_options.h deleted file mode 100644 index b55192e2..00000000 --- a/metadef/inc/graph/ge_global_options.h +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_GE_GLOBAL_OPTIONS_H_ -#define INC_GRAPH_GE_GLOBAL_OPTIONS_H_ - -#include -#include - -namespace ge { -std::map &GetMutableGlobalOptions(); -} -#endif // INC_GRAPH_GE_GLOBAL_OPTIONS_H_ diff --git a/metadef/inc/graph/ge_local_context.h b/metadef/inc/graph/ge_local_context.h deleted file mode 100644 index b47098fb..00000000 --- a/metadef/inc/graph/ge_local_context.h +++ /dev/null @@ -1,44 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_GE_LOCAL_CONTEXT_H_ -#define INC_GRAPH_GE_LOCAL_CONTEXT_H_ - -#include -#include -#include -#include "graph/ge_error_codes.h" - -using std::map; -using std::string; - -namespace ge { -class GEThreadLocalContext { - public: - graphStatus GetOption(const string &key, string &option); - void SetGraphOption(map options_map); - void SetSessionOption(map options_map); - void SetGlobalOption(map options_map); - - private: - map graph_options_; - map session_options_; - map global_options_; -}; // class GEThreadLocalContext - -GEThreadLocalContext &GetThreadLocalContext(); -} // namespace ge -#endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ diff --git a/metadef/inc/graph/ge_tensor.h b/metadef/inc/graph/ge_tensor.h deleted file mode 100644 index 834dca0b..00000000 --- a/metadef/inc/graph/ge_tensor.h +++ /dev/null @@ -1,193 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_GE_TENSOR_H_ -#define INC_GRAPH_GE_TENSOR_H_ - -#include -#include -#include -#include -#include "detail/attributes_holder.h" -#include "graph/buffer.h" -#include "graph/ge_error_codes.h" -#include "graph/types.h" - -namespace ge { -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { - public: - GeShape(); - ~GeShape() = default; - explicit GeShape(std::vector s); - - size_t GetDimNum() const; - // If the idx is invalid, return 0 - int64_t GetDim(size_t idx) const; - graphStatus SetDim(size_t idx, int64_t value); - std::vector GetDims() const; - - int64_t GetShapeSize() const; - std::string ToString() const; - - /// - /// @brief Check is unknown shape - /// @return bool - /// - bool IsUnknownShape() const; - - /// - /// @brief Check is a scalar - /// @return bool - /// - bool IsScalar() const; - - GeShape(const GeShape &other); - GeShape(GeShape &&other); - GeShape &operator=(const GeShape &other); - GeShape &operator=(GeShape &&other); - - private: - GeIrProtoHelper shape_def_; - friend class GeTensorDesc; - // Create from proto obj - GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg); - - void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; } -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrHolder { - friend class TensorUtils; - friend class GeAttrValue; - friend class ModelSerialize; - - public: - GeTensorDesc(); - explicit GeTensorDesc(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); - GeTensorDesc(const GeTensorDesc &desc); - GeTensorDesc(GeTensorDesc &&desc); - - ~GeTensorDesc() = default; - bool operator==(const GeTensorDesc &r_ge_tensor_desc) const; - - void Update(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); - - GeShape GetShape() const; - GeShape &MutableShape(); - void SetShape(GeShape shape); - - // set shape with -2, it stand for unknown shape - void SetUnknownDimNumShape(); - // for unknown shape - graphStatus SetShapeRange(const std::vector> &range); - graphStatus GetShapeRange(std::vector> &range) const; - - GeShape GetOriginShape() const; - void SetOriginShape(const GeShape &originShape); - - Format GetFormat() const; - void SetFormat(Format format); - - Format GetOriginFormat() const; - void SetOriginFormat(Format originFormat); - - void SetName(const std::string &name); - const std::string GetName() const; - - DataType GetDataType() const; - void SetDataType(DataType dt); - - DataType GetOriginDataType() const; - void SetOriginDataType(DataType originDataType); - - std::vector GetRefPortIndex() const; - void SetRefPortByIndex(const std::vector &index); - - GeTensorDesc Clone() const; - GeTensorDesc &operator=(const GeTensorDesc &desc); - GeTensorDesc &operator=(GeTensorDesc &&desc); - - graphStatus IsValid() const; - - protected: - ProtoAttrMapHelper MutableAttrMap() override; - ConstProtoAttrMapHelper GetAttrMap() const override; - - private: - bool GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const; - using AttrHolder::DelAttr; - using AttrHolder::GetAllAttrs; - using AttrHolder::GetAttr; - using AttrHolder::HasAttr; - using AttrHolder::SetAttr; - - void Init(); - - // Create from proto obj - GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg); - friend class GeTensor; - friend class GeAttrValueImp; - friend class ModelSerializeImp; - friend class OnnxUtils; - - GeIrProtoHelper tensor_descriptor_; - // Reference from tensorDescriptor_, do not direct use - mutable GeShape __shape_; - - void RefTo(const GeTensorDesc &tensorDesc) { tensor_descriptor_ = tensorDesc.tensor_descriptor_; } - GeShape &ShapeReference() const; -}; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { - public: - GeTensor(); - explicit GeTensor(const GeTensorDesc &tensorDesc); - explicit GeTensor(const GeTensorDesc &tensorDesc, const std::vector &data); - explicit GeTensor(const GeTensorDesc &tensorDesc, const Buffer &data); - explicit GeTensor(const GeTensorDesc &tensorDesc, const uint8_t *data, size_t size); - explicit GeTensor(GeTensorDesc &&tensorDesc, std::vector &&data); - ~GeTensor() = default; - - GeTensorDesc GetTensorDesc() const; - GeTensorDesc &MutableTensorDesc(); - void SetTensorDesc(const GeTensorDesc &tensorDesc); - - const Buffer GetData() const; - Buffer MutableData(); - graphStatus SetData(std::vector &&data); - graphStatus SetData(const std::vector &data); - graphStatus SetData(const Buffer &data); - graphStatus SetData(const uint8_t *data, size_t size); - - GeTensor Clone() const; - - // Share value - GeTensor(const GeTensor &other); - // Share value - GeTensor &operator=(const GeTensor &other); - - private: - friend class GeAttrValueImp; - friend class ModelSerializeImp; - friend class OnnxUtils; - // Create from proto obj - GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg); - GeIrProtoHelper tensor_def_; - // Reference from tensorDef_, do not direct use - mutable GeTensorDesc __desc_; - GeTensorDesc &DescReference() const; -}; -} // namespace ge -#endif // INC_GRAPH_GE_TENSOR_H_ diff --git a/metadef/inc/graph/graph_util.h b/metadef/inc/graph/graph_util.h deleted file mode 100644 index c39ecbc1..00000000 --- a/metadef/inc/graph/graph_util.h +++ /dev/null @@ -1,134 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_GRAPH_UTIL_H_ -#define INC_GRAPH_GRAPH_UTIL_H_ - -#include - -#include "proto/om.pb.h" - -namespace ge { -using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; -bool HasOpAttr(const OpDef *opdef, std::string attr_name); -bool GetOpAttr(const std::string &key, int32_t *value, const OpDef *opdef); - -static const char OP_TYPE_DATA[] = "Data"; -static const char OP_TYPE_INPUT[] = "Input"; -static const char ATTR_KEY_INPUT_FORMAT[] = "input_format"; -static const char ATTR_KEY_OUTPUT_FORMAT[] = "output_format"; -static const char OP_TYPE_ANN_DATA[] = "AnnData"; -} // namespace ge - -#if !defined(__ANDROID__) && !defined(ANDROID) -#include "toolchain/slog.h" -const char levelStr[4][8] = {"ERROR", "WARN", "INFO", "DEBUG"}; -#else -#include -#include -const char levelStr[8][8] = {"EMERG", "ALERT", "CRIT", "ERROR", "WARNING", "NOTICE", "INFO", "DEBUG"}; -#endif - -#ifdef _MSC_VER -#define FUNC_NAME __FUNCTION__ -#else -#define FUNC_NAME __PRETTY_FUNCTION__ -#endif - -#if !defined(__ANDROID__) && !defined(ANDROID) -#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) \ - dlog_info(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) -#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) \ - dlog_warn(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) -#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) \ - dlog_error(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) -#else -#define D_GRAPH_LOG(level, format, ...) \ - do { \ - { \ - fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ - __FILE__, __LINE__, ##__VA_ARGS__); \ - syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ - ##__VA_ARGS__); \ - } \ - } while (0) -#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#endif - -#if !defined(__ANDROID__) && !defined(ANDROID) -#define GRAPH_LOGI(...) D_GRAPH_LOGI(GRAPH_MOD_NAME, __VA_ARGS__) -#define GRAPH_LOGW(...) D_GRAPH_LOGW(GRAPH_MOD_NAME, __VA_ARGS__) -#define GRAPH_LOGE(...) D_GRAPH_LOGE(GRAPH_MOD_NAME, __VA_ARGS__) -#else - -#define GRAPH_LOG(level, format, ...) \ - do { \ - { \ - fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ - __FILE__, __LINE__, ##__VA_ARGS__); \ - syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ - ##__VA_ARGS__); \ - } \ - } while (0) -#define GRAPH_LOGI(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define GRAPH_LOGW(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define GRAPH_LOGE(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#endif - -#define GRAPH_CHK_STATUS_RET_NOLOG(expr) \ - do { \ - const domi::graphStatus _status = (expr); \ - if (_status != domi::GRAPH_SUCCESS) { \ - return _status; \ - } \ - } while (0) - -#define GRAPH_CHK_BOOL_RET_STATUS(expr, _status, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - GRAPH_LOGE(__VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -#define GRAPH_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ - { \ - bool b = (expr); \ - if (!b) { \ - exec_expr; \ - } \ - }; - -#define GRAPH_IF_BOOL_EXEC(expr, exec_expr) \ - { \ - if (expr) { \ - exec_expr; \ - } \ - } - -#define GRAPH_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ - do { \ - const ::domi::graphStatus _status = (expr); \ - if (_status) { \ - GRAPH_LOGE(__VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -#endif // INC_GRAPH_GRAPH_UTIL_H_ diff --git a/metadef/inc/graph/model.h b/metadef/inc/graph/model.h deleted file mode 100644 index 38ea501b..00000000 --- a/metadef/inc/graph/model.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_MODEL_H_ -#define INC_GRAPH_MODEL_H_ - -#include -#include -#include -#include -#include "detail/attributes_holder.h" -#include "graph/ge_attr_value.h" -#include "graph/graph.h" - -namespace ge { -using std::map; -using std::string; -using std::vector; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { - public: - Model(); - - ~Model() = default; - - Model(const string &name, const string &custom_version); - - string GetName() const; - void SetName(const string &name); - - uint32_t GetVersion() const; - - void SetVersion(uint32_t version) { version_ = version; } - - std::string GetPlatformVersion() const; - - void SetPlatformVersion(string version) { platform_version_ = version; } - - Graph GetGraph() const; - - void SetGraph(const Graph &graph); - - void SetAttr(const ProtoAttrMapHelper &attrs); - - using AttrHolder::GetAllAttrNames; - using AttrHolder::GetAllAttrs; - using AttrHolder::GetAttr; - using AttrHolder::HasAttr; - using AttrHolder::SetAttr; - - graphStatus Save(Buffer &buffer, bool is_dump = false) const; - - graphStatus SaveToFile(const string &file_name) const; - // Model will be rewrite - static graphStatus Load(const uint8_t *data, size_t len, Model &model); - graphStatus Load(ge::proto::ModelDef &model_def); - graphStatus LoadFromFile(const string &file_name); - - bool IsValid() const; - - protected: - ConstProtoAttrMapHelper GetAttrMap() const override; - ProtoAttrMapHelper MutableAttrMap() override; - - private: - void Init(); - ProtoAttrMapHelper attrs_; - friend class ModelSerializeImp; - friend class GraphDebugImp; - friend class OnnxUtils; - friend class ModelHelper; - friend class ModelBuilder; - string name_; - uint32_t version_; - std::string platform_version_{""}; - Graph graph_; -}; -} // namespace ge -using ModelPtr = std::shared_ptr; - -#endif // INC_GRAPH_MODEL_H_ diff --git a/metadef/inc/graph/model_serialize.h b/metadef/inc/graph/model_serialize.h deleted file mode 100644 index 16529512..00000000 --- a/metadef/inc/graph/model_serialize.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_MODEL_SERIALIZE_H_ -#define INC_GRAPH_MODEL_SERIALIZE_H_ - -#include -#include -#include "graph/buffer.h" -#include "graph/compute_graph.h" -#include "graph/model.h" - -namespace ge { -class ModelSerialize { - public: - Buffer SerializeModel(const Model &model, bool is_dump = false); - - Model UnserializeModel(const uint8_t *data, size_t len); - Model UnserializeModel(ge::proto::ModelDef &model_def); - - Buffer SerializeGraph(const ComputeGraphPtr &graph); - - ComputeGraphPtr UnserializeGraph(const uint8_t *data, size_t len); - - Buffer SerializeOpDesc(const ConstOpDescPtr &opDesc); - OpDescPtr UnserializeOpDesc(const uint8_t *data, size_t len); - - size_t GetSerializeModelSize(const Model &model); - - private: - static std::map &MutableTensorDescAttrMap(GeTensorDesc &tensorDesc); - - static const std::map &GetTensorDescAttrMap(const GeTensorDesc &tensorDesc); - - friend class ModelSerializeImp; - friend class GraphDebugImp; -}; -} // namespace ge -#endif // INC_GRAPH_MODEL_SERIALIZE_H_ diff --git a/metadef/inc/graph/node.h b/metadef/inc/graph/node.h deleted file mode 100644 index f4a1c6a8..00000000 --- a/metadef/inc/graph/node.h +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_NODE_H_ -#define INC_GRAPH_NODE_H_ - -#include -#include -#include -#include -#include -#include -#include "graph/ge_attr_value.h" -#include "utils/attr_utils.h" - -#include "graph/op_desc.h" -#include "graph/range_vistor.h" - -namespace ge { -class ComputeGraph; - -using ComputeGraphPtr = std::shared_ptr; - -class Node; - -using NodePtr = std::shared_ptr; -using ConstNodePtr = std::shared_ptr; -using NodeRef = std::weak_ptr; - -class Anchor; - -using AnchorPtr = std::shared_ptr; - -class InDataAnchor; - -using InDataAnchorPtr = std::shared_ptr; - -class OutDataAnchor; - -using OutDataAnchorPtr = std::shared_ptr; - -class ControlAnchor; - -using ControlAnchorPtr = std::shared_ptr; - -class InControlAnchor; - -using InControlAnchorPtr = std::shared_ptr; - -class OutControlAnchor; - -using OutControlAnchorPtr = std::shared_ptr; - -using OpDescPtr = std::shared_ptr; - -using ConstNode = const Node; - -typedef std::vector> kFusionDataFlowVec_t; - -// Node is a component of ComputeGraph -class Node : public std::enable_shared_from_this { - friend class ComputeGraph; - friend class ModelSerializeImp; - - public: - template - using Vistor = RangeVistor>; - ~Node(); - Node(const Node &) = delete; - Node &operator=(const Node &) = delete; - bool operator==(const Node &r_node) const; - - protected: - Node() = default; - Node(const OpDescPtr &op, const ComputeGraphPtr &ownerGraph); - - public: - graphStatus Init(); - - std::string GetName() const; - std::string GetType() const; - - ComputeGraphPtr GetOwnerComputeGraph() const; - graphStatus SetOwnerComputeGraph(const ComputeGraphPtr &graph); - - Vistor GetAllInDataAnchors() const; - Vistor GetAllOutDataAnchors() const; - uint32_t GetAllInDataAnchorsSize() const; - uint32_t GetAllOutDataAnchorsSize() const; - Vistor GetAllOutAnchors() const; - Vistor GetAllInAnchors() const; - InDataAnchorPtr GetInDataAnchor(int idx) const; - OutDataAnchorPtr GetOutDataAnchor(int idx) const; - InControlAnchorPtr GetInControlAnchor() const; - OutControlAnchorPtr GetOutControlAnchor() const; - Vistor GetInNodes() const; - Vistor GetOutNodes() const; - AnchorPtr GetInAnchor(int idx) const; - AnchorPtr GetOutAnchor(int idx) const; - - bool IsAllInNodesSeen(std::unordered_set &nodes_seen) const; - - // All in Data nodes - Vistor GetInDataNodes() const; - // All in Control nodes - Vistor GetInControlNodes() const; - // GetInAllNodes = InDataNodes + InControlNodes - Vistor GetInAllNodes() const; - - // All out Data nodes - Vistor GetOutDataNodes() const; - uint32_t GetOutDataNodesSize() const; - // All out Control nodes - Vistor GetOutControlNodes() const; - // GetOutAllNodes = OutDataNodes + InControlNodes - Vistor GetOutAllNodes() const; - - // Get all in data nodes and its out-anchor - Vistor> GetInDataNodesAndAnchors() const; - - // Get all out data nodes and its in-anchor - Vistor> GetOutDataNodesAndAnchors() const; - - graphStatus InferShapeAndType() const; - graphStatus Verify() const; - - graphStatus InferOriginFormat() const; - - OpDescPtr GetOpDesc() const; - - graphStatus UpdateOpDesc(const OpDescPtr &op); - - graphStatus AddLinkFrom(const NodePtr &input_node); - - graphStatus AddLinkFrom(const uint32_t &index, NodePtr input_node); - - graphStatus AddLinkFrom(const string &name, NodePtr input_node); - - graphStatus AddLinkFromForParse(const NodePtr &input_node); - - void AddSendEventId(uint32_t event_id) { send_event_id_list_.push_back(event_id); } - - void AddRecvEventId(uint32_t event_id) { recv_event_id_list_.push_back(event_id); } - - const std::vector &GetSendEventIdList() const { return send_event_id_list_; } - - const std::vector &GetRecvEventIdList() const { return recv_event_id_list_; } - void GetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { - fusion_input_list = fusion_input_dataflow_list_; - } - - void GetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { - fusion_output_list = fusion_output_dataflow_list_; - } - - void SetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { - fusion_input_dataflow_list_ = fusion_input_list; - } - - void SetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { - fusion_output_dataflow_list_ = fusion_output_list; - } - - bool GetHostNode() const { return host_node_; } - void SetHostNode(bool is_host) { host_node_ = is_host; } - - void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } - - NodePtr GetOrigNode() { return orig_node_; } - - private: - bool NodeMembersAreEqual(const Node &r_node) const; - bool NodeAttrsAreEqual(const Node &r_node) const; - bool NodeInConnectsAreEqual(const Node &r_node) const; - bool NodeOutConnectsAreEqual(const Node &r_node) const; - bool NodeAnchorIsEqual(const AnchorPtr &l_anchor, const AnchorPtr &r_anchor, size_t i) const; - OpDescPtr op_; - std::weak_ptr owner_graph_; - vector in_data_anchors_; - vector out_data_anchors_; - InControlAnchorPtr in_control_anchor_; - OutControlAnchorPtr out_control_anchor_; - map attrs_; // lint !e1073 - bool has_init_{false}; - bool host_node_{false}; - bool anchor_status_updated_{false}; - std::vector send_event_id_list_; - std::vector recv_event_id_list_; - - kFusionDataFlowVec_t fusion_input_dataflow_list_; - kFusionDataFlowVec_t fusion_output_dataflow_list_; - - NodePtr orig_node_; - friend class NodeUtils; - friend class OnnxUtils; - friend class TuningUtils; -}; -} // namespace ge - -#endif // INC_GRAPH_NODE_H_ diff --git a/metadef/inc/graph/op_desc.h b/metadef/inc/graph/op_desc.h deleted file mode 100644 index c7da30b7..00000000 --- a/metadef/inc/graph/op_desc.h +++ /dev/null @@ -1,328 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_OP_DESC_H_ -#define INC_GRAPH_OP_DESC_H_ - -#include -#include -#include -#include -#include -#include -#include -#include "detail/attributes_holder.h" -#include "graph/range_vistor.h" - -#define DYNAMIN_INPUT_NAME(name, index) (((name)) + std::to_string((index))) -#define DYNAMIN_OUTPUT_NAME(name, index) (((name)) + std::to_string((index))) -namespace ge { -using std::map; -using std::pair; -using std::shared_ptr; -using std::string; -using std::vector; - -class Operator; -class GeTensorDesc; - -using GeTensorDescPtr = shared_ptr; -using ConstGeTensorDescPtr = shared_ptr; - -class OpDesc; - -using OpDescPtr = shared_ptr; -using ConstOpDescPtr = shared_ptr; - -class GeAttrValue; - -using ConstOpDesc = const OpDesc; - -enum SubgraphType { kStatic, kDynamic, kSubgraphTypeEnd }; - -class OpDesc : public std::enable_shared_from_this, public AttrHolder { - public: - template - using Vistor = RangeVistor>; - - friend class GraphBuilderImpl; - - friend class OperatorImpl; - - OpDesc(const string &name, const string &type); - - OpDesc(); - - ~OpDesc(); - - bool operator==(const OpDesc &r_op_desc) const; - - string GetName() const; - - void SetName(const string &name); - - string GetType() const; - - void SetType(const string &type); - - graphStatus AddInputDesc(const GeTensorDesc &input_desc); - - graphStatus AddInputDesc(const string &name, const GeTensorDesc &input_desc); - - graphStatus AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc); - - graphStatus AddInputDescForward(const string &name, const unsigned int num); - - graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); - - graphStatus AddOutputDescMiddle(const string &name, const unsigned int num, size_t index); - - graphStatus AddOutputDescForward(const string &name, const unsigned int num); - - graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); - - graphStatus UpdateInputDesc(uint32_t index, const GeTensorDesc &tensor_desc); - - graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc); - - bool InputIsSet(const string &name) const; - - GeTensorDesc GetInputDesc(uint32_t index) const; - - GeTensorDesc GetInputDesc(const string &name) const; - - Vistor GetAllInputNames() const; - - GeTensorDescPtr MutableInputDesc(uint32_t index) const; - - GeTensorDescPtr MutableInputDesc(const string &name) const; - - Vistor GetAllInputsDesc() const; - - Vistor GetAllInputsDescPtr() const; - - size_t GetInputsSize() const; - - size_t GetAllInputsSize() const; - - graphStatus AddOutputDesc(const GeTensorDesc &output_desc); - - graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); - - graphStatus UpdateOutputDesc(uint32_t index, const GeTensorDesc &tensor_desc); - - graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc); - - GeTensorDesc GetOutputDesc(uint32_t index) const; - - GeTensorDesc GetOutputDesc(const string &name) const; - - GeTensorDescPtr MutableOutputDesc(uint32_t index) const; - - GeTensorDescPtr MutableOutputDesc(const string &name) const; - - uint32_t GetAllOutputsDescSize() const; - - Vistor GetAllOutputsDesc() const; - - Vistor GetAllOutputsDescPtr() const; - - size_t GetOutputsSize() const; - - ConstGeTensorDescPtr GetOutputDescPtr(uint32_t index) const; - - ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; - - ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const; - - ConstGeTensorDescPtr GetInputDescPtr(const string &name) const; - - graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); - - graphStatus AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index); - - graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); - - bool IsOptionalInput(const string &name) const; - - bool IsOptionalInput(uint32_t index) const; - - std::map GetAllInputName() const; - - std::map GetAllOutputName(); - - bool UpdateInputName(std::map inputNameIdx); - - bool UpdateOutputName(std::map outputNameIdx); - - void AddInferFunc(const std::function &func); - - std::function GetInferFunc() const; - - graphStatus InferShapeAndType(); - - void AddInferFormatFunc(const std::function &func); - - std::function GetInferFormatFunc() const; - - graphStatus DefaultInferFormat(); - - std::function GetVerifyFunc() const; - - void AddVerifierFunc(const std::function &func); - - graphStatus CallInferFormatFunc(Operator &op); - - graphStatus OpVerify(); - - graphStatus CommonVerify() const; - - graphStatus AddRegisterInputName(const string &name); - - graphStatus AddRegisterOutputName(const string &name); - - vector GetRegisterInputName() const; - - vector GetRegisterOutputName() const; - - using AttrHolder::AddRequiredAttr; - using AttrHolder::DelAttr; - using AttrHolder::GetAllAttrNames; - using AttrHolder::GetAllAttrs; - using AttrHolder::GetAttr; - using AttrHolder::HasAttr; - using AttrHolder::SetAttr; - - void SetId(int64_t id); - int64_t GetId() const; - void SetStreamId(int64_t stream_id); - int64_t GetStreamId() const; - void SetInputName(const vector &input_name); - vector GetInputName() const; - void SetSrcName(const vector &src_name); - vector GetSrcName() const; - void SetSrcIndex(const vector &src_index); - vector GetSrcIndex() const; - void SetInputOffset(const vector &input); - vector GetInputOffset() const; - void SetOutputOffset(const vector &input); - vector GetOutputOffset() const; - void SetDstName(const vector &dst_name); - vector GetDstName() const; - void SetDstIndex(const vector &dst_index); - vector GetDstIndex() const; - void SetWorkspace(const vector &workspace); - vector GetWorkspace() const; - void SetWorkspaceBytes(const vector &workspace_bytes); - vector GetWorkspaceBytes() const; - void SetIsInputConst(const vector &is_input_const); - vector GetIsInputConst() const; - - void SetOpInferDepends(const vector &depend_names); - vector GetOpInferDepends() const; - - string GetInputNameByIndex(uint32_t index) const; - - int GetInputIndexByName(const string &name) const; - - string GetOutputNameByIndex(uint32_t index) const; - - int GetOutputIndexByName(const string &name) const; - - graphStatus RestoreInputNameIdx(const string &name, const int &index); - - graphStatus RestoreOutputNameIdx(const string &name, const int &index); - - graphStatus CallInferFunc(Operator &op); - - void SetOpKernelLibName(const std::string &name); - - std::string GetOpKernelLibName() const; - - void SetOpEngineName(const std::string &name); - - std::string GetOpEngineName() const; - - void RegisterSubgraphIrName(const std::string &name, SubgraphType type); - const std::map &GetSubgraphIrNames() const; - SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; - - graphStatus AddSubgraphName(const std::string &name); - const std::map &GetSubgraphNameIndexes() const; - - std::string GetSubgraphInstanceName(uint32_t index) const; - const std::vector &GetSubgraphInstanceNames() const; - /// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`, - /// because this kind of functions will only append a new subgraph instance name - /// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`. - /// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first. - /// \param index - /// \param name - /// \return - graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); - void RemoveSubgraphInstanceName(const std::string &name); - - graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const; - - protected: - ProtoAttrMapHelper MutableAttrMap() override; - ConstProtoAttrMapHelper GetAttrMap() const override; - - private: - OpDesc(const ProtoMsgOwner &proto_msg_owner, ge::proto::OpDef *op_def); - bool OpDescMembersAreEqual(const OpDesc &r_op_desc) const; - bool OpDescAttrsAreEqual(const OpDesc &r_op_desc) const; - bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; - - GeIrProtoHelper op_def_; - std::vector subgraph_instance_names_; - - // subgraph names to index, for a `if` operator: - // then_branch: 0 - // else_branch: 1 - // or for a `case` node: - // branches0: 0 - // branches1: 1 - // branches2: 2 - std::map subgraph_names_to_index_; - - // subgraph ir names to type, for a `if` operator: - // then_branch: static - // else_branch: static - // or for a `case` op: - // branches: dynamic - std::map subgraph_ir_names_to_type_; - - vector inputs_desc_{}; - map input_name_idx_{}; - vector register_input_name_{}; - std::unordered_set optional_input_names_{}; - vector outputs_desc_{}; - map output_name_idx_{}; - vector register_output_name_{}; - std::function infer_func_ = nullptr; - std::function infer_format_func_ = nullptr; - std::function verifier_func_ = nullptr; - string op_kernel_lib_name_; - string engine_name_; - friend class OpDescUtils; - friend class ModelSerializeImp; - friend class AttrUtils; - friend class GeAttrValueImp; - friend class OnnxUtils; -}; -} // namespace ge -#endif // INC_GRAPH_OP_DESC_H_ diff --git a/metadef/inc/graph/op_kernel_bin.h b/metadef/inc/graph/op_kernel_bin.h deleted file mode 100644 index 3970460a..00000000 --- a/metadef/inc/graph/op_kernel_bin.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_OP_KERNEL_BIN_H_ -#define INC_GRAPH_OP_KERNEL_BIN_H_ - -#include -#include -#include -#include - -namespace ge { -class OpKernelBin { - public: - OpKernelBin(std::string name, std::vector &&data) : name_(std::move(name)), data_(std::move(data)) {} - - ~OpKernelBin() = default; - - const std::string &GetName() const { return name_; } - const uint8_t *GetBinData() const { return (const uint8_t *)data_.data(); } - size_t GetBinDataSize() const { return data_.size(); } - OpKernelBin(const OpKernelBin &) = delete; - const OpKernelBin &operator=(const OpKernelBin &) = delete; - - private: - std::string name_; - std::vector data_; -}; - -using OpKernelBinPtr = std::shared_ptr; -const char *const OP_EXTATTR_NAME_TBE_KERNEL = "tbeKernel"; -const char *const OP_EXTATTR_CUSTAICPU_KERNEL = "cust_aicpu_kernel"; -} // namespace ge - -#endif // INC_GRAPH_OP_KERNEL_BIN_H_ diff --git a/metadef/inc/graph/operator_factory_impl.h b/metadef/inc/graph/operator_factory_impl.h deleted file mode 100644 index ea343ebc..00000000 --- a/metadef/inc/graph/operator_factory_impl.h +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ -#define INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ - -#include -#include -#include -#include -#include "graph/operator_factory.h" - -namespace ge { -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl { - public: - static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); - - static graphStatus GetOpsTypeList(std::vector &all_ops); - - static bool IsExistOp(const string &operator_type); - - static InferShapeFunc GetInferShapeFunc(const std::string &operator_type); - - static InferFormatFunc GetInferFormatFunc(const std::string &operator_type); - - static VerifyFunc GetVerifyFunc(const std::string &operator_type); - - static graphStatus RegisterOperatorCreator(const std::string &operator_type, OpCreator const &op_creator); - - static graphStatus RegisterInferShapeFunc(const std::string &operator_type, InferShapeFunc const infer_shape_func); - - static graphStatus RegisterInferFormatFunc(const std::string &operator_type, InferFormatFunc const infer_format_func); - - static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); - - static shared_ptr> operator_creators_; - static shared_ptr> operator_infershape_funcs_; - static shared_ptr> operator_inferformat_funcs_; - static shared_ptr> operator_verify_funcs_; -}; -} // namespace ge - -#endif // INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ diff --git a/metadef/inc/graph/opsproto_manager.h b/metadef/inc/graph/opsproto_manager.h deleted file mode 100644 index 06846573..00000000 --- a/metadef/inc/graph/opsproto_manager.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_OPSPROTO_MANAGER_H_ -#define INC_GRAPH_OPSPROTO_MANAGER_H_ - -#include -#include -#include -#include -#include -#include -#include - -namespace ge { -class OpsProtoManager { - public: - static OpsProtoManager *Instance(); - - bool Initialize(const std::map &options); - void Finalize(); - - private: - void LoadOpsProtoPluginSo(std::string &path); - - std::string pluginPath_; - std::vector handles_; - bool is_init_ = false; - std::mutex mutex_; -}; -} // namespace ge - -#endif // INC_GRAPH_OPSPROTO_MANAGER_H_ diff --git a/metadef/inc/graph/range_vistor.h b/metadef/inc/graph/range_vistor.h deleted file mode 100644 index 20905bd9..00000000 --- a/metadef/inc/graph/range_vistor.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_RANGE_VISTOR_H_ -#define INC_GRAPH_RANGE_VISTOR_H_ - -#include - -template -class RangeVistor { - public: - using Iterator = typename std::vector::iterator; - using ConstIterator = typename std::vector::const_iterator; - - RangeVistor(O owner, const std::vector &vs) : owner_(owner), elements_(vs) {} - - ~RangeVistor() {} - - Iterator begin() { return elements_.begin(); } - - Iterator end() { return elements_.end(); } - - ConstIterator begin() const { return elements_.begin(); } - - ConstIterator end() const { return elements_.end(); } - - std::size_t size() const { return elements_.size(); } - - bool empty() const { return elements_.empty(); } - - E &at(std::size_t index) { return elements_.at(index); } - - const E &at(std::size_t index) const { return elements_.at(index); } - - private: - O owner_; - std::vector elements_; -}; - -#endif // INC_GRAPH_RANGE_VISTOR_H_ diff --git a/metadef/inc/graph/ref_relation.h b/metadef/inc/graph/ref_relation.h deleted file mode 100644 index 71457916..00000000 --- a/metadef/inc/graph/ref_relation.h +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMMON_GRAPH_REF_RELATION_H_ -#define COMMON_GRAPH_REF_RELATION_H_ - -#include -#include -#include -#include - -#include "graph/compute_graph.h" -#include "graph/types.h" -#include "graph/ge_error_codes.h" -#include "node.h" - -namespace ge { -enum InOutFlag { - NODE_IN = 0, // input flag - NODE_OUT = 1, // output flag -}; - -struct RefCell { - std::string node_name; - ge::NodePtr node = nullptr; - InOutFlag in_out = NODE_IN; - int in_out_idx = 0; - - bool operator==(const RefCell &c) const { - return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx; - } - - RefCell() = default; - RefCell(std::string name, ge::NodePtr node_ptr, InOutFlag in_out_flag, int idx) { - node_name = name; - node = node_ptr; - in_out = in_out_flag; - in_out_idx = idx; - }; - ~RefCell() = default; -}; - -struct RefCellHash { - size_t operator()(const RefCell &c) const { - unsigned long number = reinterpret_cast(reinterpret_cast(c.node.get())); - string tmp = c.node_name + std::to_string(c.in_out) + std::to_string(c.in_out_idx) + std::to_string(number); - return std::hash()(tmp); - } -}; - -class RefRelations { - public: - graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set &result); - graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); - graphStatus Clear(); - - RefRelations(); - ~RefRelations() = default; - - public: - class Impl; - std::shared_ptr impl_ = nullptr; -}; - -} // namespace ge -#endif // COMMON_GRAPH_REF_RELATION_H_ diff --git a/metadef/inc/graph/runtime_inference_context.h b/metadef/inc/graph/runtime_inference_context.h deleted file mode 100644 index 6c6c82e7..00000000 --- a/metadef/inc/graph/runtime_inference_context.h +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ -#define INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ - -#include -#include -#include -#include -#include "external/graph/ge_error_codes.h" -#include "external/graph/tensor.h" - -namespace ge { -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext { - public: - static graphStatus GetContext(const std::string &context_id, RuntimeInferenceContext **ctx); - static graphStatus CreateContext(const std::string &context_id); - static void DestroyContext(const std::string &context_id); - - graphStatus SetTensor(int64_t node_id, int output_id, Tensor &&tensor); - graphStatus GetTensor(int64_t node_id, int output_id, Tensor &tensor); - - private: - std::map> tensors_; - std::mutex mu_; - - static std::map> contexts_; - static std::mutex ctx_mu_; -}; -} // namespace ge - -#endif // INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ diff --git a/metadef/inc/graph/shape_refiner.h b/metadef/inc/graph/shape_refiner.h deleted file mode 100644 index 4f8783a3..00000000 --- a/metadef/inc/graph/shape_refiner.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_SHAPE_REFINER_H_ -#define INC_GRAPH_SHAPE_REFINER_H_ - -#include -#include "external/graph/inference_context.h" - -#include "external/graph/ge_error_codes.h" -#include "graph/node.h" - -namespace ge { -// ShapeRefiner performs shape inference for compute graphs -class ShapeRefiner { - public: - static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph); - static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph); - static graphStatus InferShapeAndType(const NodePtr &node); - static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); - static void ClearContextMap(); - - private: - static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); -}; -} // namespace ge -#endif // INC_GRAPH_SHAPE_REFINER_H_ diff --git a/metadef/inc/graph/tuning_utils.h b/metadef/inc/graph/tuning_utils.h deleted file mode 100644 index 98262a23..00000000 --- a/metadef/inc/graph/tuning_utils.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MAIN_TUNING_UTILS_H -#define MAIN_TUNING_UTILS_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "framework/common/debug/ge_log.h" -#include "utils/attr_utils.h" -#include "utils/node_utils.h" -#include "external/ge/ge_api_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -namespace ge { -// Configure build mode, default value is "normal" -const char *const BUILD_MODE = "ge.buildMode"; -const char *const BUILD_STEP = "ge.buildStep"; -// Configure tuning path -const char *const TUNING_PATH = "ge.tuningPath"; -// for interface: aclgrphBuildModel -const std::set ir_builder_supported_options_for_lx_fusion = {BUILD_MODE, BUILD_STEP, TUNING_PATH}; - -// Build model -const char *const BUILD_MODE_NORMAL = "normal"; -const char *const BUILD_MODE_TUNING = "tuning"; -const char *const BUILD_MODE_BASELINE = "baseline"; -const std::set build_mode_options = {BUILD_MODE_NORMAL, BUILD_MODE_TUNING, BUILD_MODE_BASELINE}; - -// Build step -const char *const BUILD_STEP_BEFORE_UB_MATCH = "before_ub_match"; -const char *const BUILD_STEP_AFTER_UB_MATCH = "after_ub_match"; -const char *const BUILD_STEP_AFTER_BUILDER = "after_builder"; -const char *const BUILD_STEP_AFTER_BUILDER_SUB = "after_builder_sub"; -const char *const BUILD_STEP_AFTER_MERGE = "after_merge"; -const std::set build_step_options = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_UB_MATCH, - BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB, - BUILD_STEP_AFTER_MERGE}; - -using SubgraphCreateOutNode = std::unordered_map; -using NodetoNodeMap = std::unordered_map; -using NodeSet = std::set; -using NodeNametoNodeNameMap = std::unordered_map; -using NodetoNodeNameMap = std::unordered_map; -class TuningUtils { - public: - TuningUtils() = default; - ~TuningUtils() = default; - // Dump all the subgraphs and modify - // the subgraphs in them to be executable subgraphs if exe_flag is true - // `tuning_path` means path to save the graphs - static graphStatus ConvertGraphToFile(std::vector tuning_subgraphs, - std::vector non_tuning_subgraphs = {}, bool exe_flag = false, - const std::string &path = "", const std::string &user_path = ""); - // Recovery `graph` from graph dump files configured in options - static graphStatus ConvertFileToGraph(const map &options, ge::Graph &graph); - - private: - // part 1 - struct HelpInfo { - int64_t index; - bool exe_flag; - bool is_tuning_graph; - const std::string &path; - const std::string &user_path; - }; - static graphStatus MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info); - static graphStatus HandlePld(NodePtr &node); - static graphStatus HandleEnd(NodePtr &node); - static graphStatus ChangePld2Data(NodePtr &node, NodePtr &data_node); - static graphStatus ChangeEnd2NetOutput(NodePtr &node, NodePtr &out_node); - static graphStatus LinkEnd2NetOutput(NodePtr &node, NodePtr &out_node); - static graphStatus CreateDataNode(NodePtr &node, NodePtr &data_node); - static graphStatus CreateNetOutput(NodePtr &node, NodePtr &out_node); - static graphStatus AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node); - static graphStatus AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node); - static void DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path); - - static SubgraphCreateOutNode create_output_; - // part 2 - static graphStatus MergeAllSubGraph(std::vector &graphs, ComputeGraphPtr &graph); - static graphStatus MergeSubGraph(ComputeGraphPtr &graph); - // Deletes new data and output nodes added by call `MakeExeGraph()` func in part 1 - static graphStatus RemoveDataNetoutputEdge(ComputeGraphPtr &graph); - static graphStatus GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor, - AnchorPtr &src_out_anchor); - static NodeNametoNodeNameMap data_2_netoutput_; - static NodetoNodeNameMap data_node_2_netoutput_; - static NodetoNodeMap data_node_2_netoutput_node_; - static NodeSet netoutput_nodes_; - static NodeSet merged_graph_nodes_; - static std::mutex mutex_; - // for debug - static std::string PrintCheckLog(); - static std::string GetNodeNameByAnchor(const Anchor *anchor); -}; -} // namespace ge -#endif // MAIN_TUNING_UTILS_H diff --git a/metadef/inc/graph/usr_types.h b/metadef/inc/graph/usr_types.h deleted file mode 100644 index 90e02001..00000000 --- a/metadef/inc/graph/usr_types.h +++ /dev/null @@ -1,133 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_USR_TYPES_H_ -#define INC_GRAPH_USR_TYPES_H_ - -#include -#include -#include -namespace ge { -#define USR_TYPE_DEC(type, name) \ - inline void set_##name(const type &value) { name = value; } \ - type *mutable_##name() { return &name; } - -#define USR_TYPE_HAS_DEC(type, name) \ - inline void set_##name(const type &value) { name = value; } \ - \ - private: \ - bool has_mutable_##name{false}; \ - \ - public: \ - bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \ - type *mutable_##name() { \ - has_mutable_##name = true; \ - return &name; \ - } - -#define USR_TYPE_BYTES_DEC(name) \ - inline void clear_##name() { name.clear(); } \ - inline void set_##name(const void *value, size_t size) { \ - name.assign(reinterpret_cast(const_cast(value)), \ - reinterpret_cast(const_cast(value)) + size); \ - } - -enum UsrQuantizeScaleType { USR_VECTOR_SCALE = 0, USR_SCALAR_SCALE = 1 }; -enum UsrQuantizeScaleMode { USR_NORMAL_MODE = 0, USR_SQRT_MODE = 1 }; -enum UsrQuantizeAlgorithm { - USR_NON_OFFSET_ALGO = 0, - USR_HALF_OFFSET_ALGO = 1, - USR_ALL_OFFSET_ALGO = 2, -}; - -struct UsrQuantizeFactor { - public: - // QuantizeScaleMode scale_mode; - UsrQuantizeScaleMode scale_mode{USR_NORMAL_MODE}; - std::vector scale_value; - int64_t scale_offset{0}; - std::vector offset_data_value; - int64_t offset_data_offset{0}; - std::vector offset_weight_value; - int64_t offset_weight_offset{0}; - std::vector offset_pad_value; - int64_t offset_pad_offset{0}; - - USR_TYPE_DEC(UsrQuantizeScaleMode, scale_mode); - USR_TYPE_BYTES_DEC(scale_value); - - USR_TYPE_DEC(int64_t, scale_offset); - USR_TYPE_BYTES_DEC(offset_data_value); - USR_TYPE_DEC(int64_t, offset_data_offset); - - USR_TYPE_BYTES_DEC(offset_weight_value); - USR_TYPE_DEC(int64_t, offset_weight_offset); - USR_TYPE_BYTES_DEC(offset_pad_value); - USR_TYPE_DEC(int64_t, offset_pad_offset); -}; - -static inline bool QuantizeFactorHasData(const UsrQuantizeFactor &factor) { - return factor.scale_value.size() > 0 || factor.offset_data_value.size() > 0 || - factor.offset_weight_value.size() > 0 || factor.offset_pad_value.size() > 0; -} - -struct UsrQuantizeCalcFactor { - public: - std::vector offsetw; - int64_t offsetw_offset{0}; - std::vector offsetd; - int64_t offsetd_offset{0}; - std::vector scalereq; - int64_t scaledreq_offset{0}; - std::vector offsetdnext; - int64_t offsetdnext_offset{0}; - - USR_TYPE_BYTES_DEC(offsetw); - USR_TYPE_DEC(int64_t, offsetw_offset); - USR_TYPE_BYTES_DEC(offsetd); - USR_TYPE_DEC(int64_t, offsetd_offset); - USR_TYPE_BYTES_DEC(scalereq); - USR_TYPE_DEC(int64_t, scaledreq_offset); - USR_TYPE_BYTES_DEC(offsetdnext); - USR_TYPE_DEC(int64_t, offsetdnext_offset); -}; - -static inline bool QuantizeFactorHasData(const UsrQuantizeCalcFactor &factor) { - return factor.offsetw.size() > 0 || factor.offsetd.size() > 0 || factor.scalereq.size() > 0 || - factor.offsetdnext.size() > 0; -} - -struct UsrQuantizeFactorParams { - UsrQuantizeAlgorithm quantize_algo{USR_NON_OFFSET_ALGO}; - UsrQuantizeScaleType scale_type{USR_VECTOR_SCALE}; - UsrQuantizeFactor quantize_param; - UsrQuantizeFactor dequantize_param; - UsrQuantizeFactor requantize_param; - UsrQuantizeCalcFactor quantizecalc_param; - USR_TYPE_DEC(UsrQuantizeAlgorithm, quantize_algo); - USR_TYPE_DEC(UsrQuantizeScaleType, scale_type); - USR_TYPE_HAS_DEC(UsrQuantizeFactor, quantize_param); - USR_TYPE_HAS_DEC(UsrQuantizeFactor, dequantize_param); - USR_TYPE_HAS_DEC(UsrQuantizeFactor, requantize_param); - USR_TYPE_HAS_DEC(UsrQuantizeCalcFactor, quantizecalc_param); -}; - -#undef USR_TYPE_DEC -#undef USR_TYPE_HAS_DEC -#undef USR_TYPE_BYTES_DEC -} // namespace ge - -#endif // INC_GRAPH_USR_TYPES_H_ diff --git a/metadef/inc/graph/utils/anchor_utils.h b/metadef/inc/graph/utils/anchor_utils.h deleted file mode 100644 index 35b3b035..00000000 --- a/metadef/inc/graph/utils/anchor_utils.h +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_UTILS_ANCHOR_UTILS_H_ -#define INC_GRAPH_UTILS_ANCHOR_UTILS_H_ - -#include "graph/anchor.h" -#include "graph/node.h" - -namespace ge { -class AnchorUtils { - public: - // Get anchor format - static Format GetFormat(const DataAnchorPtr &dataAnchor); - - // Set anchor format - static graphStatus SetFormat(const DataAnchorPtr &dataAnchor, Format dataFormat); - - // Get anchor status - static AnchorStatus GetStatus(const DataAnchorPtr &dataAnchor); - - // Set anchor status - static graphStatus SetStatus(const DataAnchorPtr &dataAnchor, AnchorStatus anchorStatus); - - static bool HasControlEdge(const AnchorPtr &anchor); - - static bool IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst); - - static int GetIdx(const AnchorPtr &anchor); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_ANCHOR_UTILS_H_ diff --git a/metadef/inc/graph/utils/attr_utils.h b/metadef/inc/graph/utils/attr_utils.h deleted file mode 100644 index 15a815d4..00000000 --- a/metadef/inc/graph/utils/attr_utils.h +++ /dev/null @@ -1,150 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_UTILS_ATTR_UTILS_H_ -#define INC_GRAPH_UTILS_ATTR_UTILS_H_ - -#include -#include -#include -#include "graph/detail/attributes_holder.h" -#include "graph/ge_attr_value.h" -#include "graph/types.h" - -namespace ge { -class OpDesc; -using OpDescPtr = std::shared_ptr; -using ConstOpDescPtr = std::shared_ptr; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { - public: - class ConstAttrHolderAdapter; - class AttrHolderAdapter; - // Set - static bool HasAttr(ConstAttrHolderAdapter &&obj, const string &name); - - static bool SetInt(AttrHolderAdapter &&obj, const string &name, const int64_t &value); - static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value); - - static bool SetFloat(AttrHolderAdapter &&obj, const string &name, const float &value); - static bool SetListFloat(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetBool(AttrHolderAdapter &&obj, const string &name, const bool &value); - static bool SetListBool(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetStr(AttrHolderAdapter &&obj, const string &name, const string &value); - static bool SetListStr(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetTensorDesc(AttrHolderAdapter &&obj, const string &name, const GeTensorDesc &value); - static bool SetListTensorDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensorPtr &value); - static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const ConstGeTensorPtr &value); - static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensor &value); - static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, - std::initializer_list &&value); - static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetGraph(AttrHolderAdapter &&obj, const string &name, const ComputeGraphPtr &value); - static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value); - static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value); - static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name, - const vector &value); - static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); - - // Get - static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int64_t &value); - static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int32_t &value); - static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, uint32_t &value); - static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetFloat(ConstAttrHolderAdapter &&obj, const string &name, float &value); - static bool GetListFloat(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetBool(ConstAttrHolderAdapter &&obj, const string &name, bool &value); - static bool GetListBool(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetStr(ConstAttrHolderAdapter &&obj, const string &name, string &value); - static bool GetListStr(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, GeTensorDesc &value); - static bool GetListTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value); - static bool MutableTensor(AttrHolderAdapter &&obj, const string &name, GeTensorPtr &value); - static bool GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetGraph(ConstAttrHolderAdapter &&obj, const string &name, ComputeGraphPtr &value); - static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value); - static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS &value); - static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, - vector &value); - static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - // Value will be moved - static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); - static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); - // Value will be moved - static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector &listBuffer); - static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &listBuffer); - - static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector> &value); - static bool GetListListInt(ConstAttrHolderAdapter &&obj, const string &name, vector> &value); - - static bool SetListDataType(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool GetListDataType(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - - static bool SetDataType(AttrHolderAdapter &&obj, const string &name, const ge::DataType &value); - static bool GetDataType(ConstAttrHolderAdapter &&obj, const string &name, ge::DataType &value); - - static OpDescPtr CloneOpDesc(const ConstOpDescPtr &orgOpDesc); - - static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); - - static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj); - - class AttrHolderAdapter { - public: - AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} - ~AttrHolderAdapter() {} - template - AttrHolderAdapter(const std::shared_ptr &obj) : obj_(obj.get()) {} - AttrHolderAdapter(AttrHolder &obj) : obj_(&obj) {} - operator bool() const { return obj_ != nullptr; } - AttrHolder *operator->() { return obj_; } - AttrHolder *get() { return obj_; } - - AttrHolder *obj_; - }; - - class ConstAttrHolderAdapter { - public: - ConstAttrHolderAdapter(const AttrHolder *obj) : obj_(obj) {} - ~ConstAttrHolderAdapter() {} - template - ConstAttrHolderAdapter(const std::shared_ptr obj) : obj_(obj.get()) {} - ConstAttrHolderAdapter(const AttrHolder &obj) : obj_(&obj) {} - operator bool() const { return obj_ != nullptr; } - const AttrHolder *operator->() const { return obj_; } - const AttrHolder *get() const { return obj_; } - - private: - const AttrHolder *obj_; - }; -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_ATTR_UTILS_H_ diff --git a/metadef/inc/graph/utils/graph_utils.h b/metadef/inc/graph/utils/graph_utils.h deleted file mode 100644 index fdcbe1a9..00000000 --- a/metadef/inc/graph/utils/graph_utils.h +++ /dev/null @@ -1,771 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_ -#define INC_GRAPH_UTILS_GRAPH_UTILS_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "graph/anchor.h" -#include "graph/node.h" -#include "graph/compute_graph.h" -#include "graph/utils/anchor_utils.h" -#include "graph/graph.h" -#include "graph/model.h" - -#define GE_DUMP(compute_graph, name) \ - do { \ - GraphUtils::DumpGEGraph(compute_graph, name); \ - GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); \ - uint64_t i = 0; \ - for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { \ - auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \ - GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); \ - GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); \ - } \ - } while (0) - -#define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ - do { \ - DataType ret; \ - attr.GetValue(ret); \ - } while (0) - -#define PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ - do { \ - if (value_type == VT_ENUM) { \ - REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ - stream << ret; \ - } \ - } while (0) - -#define PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ - do { \ - if (value_type == VT_ENUM) { \ - REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ - stream << "["; \ - for (int i = 0; i < ret.size(); i++) { \ - stream << ret[i]; \ - if (i + 1 != ret.size()) stream << ", "; \ - } \ - stream << "]"; \ - } \ - } while (0) - -#define PRINT_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ - else PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) - -#define PRINT_LIST_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ - else PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) - -#define PRINT_SHAPE(i_o, n, idx, stream) \ - do { \ - auto op = n->GetOpDesc(); \ - GeTensorDesc td = i_o == "input" ? op->GetInputDesc(idx) : op->GetOutputDesc(idx); \ - auto shape = td.GetShape().GetDims(); \ - stream << "["; \ - for (int i = 0; i < shape.size(); i++) { \ - stream << shape[i]; \ - if (i + 1 < shape.size()) stream << ", "; \ - } \ - stream << "]"; \ - } while (0) - -#define PRINT_ATTR_FUNC(stream) \ - [&](GeAttrValue attr) { \ - auto type = attr.GetValueType(); \ - PRINT_ATTR_VALUE_IF(type, GeAttrValue::ValueType::VT_STRING, GeAttrValue::STR, attr, stream) \ - PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_FLOAT, GeAttrValue::FLOAT, attr, stream) \ - PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_BOOL, GeAttrValue::BOOL, attr, stream) \ - PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_INT, GeAttrValue::INT, attr, stream) \ - PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_STRING, GeAttrValue::LIST_STR, attr, stream) \ - PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_FLOAT, GeAttrValue::LIST_FLOAT, attr, stream) \ - PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_BOOL, GeAttrValue::LIST_BOOL, attr, stream) \ - PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_INT, GeAttrValue::LIST_INT, attr, stream) \ - else if (type == GeAttrValue::ValueType::VT_TENSOR_DESC) stream << "TENSOR_DESC"; \ - else if (type == GeAttrValue::ValueType::VT_TENSOR) stream << "TENSOR"; \ - else if (type == GeAttrValue::ValueType::VT_BYTES) stream << "BYTES"; \ - else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR_DESC) stream << "LIST_TENSOR_DESC"; \ - else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR) stream << "LIST_TENSOR"; \ - else if (type == GeAttrValue::ValueType::VT_LIST_BYTES) stream << "LIST_BYTES"; \ - }; - -namespace ge { -enum IOType { kIn, kOut }; - -struct NodeIndexIO { - NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type) - : node_(std::move(node)), index_(index), io_type_(io_type) { - if (node_ != nullptr) { - value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_); - } - } - NodeIndexIO(ge::NodePtr node, int index, IOType io_type) - : node_(std::move(node)), index_(static_cast(index)), io_type_(io_type) { - if (node_ != nullptr) { - value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_); - } - } - ~NodeIndexIO() {} - - NodePtr node_ = nullptr; - uint32_t index_ = 0; - IOType io_type_ = kOut; - std::string value_; - - const std::string &ToString() const { return value_; } -}; - -class GraphUtils { - public: - static ComputeGraphPtr GetComputeGraph(const Graph &graph); - - static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); - - static graphStatus RecoverGraphOperators(const Graph &graph); - - static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector &inputs); - - static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); - - static graphStatus AddEdge(const OutDataAnchorPtr &src, const Format &src_format, const InDataAnchorPtr &dst, - const Format &dst_format); - - static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst); - - static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); - - static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); - - // check whether src is link to dst and then remove - static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); - - static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst); - - static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); - - static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); - - static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, - const InDataAnchorPtr &new_dst); - - static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, - const InControlAnchorPtr &new_dst); - - static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, - const NodePtr &new_node); - - static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node); - - static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node); - - static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, - const std::vector &vec_op_desc); - - /// - /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst - /// @param [in] src - /// @param [in] dsts - /// @param [in] insert_node - /// @param [in] input_index - /// @param [in] output_index - /// @return graphStatus - /// - static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, - const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); - - static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); - - static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); - - static void RecordOriginalNames(std::vector original_nodes, const ge::NodePtr &node); - - static void RecordOriginalNames(std::vector names_tmp, const ge::NodePtr &node); - - static bool MatchDumpStr(const std::string &suffix); - - static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false, - const std::string &user_graph_name = ""); - - static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); - - static bool LoadGEGraph(const char *file, ge::ComputeGraphPtr &compute_graph); - - static void BreakConnect(const std::map &all_nodes_infos); - - static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); - - static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph); - - static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message); - - static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *real_path); - - static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node); - - /// - /// Isolating `node`, relinking data links from the in-anchor peer nodes to - /// the out-anchor peer nodes according to `io_map`, relinking control links - /// to ensure that input nodes of `node` are before out nodes - /// - /// Link the `io_map[i]` input anchor peer node to `i` output anchor peer - /// nodes, then unlink all links connecting with `node`. If `io_map[i]` < 0, - /// unlink all links from `i` output anchor without any relinking. - /// - /// @param node - /// @param io_map - /// @return - /// - static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list &io_map); - static graphStatus IsolateNode(const NodePtr &node, const std::vector &io_map); - - /// - /// Isolate `node` which must be one input one output, equivalent to - /// `IsolateNode(node, {0})` - /// @param node - /// @return - /// - static graphStatus IsolateNodeOneIO(const NodePtr &node); - - /// - /// The data anchors replacing behavior is the same with - /// `ReplaceNodeDataAnchors`. In addition, replace all `old_node` control - /// anchors with `new_node`'s. - /// @param new_node - /// @param old_node - /// @param inputs_map - /// @param outputs_map - /// @return - /// - static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, - std::initializer_list inputs_map, std::initializer_list outputs_map); - - static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, - const std::vector &inputs_map, const std::vector &outputs_map); - - /// - /// Replace `old_node` data anchors with `new_node`'s according to `inputs_map` and `outputs_map`. - /// Replace the `i` in/out data anchor on `old_node` with - /// `inputs_map[i]`/`outputs_map[i]` data anchor on `new_node`. - /// If `inputs_map[i]`/`outputs_map[i]` < 0 or the index not contained in - /// `inputs_map[i]`/`outputs_map[i]`, the `i` data anchor will remain - /// on `old_node`. - /// @param new_node - /// @param old_node - /// @param inputs_map - /// @param outputs_map - /// @return - /// - static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, - std::initializer_list inputs_map, - std::initializer_list outputs_map); - - static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, - const std::vector &inputs_map, const std::vector &outputs_map); - - /// - /// Copy all in-control edges from `src_node` to `dst_node` - /// @param src_node - /// @param dst_node - /// @return - /// - static graphStatus CopyInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); - - static graphStatus MoveInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); - - /// - /// Copy all out-control edges from `src_node` to `dst_node` - /// @param src_node - /// @param dst_node - /// @return success: GRAPH_SUCESS - /// - static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); - - /// - /// Move all out-control edges from `src_node` to `dst_node` - /// @param src_node - /// @param dst_node - /// @return success: GRAPH_SUCESS - /// - static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); - - /// - /// Copy all in-data edges from `src_node` to `dst_node` - /// @param src_node - /// @param dst_node - /// @return - /// - static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node); - - static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); - - /// - /// Make a copy of ComputeGraph. - /// @param graph: original graph. - /// @param prefix: node name prefix of new graph. - /// @return ComputeGraphPtr - /// - static ComputeGraphPtr CloneGraph(const ComputeGraphPtr &graph, const string &prefix, - std::vector &input_nodes, std::vector &output_nodes); - - /// - /// Copy tensor attribute to new node. - /// @param [in] dst_desc: cloned node. - /// @param [in] src_node: original node. - /// @return success: GRAPH_SUCESS - /// - static graphStatus CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node); - - static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec); - - /// - /// Get reference-mapping of all data_anchors in graph - /// @param [in] graph - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus GetRefMapping(const ComputeGraphPtr &graph, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs - /// of the graph have UNKNOWN_SHAPE operators or not. - /// Note: This function will only look 'down' from the graph, not 'up'. For example, the following - /// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE - /// ROOT graph: A -----> B -----> C - /// K subgraph U - /// | - /// V - /// SUB graph: D --> E --> F - /// K K K - /// @param [in] graph - /// @return bool - /// - static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph); - - static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name); - - private: - /// - /// Get reference-mapping for in_data_anchors of node - /// @param [in] node - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus HandleInAnchorMapping(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Get reference-mapping for out_data_anchors of node - /// @param [in] node - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus HandleOutAnchorMapping(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Handle input of subgraph - /// @param [in] node - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus HandleSubgraphInput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Handle input of Merge op - /// @param [in] node - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus HandleMergeInput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Handle output of subgraph - /// @param [in] node - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus HandleSubgraphOutput(const NodePtr &node, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Relink all edges for cloned ComputeGraph. - /// @param [in] node: original node. - /// @param [in] prefix: node name prefix of new node. - /// @param [in] all_nodes: all nodes in new graph. - /// @return success: GRAPH_SUCESS - /// - static graphStatus RelinkGraphEdges(const NodePtr &node, const string &prefix, - const std::unordered_map &all_nodes); - - /// - /// Union ref-mapping - /// @param [in] exist_node_info1 - /// @param [in] exist_node_info2 - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @param [out] symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol, std::string &symbol); - - /// - /// Update symbol mapping with a new reference pair - /// @param [in] cur_node_info - /// @param [in] exist_node_info - /// @param [out] symbol_to_anchors - /// @param [out] anchor_to_symbol - /// @return success: GRAPH_SUCESS - /// - static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, - std::map> &symbol_to_anchors, - std::map &anchor_to_symbol); - - /// - /// Check if out_data_anchor is reference of input - /// @param [in] out_data_anchor - /// @param [out] reuse_in_index - /// @return bool - /// - static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index); -}; - -class ComputeGraphBuilder { - public: - ComputeGraphBuilder() : owner_graph_(nullptr) {} - ComputeGraphBuilder(const ComputeGraphBuilder &) = delete; - ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete; - ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete; - ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete; - ~ComputeGraphBuilder() = default; - - /// - /// @brief Add node to graph - /// @param [in] op_desc - /// @return ComputeGraphBuilder - /// - virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc); - - /// - /// @brief Add data-link among nodes in graph - /// @param [in] src_name - /// @param [in] out_anchor_ind - /// @param [in] dst_name - /// @param [in] in_anchor_ind - /// @return ComputeGraphBuilder - /// - virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, - const std::string &dst_name, uint32_t in_anchor_ind); - - /// - /// @brief Add ctrl-link among nodes in graph - /// @param [in] src_name - /// @param [in] dst_name - /// @return ComputeGraphBuilder - /// - virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name); - - /// - /// @brief Build graph - /// @param [out] error_code - /// @param [out] error_msg - /// @return ComputeGraphPtr - /// - virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0; - - /// @brief Get node with name - /// @param [in] name - /// @return NodePtr - /// - NodePtr GetNode(const std::string &name); - - /// @brief Get all nodes - /// @return std::vector - /// - std::vector GetAllNodes(); - - protected: - /// - /// @brief Build nodes - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void BuildNodes(graphStatus &error_code, std::string &error_msg); - - /// - /// @brief Build data-links - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void BuildDataLinks(graphStatus &error_code, std::string &error_msg); - - /// - /// @brief Build ctrl-links - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg); - - ComputeGraphPtr owner_graph_; - - // node_name -> node - std::map node_names_; - std::vector nodes_; - - // -> - std::vector, std::pair>> data_links_; - // src_node_name -> dst_node_name - std::vector> ctrl_links_; -}; - -class CompleteGraphBuilder : public ComputeGraphBuilder { - public: - explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {} - CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; - CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; - CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; - CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete; - ~CompleteGraphBuilder() = default; - - /// - /// @brief Add node to graph - /// @param [in] op_desc - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override; - - /// - /// @brief Add data-link among nodes in graph - /// @param [in] src_name - /// @param [in] out_anchor_ind - /// @param [in] dst_name - /// @param [in] in_anchor_ind - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, - uint32_t in_anchor_ind) override; - - /// - /// @brief Add ctrl-link among nodes in graph - /// @param [in] src_name - /// @param [in] dst_name - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; - - /// - /// @brief Set index_th input anchor for graph - /// @param [in] index - /// @param [in] node_names - /// @param [in] anchor_inds - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &SetInput(uint32_t index, const std::vector &node_names, - const std::vector &anchor_inds); - - /// - /// @brief Set index_th input of graph as useless - /// @param [in] index - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &SetUselessInput(uint32_t index); - - /// - /// @brief Add output anchor for graph - /// @param [in] owner_node_name - /// @param [in] anchor_ind - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); - - /// - /// @brief Add target for graph - /// @param [in] target_name - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &AddTarget(const std::string &target_name); - - /// - /// @brief Set parent-node of graph - /// @param [in] parent_node - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node); - - /// - /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node - /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &SetInputMapping(const std::map &input_mapping); - - /// - /// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind - /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node - /// @return CompleteGraphBuilder - /// - CompleteGraphBuilder &SetOutputMapping(const std::map &output_mapping); - - /// - /// @brief Build graph - /// @param [out] error_code - /// @param [out] error_msg - /// @return ComputeGraphPtr - /// - ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; - - private: - /// - /// @brief Add data nodes - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void AddDataNodes(graphStatus &error_code, std::string &error_msg); - - /// - /// @brief Add data node - /// @param [in] index - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg); - - /// - /// @brief Add RetVal nodes - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void AddRetValNodes(graphStatus &error_code, std::string &error_msg); - - /// - /// @brief Build target-nodes for graph - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void BuildGraphTargets(graphStatus &error_code, std::string &error_msg); - - std::string name_; - NodePtr parent_node_; - std::map, std::vector>> graph_inputs_; - std::vector> graph_outputs_; - std::vector graph_targets_; - - // index_of_graph_input -> in_anchor_index_of_parent_node - std::map input_mapping_; - // index_of_graph_output -> out_anchor_index_of_parent_node - std::map output_mapping_; -}; - -class PartialGraphBuilder : public ComputeGraphBuilder { - public: - PartialGraphBuilder() = default; - PartialGraphBuilder(const PartialGraphBuilder &) = delete; - PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete; - PartialGraphBuilder(const PartialGraphBuilder &&) = delete; - PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete; - ~PartialGraphBuilder() = default; - - /// - /// @brief Add node to graph - /// @param [in] op_desc - /// @return PartialGraphBuilder - /// - PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override; - - /// - /// @brief Add data-link among nodes in graph - /// @param [in] src_name - /// @param [in] out_anchor_ind - /// @param [in] dst_name - /// @param [in] in_anchor_ind - /// @return PartialGraphBuilder - /// - PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, - uint32_t in_anchor_ind) override; - - /// - /// @brief Add ctrl-link among nodes in graph - /// @param [in] src_name - /// @param [in] dst_name - /// @return PartialGraphBuilder - /// - PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; - - /// - /// @brief Set owner graph - /// @param [in] graph - /// @return PartialGraphBuilder - /// - PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph); - - /// - /// @brief Add exist node - /// @param [in] node - /// @return PartialGraphBuilder - /// - PartialGraphBuilder &AddExistNode(const NodePtr &node); - - /// - /// @brief Build multi nodes with links - /// @param [out] error_code - /// @param [out] error_msg - /// @return ComputeGraphPtr - /// - ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; - - private: - /// - /// @brief Build exist nodes - /// @param [out] error_code - /// @param [out] error_msg - /// @return void - /// - void BuildExistNodes(graphStatus &error_code, std::string &error_msg); - - std::vector exist_nodes_; -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_ diff --git a/metadef/inc/graph/utils/node_utils.h b/metadef/inc/graph/utils/node_utils.h deleted file mode 100644 index bf57148d..00000000 --- a/metadef/inc/graph/utils/node_utils.h +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_UTILS_NODE_UTILS_H_ -#define INC_GRAPH_UTILS_NODE_UTILS_H_ - -#include -#include -#include -#include "external/graph/operator.h" -#include "graph/node.h" - -namespace ge { -// Op types of Const like Opps. -extern const std::set kConstOpTypes; -// Op types of If like Opps. -extern const std::set kIfOpTypes; -// Op types of While like Opps. -extern const std::set kWhileOpTypes; -// Op types of Case like Opps. -extern const std::set kCaseOpTypes; -// Op types of For like Opps. -extern const std::set kForOpTypes; - -class NodeUtils { - public: - static graphStatus AddSendEventId(const NodePtr &node, const uint32_t &event_id); - static graphStatus AddRecvEventId(const NodePtr &node, const uint32_t &event_id); - static graphStatus GetSendEventIdList(const NodePtr &node, std::vector &vec_send); - static graphStatus GetRecvEventIdList(const NodePtr &node, std::vector &vec_recv); - - static graphStatus ClearSendInfo(); - static graphStatus ClearRecvInfo(); - - static graphStatus GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst); - - static graphStatus GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data, - InControlAnchorPtr &in_control); - - static graphStatus ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor); - static graphStatus SetAllAnchorStatus(const NodePtr &nodePtr); - static graphStatus SetAllAnchorStatus(Node &node); - static bool IsAnchorStatusSet(const NodePtr &nodePtr); - static bool IsAnchorStatusSet(const Node &node); - - static graphStatus MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node); - - static void UpdateIsInputConst(const NodePtr &nodePtr); - static void UpdateIsInputConst(Node &node); - static bool IsConst(const Node &node); - static void UnlinkAll(const Node &node); - static graphStatus UpdatePeerNodeInputDesc(const NodePtr &node_ptr); - - static graphStatus AppendInputAnchor(const NodePtr &node, uint32_t num); - static graphStatus RemoveInputAnchor(const NodePtr &node, uint32_t num); - - static graphStatus AppendOutputAnchor(const NodePtr &node, uint32_t num); - static graphStatus RemoveOutputAnchor(const NodePtr &node, uint32_t num); - - static bool IsInNodesEmpty(const Node &node); - static GeTensorDesc GetOutputDesc(const Node &node, uint32_t index); - static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); - static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); - static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); - // check node whether unknown shape.If node shape contain -1 or -2,out param "is_unknow" will be true; - // for func op, it will check subgraph yet, if some node shape of subgraph contain -1 or -2, - // the out param "is_unknow" will be true too - static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); - - static std::string GetNodeType(const Node &node); - static std::string GetNodeType(const NodePtr &node); - - static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); - static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); - - /// - /// Check if node is input of subgraph - /// @param [in] node - /// @return bool - /// - static bool IsSubgraphInput(const NodePtr &node); - - /// - /// Check if node is output of subgraph - /// @param [in] node - /// @return bool - /// - static bool IsSubgraphOutput(const NodePtr &node); - - /// - /// @brief Get subgraph original input node. - /// @param [in] node - /// @return Node - /// - static NodePtr GetParentInput(const Node &node); - static NodePtr GetParentInput(const NodePtr &node); - - /// - /// @brief Get is dynamic shape graph from node. - /// @param [in] node - /// @return bool - /// - static bool IsDynamicShape(const Node &node); - static bool IsDynamicShape(const NodePtr &node); - - /// - /// @brief Check is varying_input for while node - /// @param [in] node: Data node for subgraph - /// @return bool - /// - static bool IsWhileVaryingInput(const ge::NodePtr &node); - - /// - /// @brief Get subgraph input is constant. - /// @param [in] node - /// @param [out] string - /// @return bool - /// - static bool GetConstOpType(const NodePtr &node, std::string &type); - - /// - /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. - /// @param [in] node - /// @return return GRAPH_SUCCESS if remove successfully, other for failed. - /// - static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); - - /// - /// @brief Get subgraph input data node by index. - /// @param [in] node - /// @return Node - /// - static vector GetSubgraphDataNodesByIndex(const Node &node, int index); - - /// - /// @brief Get subgraph input data node by index. - /// @param [in] node - /// @return Node - /// - static vector GetSubgraphOutputNodes(const Node &node); - - static NodePtr GetInDataNodeByIndex(const Node &node, const int index); - - static vector> GetOutDataNodesWithAnchorByIndex(const Node &node, const int index); - - static ge::ConstNodePtr GetNodeFromOperator(const Operator &oprt); - - static graphStatus GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor); - - static graphStatus GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor); - - private: - static std::map> map_send_info_; - static std::map> map_recv_info_; -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_NODE_UTILS_H_ diff --git a/metadef/inc/graph/utils/op_desc_utils.h b/metadef/inc/graph/utils/op_desc_utils.h deleted file mode 100644 index 6a9a4695..00000000 --- a/metadef/inc/graph/utils/op_desc_utils.h +++ /dev/null @@ -1,181 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_UTILS_OP_DESC_UTILS_H_ -#define INC_GRAPH_UTILS_OP_DESC_UTILS_H_ - -#include -#include -#include -#include "graph/def_types.h" -#include "graph/node.h" -#include "graph/op_desc.h" -#include "graph/operator.h" -#include "graph/range_vistor.h" - -namespace ge { -class OpDesc; -using OpDescPtr = std::shared_ptr; - -class OpDescUtils { - public: - template - using Vistor = RangeVistor>; - - OpDescUtils() = default; - ~OpDescUtils() = default; - static bool HasQuantizeFactorParams(const OpDescPtr& op_desc); - static bool HasQuantizeFactorParams(const OpDesc& op_desc); - static graphStatus GetQuantizeFactorParams(const OpDescPtr& op_desc, QuantizeFactorParams& quant); - static graphStatus GetQuantizeFactorParams(const OpDesc& op_desc, QuantizeFactorParams& quant); - static graphStatus SetQuantizeFactorParams(const OpDescPtr& op_desc, const QuantizeFactorParams& quant); - static graphStatus SetQuantizeFactorParams(OpDesc& op_desc, const QuantizeFactorParams& quant); - - static vector GetConstInputNode(const ge::Node& node); - static vector GetInputData(const vector& input_nodes); - - static vector GetWeights(const ge::Node& node); - static vector GetWeights(const ge::ConstNodePtr& node); - static vector MutableWeights(const ge::Node& node); - static vector MutableWeights(const ge::NodePtr node); - static graphStatus SetWeights(ge::Node& node, const vector& weights); - static graphStatus SetWeights(ge::NodePtr node, const vector& weights); - static graphStatus ClearWeights(ge::NodePtr node); - - static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index); - static bool ClearInputDesc(const ge::NodePtr& node); - static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, uint32_t index); - static bool ClearOutputDesc(const ge::NodePtr& node); - static vector GetConstInputs(const ge::Node& node); - static vector GetConstInputs(const ge::ConstNodePtr& node); - static size_t GetNonConstInputsSize(const ge::Node& node); - static size_t GetNonConstInputsSize(ge::ConstNodePtr node); - // Index: Indicates the index of all non const inputs - static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, size_t index_non_const = 0); - static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, size_t index_non_const = 0); - static bool GetNonConstInputIndex(const ge::Node& node, size_t index_non_const, size_t& index); - static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, size_t index_non_const, size_t& index); - // Index: Indicates the index of all inputs - static bool IsNonConstInput(const ge::Node& node, size_t index = 0); - static bool IsNonConstInput(const ge::ConstNodePtr& node, size_t index = 0); - - static vector GetNonConstTensorDesc(const ge::ConstNodePtr& node); - static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr); - - static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); - static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); - static OpDescPtr GetOpDescFromOperator(const Operator& oprt); - - static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); - - static graphStatus SetSubgraphInstanceName(const std::string& subgraph_name, - const std::string& subgraph_instance_name, OpDescPtr& op_desc); - - private: - static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); - static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); - static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight); - static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); -}; - -class OpDescBuilder { - public: - OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} - OpDescBuilder(const OpDescBuilder&) = delete; - OpDescBuilder& operator=(const OpDescBuilder&) = delete; - OpDescBuilder(const OpDescBuilder&&) = delete; - OpDescBuilder& operator=(const OpDescBuilder&&) = delete; - ~OpDescBuilder() = default; - - /// - /// @brief Add input - /// @param [in] name - /// @return OpDescBuilder - /// - OpDescBuilder& AddInput(const std::string& name); - - /// - /// @brief Add input - /// @param [in] name - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddInput(const std::string& name, const GeTensorDesc& tensor); - - /// - /// @brief Add dynamic input - /// @param [in] name - /// @param [in] num - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); - - /// - /// @brief Add dynamic input - /// @param [in] name - /// @param [in] num - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); - - /// - /// @brief Add output - /// @param [in] name - /// @return OpDescBuilder - /// - OpDescBuilder& AddOutput(const std::string& name); - - /// - /// @brief Add output - /// @param [in] name - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddOutput(const std::string& name, const GeTensorDesc& tensor); - - /// - /// @brief Add dynamic output - /// @param [in] name - /// @param [in] num - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); - - /// - /// @brief Add dynamic output - /// @param [in] name - /// @param [in] num - /// @param [in] tensor - /// @return OpDescBuilder - /// - OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); - - /// - /// @brief Build op_desc - /// @return OpDescPtr - /// - OpDescPtr Build(); - - private: - std::string name_; - std::string type_; - std::vector> inputs_; - std::vector> outputs_; -}; -} // namespace ge - -#endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ diff --git a/metadef/inc/graph/utils/tensor_adapter.h b/metadef/inc/graph/utils/tensor_adapter.h deleted file mode 100644 index a7355553..00000000 --- a/metadef/inc/graph/utils/tensor_adapter.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ -#define INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ - -#include -#include "graph/ge_tensor.h" -#include "graph/tensor.h" - -namespace ge { -using GeTensorPtr = std::shared_ptr; -using ConstGeTensorPtr = std::shared_ptr; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorAdapter { - public: - static GeTensorDesc TensorDesc2GeTensorDesc(const TensorDesc &tensorDesc); - static TensorDesc GeTensorDesc2TensorDesc(const GeTensorDesc &geTensorDesc); - static GeTensorPtr Tensor2GeTensor(const Tensor &tensor); - static Tensor GeTensor2Tensor(const ConstGeTensorPtr &geTensor); - - static ConstGeTensorPtr AsGeTensorPtr(const Tensor &tensor); // Share value - static GeTensorPtr AsGeTensorPtr(Tensor &tensor); // Share value - static const GeTensor AsGeTensor(const Tensor &tensor); // Share value - static GeTensor AsGeTensor(Tensor &tensor); // Share value - static const Tensor AsTensor(const GeTensor &tensor); // Share value - static Tensor AsTensor(GeTensor &tensor); // Share value -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ diff --git a/metadef/inc/graph/utils/tensor_utils.h b/metadef/inc/graph/utils/tensor_utils.h deleted file mode 100644 index caa80dcf..00000000 --- a/metadef/inc/graph/utils/tensor_utils.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_UTILS_TENSOR_UTILS_H_ -#define INC_GRAPH_UTILS_TENSOR_UTILS_H_ - -#include -#include "graph/def_types.h" -#include "graph/ge_error_codes.h" -#include "graph/ge_tensor.h" - -namespace ge { -class TensorUtils { - public: - static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, int64_t &size); - static void SetSize(GeTensorDesc &tensorDesc, int64_t size); - static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr); - static uint32_t GetWeightSize(const GeTensor &tensor); - static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc); - static uint8_t *GetWeightAddr(const ConstGeTensorPtr &tensorPtr, uint8_t *base); - static uint8_t *GetWeightAddr(const GeTensor &tensor, uint8_t *base); - static void SetWeightSize(GeTensorDesc &tensorDesc, uint32_t size); - static ge::graphStatus GetReuseInput(const GeTensorDesc &tensorDesc, bool &flag); - static void SetReuseInput(GeTensorDesc &tensorDesc, bool flag); - static ge::graphStatus GetOutputTensor(const GeTensorDesc &tensorDesc, bool &flag); - static void SetOutputTensor(GeTensorDesc &tensorDesc, bool flag); - static graphStatus GetDeviceType(const GeTensorDesc &tensorDesc, DeviceType &type); - static void SetDeviceType(GeTensorDesc &tensorDesc, DeviceType type); - static ge::graphStatus GetInputTensor(const GeTensorDesc &tensorDesc, bool &flag); - static void SetInputTensor(GeTensorDesc &tensorDesc, bool flag); - static ge::graphStatus GetRealDimCnt(const GeTensorDesc &tensorDesc, uint32_t &cnt); - static void SetRealDimCnt(GeTensorDesc &tensorDesc, uint32_t cnt); - static ge::graphStatus GetReuseInputIndex(const GeTensorDesc &tensorDesc, uint32_t &idx); - static void SetReuseInputIndex(GeTensorDesc &tensorDesc, uint32_t idx); - static ge::graphStatus GetDataOffset(const GeTensorDesc &tensorDesc, int64_t &offset); - static void SetDataOffset(GeTensorDesc &tensorDesc, int64_t offset); - static ge::graphStatus GetCmpsSize(const GeTensorDesc &tensorDesc, uint32_t &cmp_size); - static void SetCmpsSize(GeTensorDesc &tensorDesc, uint32_t cmp_size); - static ge::graphStatus GetCmpsTab(const GeTensorDesc &tensorDesc, vector &vec); - static void SetCmpsTab(GeTensorDesc &tensorDesc, const uint8_t *data, size_t size); - static ge::graphStatus GetCmpsTabOffset(const GeTensorDesc &tensorDesc, int64_t &tab_offset); - static void SetCmpsTabOffset(GeTensorDesc &tensorDesc, int64_t tab_offset); - static ge::graphStatus GetCmpsInfo(const GeTensorDesc &tensorDesc, CompressInfo &info); - static void SetCmpsInfo(GeTensorDesc &tensorDesc, const CompressInfo &info); - static bool HasAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc); - static ge::graphStatus GetAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc, AllOffsetQuantizeInfo &info); - static void SetAlloffsetQuantizeInfo(GeTensorDesc &tensorDesc, const AllOffsetQuantizeInfo &info); - static ge::graphStatus GetRC(const GeTensorDesc &tensorDesc, uint32_t &rc); - static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc); - - /// - /// calculate tensor mem size. - /// @param shape tensor shape - /// @param format tensor format - /// @param data_type tensor data type - /// @param mem_size -1 means unknown shape,other means mem size - /// @return GRAPH_SUCCESS:success, other:failed - /// - static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size); - static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); - static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ diff --git a/metadef/inc/graph/utils/type_utils.h b/metadef/inc/graph/utils/type_utils.h deleted file mode 100644 index 38509b9a..00000000 --- a/metadef/inc/graph/utils/type_utils.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_GRAPH_UTILS_TYPE_UTILS_H_ -#define INC_GRAPH_UTILS_TYPE_UTILS_H_ - -#include -#include -#include -#include "graph/def_types.h" -#include "graph/ge_error_codes.h" -#include "graph/types.h" -#include "graph/usr_types.h" -#include "register/register_types.h" -#include "external/register/register_fmk_types.h" - -namespace ge { -class TypeUtils { - public: - static bool IsDataTypeValid(DataType dt); - static bool IsFormatValid(Format format); - static bool IsInternalFormat(Format format); - - static std::string ImplyTypeToSerialString(domi::ImplyType imply_type); - static std::string DataTypeToSerialString(DataType data_type); - static DataType SerialStringToDataType(const std::string &str); - static std::string FormatToSerialString(Format format); - static Format SerialStringToFormat(const std::string &str); - static Format DataFormatToFormat(const std::string &str); - static Format DomiFormatToFormat(domi::domiTensorFormat_t domi_format); - static std::string FmkTypeToSerialString(domi::FrameworkType fmk_type); - - static graphStatus Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def); - static graphStatus Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr); - - static bool GetDataTypeLength(ge::DataType data_type, uint32_t &length); - static bool CheckUint64MulOverflow(uint64_t a, uint32_t b); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_TYPE_UTILS_H_ diff --git a/metadef/proto/dump_task.proto b/metadef/proto/dump_task.proto deleted file mode 100644 index ecdf4792..00000000 --- a/metadef/proto/dump_task.proto +++ /dev/null @@ -1,127 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; -package toolkit.dumpdata; - -enum OutputDataType { - DT_UNDEFINED = 0; - DT_FLOAT = 1; - DT_FLOAT16 = 2; - DT_INT8 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_UINT16 = 6; - DT_INT32 = 7; - DT_INT64 = 8; - DT_UINT32 = 9; - DT_UINT64 = 10; - DT_BOOL = 11; - DT_DOUBLE = 12; - DT_STRING = 13; - DT_DUAL_SUB_INT8 = 14; - DT_DUAL_SUB_UINT8 = 15; - DT_COMPLEX64 = 16; - DT_COMPLEX128 = 17; - DT_QINT8 = 18; - DT_QINT16 = 19; - DT_QINT32 = 20; - DT_QUINT8 = 21; - DT_QUINT16 = 22; - DT_RESOURCE = 23; - DT_STRING_REF = 24; - DT_DUAL = 25; -} - -enum OutputFormat { - FORMAT_NCHW = 0; - FORMAT_NHWC = 1; - FORMAT_ND = 2; - FORMAT_NC1HWC0 = 3; - FORMAT_FRACTAL_Z = 4; - FORMAT_NC1C0HWPAD = 5; - FORMAT_NHWC1C0 = 6; - FORMAT_FSR_NCHW = 7; - FORMAT_FRACTAL_DECONV = 8; - FORMAT_C1HWNC0 = 9; - FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; - FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; - FORMAT_NC1HWC0_C04 = 12; - FORMAT_FRACTAL_Z_C04 = 13; - FORMAT_CHWN = 14; - FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; - FORMAT_HWCN = 16; - FORMAT_NC1KHKWHWC0 = 17; - FORMAT_BN_WEIGHT = 18; - FORMAT_FILTER_HWCK = 19; - FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; - FORMAT_HASHTABLE_LOOKUP_KEYS = 21; - FORMAT_HASHTABLE_LOOKUP_VALUE = 22; - FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; - FORMAT_HASHTABLE_LOOKUP_HITS=24; - FORMAT_C1HWNCoC0 = 25; - FORMAT_MD = 26; - FORMAT_NDHWC = 27; - FORMAT_FRACTAL_ZZ = 28; - FORMAT_FRACTAL_NZ = 29; - FORMAT_RESERVED = 30; -} - -message OriginalOp { - string name = 1; - uint32 output_index = 2; - OutputDataType data_type = 3; - OutputFormat format = 4; -} - -message Shape { - repeated uint64 dim = 1; -} - -message OpOutput { - OutputDataType data_type = 1; - OutputFormat format = 2; - Shape shape = 3; - OriginalOp original_op = 4; // the original op corresponding to the output - bytes data = 5; - uint64 size = 6; -} - -message OpInput { - OutputDataType data_type = 1; - OutputFormat format = 2; - Shape shape = 3; - bytes data = 4; - uint64 size = 5; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - bytes data = 2; - uint64 size = 3; -} - -message DumpData{ - string version = 1; - uint64 dump_time = 2; - repeated OpOutput output = 3; - repeated OpInput input = 4; - repeated OpBuffer buffer = 5; -} diff --git a/metadef/proto/fusion_model.proto b/metadef/proto/fusion_model.proto deleted file mode 100644 index 6220963c..00000000 --- a/metadef/proto/fusion_model.proto +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -import "om.proto"; - -package domi; - -message FusionModelDef { - string version = 1; - repeated OpDef fusion_op = 2; -} \ No newline at end of file diff --git a/metadef/proto/fwk_adapter.proto b/metadef/proto/fwk_adapter.proto deleted file mode 100644 index 99333d2e..00000000 --- a/metadef/proto/fwk_adapter.proto +++ /dev/null @@ -1,42 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package aicpu.FWKAdapter; -option cc_enable_arenas = true; - - -// Defines an struct for input and output. -message TensorDataInfo { - - // value DataType - uint32 dtype = 1; - - // shape dim - repeated int64 dim = 2; - - // data point addr - int64 data_addr = 3; -} - -message KernelRunParam { - // input - repeated TensorDataInfo input = 1; - // output - repeated TensorDataInfo output = 2; -} - diff --git a/metadef/proto/ge_api.proto b/metadef/proto/ge_api.proto deleted file mode 100644 index ac5b3b3a..00000000 --- a/metadef/proto/ge_api.proto +++ /dev/null @@ -1,104 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; -package ge.api_pb; - -import "ge_ir.proto"; - -// GE initialize -message GEInitialize { - map options = 1; -}; - -// initialize response -message GEInitializeResponse { - uint32 status = 1; - uint32 clientId = 2; -}; - -// GE finalize -message GEFinalize { - bool final = 1; - uint32 clientId = 2; -}; - -message GEFinalizeResponse { - uint32 status = 1; -}; - -// GE Session -message CreateSession{ - map options = 1; -}; - -message CreateSessionResponse { - uint32 status = 1; - uint64 sessionId = 2; -}; - -//GE AddGraph -//model serialize :: serializegraph -message SessionAddGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; - ge.proto.GraphDef graph = 3; -}; - -message SessionAddGraphResponse { - uint32 status = 1; -}; - -//GE SessionRemoveGraph -message SessionRemoveGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; -}; - -message SessionRemoveGraphResponse { - uint32 status = 1; -}; - -message SessionRunGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; - repeated ge.proto.TensorDef tensor = 3; -}; - -message SessionBuildGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; - repeated ge.proto.TensorDef tensor = 3; - string savePath = 4; -}; - -message SessionRunGraphResponse { - uint32 status = 1; - repeated ge.proto.TensorDef tensor = 2; -}; - -message SessionBuildGraphResponse { - uint32 status = 1; -}; - -message DestroySession{ - bool final = 1; - uint64 sessionId = 2; -}; - -message DestroySessionResponse { - uint32 status = 1; -}; diff --git a/metadef/proto/ge_ir.proto b/metadef/proto/ge_ir.proto deleted file mode 100644 index 87886c84..00000000 --- a/metadef/proto/ge_ir.proto +++ /dev/null @@ -1,206 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package ge.proto; - -enum DataType -{ - DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. - DT_FLOAT = 1; // float type - DT_FLOAT16 = 2; // fp16 type - DT_INT8 = 3; // int8 type - DT_UINT8 = 4; // uint8 type - DT_INT16 = 5; // int16 type - DT_UINT16 = 6; // uint16 type - DT_INT32 = 7; // - DT_INT64 = 8; // int64 type - DT_UINT32 = 9; // unsigned int32 - DT_UINT64 = 10; // unsigned int64 - DT_BOOL = 11; // bool type - DT_DOUBLE = 12; // double type - DT_STRING = 13; // string type - DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ - DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ - DT_COMPLEX64 = 16; // complex64 type - DT_COMPLEX128 = 17; // complex128 type - DT_QINT8 = 18; // qint8 type - DT_QINT16 = 19; // qint16 type - DT_QINT32 = 20; // qint32 type - DT_QUINT8 = 21; // quint8 type - DT_QUINT16 = 22; // quint16 type - DT_RESOURCE = 23; // resource type - DT_STRING_REF = 24; // string_ref type - DT_DUAL = 25; /**< dual output type */ -} - -message AttrDef -{ - message ListValue - { - enum ListValueType{ - VT_LIST_NONE = 0; - VT_LIST_STRING = 1; - VT_LIST_INT = 2; - VT_LIST_FLOAT = 3; - VT_LIST_BOOL = 4; - VT_LIST_BYTES = 5; - VT_LIST_TENSOR_DESC = 6; - VT_LIST_TENSOR = 7; - VT_LIST_GRAPH = 8; - VT_LIST_NAMED_ATTRS = 9; - VT_LIST_DATA_TYPE = 10; - } - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3; // "list(int)" - repeated float f = 4; // "list(float)" - repeated bool b = 5; // "list(bool)" - repeated bytes bt = 7; - repeated TensorDescriptor td = 8; - repeated TensorDef t = 9; - repeated GraphDef g = 10; - repeated NamedAttrs na = 11; - repeated int64 dt = 12; // list ge::DataType - - ListValueType val_type = 20; - } - - message ListListInt{ - message ListInt{ - repeated int64 list_i = 1; // list int - } - repeated ListInt list_list_i = 1; // list list int - } - - oneof value - { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; // Used to support attr nesting - TensorDescriptor td = 11; // GeTensorDesc type - TensorDef t = 12; // GeTensor type - GraphDef g = 13; // Graph type - ListListInt list_list_int = 14; // List List Int type - int64 dt = 15; // ge::DataType - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs -{ - string name = 1; - map attr = 2; -} - -// Shape / dimension description, using row-major order -message ShapeDef -{ - repeated int64 dim = 1; // Size of each dimension -} - -// Multidimensional data description -message TensorDescriptor -{ - string name = 1; // Optional parameter, tensor name - - DataType dtype = 2; // tensor datatype - ShapeDef shape = 3; // Shape / dimension - string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" - - bool has_out_attr = 9; - int64 size = 10; - int64 weight_size = 11; - bool reuse_input = 12; - bool output_tensor = 13; - string device_type = 14; - bool input_tensor =15; - int64 real_dim_cnt = 16; - int64 reuse_input_index = 17; - int64 data_offset = 18; - int64 cmps_size = 19; - string cmps_tab = 20; - int64 cmps_tab_offset = 21; - - map attr = 5; // Set of extra parameter fields -} - -// GeTensor definition -message TensorDef -{ - TensorDescriptor desc = 1; // Tensor description - bytes data = 2; // Tensor data -} - - -// Operator description -message OpDef -{ - string name = 1; // name - string type = 2; // type - - repeated string input = 5; // input original op name + outgoing index. op_name锛歩ndex - - map attr = 10; // Set of operator parameter fields - - bool has_out_attr = 20; - int64 id = 21; - int64 stream_id =22; - repeated string input_name = 23; - repeated string src_name = 24; - repeated int64 src_index = 25; - repeated string dst_name = 26; - repeated int64 dst_index = 27; - repeated int64 input_i = 28; - repeated int64 output_i = 29; - repeated int64 workspace = 30; - repeated int64 workspace_bytes = 31; - repeated bool is_input_const = 32; - repeated TensorDescriptor input_desc = 33; - repeated TensorDescriptor output_desc = 34; - repeated string subgraph_name = 35; -} - -// Graph definition -message GraphDef -{ - string name = 1; // name - - repeated string input = 4; // Graph input - repeated string output = 5; // Graph output - - repeated OpDef op = 6; // List of operators - - map attr = 11; // Extended field -} - -// model definition -message ModelDef -{ - string name = 1; // name - uint32 version = 2; // IR Proto verion - string custom_version = 3; // User model version number, passed in by user - - repeated GraphDef graph = 7; // Graph definition锛実raph[0] represents the main diagram in modeldef - - map attr = 11; // Extended field -} - diff --git a/metadef/proto/insert_op.proto b/metadef/proto/insert_op.proto deleted file mode 100644 index a059e122..00000000 --- a/metadef/proto/insert_op.proto +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -message InsertNewOps { - repeated AippOpParams aipp_op = 1; - repeated MultiShapeOpParams multi_shape_op = 2; -} - -message AippOpParams { - enum InputFormat { - UNDEFINED = 0; - YUV420SP_U8 = 1; - XRGB8888_U8 = 2; - RGB888_U8 = 3; - YUV400_U8 = 4; - NC1HWC0DI_FP16 = 5; - NC1HWC0DI_S8 = 6; - ARGB8888_U8 = 7; - YUYV_U8 = 8; - YUV422SP_U8 = 9; - AYUV444_U8 = 10; - RAW10 = 11; - RAW12 = 12; - RAW16 = 13; - RAW24 = 14; - RGB16 = 15; - RGB20 = 16; - RGB24 = 17; - RGB8_IR = 18; - RGB16_IR = 19; - RGB24_IR = 20; - } - - enum AippMode { - undefined = 0; - static = 1; - dynamic = 2; - } - - // AIPP模式,区分静态AIPP和动态AIPP - AippMode aipp_mode = 1; - - // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 - // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 - uint32 related_input_rank = 2; - - // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 - // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 - // 配置值 <= Data算子输出边的个数。 - repeated uint32 input_edge_idx = 3; - - // [Begin] 动态AIPP参数,配置静态AIPP时无效 - uint32 max_src_image_size = 4; - - // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 - bool support_rotation = 5; - - // [End] 动态AIPP参数 - - - // [Begin] 静态AIPP参数,配置动态AIPP时无效 - InputFormat input_format = 51; - bool csc_switch = 52; - float cpadding_value = 53; - bool rbuv_swap_switch = 54; - bool ax_swap_switch = 55; - bool single_line_mode = 56; - - int32 src_image_size_w = 57; - int32 src_image_size_h = 58; - - bool crop = 59; - int32 load_start_pos_w = 60; - int32 load_start_pos_h = 61; - int32 crop_size_w = 62; - int32 crop_size_h = 63; - - bool resize = 64; - int32 resize_output_w = 65; - int32 resize_output_h = 66; - - bool padding = 67; - int32 left_padding_size = 68; - int32 right_padding_size = 69; - int32 top_padding_size = 70; - int32 bottom_padding_size = 71; - - int32 mean_chn_0 = 10; - int32 mean_chn_1 = 11; - int32 mean_chn_2 = 12; - int32 mean_chn_3 = 19; - float min_chn_0 = 13; - float min_chn_1 = 14; - float min_chn_2 = 15; - float min_chn_3 = 20; - repeated float var_reci_chn_0 = 16; - repeated float var_reci_chn_1 = 17; - repeated float var_reci_chn_2 = 18; - repeated float var_reci_chn_3 = 21; - - repeated int32 matrix_r0c0 = 30; - repeated int32 matrix_r0c1 = 31; - repeated int32 matrix_r0c2 = 32; - repeated int32 matrix_r1c0 = 33; - repeated int32 matrix_r1c1 = 34; - repeated int32 matrix_r1c2 = 35; - repeated int32 matrix_r2c0 = 36; - repeated int32 matrix_r2c1 = 37; - repeated int32 matrix_r2c2 = 38; - repeated int32 output_bias_0 = 39; - repeated int32 output_bias_1 = 40; - repeated int32 output_bias_2 = 41; - repeated int32 input_bias_0 = 42; - repeated int32 input_bias_1 = 43; - repeated int32 input_bias_2 = 44; - - // [End] 静态AIPP参数 - - // The n number that is used for raw/rgbir data into f16 transformation. - // The transformation equation is x/(2^n). If set to 0, no transform is performed. - uint32 raw_rgbir_to_f16_n = 45; -} - -message MultiShapeOpParams { - enum MultiShapeMode { - batch = 0; //动态batch - resolution = 1; //动态分辨率,扩展用 - } - - MultiShapeMode mode = 1; //算子模式 - uint32 related_input_rank = 2; //新增算子插入到哪个输入 - - - repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 -} diff --git a/metadef/proto/om.proto b/metadef/proto/om.proto deleted file mode 100644 index dd992191..00000000 --- a/metadef/proto/om.proto +++ /dev/null @@ -1,401 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -enum TargetType -{ - MINI = 0; - TINY = 1; - LITE = 2; -} - -// offline model -message ModelDef { - string name = 1; - uint32 version = 2; - - uint64 memory_size = 10; - uint32 stream_num = 11; - uint32 event_num = 12; - uint64 weight_size = 13; - uint32 label_num = 15; - repeated OpDef op = 20; - TargetType target_type = 23; - - map attr = 30; -}; - -// operator define -message OpDef { - string name = 1; - string type = 2; - - uint32 id = 3; - uint32 stream_id = 4; - - repeated string input_name = 5; - - repeated string src_name = 8; - repeated int32 src_index = 9; - repeated int64 input = 10; - repeated int64 output = 11; - repeated TensorDescriptor input_desc = 12; - repeated TensorDescriptor output_desc = 13; - repeated WeightDef weights = 14; - repeated string dst_name = 15; - repeated int32 dst_index = 16; - - repeated int64 workspace = 20; - repeated uint32 workspace_bytes = 21; - - repeated string weight_name = 22; - repeated bool is_input_const = 23; - - map attr = 30; - - QuantizeFactorParams quantize_factor = 31; - - oneof op_params { - // start at 100 here - SendOpParams sender_param = 100; - RecvOpParams receiver_param = 200; - ConvolutionOpParams convolution_param = 300; - PoolingOpParams pooling_param = 400; - EltwiseOpParams eltwise_param = 500; - BatchNormOpParams batchnorm_param = 600; - ScaleOpParams scale_param = 700; - FullConnectionOpParams full_connection_param = 800; - SoftmaxOpParams softmax_param = 900; - ActivationOpParams activation_param = 1000; - ReshapeOpParams reshape_param = 1100; - } -}; - -message SendOpParams { - uint32 event_id = 1; -}; - -message RecvOpParams { - uint32 event_id = 1; -}; - -enum QuantizeScaleType -{ - VECTOR_SCALE = 0; - SCALAR_SCALE = 1; -} - -enum QuantizeScaleMode -{ - NORMAL_MODE = 0; - SQRT_MODE = 1; -} - -enum QuantizeAlgorithm -{ - NON_OFFSET_ALGO = 0; - HALF_OFFSET_ALGO = 1; - ALL_OFFSET_ALGO = 2; -} -message QuantizeFactor -{ - QuantizeScaleMode scale_mode = 1; - bytes scale_value = 2; - int64 scale_offset = 3; - bytes offset_data_value = 4; - int64 offset_data_offset = 5; - bytes offset_weight_value = 6; - int64 offset_weight_offset = 7; - bytes offset_pad_value = 8; - int64 offset_pad_offset = 9; -}; - -message QuantizeCalcFactor -{ - bytes offsetw = 1; - int64 offsetw_offset = 2; - bytes offsetd = 3; - int64 offsetd_offset = 4; - bytes scalereq = 5; - int64 scaledreq_offset = 6; - bytes offsetdnext = 7; - int64 offsetdnext_offset = 8; -} - -message QuantizeFactorParams -{ - QuantizeAlgorithm quantize_algo = 1; - QuantizeScaleType scale_type = 2; - QuantizeFactor quantize_param = 3; - QuantizeFactor dequantize_param = 4; - QuantizeFactor requantize_param = 5; - QuantizeCalcFactor quantizecalc_param = 6; -}; - -message ConvolutionOpParams { - int32 mode = 1; - int32 algo = 2; - int32 pad_mode = 3; - uint32 group = 4; - uint32 num_output = 5; - - repeated uint32 pad = 10; - repeated uint32 stride = 11; - repeated uint32 dilation = 12; - repeated uint32 kernel = 13; - - float alpha = 20; - float beta = 21; - - WeightDef filter = 40; - WeightDef bias = 41; - - bool relu_flag = 62; - repeated uint32 adj = 70; - repeated uint32 target_shape = 71; - repeated uint32 before_pad = 72; -}; - -message PoolingOpParams { - int32 mode = 1; - int32 nan_opt = 2; - int32 pad_mode = 3; - bool global_pooling = 4; - - repeated uint32 window = 10; - repeated uint32 pad = 11; - repeated uint32 stride = 12; - bool ceil_mode = 13; - int32 data_mode = 14; - - float alpha = 20; - float beta = 21; - repeated uint32 before_pad = 22; -}; - -message EltwiseOpParams { - int32 mode = 1; - repeated float coeff = 2; - float alpha = 3; - float beta = 4; - repeated WeightDef weight = 5; - bool relu_flag = 6; -}; - -message ActivationOpParams { - int32 mode = 1; - float coef = 2; - float alpha = 3; - float beta = 4; -}; - -message BatchNormOpParams { - int32 mode = 1; - - float alpha = 2; - float beta = 3; - double epsilon = 4;//optinal,[default = 1e-5] - bool use_global_stats = 5; //optinal,by default true,testing mode - float moving_average_fraction = 6; //optinal,[default = .999]; - - WeightDef estimated_mean = 7; - WeightDef estimated_variance = 8; - - WeightDef scale = 9; - WeightDef bias = 10; -}; - -message ScaleOpParams { - WeightDef scale = 1; - WeightDef bias = 2; -}; - -message ReshapeOpParams { - float alpha = 1; - float beta = 2; - ShapeDef shape = 3; - int32 axis = 4; - int32 num_axes = 5; - int32 format = 6; -}; - -message SoftmaxOpParams { - int32 algo = 1; - int32 mode = 2; - float alpha = 3; - float beta = 4; -}; - -message FullConnectionOpParams { - WeightDef filter = 1; - WeightDef bias = 2; - uint32 num_output = 3; - bool relu_flag = 12; -}; - -message FlattenOpParams { - float alpha = 1; - float beta = 2; - int32 start_axis = 3; - int32 end_axis = 4; -} - -message AddLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message MulLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message AddOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message MulOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message SubOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message BiasAddOpParams { - float alpha = 1; - float beta = 2; - - WeightDef bias = 10; -}; - -message MatMulOpParams { - float alpha = 1; - float beta = 2; - bool transposeX = 3; - bool transposeW = 4; - - WeightDef filter = 10; - WeightDef bias = 12; -}; - -message RsqrtOpParams { - float alpha = 1; - float beta = 2; -}; - - -message WeightDef { - int32 format = 1; - int32 data_type = 2; - ShapeDef shape = 3; - bytes data = 4; - int64 data_offset = 5; - uint32 cmps_size = 6; - bytes cmps_tab = 7; - int64 cmps_tab_offset = 10; - CompressInfo cmps_info = 8; - AllOffsetQuantizeInfo alloffset_quantize_info = 11; -} - -message ShapeDef { - repeated int64 dim = 1; -} - -enum DeviceType { - NPU = 0; // In default, we will use NPU. - CPU = 1; // CPU -} - -message AllOffsetQuantizeInfo { - float scale = 1; - int32 offset = 2; -} - -message TensorDescriptor { - int32 format = 1; - int32 data_type = 2; - repeated int64 dim = 3; - uint32 size = 4; - bool reuse_input = 5; - bool output_tensor = 7; - DeviceType device_type = 8; - bool input_tensor = 9; - uint32 real_dim_cnt = 10; - uint32 reuse_input_index = 11; - AllOffsetQuantizeInfo alloffset_quantize_info = 12; -} - -message CompressInfo { - int32 blockRow = 1; // block row - int32 blockCol = 2; // block col - int32 fractalK = 3; // fractal K - int32 fractalN = 4; // fractal N - int32 lastFractalK = 5; // K of last fractal - int32 lastFractalN = 6; // N of last fractal - int32 cubeSize = 7; // cube's length - int32 loadDir = 8; // data load directtiono 0:col load 1:row load -} - -message AttrDef { - message ListValue { - repeated string s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated uint32 u = 6 [packed = true]; // "list(uint)" - repeated bytes bt = 7; - } - - oneof value { - string s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - uint32 u = 6; // "uint32" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs { - string name = 1; - map attr = 2; -} - diff --git a/metadef/proto/op_mapping_info.proto b/metadef/proto/op_mapping_info.proto deleted file mode 100644 index 7b84a115..00000000 --- a/metadef/proto/op_mapping_info.proto +++ /dev/null @@ -1,89 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; -package aicpu.dump; - -message Shape { - repeated uint64 dim = 1; -} - -message Output { - int32 data_type = 1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - string original_name = 5; - int32 original_output_index = 6; - int32 original_output_data_type = 7; - int32 original_output_format = 8; - uint64 size = 9; -} - -message Input { - int32 data_type =1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - uint64 size = 5; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - uint64 address = 2; - uint64 size = 3; -} - -message Op { - string op_name = 1; - string op_type = 2; -} - -message Task { - uint32 task_id = 1; - uint32 stream_id = 2; - Op op = 3; - repeated Output output = 4; - bool end_graph = 5; - repeated Input input = 6; - repeated OpBuffer buffer = 7; -} - -message OpMappingInfo { - string dump_path = 1; - oneof model_name_param { - string model_name = 2; - } - oneof model_id_param { - uint32 model_id = 3; - } - oneof step_id { - uint64 step_id_addr = 4; - } - oneof iterations_per_loop { - uint64 iterations_per_loop_addr = 5; - } - oneof loop_cond { - uint64 loop_cond_addr = 6; - } - uint32 flag = 7; // 0x01 load, 0x00 unload - repeated Task task = 8; - string dump_step = 9; -} \ No newline at end of file diff --git a/metadef/proto/optimizer_priority.proto b/metadef/proto/optimizer_priority.proto deleted file mode 100644 index 3327be8a..00000000 --- a/metadef/proto/optimizer_priority.proto +++ /dev/null @@ -1,23 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; -package ge.optimizers; - -// Default: GE>FE>AICPU -message Priority{ - repeated string optimizer = 1; -} \ No newline at end of file diff --git a/metadef/proto/task.proto b/metadef/proto/task.proto deleted file mode 100644 index 50ea061b..00000000 --- a/metadef/proto/task.proto +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - repeated uint32 origin_op_index = 8; -} - - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; -} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; - uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/parser b/parser new file mode 160000 index 00000000..971fbfdf --- /dev/null +++ b/parser @@ -0,0 +1 @@ +Subproject commit 971fbfdf017ee197a9a2f1edc41167e825803a8f From b070cc498ccc4788a85f3ce575b33d50fd17d30c Mon Sep 17 00:00:00 2001 From: taoxiangdong Date: Fri, 9 Oct 2020 17:42:59 +0800 Subject: [PATCH 2/2] ge submodule metadef --- .gitmodules | 3 +++ metadef | 1 + 2 files changed, 4 insertions(+) create mode 160000 metadef diff --git a/.gitmodules b/.gitmodules index 4a36bfba..ea879278 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "parser"] path = parser url = https://gitee.com/ascend/parser.git +[submodule "metadef"] + path = metadef + url = https://gitee.com/ascend/metadef.git diff --git a/metadef b/metadef new file mode 160000 index 00000000..d097f7ce --- /dev/null +++ b/metadef @@ -0,0 +1 @@ +Subproject commit d097f7ce4e7a3ec449e13df78e26e8a76ca48cec