| @@ -0,0 +1,3 @@ | |||||
| [submodule "parser"] | |||||
| path = parser | |||||
| url = https://gitee.com/ascend/parser.git | |||||
| @@ -1,79 +0,0 @@ | |||||
| # Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| # libgraph.so | |||||
| # compiling proto files generates some warnings, use no-unused-variable to suppress them | |||||
| set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | |||||
| # add all proto files, generate corresponding .h and .cc files | |||||
| file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "${GE_SOURCE_DIR}/metadef/proto/om.proto" | |||||
| "${GE_SOURCE_DIR}/metadef/proto/ge_ir.proto" | |||||
| "${GE_SOURCE_DIR}/metadef/proto/insert_op.proto" | |||||
| "${GE_SOURCE_DIR}/metadef/proto/task.proto" | |||||
| "${GE_SOURCE_DIR}/metadef/proto/fwk_adaper.proto" | |||||
| "${GE_SOURCE_DIR}/metadef/proto/op_mapping_info.proto" | |||||
| "${GE_SOURCE_DIR}/metadef/proto/dump_task.proto" | |||||
| ) | |||||
| file(GLOB_RECURSE ONNX_PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "${onnx_INC}/onnx/onnx.proto" | |||||
| ) | |||||
| ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
| ge_protobuf_generate(ge PROTO_ONNX_SRCS PROTO_ONNX_HDRS ${ONNX_PROTO_LIST}) | |||||
| # need to remove dependencies on pb files later | |||||
| file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "*.cc" | |||||
| "utils/*.cc" | |||||
| "opsproto/*.cc" | |||||
| "detail/*.cc" | |||||
| "debug/*.cc" | |||||
| "option/*.cc" | |||||
| ) | |||||
| # include directories | |||||
| include_directories(${CMAKE_CURRENT_LIST_DIR}) | |||||
| include_directories(${GE_SOURCE_DIR}) | |||||
| #include_directories(${GE_SOURCE_DIR}/src) | |||||
| include_directories(${GE_SOURCE_DIR}/ge) | |||||
| include_directories(${GE_SOURCE_DIR}/metadef) | |||||
| include_directories(${GE_SOURCE_DIR}/metadef/graph) | |||||
| include_directories(${GE_SOURCE_DIR}/inc) | |||||
| include_directories(${GE_SOURCE_DIR}/inc/framework) | |||||
| include_directories(${GE_SOURCE_DIR}/inc/external) | |||||
| include_directories(${GE_SOURCE_DIR}/metadef/inc) | |||||
| include_directories(${GE_SOURCE_DIR}/metadef/inc/external/graph) | |||||
| include_directories(${GE_SOURCE_DIR}/metadef/inc/external) | |||||
| include_directories(${GE_SOURCE_DIR}/metadef/inc/graph) | |||||
| include_directories(${GE_SOURCE_DIR}/inc/common) | |||||
| include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||||
| include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) | |||||
| include_directories(${CMAKE_BINARY_DIR}) | |||||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||||
| include_directories(${GE_SOURCE_DIR}/build) | |||||
| ######### libgraph.so ############# | |||||
| add_library(graph SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_ONNX_SRCS}) | |||||
| target_compile_definitions(graph PRIVATE | |||||
| DAVINCI_CLOUD | |||||
| Werror) | |||||
| target_link_libraries(graph PRIVATE | |||||
| ${PROTOBUF_LIBRARY} | |||||
| ${c_sec} | |||||
| ${slog} | |||||
| ${error_manager} | |||||
| rt | |||||
| dl) | |||||
| @@ -1,371 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/anchor.h" | |||||
| #include <algorithm> | |||||
| #include <cstring> | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/node.h" | |||||
| namespace ge { | |||||
| Anchor::Anchor(const NodePtr &owner_node, int idx) : owner_node_(owner_node), idx_(idx) {} | |||||
| bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf<Anchor>(), type) == 0; } | |||||
| size_t Anchor::GetPeerAnchorsSize() const { return peer_anchors_.size(); } | |||||
| Anchor::Vistor<AnchorPtr> Anchor::GetPeerAnchors() const { | |||||
| vector<AnchorPtr> ret; | |||||
| for (const auto &anchor : peer_anchors_) { | |||||
| ret.push_back(anchor.lock()); | |||||
| } | |||||
| return Anchor::Vistor<AnchorPtr>(shared_from_this(), ret); | |||||
| } | |||||
| AnchorPtr Anchor::GetFirstPeerAnchor() const { | |||||
| if (peer_anchors_.empty()) { | |||||
| return nullptr; | |||||
| } else { | |||||
| return Anchor::DynamicAnchorCast<Anchor>(peer_anchors_.begin()->lock()); | |||||
| } | |||||
| } | |||||
| NodePtr Anchor::GetOwnerNode() const { return owner_node_.lock(); } | |||||
| void Anchor::UnlinkAll() noexcept { | |||||
| if (!peer_anchors_.empty()) { | |||||
| do { | |||||
| auto peer_anchor_ptr = peer_anchors_.begin()->lock(); | |||||
| if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { | |||||
| GELOGW("unlink peer_anchor_ptr failed."); | |||||
| } | |||||
| } while (!peer_anchors_.empty()); | |||||
| } | |||||
| } | |||||
| graphStatus Anchor::Unlink(const AnchorPtr &peer) { | |||||
| if (peer == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "peer anchor is invalid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr<Anchor> &an) { | |||||
| auto anchor = an.lock(); | |||||
| return peer->Equal(anchor); | |||||
| }); | |||||
| GE_IF_BOOL_EXEC(it == peer_anchors_.end(), GELOGW("this anchor is not connected to peer"); return GRAPH_FAILED); | |||||
| auto it_peer = | |||||
| std::find_if(peer->peer_anchors_.begin(), peer->peer_anchors_.end(), [this](const std::weak_ptr<Anchor> &an) { | |||||
| auto anchor = an.lock(); | |||||
| return Equal(anchor); | |||||
| }); | |||||
| GE_CHK_BOOL_RET_STATUS(it_peer != peer->peer_anchors_.end(), GRAPH_FAILED, "peer is not connected to this anchor"); | |||||
| (void)peer_anchors_.erase(it); | |||||
| (void)peer->peer_anchors_.erase(it_peer); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus Anchor::ReplacePeer(const AnchorPtr &old_peer, const AnchorPtr &first_peer, const AnchorPtr &second_peer) { | |||||
| GE_CHK_BOOL_RET_STATUS(old_peer != nullptr, GRAPH_FAILED, "this old peer anchor is nullptr"); | |||||
| GE_CHK_BOOL_RET_STATUS(first_peer != nullptr, GRAPH_FAILED, "this first peer anchor is nullptr"); | |||||
| GE_CHK_BOOL_RET_STATUS(second_peer != nullptr, GRAPH_FAILED, "this second peer anchor is nullptr"); | |||||
| auto this_it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [old_peer](const std::weak_ptr<Anchor> &an) { | |||||
| auto anchor = an.lock(); | |||||
| return old_peer->Equal(anchor); | |||||
| }); | |||||
| GE_CHK_BOOL_RET_STATUS(this_it != peer_anchors_.end(), GRAPH_FAILED, "this anchor is not connected to old_peer"); | |||||
| auto old_it = std::find_if(old_peer->peer_anchors_.begin(), old_peer->peer_anchors_.end(), | |||||
| [this](const std::weak_ptr<Anchor> &an) { | |||||
| auto anchor = an.lock(); | |||||
| return Equal(anchor); | |||||
| }); | |||||
| GE_CHK_BOOL_RET_STATUS(old_it != old_peer->peer_anchors_.end(), GRAPH_FAILED, | |||||
| "old_peer is not connected to this anchor"); | |||||
| *this_it = first_peer; | |||||
| first_peer->peer_anchors_.push_back(shared_from_this()); | |||||
| *old_it = second_peer; | |||||
| second_peer->peer_anchors_.push_back(old_peer); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool Anchor::IsLinkedWith(const AnchorPtr &peer) { | |||||
| auto it = std::find_if(peer_anchors_.begin(), peer_anchors_.end(), [peer](const std::weak_ptr<Anchor> &an) { | |||||
| auto anchor = an.lock(); | |||||
| GE_CHK_BOOL_RET_STATUS(peer != nullptr, false, "this old peer anchor is nullptr"); | |||||
| return peer->Equal(anchor); | |||||
| }); | |||||
| return (it != peer_anchors_.end()); | |||||
| } | |||||
| int Anchor::GetIdx() const { return idx_; } | |||||
| void Anchor::SetIdx(int index) { idx_ = index; } | |||||
| DataAnchor::DataAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {} | |||||
| bool DataAnchor::IsTypeOf(TYPE type) const { | |||||
| if (strcmp(Anchor::TypeOf<DataAnchor>(), type) == 0) { | |||||
| return true; | |||||
| } | |||||
| return Anchor::IsTypeOf(type); | |||||
| } | |||||
| InDataAnchor::InDataAnchor(const NodePtr &owner_node, int idx) : DataAnchor(owner_node, idx) {} | |||||
| OutDataAnchorPtr InDataAnchor::GetPeerOutAnchor() const { | |||||
| if (peer_anchors_.empty()) { | |||||
| return nullptr; | |||||
| } else { | |||||
| return Anchor::DynamicAnchorCast<OutDataAnchor>(peer_anchors_.begin()->lock()); | |||||
| } | |||||
| } | |||||
| graphStatus InDataAnchor::LinkFrom(const OutDataAnchorPtr &src) { | |||||
| // InDataAnchor must be only linkfrom once | |||||
| if (src == nullptr || !peer_anchors_.empty()) { | |||||
| GELOGE(GRAPH_FAILED, "src anchor is invalid or the peerAnchors is not empty."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| peer_anchors_.push_back(src); | |||||
| src->peer_anchors_.push_back(shared_from_this()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool InDataAnchor::Equal(AnchorPtr anchor) const { | |||||
| auto in_data_anchor = Anchor::DynamicAnchorCast<InDataAnchor>(anchor); | |||||
| if (in_data_anchor != nullptr) { | |||||
| if (GetOwnerNode() == in_data_anchor->GetOwnerNode() && GetIdx() == in_data_anchor->GetIdx()) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool InDataAnchor::IsTypeOf(TYPE type) const { | |||||
| if (strcmp(Anchor::TypeOf<InDataAnchor>(), type) == 0) { | |||||
| return true; | |||||
| } | |||||
| return DataAnchor::IsTypeOf(type); | |||||
| } | |||||
| OutDataAnchor::OutDataAnchor(const NodePtr &owner_node, int idx) : DataAnchor(owner_node, idx) {} | |||||
| OutDataAnchor::Vistor<InDataAnchorPtr> OutDataAnchor::GetPeerInDataAnchors() const { | |||||
| vector<InDataAnchorPtr> ret; | |||||
| for (const auto &anchor : peer_anchors_) { | |||||
| auto in_data_anchor = Anchor::DynamicAnchorCast<InDataAnchor>(anchor.lock()); | |||||
| if (in_data_anchor != nullptr) { | |||||
| ret.push_back(in_data_anchor); | |||||
| } | |||||
| } | |||||
| return OutDataAnchor::Vistor<InDataAnchorPtr>(shared_from_this(), ret); | |||||
| } | |||||
| uint32_t OutDataAnchor::GetPeerInDataNodesSize() const { | |||||
| uint32_t out_nums = 0; | |||||
| for (const auto &anchor : peer_anchors_) { | |||||
| auto in_data_anchor = Anchor::DynamicAnchorCast<InDataAnchor>(anchor.lock()); | |||||
| if (in_data_anchor != nullptr && in_data_anchor->GetOwnerNode() != nullptr) { | |||||
| out_nums++; | |||||
| } | |||||
| } | |||||
| return out_nums; | |||||
| } | |||||
| OutDataAnchor::Vistor<InControlAnchorPtr> OutDataAnchor::GetPeerInControlAnchors() const { | |||||
| vector<InControlAnchorPtr> ret; | |||||
| for (const auto &anchor : peer_anchors_) { | |||||
| auto in_control_anchor = Anchor::DynamicAnchorCast<InControlAnchor>(anchor.lock()); | |||||
| if (in_control_anchor != nullptr) { | |||||
| ret.push_back(in_control_anchor); | |||||
| } | |||||
| } | |||||
| return OutDataAnchor::Vistor<InControlAnchorPtr>(shared_from_this(), ret); | |||||
| } | |||||
| graphStatus OutDataAnchor::LinkTo(const InDataAnchorPtr &dest) { | |||||
| if (dest == nullptr || !dest->peer_anchors_.empty()) { | |||||
| GELOGE(GRAPH_FAILED, "dest anchor is invalid or the peerAnchors is not empty."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| peer_anchors_.push_back(dest); | |||||
| dest->peer_anchors_.push_back(shared_from_this()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus OutDataAnchor::LinkTo(const InControlAnchorPtr &dest) { | |||||
| if (dest == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "dest anchor is invalid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| peer_anchors_.push_back(dest); | |||||
| dest->peer_anchors_.push_back(shared_from_this()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus OutControlAnchor::LinkTo(const InDataAnchorPtr &dest) { | |||||
| if (dest == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "dest anchor is invalid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| peer_anchors_.push_back(dest); | |||||
| dest->peer_anchors_.push_back(shared_from_this()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool OutDataAnchor::Equal(AnchorPtr anchor) const { | |||||
| CHECK_FALSE_EXEC(anchor != nullptr, return false); | |||||
| auto out_data_anchor = Anchor::DynamicAnchorCast<OutDataAnchor>(anchor); | |||||
| if (out_data_anchor != nullptr) { | |||||
| if (GetOwnerNode() == out_data_anchor->GetOwnerNode() && GetIdx() == out_data_anchor->GetIdx()) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool OutDataAnchor::IsTypeOf(TYPE type) const { | |||||
| if (strcmp(Anchor::TypeOf<OutDataAnchor>(), type) == 0) { | |||||
| return true; | |||||
| } | |||||
| return DataAnchor::IsTypeOf(type); | |||||
| } | |||||
| ControlAnchor::ControlAnchor(const NodePtr &owner_node) : Anchor(owner_node, -1) {} | |||||
| ControlAnchor::ControlAnchor(const NodePtr &owner_node, int idx) : Anchor(owner_node, idx) {} | |||||
| bool ControlAnchor::IsTypeOf(TYPE type) const { | |||||
| if (strcmp(Anchor::TypeOf<ControlAnchor>(), type) == 0) { | |||||
| return true; | |||||
| } | |||||
| return Anchor::IsTypeOf(type); | |||||
| } | |||||
| InControlAnchor::InControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {} | |||||
| InControlAnchor::InControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {} | |||||
| InControlAnchor::Vistor<OutControlAnchorPtr> InControlAnchor::GetPeerOutControlAnchors() const { | |||||
| vector<OutControlAnchorPtr> ret; | |||||
| for (const auto &anchor : peer_anchors_) { | |||||
| auto out_control_anchor = Anchor::DynamicAnchorCast<OutControlAnchor>(anchor.lock()); | |||||
| if (out_control_anchor != nullptr) { | |||||
| ret.push_back(out_control_anchor); | |||||
| } | |||||
| } | |||||
| return InControlAnchor::Vistor<OutControlAnchorPtr>(shared_from_this(), ret); | |||||
| } | |||||
| InControlAnchor::Vistor<OutDataAnchorPtr> InControlAnchor::GetPeerOutDataAnchors() const { | |||||
| vector<OutDataAnchorPtr> ret; | |||||
| for (const auto &anchor : peer_anchors_) { | |||||
| auto out_data_anchor = Anchor::DynamicAnchorCast<OutDataAnchor>(anchor.lock()); | |||||
| if (out_data_anchor != nullptr) { | |||||
| ret.push_back(out_data_anchor); | |||||
| } | |||||
| } | |||||
| return InControlAnchor::Vistor<OutDataAnchorPtr>(shared_from_this(), ret); | |||||
| } | |||||
| graphStatus InControlAnchor::LinkFrom(const OutControlAnchorPtr &src) { | |||||
| if (src == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "src anchor is invalid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| peer_anchors_.push_back(src); | |||||
| src->peer_anchors_.push_back(shared_from_this()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool InControlAnchor::Equal(AnchorPtr anchor) const { | |||||
| CHECK_FALSE_EXEC(anchor != nullptr, return false); | |||||
| auto in_control_anchor = Anchor::DynamicAnchorCast<InControlAnchor>(anchor); | |||||
| if (in_control_anchor != nullptr) { | |||||
| if (GetOwnerNode() == in_control_anchor->GetOwnerNode()) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool InControlAnchor::IsTypeOf(TYPE type) const { | |||||
| if (strcmp(Anchor::TypeOf<InControlAnchor>(), type) == 0) { | |||||
| return true; | |||||
| } | |||||
| return ControlAnchor::IsTypeOf(type); | |||||
| } | |||||
| OutControlAnchor::OutControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {} | |||||
| OutControlAnchor::OutControlAnchor(const NodePtr &owner_node, int idx) : ControlAnchor(owner_node, idx) {} | |||||
| OutControlAnchor::Vistor<InControlAnchorPtr> OutControlAnchor::GetPeerInControlAnchors() const { | |||||
| vector<InControlAnchorPtr> ret; | |||||
| for (const auto &anchor : peer_anchors_) { | |||||
| auto in_control_anchor = Anchor::DynamicAnchorCast<InControlAnchor>(anchor.lock()); | |||||
| if (in_control_anchor != nullptr) { | |||||
| ret.push_back(in_control_anchor); | |||||
| } | |||||
| } | |||||
| return OutControlAnchor::Vistor<InControlAnchorPtr>(shared_from_this(), ret); | |||||
| } | |||||
| OutControlAnchor::Vistor<InDataAnchorPtr> OutControlAnchor::GetPeerInDataAnchors() const { | |||||
| vector<InDataAnchorPtr> ret; | |||||
| for (const auto &anchor : peer_anchors_) { | |||||
| auto in_data_anchor = Anchor::DynamicAnchorCast<InDataAnchor>(anchor.lock()); | |||||
| if (in_data_anchor != nullptr) { | |||||
| ret.push_back(in_data_anchor); | |||||
| } | |||||
| } | |||||
| return OutControlAnchor::Vistor<InDataAnchorPtr>(shared_from_this(), ret); | |||||
| } | |||||
| graphStatus OutControlAnchor::LinkTo(const InControlAnchorPtr &dest) { | |||||
| if (dest == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "dest anchor is invalid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| peer_anchors_.push_back(dest); | |||||
| dest->peer_anchors_.push_back(shared_from_this()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool OutControlAnchor::Equal(AnchorPtr anchor) const { | |||||
| auto out_control_anchor = Anchor::DynamicAnchorCast<OutControlAnchor>(anchor); | |||||
| if (out_control_anchor != nullptr) { | |||||
| if (GetOwnerNode() == out_control_anchor->GetOwnerNode()) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool OutControlAnchor::IsTypeOf(TYPE type) const { | |||||
| if (strcmp(Anchor::TypeOf<OutControlAnchor>(), type) == 0) { | |||||
| return true; | |||||
| } | |||||
| return ControlAnchor::IsTypeOf(type); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,38 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "external/graph/attr_value.h" | |||||
| #include "debug/ge_log.h" | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/ge_attr_value.h" | |||||
| namespace ge { | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue::AttrValue() { impl = ComGraphMakeShared<AttrValueImpl>(); } | |||||
| #define ATTR_VALUE_SET_GET_IMP(type) \ | |||||
| graphStatus AttrValue::GetValue(type &val) const { \ | |||||
| if (impl != nullptr) { \ | |||||
| GELOGW("GetValue failed."); \ | |||||
| return impl->geAttrValue_.GetValue<type>(val); \ | |||||
| } \ | |||||
| return GRAPH_FAILED; \ | |||||
| } | |||||
| ATTR_VALUE_SET_GET_IMP(AttrValue::STR) | |||||
| ATTR_VALUE_SET_GET_IMP(AttrValue::INT) | |||||
| ATTR_VALUE_SET_GET_IMP(AttrValue::FLOAT) | |||||
| } // namespace ge | |||||
| @@ -1,112 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/buffer.h" | |||||
| #include "proto/ge_ir.pb.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace ge { | |||||
| Buffer::Buffer() { | |||||
| data_.InitDefault(); | |||||
| if (data_.GetProtoMsg()) { | |||||
| buffer_ = data_.GetProtoMsg()->mutable_bt(); | |||||
| } | |||||
| } | |||||
| Buffer::Buffer(const Buffer &other) { | |||||
| // Share data | |||||
| data_ = other.data_; | |||||
| buffer_ = other.buffer_; | |||||
| } | |||||
| Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { // default | |||||
| auto proto_msg = data_.GetProtoMsg(); | |||||
| if (proto_msg != nullptr) { | |||||
| try { | |||||
| proto_msg->set_bt(std::string(buffer_size, default_val)); | |||||
| buffer_ = proto_msg->mutable_bt(); | |||||
| } catch (std::bad_alloc &e) { | |||||
| GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size); | |||||
| buffer_ = nullptr; | |||||
| } | |||||
| } | |||||
| } | |||||
| Buffer Buffer::CopyFrom(const std::uint8_t *data, std::size_t buffer_size) { | |||||
| Buffer buffer; | |||||
| auto proto_msg = buffer.data_.GetProtoMsg(); | |||||
| if (proto_msg != nullptr && data != nullptr) { | |||||
| try { | |||||
| proto_msg->set_bt(data, buffer_size); | |||||
| buffer.buffer_ = proto_msg->mutable_bt(); | |||||
| } catch (std::bad_alloc &e) { | |||||
| GELOGE(MEMALLOC_FAILED, "Failed to alloc buffer memory, buffer size %zu", buffer_size); | |||||
| buffer.buffer_ = nullptr; | |||||
| } | |||||
| } | |||||
| return buffer; | |||||
| } | |||||
| Buffer::Buffer(const std::shared_ptr<google::protobuf::Message> &proto_owner, proto::AttrDef *buffer) | |||||
| : data_(proto_owner, buffer) { | |||||
| if (data_.GetProtoMsg() != nullptr) { | |||||
| buffer_ = data_.GetProtoMsg()->mutable_bt(); | |||||
| } | |||||
| } | |||||
| Buffer::Buffer(const std::shared_ptr<google::protobuf::Message> &proto_owner, std::string *buffer) | |||||
| : data_(proto_owner, nullptr) { | |||||
| buffer_ = buffer; | |||||
| } | |||||
| Buffer &Buffer::operator=(const Buffer &other) { | |||||
| if (&other != this) { | |||||
| // Share data | |||||
| data_ = other.data_; | |||||
| buffer_ = other.buffer_; | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| const std::uint8_t *Buffer::GetData() const { | |||||
| if (buffer_ != nullptr) { | |||||
| return (const std::uint8_t *)buffer_->data(); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| std::uint8_t *Buffer::GetData() { | |||||
| if (buffer_ != nullptr && !buffer_->empty()) { | |||||
| // Avoid copy on write | |||||
| (void)(*buffer_)[0]; | |||||
| return reinterpret_cast<uint8_t *>(const_cast<char *>(buffer_->data())); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| std::size_t Buffer::GetSize() const { | |||||
| if (buffer_ != nullptr) { | |||||
| return buffer_->size(); | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| void Buffer::ClearBuffer() { | |||||
| if (buffer_ != nullptr) { | |||||
| buffer_->clear(); | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,147 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef COMMON_GRAPH_DEBUG_GE_LOG_H_ | |||||
| #define COMMON_GRAPH_DEBUG_GE_LOG_H_ | |||||
| #include "graph/ge_error_codes.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||||
| #define GE_LOGI_IF(condition, ...) \ | |||||
| if ((condition)) { \ | |||||
| GELOGI(__VA_ARGS__); \ | |||||
| } | |||||
| #define GE_LOGW_IF(condition, ...) \ | |||||
| if ((condition)) { \ | |||||
| GELOGW(__VA_ARGS__); \ | |||||
| } | |||||
| #define GE_LOGE_IF(condition, ...) \ | |||||
| if ((condition)) { \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| } | |||||
| #define GE_CHK_STATUS_RET_NOLOG(expr) \ | |||||
| do { \ | |||||
| const ge::graphStatus _status = (expr); \ | |||||
| if (ge::SUCCESS != _status) { \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||||
| do { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ | |||||
| { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } | |||||
| #define GE_IF_BOOL_EXEC(expr, exec_expr) \ | |||||
| { \ | |||||
| if (expr) { \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } | |||||
| #define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ | |||||
| do { \ | |||||
| const ge::graphStatus _status = (expr); \ | |||||
| if (_status) { \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // If expr is true, the log is printed and a custom statement is executed | |||||
| #define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ | |||||
| { \ | |||||
| bool b = (expr); \ | |||||
| if (b) { \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } | |||||
| // Only check error log | |||||
| #define GE_CHK_BOOL_ONLY_LOG(expr, ...) \ | |||||
| do { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| GELOGI(__VA_ARGS__); \ | |||||
| } \ | |||||
| } while (0) | |||||
| // If expr is not true, do not print the log and return the specified status | |||||
| #define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ | |||||
| do { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // If expr is not true, the log is printed and a custom statement is executed | |||||
| #define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ | |||||
| { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } | |||||
| // If expr is not true, the log is printed and a custom statement is executed | |||||
| #define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ | |||||
| { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| GELOGI(__VA_ARGS__); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } | |||||
| // If expr is not GRAPH_SUCCESS, print the log and return the same value | |||||
| #define GE_CHK_STATUS_RET(expr, ...) \ | |||||
| do { \ | |||||
| const ge::graphStatus _status = (expr); \ | |||||
| if (ge::SUCCESS != _status) { \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ | |||||
| try { \ | |||||
| exec_expr0; \ | |||||
| } catch (...) { \ | |||||
| GELOGE(ge::FAILED, "Make shared failed"); \ | |||||
| exec_expr1; \ | |||||
| } | |||||
| #endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ | |||||
| @@ -1,69 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | |||||
| #define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | |||||
| namespace ge { | |||||
| #define GE_REGISTER_OPTYPE(var_name, str_name) static const char *var_name __attribute__((unused)) = str_name | |||||
| GE_REGISTER_OPTYPE(DATA, "Data"); | |||||
| GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); | |||||
| GE_REGISTER_OPTYPE(MATMUL, "MatMul"); | |||||
| GE_REGISTER_OPTYPE(RESHAPE, "Reshape"); | |||||
| GE_REGISTER_OPTYPE(PERMUTE, "Permute"); | |||||
| GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput"); | |||||
| GE_REGISTER_OPTYPE(_WHILE, "_While"); | |||||
| GE_REGISTER_OPTYPE(WHILE, "While"); | |||||
| GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile"); | |||||
| GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); | |||||
| GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); | |||||
| GE_REGISTER_OPTYPE(SWITCH, "Switch"); | |||||
| GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch"); | |||||
| GE_REGISTER_OPTYPE(SWITCHN, "SwitchN"); | |||||
| GE_REGISTER_OPTYPE(MERGE, "Merge"); | |||||
| GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); | |||||
| GE_REGISTER_OPTYPE(ENTER, "Enter"); | |||||
| GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); | |||||
| GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); | |||||
| GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); | |||||
| GE_REGISTER_OPTYPE(CONSTANT, "Const"); | |||||
| GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); | |||||
| GE_REGISTER_OPTYPE(END, "End"); | |||||
| GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); | |||||
| GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); | |||||
| GE_REGISTER_OPTYPE(INITDATA, "InitData"); | |||||
| GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity"); | |||||
| GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); | |||||
| GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); | |||||
| GE_REGISTER_OPTYPE(VARIABLE, "Variable"); | |||||
| GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); | |||||
| GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); | |||||
| // Horovod operator | |||||
| GE_REGISTER_OPTYPE(HVDCALLBACKALLREDUCE, "hvdCallbackAllreduce"); | |||||
| GE_REGISTER_OPTYPE(HVDCALLBACKALLGATHER, "hvdCallbackAllgather"); | |||||
| GE_REGISTER_OPTYPE(HVDCALLBACKBROADCAST, "hvdCallbackBroadcast"); | |||||
| GE_REGISTER_OPTYPE(HVDWAIT, "hvdWait"); | |||||
| GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); | |||||
| GE_REGISTER_OPTYPE(RECV, "Recv"); | |||||
| GE_REGISTER_OPTYPE(SEND, "Send"); | |||||
| }; // namespace ge | |||||
| #endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | |||||
| @@ -1,274 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef COMMON_GRAPH_DEBUG_GE_UTIL_H_ | |||||
| #define COMMON_GRAPH_DEBUG_GE_UTIL_H_ | |||||
| #include <limits.h> | |||||
| #include <math.h> | |||||
| #include <iostream> | |||||
| #include <memory> | |||||
| #include <sstream> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/debug/ge_log.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||||
| #define GE_DYNAMIC_CAST dynamic_cast | |||||
| #define GE_DYNAMIC_POINTER_CAST std::dynamic_pointer_cast | |||||
| #else | |||||
| #define GE_DYNAMIC_CAST static_cast | |||||
| #define GE_DYNAMIC_POINTER_CAST std::static_pointer_cast | |||||
| #endif | |||||
| #define GE_RETURN_IF_ERROR(expr) \ | |||||
| do { \ | |||||
| const ::ge::optStatus _status = (expr); \ | |||||
| if (_status) return _status; \ | |||||
| } while (0) | |||||
| #define GE_RETURN_WITH_LOG_IF_INFO(expr, ...) \ | |||||
| do { \ | |||||
| const ::ge::optStatus _status = (expr); \ | |||||
| if (_status) { \ | |||||
| GELOGI(__VA_ARGS__); \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Verify whether the parameter is true. If yes, return graph failed and record the error log | |||||
| #define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ | |||||
| do { \ | |||||
| if (condition) { \ | |||||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
| return ge::GRAPH_FAILED; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Verify whether the parameter is false. If yes, return graph failed and record the error log | |||||
| #define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \ | |||||
| do { \ | |||||
| bool _condition = (condition); \ | |||||
| if (!_condition) { \ | |||||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
| return ge::GRAPH_FAILED; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Verify whether the parameter is true. If yes, return GRAPH_PARAM_INVALID and record the error log | |||||
| #define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ | |||||
| do { \ | |||||
| if (condition) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Verify whether the parameter is false. If yes, return GRAPH_PARAM_INVALID and record the error log | |||||
| #define GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(condition, ...) \ | |||||
| do { \ | |||||
| bool _condition = (condition); \ | |||||
| if (!_condition) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log | |||||
| #define GE_CHECK_NOTNULL(val) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Verify whether the parameter is null. If yes, return GRAPH_PARAM_INVALID and record the error log | |||||
| #define GE_CHECK_NOTNULL_EXEC(val, expr) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] must not be null.", #val); \ | |||||
| expr; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Verify whether the parameter is null. If yes, return false and record the error log | |||||
| #define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| GELOGE(ge::GRAPH_FAILED, "param[%s] must not be null.", #val); \ | |||||
| return false; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Check whether the parameter is out of range | |||||
| #define GE_CHECK_SIZE(size) \ | |||||
| do { \ | |||||
| if (size == 0) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| /// | |||||
| /// @ingroup GE_common | |||||
| /// eg:GE_DEFINE_BYTE_SIZE(filter_byte, filter.data().size(), sizeof(float)); | |||||
| /// | |||||
| #define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ | |||||
| uint32_t _var_name; \ | |||||
| do { \ | |||||
| uint32_t _expr_size = (_expr); \ | |||||
| uint32_t _sizeof_size = (_sizeof); \ | |||||
| if (_expr_size > (0xffffffff) / _sizeof_size) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, "byte size : %s is out of range", #_var_name); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| _var_name = _sizeof_size * _expr_size; \ | |||||
| } while (0); | |||||
| // Check whether the container is empty | |||||
| #define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||||
| do { \ | |||||
| if (vector.empty()) { \ | |||||
| GELOGE(ge::GRAPH_FAILED, "param[#vector] is empty", #vector); \ | |||||
| return ge::GRAPH_FAILED; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Check whether the container is empty and return the specified status code | |||||
| #define GE_CHECK_VECTOR_NOT_EMPTY_RET_STATUS(vector, _status) \ | |||||
| do { \ | |||||
| if (vector.empty()) { \ | |||||
| GELOGE(_status, "param[%s] is empty", #vector); \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0) | |||||
| /// | |||||
| /// @ingroup GE_common | |||||
| /// @brief This macro provides the ability to disable copying constructors and assignment operators. | |||||
| /// It is usually placed under private | |||||
| /// | |||||
| #define GE_DISALLOW_COPY_AND_ASSIGN(TypeName) \ | |||||
| TypeName(const TypeName &) = delete; \ | |||||
| void operator=(const TypeName &) = delete | |||||
| /// Check whether the size is 0 or out of range | |||||
| /// @param:size:Size to be verified | |||||
| #define GE_CHECK_SIZE_RANGE(size) \ | |||||
| do { \ | |||||
| if (size == 0 || size >= UINT_MAX / 4) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GE_CHECK_SHORT_SIZE_RANGE(size) \ | |||||
| do { \ | |||||
| if (size == 0 || size >= UINT_MAX / 2) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||||
| do { \ | |||||
| if (size <= 0) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not a positive number", #size); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GE_CHECK_POSITIVE_SHORT_SIZE_RANGE(size) \ | |||||
| do { \ | |||||
| if (size <= 0 || size == 0 || size >= UINT_MAX / 4) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is out of range", #size); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Verify that the value on the left is greater than or equal to the value on the right | |||||
| #define GE_CHECK_GE(lhs, rhs) \ | |||||
| do { \ | |||||
| if (lhs < rhs) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is less than[%s]", #lhs, #rhs); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Check whether the parameters are equal | |||||
| #define GE_CHECK_EQ(val1, val2) \ | |||||
| do { \ | |||||
| if (val1 != val2) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is not equals to[%s]", #val1, #val2); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Verify that the value on the left is less than or equal to the value on the right | |||||
| #define GE_CHECK_LE(lhs, rhs) \ | |||||
| do { \ | |||||
| if (lhs > rhs) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, "param[%s] is greater than[%s]", #lhs, #rhs); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // Check whether the parameters are equal | |||||
| #define GE_CHECK_EQ_WITH_LOG(val1, val2, ...) \ | |||||
| do { \ | |||||
| if (val1 != val2) { \ | |||||
| GELOGE(ge::GRAPH_PARAM_INVALID, __VA_ARGS__); \ | |||||
| return ge::GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| // If expr is false, the custom statement is executed | |||||
| #define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | |||||
| do { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GE_DELETE_NEW_SINGLE(var) \ | |||||
| do { \ | |||||
| if (var != nullptr) { \ | |||||
| delete var; \ | |||||
| var = nullptr; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GE_DELETE_NEW_ARRAY(var) \ | |||||
| do { \ | |||||
| if (var != nullptr) { \ | |||||
| delete[] var; \ | |||||
| var = nullptr; \ | |||||
| } \ | |||||
| } while (0) | |||||
| template <typename T, typename... Args> | |||||
| static inline std::shared_ptr<T> ComGraphMakeShared(Args &&... args) { | |||||
| using T_nc = typename std::remove_const<T>::type; | |||||
| std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...)); | |||||
| return ret; | |||||
| } | |||||
| #endif // COMMON_GRAPH_DEBUG_GE_UTIL_H_ | |||||
| @@ -1,246 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/debug/graph_debug.h" | |||||
| #include <algorithm> | |||||
| #include <unordered_set> | |||||
| #include <vector> | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #define TAB " " | |||||
| #define STR_FMT(str) (" \"" + std::string(str) + "\" ") | |||||
| #define INPUT_ANCHOR_PORT(name) ("__input__" + (name)) | |||||
| #define OUTPUT_ANCHOR_PORT(name) ("__output__" + (name)) | |||||
| namespace ge { | |||||
| std::unordered_set<std::string> control_anchor; | |||||
| std::vector<string> types = { | |||||
| "DT_FLOAT", "DT_FLOAT16", "DT_INT8", "DT_INT32", "DT_UINT8", "", | |||||
| "DT_INT16", "DT_UINT16", "DT_UINT32", "DT_INT64", "DT_UINT64", "DT_DOUBLE", | |||||
| "DT_BOOL", "DT_DUAL", "DT_DUAL_SUB_INT8", "DT_DUAL_SUB_UINT8", "DT_UNDEFINED"}; | |||||
| std::vector<string> formats = {"FORMAT_NCHW", | |||||
| "FORMAT_NHWC", | |||||
| "FORMAT_ND", | |||||
| "FORMAT_NC1HWC0", | |||||
| "FORMAT_FRACTAL_Z", | |||||
| "FORMAT_NC1C0HWPAD", | |||||
| "FORMAT_NHWC1C0", | |||||
| "FORMAT_FSR_NCHW", | |||||
| "FORMAT_FRACTAL_DECONV", | |||||
| "FORMAT_C1HWNC0", | |||||
| "FORMAT_FRACTAL_DECONV_TRANSPOSE", | |||||
| "FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS", | |||||
| "FORMAT_NC1HWC0_C04", | |||||
| "FORMAT_FRACTAL_Z_C04", | |||||
| "FORMAT_CHWN", | |||||
| "FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS", | |||||
| "FORMAT_HWCN", | |||||
| "FORMAT_NC1KHKWHWC0", | |||||
| "FORMAT_BN_WEIGHT", | |||||
| "FORMAT_FILTER_HWCK", | |||||
| "FORMAT_HASHTABLE_LOOKUP_LOOKUPS", | |||||
| "FORMAT_HASHTABLE_LOOKUP_KEYS", | |||||
| "FORMAT_HASHTABLE_LOOKUP_VALUE", | |||||
| "FORMAT_HASHTABLE_LOOKUP_OUTPUT", | |||||
| "FORMAT_HASHTABLE_LOOKUP_HITS", | |||||
| "FORMAT_RESERVED"}; | |||||
| std::vector<string> data_nodes = {"Const", "Data"}; | |||||
| void GraphDebugPrinter::DumpNodeToDot(const NodePtr node, std::ostringstream &out_) { | |||||
| if (node == nullptr) { | |||||
| GELOGI("Some nodes are null."); | |||||
| return; | |||||
| } | |||||
| bool in_control = false; | |||||
| auto name = node->GetName(); | |||||
| out_ << TAB << STR_FMT(name); | |||||
| auto input_cnt = std::max(static_cast<size_t>(1), node->GetAllInDataAnchors().size()); | |||||
| auto output_cnt = std::max(static_cast<size_t>(1), node->GetAllOutDataAnchors().size()); | |||||
| if (control_anchor.find(node->GetName()) != control_anchor.end()) { | |||||
| input_cnt++; | |||||
| in_control = true; | |||||
| } | |||||
| auto max_col = input_cnt * output_cnt; | |||||
| out_ << "[\n"; | |||||
| if (find(data_nodes.begin(), data_nodes.end(), node->GetType()) != data_nodes.end()) { | |||||
| out_ << TAB << TAB << "shape=plaintext, color=goldenrod\n"; | |||||
| } else { | |||||
| out_ << TAB << TAB << "shape=plaintext, color=deepskyblue\n"; | |||||
| } | |||||
| out_ << TAB << TAB << "label=<\n"; | |||||
| out_ << TAB << TAB << R"(<table border="0" cellborder="1" align="center")" | |||||
| << ">" << std::endl; | |||||
| auto input_anchors = node->GetAllInDataAnchors(); | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return ); | |||||
| if (!input_anchors.empty()) { | |||||
| out_ << TAB << TAB << "<tr>"; | |||||
| } | |||||
| for (const auto &anchor : input_anchors) { | |||||
| string anchor_text = op_desc->GetInputNameByIndex(anchor->GetIdx()); | |||||
| out_ << "<td port = " << STR_FMT(INPUT_ANCHOR_PORT(anchor_text)) << " colspan='" << output_cnt << "'>" | |||||
| << anchor_text << "</td>"; | |||||
| } | |||||
| if (in_control) { | |||||
| string anchor_text = "ctrl"; | |||||
| out_ << "<td port = " << STR_FMT(INPUT_ANCHOR_PORT(anchor_text)) << " colspan='" << output_cnt << "'>" | |||||
| << anchor_text << "</td>"; | |||||
| } | |||||
| if (!input_anchors.empty()) { | |||||
| out_ << "</tr>\n"; | |||||
| } | |||||
| // Node type | |||||
| out_ << TAB << TAB << "<tr><td colspan='" << max_col << "'>" | |||||
| << "<b>" << node->GetType() << "</b></td></tr>\n"; | |||||
| // Output | |||||
| auto output_anchors = node->GetAllOutDataAnchors(); | |||||
| if (!output_anchors.empty()) { | |||||
| out_ << TAB << TAB << "<tr>"; | |||||
| } | |||||
| for (const auto &anchor : output_anchors) { | |||||
| string anchor_text = op_desc->GetOutputNameByIndex(anchor->GetIdx()); | |||||
| out_ << "<td port = " << STR_FMT(OUTPUT_ANCHOR_PORT(anchor_text)) << " colspan='" << input_cnt << "'>" | |||||
| << anchor_text << "</td>"; | |||||
| } | |||||
| if (!output_anchors.empty()) { | |||||
| out_ << "</tr>\n"; | |||||
| } | |||||
| out_ << TAB << TAB << "</table>\n" << TAB << ">];\n"; | |||||
| } | |||||
| void GraphDebugPrinter::DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag) { | |||||
| if (node == nullptr) { | |||||
| GELOGI("Some nodes are null."); | |||||
| return; | |||||
| } | |||||
| auto all_out_anchor = node->GetAllOutDataAnchors(); | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return ); | |||||
| for (const auto &anchor : all_out_anchor) { | |||||
| auto src_anchor = anchor; | |||||
| auto src_node_name = node->GetName(); | |||||
| auto src_anchor_index = op_desc->GetOutputNameByIndex(static_cast<uint32_t>(src_anchor->GetIdx())); | |||||
| auto des_anchors = anchor->GetPeerAnchors(); | |||||
| for (const auto &peer_in_anchor : des_anchors) { | |||||
| auto in_data_anchor = Anchor::DynamicAnchorCast<InDataAnchor>(peer_in_anchor); | |||||
| std::string dst_node_name; | |||||
| out_ << TAB << STR_FMT(src_node_name); | |||||
| out_ << ":" << OUTPUT_ANCHOR_PORT(src_anchor_index); | |||||
| auto op = peer_in_anchor->GetOwnerNode()->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL_EXEC(op, continue); | |||||
| if (in_data_anchor != nullptr) { | |||||
| dst_node_name = in_data_anchor->GetOwnerNode()->GetName(); | |||||
| string des_anchor_index = op->GetInputNameByIndex(static_cast<uint32_t>(in_data_anchor->GetIdx())); | |||||
| out_ << " -> " << STR_FMT(dst_node_name); | |||||
| out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index); | |||||
| out_ << "["; | |||||
| } | |||||
| auto in_control_anchor = Anchor::DynamicAnchorCast<InControlAnchor>(peer_in_anchor); | |||||
| if (in_control_anchor != nullptr) { | |||||
| dst_node_name = in_control_anchor->GetOwnerNode()->GetName(); | |||||
| string des_anchor_index = "ctrl"; | |||||
| out_ << " -> " << STR_FMT(dst_node_name); | |||||
| out_ << ":" << INPUT_ANCHOR_PORT(des_anchor_index); | |||||
| out_ << "["; | |||||
| out_ << " style=dashed "; | |||||
| } | |||||
| if (flag != DOT_NOT_SHOW_EDGE_LABEL && in_data_anchor) { | |||||
| string label; | |||||
| auto src_ops = src_anchor->GetOwnerNode()->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL_EXEC(src_ops, return ); | |||||
| auto src_shape = src_ops->GetOutputDesc(src_anchor->GetIdx()).GetShape(); | |||||
| auto dim = src_shape.GetDims(); | |||||
| std::ostringstream tensor_info; | |||||
| if (dim.size() > 0) { | |||||
| for (size_t i = 0; i < dim.size(); i++) { | |||||
| if (i != dim.size() - 1) { | |||||
| tensor_info << dim[i] << "x"; | |||||
| } else { | |||||
| tensor_info << dim[i]; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| tensor_info << "?"; | |||||
| } | |||||
| auto src_tensor_desc = src_ops->GetOutputDescPtr(src_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL_EXEC(src_tensor_desc, return ); | |||||
| auto format = src_tensor_desc->GetFormat(); | |||||
| auto datatype = src_tensor_desc->GetDataType(); | |||||
| tensor_info << " : " << formats[format] << " : " << types[datatype]; | |||||
| label = tensor_info.str(); | |||||
| out_ << "label=" << STR_FMT(label); | |||||
| } | |||||
| out_ << "]" << std::endl; | |||||
| } | |||||
| } | |||||
| } | |||||
| graphStatus GraphDebugPrinter::DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name, | |||||
| uint32_t flag) { | |||||
| auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
| if (compute_graph == nullptr) { | |||||
| GELOGI("Compute graph is NULL ."); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| return DumpGraphDotFile(compute_graph, output_dot_file_name, flag); | |||||
| } | |||||
| graphStatus GraphDebugPrinter::DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name, | |||||
| uint32_t flag) { | |||||
| if (graph == nullptr) { | |||||
| GELOGI("graph is null."); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| std::ostringstream out_; | |||||
| out_ << "digraph G{\n"; | |||||
| out_ << TAB << R"(ratio=compress;size="8, 100")" << std::endl; | |||||
| out_ << TAB << R"(node[fontname="Consolas"])" << std::endl; | |||||
| out_ << TAB << R"(edge[fontsize = "8" fontname = "Consolas" color="dimgray" ])" << std::endl; | |||||
| auto all_nodes = graph->GetAllNodes(); | |||||
| for (const auto &node : all_nodes) { | |||||
| for (const auto &temp : node->GetAllOutDataAnchors()) { | |||||
| for (const auto &peer : temp->GetPeerAnchors()) { | |||||
| auto temp_control_anchor = Anchor::DynamicAnchorCast<InControlAnchor>(peer); | |||||
| if (temp_control_anchor) { | |||||
| (void)control_anchor.insert(peer->GetOwnerNode()->GetName()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| for (const auto &node : all_nodes) { | |||||
| DumpNodeToDot(node, out_); | |||||
| } | |||||
| for (const auto &node : all_nodes) { | |||||
| DumpEdgeToDot(node, out_, flag); | |||||
| } | |||||
| out_ << "}"; | |||||
| std::ofstream output_file(output_dot_file_name); | |||||
| if (output_file.is_open()) { | |||||
| output_file << out_.str(); | |||||
| } else { | |||||
| GELOGW("%s open error.", output_dot_file_name.c_str()); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,48 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | |||||
| #define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | |||||
| #include <cstdint> | |||||
| #include <fstream> | |||||
| #include <iostream> | |||||
| #include <sstream> | |||||
| #include <string> | |||||
| #include "external/graph/graph.h" | |||||
| #include "./ge_error_codes.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/debug/ge_log.h" | |||||
| #include "graph/node.h" | |||||
| #include "utils/graph_utils.h" | |||||
| namespace ge { | |||||
| enum DotFileFlag { | |||||
| // Show nodes, edges, size, type and format | |||||
| DOT_FLAG_DEFAULT = 0, | |||||
| DOT_NOT_SHOW_EDGE_LABEL = 1, | |||||
| }; | |||||
| class GraphDebugPrinter { | |||||
| public: | |||||
| static graphStatus DumpGraphDotFile(const Graph &graph, const std::string &output_dot_file_name, | |||||
| uint32_t flag = DOT_FLAG_DEFAULT); | |||||
| static graphStatus DumpGraphDotFile(const ComputeGraphPtr graph, const std::string &output_dot_file_name, | |||||
| uint32_t flag = DOT_FLAG_DEFAULT); | |||||
| static void DumpNodeToDot(const NodePtr node, std::ostringstream &out_); | |||||
| static void DumpEdgeToDot(const NodePtr node, std::ostringstream &out_, uint32_t flag = DOT_FLAG_DEFAULT); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | |||||
| @@ -1,241 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "detail/attributes_holder.h" | |||||
| #include <map> | |||||
| #include "debug/ge_log.h" | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/ge_attr_value.h" | |||||
| #include "proto/ge_ir.pb.h" | |||||
| namespace ge { | |||||
| using std::map; | |||||
| using std::unordered_set; | |||||
| void AttrHolder::CopyAttrsFrom(const AttrHolder &holder) { MutableAttrMap().CopyValueFrom(holder.GetAttrMap()); } | |||||
| graphStatus AttrHolder::SetAttr(const std::string &name, const GeAttrValue &value) { | |||||
| if (value.IsEmpty()) { | |||||
| GELOGE(GRAPH_FAILED, "value is empty, key %s", name.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto proto_map = MutableAttrMap().GetProtoMsg(); | |||||
| auto proto_val = value.value_.GetProtoMsg(); | |||||
| if (proto_map == nullptr || proto_val == nullptr) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto it = proto_map->find(name); | |||||
| if (it != proto_map->end()) { | |||||
| if (it->second.value_case() != proto::AttrDef::VALUE_NOT_SET && | |||||
| it->second.value_case() != proto_val->value_case()) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| (*proto_map)[name] = *proto_val; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus AttrHolder::AddRequiredAttr(const std::string &name) { | |||||
| if (HasAttr(name)) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| requiredAttrs_.push_back(name); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus AttrHolder::GetAttr(const std::string &name, GeAttrValue &value) const { | |||||
| auto proto_map = GetAttrMap().GetProtoMsg(); | |||||
| auto proto_val = value.value_.GetProtoMsg(); | |||||
| if (proto_map == nullptr || proto_val == nullptr) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto it = proto_map->find(name); | |||||
| if (it != proto_map->end()) { | |||||
| *proto_val = it->second; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| bool AttrHolder::HasAttr(const std::string &name) const { | |||||
| auto proto_map = GetAttrMap().GetProtoMsg(); | |||||
| if (proto_map != nullptr) { | |||||
| if (proto_map->find(name) != proto_map->end()) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return std::find(requiredAttrs_.begin(), requiredAttrs_.end(), name) != requiredAttrs_.end(); | |||||
| } | |||||
| graphStatus AttrHolder::DelAttr(const std::string &name) { | |||||
| auto proto_map = MutableAttrMap().GetProtoMsg(); | |||||
| if (proto_map == nullptr) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto it = proto_map->find(name); | |||||
| if (it != proto_map->end()) { | |||||
| (void)proto_map->erase(it); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| const std::map<string, GeAttrValue> AttrHolder::GetAllAttrs() const { | |||||
| std::map<string, GeAttrValue> attr_value_map; | |||||
| auto proto_map = GetAttrMap().GetProtoMsg(); | |||||
| if (proto_map != nullptr) { | |||||
| auto proto_owner = GetAttrMap().GetProtoOwner(); | |||||
| GE_CHK_BOOL_EXEC(proto_owner != nullptr, return attr_value_map, "proto_owner is nullptr"); | |||||
| for (const auto &it : *proto_map) { | |||||
| attr_value_map[it.first] = GeAttrValue(proto_owner, const_cast<proto::AttrDef *>(&it.second)); | |||||
| } | |||||
| } | |||||
| return attr_value_map; | |||||
| } | |||||
| const std::unordered_set<string> AttrHolder::GetAllAttrNames() const { | |||||
| std::unordered_set<string> names; | |||||
| auto proto_map = GetAttrMap().GetProtoMsg(); | |||||
| if (proto_map != nullptr) { | |||||
| for (const auto &it : *proto_map) { | |||||
| (void)names.insert(it.first); | |||||
| } | |||||
| } | |||||
| for (const string &it : requiredAttrs_) { | |||||
| (void)names.insert(it); | |||||
| } | |||||
| return names; | |||||
| } | |||||
| template <> | |||||
| void GeIrProtoHelper<proto::AttrDef>::InitDefault() { | |||||
| std::shared_ptr<proto::AttrDef> proto_owner; | |||||
| proto_owner = ComGraphMakeShared<proto::AttrDef>(); | |||||
| if (proto_owner == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::AttrDef make shared failed"); | |||||
| return; | |||||
| } | |||||
| protoMsg_ = proto_owner.get(); | |||||
| protoOwner_ = proto_owner; | |||||
| } | |||||
| template <> | |||||
| void GeIrProtoHelper<proto::TensorDef>::InitDefault() { | |||||
| std::shared_ptr<proto::TensorDef> proto_owner; | |||||
| proto_owner = ComGraphMakeShared<proto::TensorDef>(); | |||||
| if (proto_owner == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::TensorDef make shared failed"); | |||||
| return; | |||||
| } | |||||
| protoMsg_ = proto_owner.get(); | |||||
| protoOwner_ = proto_owner; | |||||
| } | |||||
| template <> | |||||
| void GeIrProtoHelper<proto::TensorDescriptor>::InitDefault() { | |||||
| std::shared_ptr<proto::TensorDescriptor> proto_owner; | |||||
| proto_owner = ComGraphMakeShared<proto::TensorDescriptor>(); | |||||
| if (proto_owner == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); | |||||
| return; | |||||
| } | |||||
| protoMsg_ = proto_owner.get(); | |||||
| protoOwner_ = proto_owner; | |||||
| } | |||||
| template <> | |||||
| void GeIrProtoHelper<proto::ShapeDef>::InitDefault() { | |||||
| std::shared_ptr<proto::ShapeDef> proto_owner; | |||||
| proto_owner = ComGraphMakeShared<proto::ShapeDef>(); | |||||
| if (proto_owner == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::ShapeDef make shared failed"); | |||||
| return; | |||||
| } | |||||
| protoMsg_ = proto_owner.get(); | |||||
| protoOwner_ = proto_owner; | |||||
| } | |||||
| template <> | |||||
| void GeIrProtoHelper<proto::NamedAttrs>::InitDefault() { | |||||
| std::shared_ptr<proto::NamedAttrs> proto_owner; | |||||
| proto_owner = ComGraphMakeShared<proto::NamedAttrs>(); | |||||
| if (proto_owner == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::NamedAttrs make shared failed"); | |||||
| return; | |||||
| } | |||||
| protoMsg_ = proto_owner.get(); | |||||
| protoOwner_ = proto_owner; | |||||
| } | |||||
| template <> | |||||
| void GeIrProtoHelper<proto::ModelDef>::InitDefault() { | |||||
| std::shared_ptr<proto::ModelDef> proto_owner; | |||||
| proto_owner = ComGraphMakeShared<proto::ModelDef>(); | |||||
| if (proto_owner == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed"); | |||||
| return; | |||||
| } | |||||
| protoMsg_ = proto_owner.get(); | |||||
| protoOwner_ = proto_owner; | |||||
| } | |||||
| template <> | |||||
| void GeIrProtoHelper<proto::OpDef>::InitDefault() { | |||||
| std::shared_ptr<proto::OpDef> proto_owner; | |||||
| proto_owner = ComGraphMakeShared<proto::OpDef>(); | |||||
| if (proto_owner == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); | |||||
| return; | |||||
| } | |||||
| protoMsg_ = proto_owner.get(); | |||||
| protoOwner_ = proto_owner; | |||||
| } | |||||
| template <> | |||||
| void GeIrProtoHelper<proto::GraphDef>::InitDefault() { | |||||
| std::shared_ptr<proto::GraphDef> proto_owner; | |||||
| proto_owner = ComGraphMakeShared<proto::GraphDef>(); | |||||
| if (proto_owner == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); | |||||
| return; | |||||
| } | |||||
| protoMsg_ = proto_owner.get(); | |||||
| protoOwner_ = proto_owner; | |||||
| } | |||||
| template <> | |||||
| void GeIrProtoHelper<ProtoAttrMap>::InitDefault() { | |||||
| std::shared_ptr<proto::TensorDescriptor> proto_owner; | |||||
| proto_owner = ComGraphMakeShared<proto::TensorDescriptor>(); | |||||
| if (proto_owner == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); | |||||
| return; | |||||
| } | |||||
| protoMsg_ = proto_owner->mutable_attr(); | |||||
| protoOwner_ = proto_owner; | |||||
| } | |||||
| template <> | |||||
| void GeIrProtoHelper<const ProtoAttrMap>::InitDefault() { | |||||
| std::shared_ptr<proto::TensorDescriptor> proto_owner; | |||||
| proto_owner = ComGraphMakeShared<proto::TensorDescriptor>(); | |||||
| if (proto_owner == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::TensorDescriptor make shared failed"); | |||||
| return; | |||||
| } | |||||
| protoMsg_ = &proto_owner->attr(); | |||||
| protoOwner_ = proto_owner; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,508 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "format_refiner.h" | |||||
| #include <deque> | |||||
| #include <iostream> | |||||
| #include <set> | |||||
| #include <unordered_map> | |||||
| #include <unordered_set> | |||||
| #include "graph/ref_relation.h" | |||||
| #include "./compute_graph.h" | |||||
| #include "./ge_error_codes.h" | |||||
| #include "./graph/ge_tensor.h" | |||||
| #include "./operator.h" | |||||
| #include "./operator_factory.h" | |||||
| #include "debug/ge_log.h" | |||||
| #include "debug/ge_op_types.h" | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "utils/node_utils.h" | |||||
| #include "utils/op_desc_utils.h" | |||||
| #include "utils/tensor_utils.h" | |||||
| #include "utils/type_utils.h" | |||||
| using namespace ge; | |||||
| using namespace std; | |||||
| namespace ge { | |||||
| namespace { | |||||
| const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; | |||||
| const string kIsGraphInferred = "_is_graph_inferred"; | |||||
| thread_local RefRelations reflection_builder; | |||||
| } // namespace | |||||
| graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection, | |||||
| std::deque<ge::NodePtr> &nodes, ge::Format to_be_set_format) { | |||||
| for (const auto &cell : reflection) { | |||||
| auto node = cell.node; | |||||
| auto in_out_idx = cell.in_out_idx; | |||||
| GE_CHECK_NOTNULL(node); | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| if (cell.in_out == ge::NODE_IN) { | |||||
| auto desc = node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(in_out_idx)); | |||||
| desc.SetOriginFormat(to_be_set_format); | |||||
| desc.SetFormat(to_be_set_format); | |||||
| (void)node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(in_out_idx), desc); | |||||
| } else { | |||||
| auto desc = node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(in_out_idx)); | |||||
| desc.SetOriginFormat(to_be_set_format); | |||||
| desc.SetFormat(to_be_set_format); | |||||
| (void)node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(in_out_idx), desc); | |||||
| } | |||||
| nodes.push_back(cell.node); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) { | |||||
| // 5 meas dim num | |||||
| if (node_ptr->GetType() != "BiasAdd") { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| std::unordered_map<string, Format> kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}}; | |||||
| for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) { | |||||
| auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i); | |||||
| GE_CHECK_NOTNULL(in_desc); | |||||
| if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num | |||||
| continue; | |||||
| } | |||||
| auto format = in_desc->GetOriginFormat(); | |||||
| auto key = TypeUtils::FormatToSerialString(format); | |||||
| auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; | |||||
| in_desc->SetOriginFormat(fixed_format); | |||||
| in_desc->SetFormat(fixed_format); | |||||
| GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i, | |||||
| node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), | |||||
| TypeUtils::FormatToSerialString(fixed_format).c_str()); | |||||
| } | |||||
| for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) { | |||||
| auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(out_desc); | |||||
| if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num | |||||
| continue; | |||||
| } | |||||
| auto format = out_desc->GetOriginFormat(); | |||||
| auto key = TypeUtils::FormatToSerialString(format); | |||||
| auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; | |||||
| out_desc->SetOriginFormat(fixed_format); | |||||
| out_desc->SetFormat(fixed_format); | |||||
| GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", i, | |||||
| node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), | |||||
| TypeUtils::FormatToSerialString(fixed_format).c_str()); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) { | |||||
| ConstGeTensorPtr tensor_value; | |||||
| if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { | |||||
| GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GE_CHECK_NOTNULL(tensor_value); | |||||
| (void)op_desc->UpdateOutputDesc(0, tensor_value->GetTensorDesc()); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points, | |||||
| std::vector<ge::NodePtr> &data_nodes, | |||||
| std::unordered_map<ge::NodePtr, bool> &node_status) { | |||||
| if (graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "input graph is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| anchor_points.clear(); | |||||
| // Get all anchor point nodes and switch nodes | |||||
| for (auto &node_ptr : graph->GetAllNodes()) { | |||||
| if (node_ptr == nullptr) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto op_desc = node_ptr->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus status = RefreshConstantOutProcess(graph, op_desc); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "refresh constant out process failed!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // consider special node save process | |||||
| // get all input desc format | |||||
| bool node_is_all_nd = false; | |||||
| auto input_size = static_cast<uint32_t>(op_desc->GetAllInputsSize()); | |||||
| for (uint32_t i = 0; i < input_size; i++) { | |||||
| // Operator pre-set format but not origin format | |||||
| GE_IF_BOOL_EXEC(op_desc->MutableInputDesc(i) == nullptr, continue); | |||||
| auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); | |||||
| // Pre-save data node (only main graph data) and default infer fail | |||||
| if (node_ptr->GetType() == DATA) { | |||||
| data_nodes.push_back(node_ptr); | |||||
| } | |||||
| if (input_format != FORMAT_ND && input_format != FORMAT_RESERVED) { | |||||
| node_is_all_nd = true; | |||||
| } | |||||
| } | |||||
| // Get all output desc format | |||||
| auto output_size = static_cast<uint32_t>(op_desc->GetOutputsSize()); | |||||
| for (uint32_t i = 0; i < output_size; i++) { | |||||
| GE_IF_BOOL_EXEC(op_desc->MutableOutputDesc(i) == nullptr, continue); | |||||
| auto output_format = op_desc->MutableOutputDesc(i)->GetFormat(); | |||||
| if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) { | |||||
| node_is_all_nd = true; | |||||
| } | |||||
| } | |||||
| // check anchor point valid | |||||
| if (!node_is_all_nd) { | |||||
| continue; | |||||
| } | |||||
| // special process for biasAdd op | |||||
| // In tensorflow, biasAdd's format is alwayse NHWC even though set the arg | |||||
| // "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism | |||||
| // so here do special process | |||||
| status = BiasAddFormatFixProcess(node_ptr); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "fix biasAdd process failed!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str()); | |||||
| anchor_points.push_back(node_ptr); | |||||
| } | |||||
| GELOGI("anchor_points number is %zu", anchor_points.size()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus FormatRefiner::AnchorProcess(const ge::NodePtr &anchor_node, | |||||
| std::unordered_map<ge::NodePtr, bool> &node_status) { | |||||
| if (anchor_node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "anchor node is null!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::deque<ge::NodePtr> nodes; | |||||
| nodes.push_back(anchor_node); | |||||
| while (!nodes.empty()) { | |||||
| ge::NodePtr node = nodes.front(); | |||||
| nodes.pop_front(); | |||||
| graphStatus status = BackInferProcess(nodes, node, node_status); | |||||
| if (status != GRAPH_SUCCESS && node != nullptr) { | |||||
| GELOGE(status, "BackInferProcess failed!node name [%s]", node->GetName().c_str()); | |||||
| return status; | |||||
| } | |||||
| status = ForwardInferProcess(nodes, node, node_status); | |||||
| if (status != GRAPH_SUCCESS && node != nullptr) { | |||||
| GELOGE(status, "ForwardInferProcess failed!node name [%s]", node->GetName().c_str()); | |||||
| return status; | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node, | |||||
| std::unordered_map<ge::NodePtr, bool> &node_status) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| GELOGD("Enter back infer process!Node is [%s]", (node->GetName()).c_str()); | |||||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||||
| GELOGD("Node is [%s] [B]", (node->GetName()).c_str()); | |||||
| auto in_data_anchor_idx = in_anchor->GetIdx(); | |||||
| auto input_desc = node->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx)); | |||||
| GE_IF_BOOL_EXEC(input_desc == nullptr, continue); | |||||
| auto to_be_set_format = input_desc->GetOriginFormat(); | |||||
| if (to_be_set_format == FORMAT_ND) { | |||||
| GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str()); | |||||
| continue; | |||||
| } | |||||
| auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (peer_out_data_anchor == nullptr) { | |||||
| GELOGW("Node[%s] %dth in data anchor's peer_out_anchor is null", (node->GetName()).c_str(), in_data_anchor_idx); | |||||
| continue; | |||||
| } | |||||
| auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); | |||||
| if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { | |||||
| GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (node->GetName()).c_str()); | |||||
| continue; | |||||
| } | |||||
| // Check format whether have been set | |||||
| int idx = peer_out_data_anchor->GetIdx(); | |||||
| // do peer_out_node name and index as key to lookup reflections | |||||
| ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx); | |||||
| std::unordered_set<RefCell, RefCellHash> reflection; | |||||
| auto status = reflection_builder.LookUpRefRelations(key, reflection); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d out edge", | |||||
| (peer_out_data_node->GetName()).c_str(), idx); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx)); | |||||
| if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | |||||
| auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | |||||
| if (dim_num == 0) { | |||||
| GELOGD("node name:%s idx:%d out is scalar. stop back infer!", peer_out_data_node->GetName().c_str(), idx); | |||||
| continue; | |||||
| } | |||||
| /// Check whether node to change dims () | |||||
| /// Because some node will calculate with 5D, C dim maybe multi meaning | |||||
| auto peer_out_data_node_type = peer_out_data_node->GetType(); | |||||
| auto iter1 = kChangeDimNodes.find(peer_out_data_node_type); | |||||
| // 4 means dims num | |||||
| if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { | |||||
| GELOGD("Node[%s] is change dim node and shape is smaller than 4. do not modify format", | |||||
| (peer_out_data_node->GetName()).c_str()); | |||||
| continue; | |||||
| } | |||||
| if (reflection.empty()) { | |||||
| ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||||
| ge_tensor_desc.SetFormat(to_be_set_format); | |||||
| (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||||
| // Call operator infer format api (forward) to get out format | |||||
| GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | |||||
| status = peer_out_data_node->InferOriginFormat(); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| nodes.push_back(peer_out_data_node); | |||||
| } else { | |||||
| auto status = ReflectionProcess(reflection, nodes, to_be_set_format); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "reflection process failed!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node, | |||||
| std::unordered_map<ge::NodePtr, bool> &node_status) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| GELOGD("Enter forward infer process!Node is [%s]", (node->GetName()).c_str()); | |||||
| for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||||
| GELOGD("Node is [%s] [F]", (node->GetName()).c_str()); | |||||
| GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); | |||||
| auto out_data_anchor_idx = out_data_anchor->GetIdx(); | |||||
| auto to_be_set_format = | |||||
| node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(out_data_anchor_idx))->GetOriginFormat(); | |||||
| if (to_be_set_format == FORMAT_ND) { | |||||
| GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str()); | |||||
| continue; | |||||
| } | |||||
| for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue); | |||||
| auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | |||||
| GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue); | |||||
| GE_IF_BOOL_EXEC(peer_in_data_node->GetOpDesc() == nullptr, continue); | |||||
| // Check format whether have been set | |||||
| int idx = peer_in_data_anchor->GetIdx(); | |||||
| // do peer_out_node name and index as key to lookup reflections | |||||
| ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx); | |||||
| std::unordered_set<RefCell, RefCellHash> reflection; | |||||
| auto status = reflection_builder.LookUpRefRelations(key, reflection); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d input edge", | |||||
| (peer_in_data_node->GetName()).c_str(), idx); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx)); | |||||
| if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | |||||
| auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | |||||
| if (dim_num == 0) { | |||||
| GELOGI("node name:%s idx:%d in is scalar. stop forward infer!", peer_in_data_node->GetName().c_str(), idx); | |||||
| continue; | |||||
| } | |||||
| /// Check whether node to change dims () | |||||
| /// Because some node will calculate with 5D, C dim maybe multi meaning | |||||
| auto peer_in_data_node_type = peer_in_data_node->GetType(); | |||||
| auto iter1 = kChangeDimNodes.find(peer_in_data_node_type); | |||||
| // 4 means dims num | |||||
| if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4)) { | |||||
| GELOGD("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str()); | |||||
| continue; | |||||
| } | |||||
| if (reflection.empty()) { | |||||
| ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||||
| ge_tensor_desc.SetFormat(to_be_set_format); | |||||
| (void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||||
| /// Because netoutput node added before infer format ,so netoutput is end condition | |||||
| /// must set netoutput format , because saved result depend on format | |||||
| if (peer_in_data_node_type == NETOUTPUT) { | |||||
| continue; | |||||
| } | |||||
| // Call operator infer format api (forward) to get out format | |||||
| GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); | |||||
| status = peer_in_data_node->InferOriginFormat(); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| nodes.push_back(peer_in_data_node); | |||||
| } else { | |||||
| auto status = ReflectionProcess(reflection, nodes, to_be_set_format); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "reflection process failed!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor_points) { | |||||
| for (const auto &node : anchor_points) { | |||||
| if (node == nullptr || node->GetOpDesc() == nullptr) { | |||||
| continue; | |||||
| } | |||||
| for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) { | |||||
| if (input_desc != nullptr) { | |||||
| input_desc->SetOriginFormat(input_desc->GetFormat()); | |||||
| } | |||||
| } | |||||
| for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) { | |||||
| if (output_desc != nullptr) { | |||||
| output_desc->SetOriginFormat(output_desc->GetFormat()); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes, | |||||
| ge::Format data_format, | |||||
| std::unordered_map<ge::NodePtr, bool> &node_status) { | |||||
| if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) { | |||||
| GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph), | |||||
| TypeUtils::FormatToSerialString(data_format).c_str()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GELOGD("Enter DataNodeFormatProcess"); | |||||
| std::vector<ge::NodePtr> uninfered_data_nodes; | |||||
| // Check and renew data nodes format | |||||
| for (const auto &data_node : data_nodes) { | |||||
| GE_CHECK_NOTNULL(data_node); | |||||
| auto op_desc = data_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(0)); | |||||
| auto curr_format = op_desc->GetOutputDescPtr(0)->GetOriginFormat(); | |||||
| if (curr_format != FORMAT_ND) { | |||||
| // Data format has been infered , continue | |||||
| continue; | |||||
| } | |||||
| // Set format for un-infered data node | |||||
| auto input_descs = op_desc->GetAllInputsDescPtr(); | |||||
| auto output_descs = op_desc->GetAllOutputsDescPtr(); | |||||
| for (const auto &input_desc : input_descs) { | |||||
| if (input_desc != nullptr) { | |||||
| input_desc->SetOriginFormat(data_format); | |||||
| input_desc->SetFormat(data_format); | |||||
| } | |||||
| } | |||||
| for (const auto &output_desc : output_descs) { | |||||
| if (output_desc != nullptr) { | |||||
| output_desc->SetOriginFormat(data_format); | |||||
| output_desc->SetFormat(data_format); | |||||
| } | |||||
| } | |||||
| uninfered_data_nodes.push_back(data_node); | |||||
| } | |||||
| // Reinfer format from uninfered data nodes | |||||
| for (const auto &node : uninfered_data_nodes) { | |||||
| if (node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| GELOGD("data node [%s] start infer format process", node->GetName().c_str()); | |||||
| auto status = AnchorProcess(node, node_status); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "data node [%s] infer format process failed!", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| GELOGD("DataNodeFormatProcess success"); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) { | |||||
| GELOGI("Enter InferOrigineFormat process!"); | |||||
| // True: infered false:no-infered | |||||
| std::unordered_map<ge::NodePtr, bool> node_status; | |||||
| std::vector<ge::NodePtr> anchor_points; | |||||
| std::vector<ge::NodePtr> data_nodes; | |||||
| // global net format | |||||
| if (graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "input graph is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // build reflection relations of boundary | |||||
| (void)reflection_builder.Clear(); | |||||
| auto status = reflection_builder.BuildRefRelations(*graph); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "build reflection relations failed for main and subgraph!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // User set global net format | |||||
| status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // Refresh origin format of anchor point | |||||
| RefreshOriginFormatOfAnchor(anchor_points); | |||||
| // Infer format process | |||||
| for (const auto &anchor_node : anchor_points) { | |||||
| if (anchor_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| status = AnchorProcess(anchor_node, node_status); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Anchor node [%s] process failed!", anchor_node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| /// According to discuss with sys-enginer, data node default format is ND.Its format | |||||
| /// should be set by infered.But if some data-node can not be got by infer, set context's | |||||
| /// format for these data nodes. | |||||
| /// Notice: ignore 5D formats | |||||
| auto data_format = graph->GetDataFormat(); | |||||
| status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status); | |||||
| (void)AttrUtils::SetBool(graph, kIsGraphInferred, true); | |||||
| return status; | |||||
| } | |||||
| bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) { | |||||
| bool is_graph_inferred = false; | |||||
| return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,50 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef COMMON_GRAPH_FORMAT_REFINER_H_ | |||||
| #define COMMON_GRAPH_FORMAT_REFINER_H_ | |||||
| #include <deque> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "./compute_graph.h" | |||||
| #include "./external/graph/types.h" | |||||
| #include "./ge_error_codes.h" | |||||
| namespace ge { | |||||
| // ShapeRefiner performs shape inference for compute graphs | |||||
| class FormatRefiner { | |||||
| public: | |||||
| static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph); | |||||
| private: | |||||
| static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | |||||
| static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points, | |||||
| std::vector<ge::NodePtr> &data_nodes, | |||||
| std::unordered_map<ge::NodePtr, bool> &node_status); | |||||
| static graphStatus AnchorProcess(const ge::NodePtr &anchor_node, std::unordered_map<ge::NodePtr, bool> &node_status); | |||||
| static void RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor_points); | |||||
| static graphStatus BackInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node, | |||||
| std::unordered_map<ge::NodePtr, bool> &node_status); | |||||
| static graphStatus ForwardInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node, | |||||
| std::unordered_map<ge::NodePtr, bool> &node_status); | |||||
| static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes, | |||||
| ge::Format data_format, std::unordered_map<ge::NodePtr, bool> &node_status); | |||||
| static bool IsGraphInferred(const ComputeGraphPtr &graph); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // COMMON_GRAPH_FORMAT_REFINER_H_ | |||||
| @@ -1,384 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "external/graph/graph.h" | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/debug/ge_op_types.h" | |||||
| #include "graph/model.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| using std::map; | |||||
| using std::pair; | |||||
| using std::string; | |||||
| using std::vector; | |||||
| namespace ge { | |||||
| class GraphImpl { | |||||
| public: | |||||
| friend class GraphUtils; | |||||
| GraphImpl(const GraphImpl &) = delete; | |||||
| GraphImpl &operator=(const GraphImpl &) = delete; | |||||
| explicit GraphImpl(const std::string &name) : name_(name) {} | |||||
| ~GraphImpl() { | |||||
| if (IsValid()) { | |||||
| if (compute_graph_ != nullptr) { | |||||
| GraphUtils::BreakConnect(compute_graph_->GetAllNodesInfo()); | |||||
| } | |||||
| } | |||||
| for (const auto &it : op_list_) { | |||||
| Operator op = it.second; | |||||
| op.BreakConnect(); | |||||
| } | |||||
| } | |||||
| graphStatus SetInputs(const std::vector<Operator> &inputs) { | |||||
| compute_graph_ = GraphUtils::CreateGraphFromOperator(name_, inputs); | |||||
| GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "Build Graph failed."); | |||||
| GE_CHK_BOOL_RET_STATUS(inputs.size() != 0, GRAPH_FAILED, "set input NULL."); | |||||
| compute_graph_->SetInputSize(static_cast<uint32_t>(inputs.size())); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus SetOutputs(const std::vector<Operator> &outputs) { | |||||
| if (compute_graph_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "set ComputeGraph failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (outputs.empty()) { | |||||
| GELOGW("set outputs size is 0."); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| // Construct special output node | |||||
| std::vector<std::pair<Operator, std::vector<size_t>>> output_indexs; | |||||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||||
| output_indexs.emplace_back(outputs[i], std::vector<size_t>{}); | |||||
| } | |||||
| graphStatus ret = SetOutputs(output_indexs); | |||||
| return ret; | |||||
| } | |||||
| graphStatus SetOutputs(const std::vector<std::pair<Operator, std::vector<size_t>>> &output_indexs) { | |||||
| if (compute_graph_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "set ComputeGraph failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (output_indexs.empty()) { | |||||
| GELOGW("set outputs size is 0."); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| // Construct special output node | |||||
| std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes; | |||||
| for (const auto &item : output_indexs) { | |||||
| const Operator &output = item.first; | |||||
| const vector<size_t> &indexs = item.second; | |||||
| ge::NodePtr node = compute_graph_->FindNode(output.GetName()); | |||||
| if (node == nullptr) { | |||||
| GELOGW("user designated out_node [%s] not exist in graph, will ignored!", output.GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| ge::OpDescPtr tmp_op_ptr = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue); | |||||
| size_t out_size = tmp_op_ptr->GetOutputsSize(); | |||||
| if (indexs.empty()) { | |||||
| for (size_t i = 0; i < out_size; ++i) { | |||||
| output_name_ += output.GetName() + ":" + std::to_string(i) + ";"; | |||||
| output_nodes.emplace_back(node, i); | |||||
| } | |||||
| } else { | |||||
| for (size_t i = 0; i < indexs.size(); ++i) { | |||||
| if (indexs[i] >= out_size) { | |||||
| GELOGW("index[%zu] is not belong to out_node[%s]", indexs[i], output.GetName().c_str()); | |||||
| } else { | |||||
| output_name_ += output.GetName() + ":" + std::to_string(i) + ";"; | |||||
| output_nodes.emplace_back(node, indexs[i]); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| // Del last ";" | |||||
| if (!output_name_.empty()) { | |||||
| output_name_ = output_name_.substr(0, output_name_.length() - 1); | |||||
| } | |||||
| compute_graph_->SetUserDefOutput(output_name_); | |||||
| compute_graph_->SetOutputSize(static_cast<uint32_t>(output_indexs.size())); | |||||
| compute_graph_->SetGraphOutNodesInfo(output_nodes); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus SetOutputs(const std::vector<pair<Operator, string>> &outputs) { | |||||
| GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild."); | |||||
| GE_CHK_BOOL_EXEC_INFO(outputs.size() != 0, return GRAPH_SUCCESS, "set outputs size is 0."); | |||||
| // Construct specified output | |||||
| std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes; | |||||
| for (auto item : outputs) { | |||||
| ge::NodePtr node = compute_graph_->FindNode(item.first.GetName()); | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, " Warning, user designated out_node (%s) not exist in graph, this out_node ignored!", | |||||
| item.first.GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| ge::OpDescPtr tmp_op_ptr = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL_EXEC(tmp_op_ptr, continue); | |||||
| size_t out_size = tmp_op_ptr->GetOutputsSize(); | |||||
| if (item.second.empty()) { | |||||
| for (size_t i = 0; i < out_size; ++i) { | |||||
| output_name_ += item.first.GetName() + ":" + std::to_string(i) + ";"; | |||||
| output_nodes.push_back(std::make_pair(node, i)); | |||||
| } | |||||
| } else { | |||||
| int32_t index = tmp_op_ptr->GetOutputIndexByName(item.second); | |||||
| if (index < 0) { | |||||
| GELOGE(GRAPH_FAILED, | |||||
| " Warning, user designated out_node (%s):(%s) not exist in graph, this out_node ignored!", | |||||
| item.first.GetName().c_str(), item.second.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| output_name_ += item.first.GetName() + ":" + std::to_string(index) + ";"; | |||||
| output_nodes.push_back(std::make_pair(node, index)); | |||||
| } | |||||
| } | |||||
| // Del last ";" | |||||
| if (!output_name_.empty()) { | |||||
| output_name_ = output_name_.substr(0, output_name_.length() - 1); | |||||
| } | |||||
| compute_graph_->SetOutputSize(static_cast<uint32_t>(outputs.size())); | |||||
| compute_graph_->SetGraphOutNodesInfo(output_nodes); | |||||
| GELOGI("********************SetOutputs Success***********************"); | |||||
| GE_IF_BOOL_EXEC(!output_name_.empty(), GELOGI(" NetOutputs: (%s)", output_name_.c_str())); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus SetTargets(const std::vector<Operator> &targets) { | |||||
| GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "set ComputeGraph faild."); | |||||
| GE_CHK_BOOL_EXEC_INFO(targets.size() != 0, return GRAPH_SUCCESS, "set targets size is 0."); | |||||
| std::vector<ge::NodePtr> target_nodes; | |||||
| for (auto item : targets) { | |||||
| ge::NodePtr node = compute_graph_->FindNode(item.GetName()); | |||||
| if (node == nullptr) { | |||||
| GELOGW(" Warning, user designated target_node (%s) not exist in graph, this target_node ignored!", | |||||
| item.GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| target_nodes.push_back(node); | |||||
| } | |||||
| compute_graph_->SetGraphTargetNodesInfo(target_nodes); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool IsValid() const { return (compute_graph_ != nullptr); } | |||||
| graphStatus AddOp(const ge::Operator &op) { | |||||
| std::pair<std::map<string, ge::Operator>::iterator, bool> ret; | |||||
| ret = op_list_.emplace(std::pair<string, ge::Operator>(op.GetName(), op)); | |||||
| GE_CHK_BOOL_RET_STATUS(ret.second != false, GRAPH_FAILED, "the op have added before, op name:%s.", | |||||
| op.GetName().c_str()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GetAllOpName(std::vector<string> &op_name) const { | |||||
| for (const auto &it : op_list_) { | |||||
| op_name.push_back(it.second.GetName()); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus FindOpByName(const string &name, ge::Operator &op) const { | |||||
| auto it = op_list_.find(name); | |||||
| GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str()); | |||||
| op = it->second; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const { | |||||
| for (auto &op : op_list_) { | |||||
| auto op_type = op.second.GetOpType(); | |||||
| if (op_type == type) { | |||||
| ops.push_back(op.second); | |||||
| continue; | |||||
| } | |||||
| if (op_type == ge::FRAMEWORKOP) { | |||||
| op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type); | |||||
| if (op_type == type) { | |||||
| ops.push_back(op.second); | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| void SetNeedIteration(bool need_iteration) { | |||||
| if (compute_graph_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null."); | |||||
| return; | |||||
| } | |||||
| compute_graph_->SetNeedIteration(need_iteration); | |||||
| } | |||||
| const std::string &GetName() const { return name_; } | |||||
| private: | |||||
| std::string name_; | |||||
| std::string output_name_; | |||||
| std::map<string, ge::Operator> op_list_; | |||||
| ComputeGraphPtr compute_graph_{nullptr}; | |||||
| }; | |||||
| Graph::Graph(const std::string &name) { | |||||
| impl_ = ComGraphMakeShared<GraphImpl>(name); | |||||
| if (impl_ == nullptr) { | |||||
| GELOGW("GraphImpl make shared failed, impl_ is nullptr"); | |||||
| } | |||||
| } | |||||
| graphStatus Graph::AddOp(const ge::Operator &op) { | |||||
| GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, "AddOp failed: graph can not be used, impl is nullptr."); | |||||
| return impl_->AddOp(op); | |||||
| } | |||||
| graphStatus Graph::GetAllOpName(std::vector<string> &op_name) const { | |||||
| GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, | |||||
| "GetAllOpName failed: graph can not be used, impl is nullptr."); | |||||
| return impl_->GetAllOpName(op_name); | |||||
| } | |||||
| graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const { | |||||
| Operator op_find_op_def("NULL"); | |||||
| op = op_find_op_def; | |||||
| GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, | |||||
| "FindOpByName failed: graph can not be used, impl is nullptr."); | |||||
| return impl_->FindOpByName(name, op); | |||||
| } | |||||
| graphStatus Graph::FindOpByType(const string &type, std::vector<ge::Operator> &ops) const { | |||||
| GE_CHECK_NOTNULL(impl_); | |||||
| return impl_->FindOpByType(type, ops); | |||||
| } | |||||
| Graph &Graph::SetInputs(const vector<ge::Operator> &inputs) { | |||||
| GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.") | |||||
| GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0."); | |||||
| (void)impl_->SetInputs(inputs); | |||||
| return *this; | |||||
| } | |||||
| Graph &Graph::SetOutputs(const vector<ge::Operator> &outputs) { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr."); | |||||
| return *this; | |||||
| } | |||||
| (void)impl_->SetOutputs(outputs); | |||||
| return *this; | |||||
| } | |||||
| Graph &Graph::SetOutputs(const std::vector<std::pair<Operator, std::vector<size_t>>> &output_indexs) { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "SetOutputs failed: graph can not be used, impl is nullptr."); | |||||
| return *this; | |||||
| } | |||||
| (void)impl_->SetOutputs(output_indexs); | |||||
| return *this; | |||||
| } | |||||
| Graph &Graph::SetOutputs(const std::vector<pair<Operator, string>> &outputs) { | |||||
| GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetOutputs failed: graph can not be used, impl is nullptr.") | |||||
| (void)impl_->SetOutputs(outputs); | |||||
| return *this; | |||||
| } | |||||
| Graph &Graph::SetTargets(const vector<ge::Operator> &targets) { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "SetTargets failed: graph can not be used, impl is nullptr."); | |||||
| return *this; | |||||
| } | |||||
| (void)impl_->SetTargets(targets); | |||||
| return *this; | |||||
| } | |||||
| bool Graph::IsValid() const { | |||||
| if (impl_ == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return impl_->IsValid(); | |||||
| } | |||||
| void Graph::SetNeedIteration(bool need_iteration) { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Set need iteration failed, as impl is null."); | |||||
| return; | |||||
| } | |||||
| impl_->SetNeedIteration(need_iteration); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::GetComputeGraph(const ge::Graph &graph) { | |||||
| GE_CHK_BOOL_EXEC_NOLOG(graph.IsValid(), return nullptr); | |||||
| return graph.impl_->compute_graph_; | |||||
| } | |||||
| graphStatus Graph::SaveToFile(const string &file_name) const { | |||||
| Model model = Model(); | |||||
| model.SetGraph(*this); | |||||
| return model.SaveToFile(file_name); | |||||
| } | |||||
| graphStatus Graph::LoadFromFile(const string &file_name) { | |||||
| Model model = Model(); | |||||
| graphStatus ret = model.LoadFromFile(file_name); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| *this = model.GetGraph(); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string &Graph::GetName() const { return impl_->GetName(); } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph | |||||
| GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { | |||||
| GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph("")); | |||||
| auto name = compute_graph->GetName(); | |||||
| auto graph = Graph(name); | |||||
| GE_CHK_BOOL_EXEC_NOLOG(graph.impl_ != nullptr, return graph); | |||||
| graph.impl_->compute_graph_ = compute_graph; | |||||
| return graph; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { | |||||
| GE_CHECK_NOTNULL(graph.impl_); | |||||
| GE_CHECK_NOTNULL(graph.impl_->compute_graph_); | |||||
| graph.impl_->op_list_.clear(); | |||||
| for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) { | |||||
| graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,294 +0,0 @@ | |||||
| LOCAL_PATH := $(call my-dir) | |||||
| include $(LOCAL_PATH)/stub/Makefile | |||||
| COMMON_LOCAL_SRC_FILES := \ | |||||
| ./proto/om.proto \ | |||||
| ./proto/ge_ir.proto \ | |||||
| ./proto/ge_onnx.proto \ | |||||
| ./proto/insert_op.proto \ | |||||
| ./proto/task.proto \ | |||||
| ./proto/fwk_adapter.proto \ | |||||
| ./proto/op_mapping_info.proto \ | |||||
| ./proto/dump_task.proto \ | |||||
| ./anchor.cc \ | |||||
| ./ge_attr_value.cc \ | |||||
| ./attr_value.cc \ | |||||
| ./buffer.cc \ | |||||
| ./compute_graph.cc \ | |||||
| ./graph.cc \ | |||||
| ./inference_context.cc \ | |||||
| ./shape_refiner.cc \ | |||||
| ./format_refiner.cc \ | |||||
| ./ref_relation.cc \ | |||||
| ./model.cc \ | |||||
| ./model_serialize.cc \ | |||||
| ./node.cc \ | |||||
| ./op_desc.cc \ | |||||
| ./operator.cc \ | |||||
| ./operator_factory.cc \ | |||||
| ./operator_factory_impl.cc \ | |||||
| ./ge_attr_define.cc \ | |||||
| ./ge_tensor.cc \ | |||||
| ./detail/attributes_holder.cc \ | |||||
| ./utils/anchor_utils.cc \ | |||||
| ./utils/tuning_utils.cc \ | |||||
| ./utils/graph_utils.cc \ | |||||
| ./utils/ge_ir_utils.cc \ | |||||
| ./utils/node_utils.cc \ | |||||
| ./utils/op_desc_utils.cc \ | |||||
| ./utils/type_utils.cc \ | |||||
| ./utils/tensor_utils.cc \ | |||||
| ./tensor.cc \ | |||||
| ./debug/graph_debug.cc \ | |||||
| ./opsproto/opsproto_manager.cc \ | |||||
| ../ops/op_imp.cpp \ | |||||
| option/ge_context.cc \ | |||||
| option/ge_local_context.cc \ | |||||
| ./runtime_inference_context.cc \ | |||||
| COMMON_LOCAL_C_INCLUDES := \ | |||||
| proto/om.proto \ | |||||
| proto/ge_ir.proto \ | |||||
| proto_inner/ge_onnx.proto \ | |||||
| proto/insert_op.proto \ | |||||
| proto/task.proto \ | |||||
| proto/fwk_adapter.proto \ | |||||
| proto/op_mapping_info.proto \ | |||||
| proto/dump_task.proto \ | |||||
| inc \ | |||||
| inc/external \ | |||||
| inc/external/graph \ | |||||
| inc/graph \ | |||||
| inc/common \ | |||||
| common \ | |||||
| common/graph \ | |||||
| third_party/protobuf/include \ | |||||
| libc_sec/include \ | |||||
| ops/built-in/op_proto/inc \ | |||||
| #compiler for host | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := libgraph | |||||
| LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 | |||||
| LOCAL_CPPFLAGS += -fexceptions | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||||
| LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | |||||
| LOCAL_SHARED_LIBRARIES := \ | |||||
| libc_sec \ | |||||
| libprotobuf \ | |||||
| libslog \ | |||||
| liberror_manager \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| LOCAL_MULTILIB := 64 | |||||
| LOCAL_PROPRIETARY_MODULE := true | |||||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||||
| #compiler for host | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := stub/libgraph | |||||
| LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 | |||||
| LOCAL_CPPFLAGS += -fexceptions | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||||
| LOCAL_SRC_FILES := \ | |||||
| ../../out/graph/lib64/stub/graph.cc \ | |||||
| ../../out/graph/lib64/stub/operator.cc \ | |||||
| ../../out/graph/lib64/stub/tensor.cc \ | |||||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||||
| LOCAL_SHARED_LIBRARIES := | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| LOCAL_MULTILIB := 64 | |||||
| LOCAL_PROPRIETARY_MODULE := true | |||||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||||
| #compiler for host | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := fwk_stub/libgraph | |||||
| LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 | |||||
| LOCAL_CPPFLAGS += -fexceptions | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||||
| LOCAL_SRC_FILES := \ | |||||
| ../../out/graph/lib64/stub/attr_value.cc \ | |||||
| ../../out/graph/lib64/stub/graph.cc \ | |||||
| ../../out/graph/lib64/stub/operator.cc \ | |||||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||||
| ../../out/graph/lib64/stub/tensor.cc \ | |||||
| ../../out/graph/lib64/stub/inference_context.cc \ | |||||
| LOCAL_SHARED_LIBRARIES := | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| LOCAL_MULTILIB := 64 | |||||
| LOCAL_PROPRIETARY_MODULE := true | |||||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||||
| #compiler for device | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := libgraph | |||||
| LOCAL_CFLAGS += -O2 | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||||
| LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | |||||
| LOCAL_SHARED_LIBRARIES := \ | |||||
| libc_sec \ | |||||
| libprotobuf \ | |||||
| libslog \ | |||||
| liberror_manager \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| ifeq ($(device_os),android) | |||||
| LOCAL_LDFLAGS := -ldl | |||||
| endif | |||||
| LOCAL_MULTILIB := 64 | |||||
| LOCAL_PROPRIETARY_MODULE := true | |||||
| include $(BUILD_SHARED_LIBRARY) | |||||
| #compiler for device | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := stub/libgraph | |||||
| LOCAL_CFLAGS += -O2 | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||||
| LOCAL_SRC_FILES := \ | |||||
| ../../out/graph/lib64/stub/graph.cc \ | |||||
| ../../out/graph/lib64/stub/operator.cc \ | |||||
| ../../out/graph/lib64/stub/tensor.cc \ | |||||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||||
| LOCAL_SHARED_LIBRARIES := | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| ifeq ($(device_os),android) | |||||
| LOCAL_LDFLAGS := -ldl | |||||
| endif | |||||
| LOCAL_MULTILIB := 64 | |||||
| LOCAL_PROPRIETARY_MODULE := true | |||||
| include $(BUILD_SHARED_LIBRARY) | |||||
| #compiler for device | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := fwk_stub/libgraph | |||||
| LOCAL_CFLAGS += -O2 | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||||
| LOCAL_SRC_FILES := \ | |||||
| ../../out/graph/lib64/stub/attr_value.cc \ | |||||
| ../../out/graph/lib64/stub/graph.cc \ | |||||
| ../../out/graph/lib64/stub/operator.cc \ | |||||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||||
| ../../out/graph/lib64/stub/tensor.cc \ | |||||
| ../../out/graph/lib64/stub/inference_context.cc \ | |||||
| LOCAL_SHARED_LIBRARIES := | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| ifeq ($(device_os),android) | |||||
| LOCAL_LDFLAGS := -ldl | |||||
| endif | |||||
| LOCAL_MULTILIB := 64 | |||||
| LOCAL_PROPRIETARY_MODULE := true | |||||
| include $(BUILD_SHARED_LIBRARY) | |||||
| # compile for ut/st | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := libgraph | |||||
| LOCAL_CFLAGS += | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||||
| LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | |||||
| LOCAL_SHARED_LIBRARIES := \ | |||||
| libc_sec \ | |||||
| libprotobuf \ | |||||
| libslog \ | |||||
| liberror_manager \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| LOCAL_MULTILIB := 64 | |||||
| LOCAL_PROPRIETARY_MODULE := true | |||||
| include $(BUILD_LLT_SHARED_LIBRARY) | |||||
| #compiler for host static lib | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := libgraph | |||||
| LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 | |||||
| LOCAL_CPPFLAGS += -fexceptions | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||||
| LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | |||||
| LOCAL_STATIC_LIBRARIES := \ | |||||
| libprotobuf \ | |||||
| LOCAL_SHARED_LIBRARIES := \ | |||||
| libc_sec \ | |||||
| libslog \ | |||||
| liberror_manager \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| LOCAL_MULTILIB := 64 | |||||
| LOCAL_PROPRIETARY_MODULE := true | |||||
| include $(BUILD_HOST_STATIC_LIBRARY) | |||||
| #compiler for device static lib | |||||
| include $(CLEAR_VARS) | |||||
| LOCAL_MODULE := libgraph | |||||
| LOCAL_CFLAGS += -O2 | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||||
| LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | |||||
| LOCAL_STATIC_LIBRARIES := \ | |||||
| libprotobuf \ | |||||
| LOCAL_SHARED_LIBRARIES := \ | |||||
| libc_sec \ | |||||
| libslog \ | |||||
| liberror_manager \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | |||||
| LOCAL_MULTILIB := 64 | |||||
| LOCAL_PROPRIETARY_MODULE := true | |||||
| include $(BUILD_STATIC_LIBRARY) | |||||
| @@ -1,112 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "external/graph/inference_context.h" | |||||
| #include "debug/ge_util.h" | |||||
| namespace ge { | |||||
| class ShapeAndTypeImpl { | |||||
| public: | |||||
| ShapeAndTypeImpl() = default; | |||||
| ~ShapeAndTypeImpl() = default; | |||||
| ShapeAndTypeImpl(const Shape &shape, DataType data_type) : shape_(shape), data_type_(data_type) {} | |||||
| Shape shape_; | |||||
| DataType data_type_ = DT_UNDEFINED; | |||||
| }; | |||||
| class InferenceContextImpl { | |||||
| public: | |||||
| InferenceContextImpl() = default; | |||||
| ~InferenceContextImpl() = default; | |||||
| // For deliver to op in pair, help to support dynamic shape | |||||
| std::vector<std::string> marks_; | |||||
| std::vector<std::vector<ShapeAndType>> input_handle_shapes_and_types_; | |||||
| std::vector<std::vector<ShapeAndType>> output_handle_shapes_and_types_; | |||||
| }; | |||||
| ShapeAndType::ShapeAndType() { shape_and_type_impl_ = ComGraphMakeShared<ShapeAndTypeImpl>(); } | |||||
| ShapeAndType::ShapeAndType(const Shape &shape, DataType data_type) { | |||||
| shape_and_type_impl_ = ComGraphMakeShared<ShapeAndTypeImpl>(shape, data_type); | |||||
| } | |||||
| void ShapeAndType::SetShape(const Shape &shape) { | |||||
| if (shape_and_type_impl_ != nullptr) { | |||||
| shape_and_type_impl_->shape_ = shape; | |||||
| } | |||||
| } | |||||
| void ShapeAndType::SetType(DataType data_type) { | |||||
| if (shape_and_type_impl_ != nullptr) { | |||||
| shape_and_type_impl_->data_type_ = data_type; | |||||
| } | |||||
| } | |||||
| Shape ShapeAndType::GetShape() const { | |||||
| if (shape_and_type_impl_ != nullptr) { | |||||
| return shape_and_type_impl_->shape_; | |||||
| } | |||||
| return Shape(); | |||||
| } | |||||
| DataType ShapeAndType::GetDataType() const { | |||||
| if (shape_and_type_impl_ != nullptr) { | |||||
| return shape_and_type_impl_->data_type_; | |||||
| } | |||||
| return DT_UNDEFINED; | |||||
| } | |||||
| InferenceContext::InferenceContext(std::unique_ptr<InferenceContextImpl> &impl) { | |||||
| inference_context_impl_ = std::move(impl); | |||||
| } | |||||
| std::unique_ptr<InferenceContext> InferenceContext::Create() { | |||||
| std::unique_ptr<InferenceContextImpl> impl = | |||||
| std::unique_ptr<InferenceContextImpl>(new (std::nothrow) InferenceContextImpl()); | |||||
| if (impl == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| return std::unique_ptr<InferenceContext>(new (std::nothrow) InferenceContext(impl)); | |||||
| } | |||||
| void InferenceContext::SetInputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types) { | |||||
| inference_context_impl_->input_handle_shapes_and_types_.swap(shapes_and_types); | |||||
| } | |||||
| const std::vector<std::vector<ShapeAndType>> &InferenceContext::GetInputHandleShapesAndTypes() const { | |||||
| return inference_context_impl_->input_handle_shapes_and_types_; | |||||
| } | |||||
| const std::vector<std::vector<ShapeAndType>> &InferenceContext::GetOutputHandleShapesAndTypes() const { | |||||
| return inference_context_impl_->output_handle_shapes_and_types_; | |||||
| } | |||||
| void InferenceContext::SetOutputHandleShapesAndTypes(const std::vector<std::vector<ShapeAndType>> &shapes_and_types) { | |||||
| inference_context_impl_->output_handle_shapes_and_types_ = shapes_and_types; | |||||
| } | |||||
| void InferenceContext::SetOutputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types) { | |||||
| inference_context_impl_->output_handle_shapes_and_types_.swap(shapes_and_types); | |||||
| } | |||||
| void InferenceContext::SetMarks(const std::vector<std::string> &marks) { inference_context_impl_->marks_ = marks; } | |||||
| const std::vector<std::string> &InferenceContext::GetMarks() const { return inference_context_impl_->marks_; } | |||||
| } // namespace ge | |||||
| @@ -1,190 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/model.h" | |||||
| #include <fcntl.h> | |||||
| #include <google/protobuf/io/coded_stream.h> | |||||
| #include <google/protobuf/io/zero_copy_stream.h> | |||||
| #include <google/protobuf/io/zero_copy_stream_impl.h> | |||||
| #include <google/protobuf/text_format.h> | |||||
| #include <sys/stat.h> | |||||
| #include <sys/types.h> | |||||
| #include <unistd.h> | |||||
| #include <algorithm> | |||||
| #include <cstring> | |||||
| #include <fstream> | |||||
| #include <iomanip> | |||||
| #include "debug/ge_attr_define.h" | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/model_serialize.h" | |||||
| #include "proto/ge_ir.pb.h" | |||||
| #include "utils/attr_utils.h" | |||||
| #include "utils/ge_ir_utils.h" | |||||
| using google::protobuf::io::FileInputStream; | |||||
| using google::protobuf::io::FileOutputStream; | |||||
| using google::protobuf::io::ZeroCopyInputStream; | |||||
| namespace { | |||||
| const int DEFAULT_VERSION = 1; | |||||
| const int ACCESS_PERMISSION_BITS = 0400; | |||||
| } // namespace | |||||
| namespace ge { | |||||
| void Model::Init() { | |||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); | |||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); | |||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); | |||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); | |||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); | |||||
| (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); | |||||
| version_ = 0; | |||||
| } | |||||
| Model::Model() { | |||||
| attrs_.InitDefault(); | |||||
| Init(); | |||||
| } | |||||
| Model::Model(const string &name, const string &custom_version) | |||||
| : name_(name), version_(DEFAULT_VERSION), platform_version_(custom_version) { | |||||
| attrs_.InitDefault(); | |||||
| Init(); | |||||
| } | |||||
| string Model::GetName() const { return name_; } | |||||
| void Model::SetName(const string &name) { name_ = name; } | |||||
| uint32_t Model::GetVersion() const { return version_; } | |||||
| string Model::GetPlatformVersion() const { return platform_version_; } | |||||
| void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; } | |||||
| Graph Model::GetGraph() const { return graph_; } | |||||
| graphStatus Model::Save(Buffer &buffer, bool is_dump) const { | |||||
| ModelSerialize serialize; | |||||
| buffer = serialize.SerializeModel(*this, is_dump); | |||||
| return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED; | |||||
| } | |||||
| void Model::SetAttr(const ProtoAttrMapHelper &attrs) { attrs_ = attrs; } | |||||
| graphStatus Model::Load(const uint8_t *data, size_t len, Model &model) { | |||||
| ModelSerialize serialize; | |||||
| model = serialize.UnserializeModel(data, len); | |||||
| return model.IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED; | |||||
| } | |||||
| graphStatus Model::SaveToFile(const string &file_name) const { | |||||
| Buffer buffer; | |||||
| if ((*this).Save(buffer) != GRAPH_SUCCESS) { | |||||
| GE_LOGE("save to file fail."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // Write file | |||||
| ge::proto::ModelDef ge_proto; | |||||
| if (buffer.GetData() != nullptr) { | |||||
| std::string str((const char *)buffer.GetData(), buffer.GetSize()); | |||||
| if (!ge_proto.ParseFromString(str)) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| char real_path[PATH_MAX] = {0x00}; | |||||
| if (strlen(file_name.c_str()) >= PATH_MAX) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (realpath(file_name.c_str(), real_path) == nullptr) { | |||||
| GELOGI("file %s does not exit, it will be created.", file_name.c_str()); | |||||
| } | |||||
| int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS); | |||||
| if (fd < 0) { | |||||
| GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno)); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| bool ret = ge_proto.SerializeToFileDescriptor(fd); | |||||
| if (!ret) { | |||||
| GELOGE(GRAPH_FAILED, "SerializeToFileDescriptor failed"); | |||||
| if (close(fd) != 0) { | |||||
| GELOGE(GRAPH_FAILED, "close file descriptor fail."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (close(fd) != 0) { | |||||
| GELOGE(GRAPH_FAILED, "close file descriptor fail."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (!ret) { | |||||
| GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus Model::Load(ge::proto::ModelDef &model_def) { | |||||
| ModelSerialize serialize; | |||||
| *this = serialize.UnserializeModel(model_def); | |||||
| return this->IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED; | |||||
| } | |||||
| bool Model::IsValid() const { return graph_.IsValid(); } | |||||
| graphStatus Model::LoadFromFile(const string &file_name) { | |||||
| char real_path[PATH_MAX] = {0x00}; | |||||
| if (strlen(file_name.c_str()) >= PATH_MAX) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (realpath(file_name.c_str(), real_path) == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "file %s does not exit, can not load.", file_name.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| int fd = open(real_path, O_RDONLY); | |||||
| if (fd < 0) { | |||||
| GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno)); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| ge::proto::ModelDef model_def; | |||||
| bool ret = model_def.ParseFromFileDescriptor(fd); | |||||
| if (!ret) { | |||||
| GELOGE(GRAPH_FAILED, "ParseFromFileDescriptor failed"); | |||||
| if (close(fd) != 0) { | |||||
| GELOGE(GRAPH_FAILED, "close file descriptor fail."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (close(fd) != 0) { | |||||
| GELOGE(GRAPH_FAILED, "close file descriptor fail."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (!ret) { | |||||
| GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return Load(model_def); | |||||
| } | |||||
| ProtoAttrMapHelper Model::MutableAttrMap() { return attrs_; } | |||||
| ConstProtoAttrMapHelper Model::GetAttrMap() const { | |||||
| return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg()); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,763 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/model_serialize.h" | |||||
| #include <google/protobuf/text_format.h> | |||||
| #include <queue> | |||||
| #include <iostream> | |||||
| #include "debug/ge_attr_define.h" | |||||
| #include "debug/ge_log.h" | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/detail/model_serialize_imp.h" | |||||
| #include "proto/ge_ir.pb.h" | |||||
| #include "utils/graph_utils.h" | |||||
| #include "debug/ge_op_types.h" | |||||
| using std::map; | |||||
| using std::string; | |||||
| namespace ge { | |||||
| bool ModelSerializeImp::ParseNodeIndex(const string &node_index, string &node_name, int32_t &index) { | |||||
| auto sep = node_index.rfind(":"); | |||||
| if (sep == string::npos) { | |||||
| GELOGW("separator is not found in node_index."); | |||||
| return false; | |||||
| } | |||||
| node_name = node_index.substr(0, sep); | |||||
| auto index_str = node_index.substr(sep + 1); | |||||
| index = static_cast<int32_t>(std::strtol(index_str.c_str(), nullptr, 10)); | |||||
| return true; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeTensor(const ConstGeTensorPtr &tensor, | |||||
| proto::TensorDef *tensor_proto) { | |||||
| GE_CHK_BOOL_EXEC(tensor != nullptr, return false, "tensor is null."); | |||||
| GE_CHK_BOOL_EXEC(tensor_proto != nullptr, return false, "tensor_proto is null."); | |||||
| if (tensor->tensor_def_.GetProtoMsg() != nullptr) { | |||||
| *tensor_proto = *tensor->tensor_def_.GetProtoMsg(); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_proto) { | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is null."); | |||||
| GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); | |||||
| op_def_proto->clear_input(); | |||||
| // Inputs | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| if (in_data_anchor != nullptr) { | |||||
| auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) { | |||||
| op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + | |||||
| std::to_string(peer_out_anchor->GetIdx())); | |||||
| } else { | |||||
| op_def_proto->add_input(""); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Control edge | |||||
| auto control_anchor = node->GetInControlAnchor(); | |||||
| if (control_anchor != nullptr) { | |||||
| auto peer_out_anchors = control_anchor->GetPeerOutControlAnchors(); | |||||
| for (const auto &peer_out_anchor : peer_out_anchors) { | |||||
| if (peer_out_anchor != nullptr && peer_out_anchor->GetOwnerNode()) { | |||||
| op_def_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":-1"); | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { | |||||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null."); | |||||
| GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); | |||||
| if (op_desc->op_def_.GetProtoMsg() != nullptr) { | |||||
| *op_def_proto = *op_desc->op_def_.GetProtoMsg(); | |||||
| // Delete unnecessary attr | |||||
| if (is_dump) { | |||||
| auto attr = op_def_proto->mutable_attr(); | |||||
| attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF); | |||||
| attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF); | |||||
| attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF); | |||||
| GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP), | |||||
| attr->erase(ATTR_NAME_WEIGHTS)); | |||||
| } | |||||
| op_def_proto->clear_input_desc(); | |||||
| op_def_proto->clear_output_desc(); | |||||
| // Input descs | |||||
| if (op_desc->GetAllInputsSize() > 0) { | |||||
| auto size = static_cast<uint32_t>(op_desc->GetAllInputsSize()); | |||||
| for (uint32_t i = 0; i < size; i++) { | |||||
| auto tensor_desc = op_desc->GetInputDescPtrDfault(i); | |||||
| if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { | |||||
| *op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Output descs | |||||
| if (op_desc->GetOutputsSize() > 0) { | |||||
| auto size = static_cast<uint32_t>(op_desc->GetOutputsSize()); | |||||
| for (uint32_t i = 0; i < size; i++) { | |||||
| auto tensor_desc = op_desc->GetOutputDescPtr(i); | |||||
| if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { | |||||
| *op_def_proto->add_output_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); | |||||
| } | |||||
| } | |||||
| } | |||||
| op_def_proto->set_id(op_desc->GetId()); | |||||
| for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { | |||||
| op_def_proto->add_subgraph_name(name); | |||||
| } | |||||
| OpDescToAttrDef(op_desc, op_def_proto); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) { | |||||
| proto::AttrDef key_in; | |||||
| proto::AttrDef value_in; | |||||
| auto op_desc_attr = op_def_proto->mutable_attr(); | |||||
| if (!op_desc->input_name_idx_.empty()) { | |||||
| for (auto &item : op_desc->input_name_idx_) { | |||||
| key_in.mutable_list()->add_s(item.first); | |||||
| value_in.mutable_list()->add_i(item.second); | |||||
| } | |||||
| op_desc_attr->insert({"_input_name_key", key_in}); | |||||
| op_desc_attr->insert({"_input_name_value", value_in}); | |||||
| } | |||||
| proto::AttrDef key_out; | |||||
| proto::AttrDef value_out; | |||||
| if (!op_desc->output_name_idx_.empty()) { | |||||
| for (auto &item : op_desc->output_name_idx_) { | |||||
| key_out.mutable_list()->add_s(item.first); | |||||
| value_out.mutable_list()->add_i(item.second); | |||||
| } | |||||
| op_desc_attr->insert({"_output_name_key", key_out}); | |||||
| op_desc_attr->insert({"_output_name_value", value_out}); | |||||
| } | |||||
| proto::AttrDef opt_input; | |||||
| if (!op_desc->optional_input_names_.empty()) { | |||||
| for (auto &item : op_desc->optional_input_names_) { | |||||
| opt_input.mutable_list()->add_s(item); | |||||
| } | |||||
| op_desc_attr->insert({"_opt_input", opt_input}); | |||||
| } | |||||
| } | |||||
| bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) { | |||||
| if (node == nullptr || op_def_proto == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Input Para Node Invalid"); | |||||
| return false; | |||||
| } | |||||
| if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) { | |||||
| GELOGE(GRAPH_FAILED, "Serialize OpDesc failed"); | |||||
| return false; | |||||
| } | |||||
| if (SerializeEdge(node, op_def_proto)) { | |||||
| return true; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, | |||||
| proto::GraphDef *graph_proto, | |||||
| bool is_dump) { | |||||
| if (graph == nullptr || graph_proto == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Input para Invalid"); | |||||
| return false; | |||||
| } | |||||
| graph_proto->set_name(graph->GetName()); | |||||
| // Inputs | |||||
| for (const auto &input : graph->GetInputNodes()) { | |||||
| if (input != nullptr) { | |||||
| graph_proto->add_input(input->GetName() + ":0"); | |||||
| } | |||||
| } | |||||
| // Outputs | |||||
| for (const auto &output : graph->GetGraphOutNodesInfo()) { | |||||
| if (output.first != nullptr) { | |||||
| graph_proto->add_output(output.first->GetName() + ":" + std::to_string(output.second)); | |||||
| GELOGI("Add output to graph proto, node name:%s, index:%ld", output.first->GetName().c_str(), output.second); | |||||
| } | |||||
| } | |||||
| if (graph->attrs_.GetProtoMsg() != nullptr) { | |||||
| *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg(); | |||||
| } | |||||
| for (const auto &node : graph->GetDirectNode()) { | |||||
| if (!SerializeNode(node, graph_proto->add_op(), is_dump)) { | |||||
| if (node->GetOpDesc() != nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str()); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) { | |||||
| if (model_proto == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "model_proto para Invalid"); | |||||
| return false; | |||||
| } | |||||
| model_proto->set_name(model.GetName()); | |||||
| model_proto->set_custom_version(model.GetPlatformVersion()); | |||||
| model_proto->set_version(model.GetVersion()); | |||||
| if (model.attrs_.GetProtoMsg()) { | |||||
| *model_proto->mutable_attr() = *model.attrs_.GetProtoMsg(); | |||||
| } | |||||
| auto &graph = model.graph_; | |||||
| auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
| if (compute_graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr"); | |||||
| return false; | |||||
| } | |||||
| if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) { | |||||
| GELOGE(GRAPH_FAILED, "SerializeGraph fail"); | |||||
| return false; | |||||
| } | |||||
| for (auto subgraph : compute_graph->GetAllSubgraphs()) { | |||||
| if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) { | |||||
| GELOGE(GRAPH_FAILED, "Serialize subgraph failed"); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeTensor( | |||||
| GeTensorPtr &tensor, proto::TensorDef &tensor_proto) { | |||||
| tensor = std::shared_ptr<GeTensor>(new (std::nothrow) GeTensor(protobuf_owner_, &tensor_proto)); | |||||
| if (tensor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "tensor is nullptr"); | |||||
| return false; | |||||
| } else { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc, std::vector<string> &key_in, std::vector<string> &key_out, | |||||
| std::vector<uint32_t> &value_in, std::vector<uint32_t> &value_out, | |||||
| std::vector<string> &opt_input) { | |||||
| if (!key_in.empty()) { | |||||
| if (key_in.size() != value_in.size()) { | |||||
| GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(), | |||||
| value_in.size()); | |||||
| } else { | |||||
| for (uint32_t i = 0; i < key_in.size(); ++i) { | |||||
| op_desc->input_name_idx_.insert(std::pair<string, uint32_t>(key_in.at(i), value_in.at(i))); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!key_out.empty()) { | |||||
| if (key_out.size() != value_out.size()) { | |||||
| GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(), | |||||
| value_out.size()); | |||||
| } else { | |||||
| for (uint32_t i = 0; i < key_out.size(); ++i) { | |||||
| op_desc->output_name_idx_.insert(std::pair<string, uint32_t>(key_out.at(i), value_out.at(i))); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (!opt_input.empty()) { | |||||
| for (const auto &i : opt_input) { | |||||
| op_desc->optional_input_names_.insert(i); | |||||
| } | |||||
| } | |||||
| } | |||||
| bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) { | |||||
| std::vector<string> opt_input; | |||||
| std::vector<string> key_in; | |||||
| std::vector<uint32_t> value_in; | |||||
| if (op_def_proto.attr().count("_opt_input") > 0) { | |||||
| auto &name_list = op_def_proto.attr().at("_opt_input").list(); | |||||
| for (const auto &item_s : name_list.s()) { | |||||
| opt_input.push_back(item_s); | |||||
| } | |||||
| auto op_desc_attr = op_def_proto.mutable_attr(); | |||||
| op_desc_attr->erase("_opt_input"); | |||||
| } | |||||
| if (op_def_proto.attr().count("_input_name_key") > 0) { | |||||
| auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list(); | |||||
| for (const auto &item_s : output_name_key_list.s()) { | |||||
| key_in.push_back(item_s); | |||||
| } | |||||
| auto op_desc_attr = op_def_proto.mutable_attr(); | |||||
| op_desc_attr->erase("_input_name_key"); | |||||
| } | |||||
| if (op_def_proto.attr().count("_input_name_value") > 0) { | |||||
| auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list(); | |||||
| for (const auto &item_i : input_name_value_list.i()) { | |||||
| value_in.push_back(static_cast<uint32_t>(item_i)); | |||||
| } | |||||
| auto op_desc_attr = op_def_proto.mutable_attr(); | |||||
| op_desc_attr->erase("_input_name_value"); | |||||
| } | |||||
| std::vector<string> key_out; | |||||
| std::vector<uint32_t> value_out; | |||||
| if (op_def_proto.attr().count("_output_name_key") > 0) { | |||||
| auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list(); | |||||
| for (const auto &item_s : output_name_key_list.s()) { | |||||
| key_out.push_back(item_s); | |||||
| } | |||||
| auto op_desc_attr = op_def_proto.mutable_attr(); | |||||
| op_desc_attr->erase("_output_name_key"); | |||||
| } | |||||
| if (op_def_proto.attr().count("_output_name_value") > 0) { | |||||
| auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list(); | |||||
| for (const auto &item_i : output_name_value_list.i()) { | |||||
| value_out.push_back(static_cast<uint32_t>(item_i)); | |||||
| } | |||||
| auto op_desc_attr = op_def_proto.mutable_attr(); | |||||
| op_desc_attr->erase("_output_name_value"); | |||||
| } | |||||
| op_desc = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(protobuf_owner_, &op_def_proto)); | |||||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr."); | |||||
| // Input tensor | |||||
| for (auto &input_desc : *op_def_proto.mutable_input_desc()) { | |||||
| std::shared_ptr<GeTensorDesc> temp_value = | |||||
| std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &input_desc)); | |||||
| GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); | |||||
| op_desc->inputs_desc_.push_back(temp_value); | |||||
| } | |||||
| // Output tensor | |||||
| for (auto &output_desc : *op_def_proto.mutable_output_desc()) { | |||||
| std::shared_ptr<GeTensorDesc> temp_value = | |||||
| std::shared_ptr<GeTensorDesc>(new (std::nothrow) GeTensorDesc(protobuf_owner_, &output_desc)); | |||||
| GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); | |||||
| op_desc->outputs_desc_.push_back(temp_value); | |||||
| } | |||||
| op_desc->SetId(op_def_proto.id()); | |||||
| uint32_t graph_index = 0; | |||||
| for (const std::string &name : op_def_proto.subgraph_name()) { | |||||
| op_desc->AddSubgraphName(name); | |||||
| op_desc->SetSubgraphInstanceName(graph_index++, name); | |||||
| } | |||||
| // insert name index by key and value | |||||
| AttrDefToOpDesc(op_desc, key_in, key_out, value_in, value_out, opt_input); | |||||
| return true; | |||||
| } | |||||
| bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto) { | |||||
| GE_RT_FALSE_CHECK_NOTNULL(graph); | |||||
| OpDescPtr op_desc = nullptr; | |||||
| if (!UnserializeOpDesc(op_desc, op_def_proto)) { | |||||
| GELOGW("UnserializeOpDesc error."); | |||||
| } | |||||
| NodePtr node = graph->AddNode(op_desc, op_desc->GetId()); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr."); | |||||
| // Inputs | |||||
| int dst_index = 0; | |||||
| for (const auto &input : op_def_proto.input()) { | |||||
| string node_name; | |||||
| int32_t index = 0; | |||||
| if (ParseNodeIndex(input, node_name, index)) { | |||||
| node_input_node_names_.push_back(NodeNameNodeReq{node_name, index, node, dst_index, op_def_proto.name()}); | |||||
| } | |||||
| if (index >= 0) { | |||||
| dst_index++; | |||||
| } | |||||
| } | |||||
| node_map_[op_def_proto.name()] = node; | |||||
| return true; | |||||
| } | |||||
| bool ModelSerializeImp::HandleNodeNameRef() { | |||||
| // Edges | |||||
| for (auto &item : node_input_node_names_) { | |||||
| auto src_node_it = node_map_.find(item.src_node_name); | |||||
| if (src_node_it == node_map_.end()) { | |||||
| GELOGE(GRAPH_FAILED, "cannot find node %s", item.src_node_name.c_str()); | |||||
| return false; | |||||
| } | |||||
| GE_IF_BOOL_EXEC(src_node_it->second == nullptr || item.dst_node == nullptr, continue); | |||||
| if (item.src_out_index >= 0) { | |||||
| auto src_anchor = src_node_it->second->GetOutDataAnchor(item.src_out_index); | |||||
| auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index); | |||||
| if (src_anchor == nullptr || dst_anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "get anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index, | |||||
| item.dst_node_name.c_str(), item.dst_in_index); | |||||
| return false; | |||||
| } | |||||
| GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 | |||||
| } else { | |||||
| // Control edge | |||||
| auto src_anchor = src_node_it->second->GetOutControlAnchor(); | |||||
| auto dst_anchor = item.dst_node->GetInControlAnchor(); | |||||
| if (src_anchor != nullptr && dst_anchor != nullptr) { | |||||
| GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 | |||||
| } | |||||
| } | |||||
| } | |||||
| // Graph input | |||||
| for (auto &item : graph_input_node_names_) { | |||||
| auto node_it = node_map_.find(item.node_name); | |||||
| if (node_it == node_map_.end()) { | |||||
| GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str()); | |||||
| return false; | |||||
| } | |||||
| GE_IF_BOOL_EXEC(item.graph == nullptr, continue); | |||||
| auto ret = item.graph->AddInputNode(node_it->second); | |||||
| if (ret == nullptr) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| // Graph output | |||||
| for (auto &item : graph_output_node_names_) { | |||||
| auto node_it = node_map_.find(item.node_name); | |||||
| if (node_it == node_map_.end()) { | |||||
| GELOGE(GRAPH_FAILED, "cannot find node %s", item.node_name.c_str()); | |||||
| return false; | |||||
| } | |||||
| GE_IF_BOOL_EXEC(item.graph == nullptr, continue); | |||||
| auto ret = item.graph->AddOutputNodeByIndex(node_it->second, item.index); | |||||
| GELOGI("node name:%s, item.index:%ld", node_it->second->GetName().c_str(), item.index); | |||||
| if (ret == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AddOutputNode failed."); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| node_input_node_names_.clear(); | |||||
| graph_input_node_names_.clear(); | |||||
| graph_output_node_names_.clear(); | |||||
| node_map_.clear(); | |||||
| return true; | |||||
| } | |||||
| bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map<string, ComputeGraphPtr> &subgraphs) { | |||||
| std::queue<ComputeGraphPtr> all_graphs; | |||||
| all_graphs.emplace(compute_graph); | |||||
| while (!all_graphs.empty()) { | |||||
| ComputeGraphPtr graph = all_graphs.front(); | |||||
| all_graphs.pop(); | |||||
| for (const NodePtr &node : graph->GetDirectNode()) { | |||||
| const OpDescPtr op_desc = node->GetOpDesc(); | |||||
| for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { | |||||
| auto it = subgraphs.find(name); | |||||
| if (it == subgraphs.end()) { | |||||
| GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), name.c_str(), | |||||
| subgraphs.size()); | |||||
| return false; | |||||
| } | |||||
| ComputeGraphPtr &subgraph = it->second; | |||||
| subgraph->SetParentGraph(graph); | |||||
| subgraph->SetParentNode(node); | |||||
| compute_graph->AddSubgraph(subgraph->GetName(), subgraph); | |||||
| all_graphs.emplace(subgraph); | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) { | |||||
| model.name_ = model_proto.name(); | |||||
| model.version_ = model_proto.version(); | |||||
| model.platform_version_ = model_proto.custom_version(); | |||||
| model.attrs_ = ProtoAttrMapHelper(protobuf_owner_, model_proto.mutable_attr()); | |||||
| auto &graphs_proto = *model_proto.mutable_graph(); | |||||
| if (!graphs_proto.empty()) { | |||||
| auto &graph_proto = graphs_proto[0]; | |||||
| ComputeGraphPtr compute_graph_ptr; | |||||
| if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) { | |||||
| model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); | |||||
| } | |||||
| // 0 is main graph, following is subgraph. | |||||
| map<string, ComputeGraphPtr> subgraphs; | |||||
| for (int idx = 1; idx < graphs_proto.size(); ++idx) { | |||||
| ComputeGraphPtr subgraph; | |||||
| ModelSerializeImp impl; | |||||
| if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) { | |||||
| GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed"); | |||||
| return false; | |||||
| } | |||||
| if (!impl.HandleNodeNameRef()) { | |||||
| GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); | |||||
| return false; | |||||
| } | |||||
| subgraphs[subgraph->GetName()] = subgraph; | |||||
| } | |||||
| if (!RebuildOwnership(compute_graph_ptr, subgraphs)) { | |||||
| GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed"); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| if (!HandleNodeNameRef()) { | |||||
| GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool ModelSerializeImp::UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto) { | |||||
| graph = ComGraphMakeShared<ComputeGraph>(graph_proto.name()); | |||||
| if (graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); | |||||
| return false; | |||||
| } | |||||
| // Inputs | |||||
| for (auto input : graph_proto.input()) { | |||||
| string node_name; | |||||
| int32_t index; | |||||
| if (ParseNodeIndex(input, node_name, index)) { | |||||
| graph_input_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); | |||||
| } | |||||
| } | |||||
| // Outputs | |||||
| for (auto output : graph_proto.output()) { | |||||
| string node_name; | |||||
| int32_t index; | |||||
| if (ParseNodeIndex(output, node_name, index)) { | |||||
| graph_output_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); | |||||
| } | |||||
| } | |||||
| graph->attrs_ = ProtoAttrMapHelper(protobuf_owner_, graph_proto.mutable_attr()); | |||||
| for (auto &op_def_proto : *graph_proto.mutable_op()) { | |||||
| if (!UnserializeNode(graph, op_def_proto)) { | |||||
| GELOGE(GRAPH_FAILED, "UnserializeNode fail"); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeGraph(ComputeGraphPtr &graph, | |||||
| proto::GraphDef &graph_proto) { | |||||
| if (!UnserializeGraphWithoutEdge(graph, graph_proto)) { | |||||
| GELOGW("UnserializeGraphWithoutEdge fail"); | |||||
| } | |||||
| if (!HandleNodeNameRef()) { | |||||
| GELOGE(GRAPH_FAILED, "Link Anchor or set graph input or output fail"); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf::Message *proto) { | |||||
| GE_CHK_BOOL_EXEC(data != nullptr, return false, "data is null."); | |||||
| GE_CHK_BOOL_EXEC(proto != nullptr, return false, "proto is null."); | |||||
| google::protobuf::io::CodedInputStream coded_stream(data, len); | |||||
| // 2048M -1 | |||||
| coded_stream.SetTotalBytesLimit(INT32_MAX, -1); | |||||
| if (!proto->ParseFromCodedStream(&coded_stream)) { | |||||
| GELOGE(GRAPH_FAILED, "ReadProtoFromBinaryFile failed len %zu", len); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) { | |||||
| proto::ModelDef model_def; | |||||
| ModelSerializeImp imp; | |||||
| if (!imp.SerializeModel(model, &model_def, is_dump)) { | |||||
| return Buffer(); | |||||
| } | |||||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||||
| Buffer buffer(model_def.ByteSizeLong()); | |||||
| #else | |||||
| Buffer buffer(model_def.ByteSize()); | |||||
| #endif | |||||
| GE_CHK_BOOL_ONLY_LOG(buffer.GetSize() != 0, "get size failed"); | |||||
| GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); | |||||
| auto ret = model_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize())); | |||||
| if (ret != true) { | |||||
| GELOGW("serialize to array fail."); | |||||
| } | |||||
| return buffer; | |||||
| } | |||||
| size_t ModelSerialize::GetSerializeModelSize(const Model &model) { | |||||
| proto::ModelDef model_def; | |||||
| ModelSerializeImp imp; | |||||
| if (!imp.SerializeModel(model, &model_def)) { | |||||
| return 0; | |||||
| } | |||||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||||
| return model_def.ByteSizeLong(); | |||||
| #else | |||||
| return model_def.ByteSize(); | |||||
| #endif | |||||
| } | |||||
| Model ModelSerialize::UnserializeModel(const uint8_t *data, size_t len) { | |||||
| if (data == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "data is nullptr"); | |||||
| return Model(); | |||||
| } | |||||
| std::shared_ptr<proto::ModelDef> model_proto_ptr; | |||||
| model_proto_ptr = ComGraphMakeShared<proto::ModelDef>(); | |||||
| if (model_proto_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::ModelDef make shared failed"); | |||||
| return Model(); | |||||
| } | |||||
| auto &model_proto = *model_proto_ptr; | |||||
| if (!ReadProtoFromBinaryFile(data, len, &model_proto)) { | |||||
| GELOGE(GRAPH_FAILED, "ParseFromArray fail"); | |||||
| return Model(); | |||||
| } | |||||
| Model model; | |||||
| ModelSerializeImp imp; | |||||
| imp.SetProtobufOwner(model_proto_ptr); | |||||
| if (!imp.UnserializeModel(model, model_proto)) { | |||||
| GELOGE(GRAPH_FAILED, "Unserialize Model fail"); | |||||
| return Model(); | |||||
| } | |||||
| return model; | |||||
| } | |||||
| Model ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def) { | |||||
| std::shared_ptr<proto::ModelDef> model_def_ptr = ComGraphMakeShared<proto::ModelDef>(model_def); | |||||
| GE_CHK_BOOL_EXEC(model_def_ptr != nullptr, return Model(), "mode_def make shared failed"); | |||||
| ModelSerializeImp imp; | |||||
| imp.SetProtobufOwner(model_def_ptr); | |||||
| Model model; | |||||
| if (!imp.UnserializeModel(model, *model_def_ptr)) { | |||||
| GELOGE(GRAPH_FAILED, "Unserialize Model fail"); | |||||
| return Model(); | |||||
| } | |||||
| return model; | |||||
| } | |||||
| Buffer ModelSerialize::SerializeGraph(const ComputeGraphPtr &graph) { | |||||
| proto::GraphDef graph_def; | |||||
| ModelSerializeImp imp; | |||||
| if (!imp.SerializeGraph(graph, &graph_def)) { | |||||
| return Buffer(); | |||||
| } | |||||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||||
| Buffer buffer(graph_def.ByteSizeLong()); | |||||
| #else | |||||
| Buffer buffer(graph_def.ByteSize()); | |||||
| #endif | |||||
| GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed"); | |||||
| GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); | |||||
| auto ret = graph_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize())); | |||||
| if (ret != true) { | |||||
| GE_LOGE("serialize to array fail."); | |||||
| } | |||||
| return buffer; | |||||
| } | |||||
| ComputeGraphPtr ModelSerialize::UnserializeGraph(const uint8_t *data, size_t len) { | |||||
| if (data == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "data is nullptr"); | |||||
| return nullptr; | |||||
| } | |||||
| std::shared_ptr<proto::GraphDef> graph_proto_ptr; | |||||
| graph_proto_ptr = ComGraphMakeShared<proto::GraphDef>(); | |||||
| if (graph_proto_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); | |||||
| return nullptr; | |||||
| } | |||||
| proto::GraphDef &graph_proto = *graph_proto_ptr; | |||||
| if (!ReadProtoFromBinaryFile(data, len, &graph_proto)) { | |||||
| GELOGE(GRAPH_FAILED, "ParseFromArray fail"); | |||||
| return nullptr; | |||||
| } | |||||
| ComputeGraphPtr graph; | |||||
| ModelSerializeImp imp; | |||||
| imp.SetProtobufOwner(graph_proto_ptr); | |||||
| if (!imp.UnserializeGraph(graph, graph_proto)) { | |||||
| return nullptr; | |||||
| } | |||||
| return graph; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer ModelSerialize::SerializeOpDesc(const ConstOpDescPtr &op_desc) { | |||||
| proto::OpDef op_def; | |||||
| ModelSerializeImp imp; | |||||
| if (!imp.SerializeOpDesc(op_desc, &op_def)) { | |||||
| return Buffer(); | |||||
| } | |||||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||||
| Buffer buffer(op_def.ByteSizeLong()); | |||||
| #else | |||||
| Buffer buffer(op_def.ByteSize()); | |||||
| #endif | |||||
| GE_CHK_BOOL_ONLY_LOG((buffer.GetSize() != 0), "get size failed"); | |||||
| GE_CHK_BOOL_ONLY_LOG((buffer.GetData() != nullptr), "get size failed"); | |||||
| auto ret = op_def.SerializeToArray(buffer.GetData(), static_cast<int>(buffer.GetSize())); | |||||
| if (ret != true) { | |||||
| GE_LOGE("serialize to array fail."); | |||||
| } | |||||
| return buffer; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr ModelSerialize::UnserializeOpDesc(const uint8_t *data, | |||||
| size_t len) { | |||||
| if (data == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "data is nullptr"); | |||||
| return nullptr; | |||||
| } | |||||
| std::shared_ptr<proto::OpDef> op_def_ptr; | |||||
| op_def_ptr = ComGraphMakeShared<proto::OpDef>(); | |||||
| if (op_def_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); | |||||
| return nullptr; | |||||
| } | |||||
| proto::OpDef &op_def = *op_def_ptr; | |||||
| if (!ReadProtoFromBinaryFile(data, len, &op_def)) { | |||||
| GELOGE(GRAPH_FAILED, "ParseFromArray fail"); | |||||
| return nullptr; | |||||
| } | |||||
| OpDescPtr op_desc; | |||||
| ModelSerializeImp imp; | |||||
| imp.SetProtobufOwner(op_def_ptr); | |||||
| if (!imp.UnserializeOpDesc(op_desc, op_def)) { | |||||
| GELOGW("UnserializeOpDesc error."); | |||||
| } | |||||
| return op_desc; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,3 +0,0 @@ | |||||
| LOCAL_PATH := $(call my-dir) | |||||
| include $(LOCAL_PATH)/graph.mk | |||||
| @@ -1,877 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/node.h" | |||||
| #include <utility> | |||||
| #include "debug/ge_op_types.h" | |||||
| #include "debug/ge_util.h" | |||||
| #include "external/graph/operator_factory.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "graph/operator_factory_impl.h" | |||||
| #include "graph/shape_refiner.h" | |||||
| #include "utils/ge_ir_utils.h" | |||||
| #include "utils/node_utils.h" | |||||
| #include "utils/op_desc_utils.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| using std::string; | |||||
| using std::vector; | |||||
| namespace ge { | |||||
| Node::Node(const OpDescPtr &op, const ComputeGraphPtr &owner_graph) | |||||
| : op_(op), | |||||
| owner_graph_(owner_graph), | |||||
| in_data_anchors_(), | |||||
| out_data_anchors_(), | |||||
| in_control_anchor_(nullptr), | |||||
| out_control_anchor_(nullptr), | |||||
| attrs_(), | |||||
| has_init_(false) { | |||||
| anchor_status_updated_ = false; | |||||
| } | |||||
| Node::~Node() { | |||||
| for (const auto &in_data_anchor : in_data_anchors_) { | |||||
| if (in_data_anchor != nullptr) { | |||||
| in_data_anchor->UnlinkAll(); | |||||
| } | |||||
| } | |||||
| for (const auto &out_data_anchor : out_data_anchors_) { | |||||
| if (out_data_anchor != nullptr) { | |||||
| out_data_anchor->UnlinkAll(); | |||||
| } | |||||
| } | |||||
| if (in_control_anchor_ != nullptr) { | |||||
| in_control_anchor_->UnlinkAll(); | |||||
| } | |||||
| if (out_control_anchor_ != nullptr) { | |||||
| out_control_anchor_->UnlinkAll(); | |||||
| } | |||||
| } | |||||
| graphStatus Node::Init() { | |||||
| if (has_init_) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); | |||||
| size_t size = op_->GetAllInputsSize(); | |||||
| for (size_t i = 0; i < size; i++) { | |||||
| std::shared_ptr<InDataAnchor> anchor = ComGraphMakeShared<InDataAnchor>(shared_from_this(), i); | |||||
| if (anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| in_data_anchors_.push_back(anchor); | |||||
| } | |||||
| size = op_->GetOutputsSize(); | |||||
| for (size_t i = 0; i < size; i++) { | |||||
| std::shared_ptr<OutDataAnchor> anchor = ComGraphMakeShared<OutDataAnchor>(shared_from_this(), i); | |||||
| if (anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Current out_data_anchor is null, malloc shared_ptr failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| out_data_anchors_.push_back(anchor); | |||||
| } | |||||
| in_control_anchor_ = ComGraphMakeShared<InControlAnchor>(shared_from_this(), -1); | |||||
| out_control_anchor_ = ComGraphMakeShared<OutControlAnchor>(shared_from_this(), -1); | |||||
| if (in_control_anchor_ == nullptr || out_control_anchor_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Current in_control_anchor or out_control_anchor is null, malloc shared_ptr failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| has_init_ = true; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetName() const { | |||||
| GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr"); | |||||
| return op_->GetName(); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetType() const { | |||||
| GE_CHK_BOOL_EXEC(op_ != nullptr, return string(), "original OpDesc is nullptr"); | |||||
| return op_->GetType(); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAttrsAreEqual(const Node &r_node) const { | |||||
| const auto &attr_map = this->attrs_; | |||||
| const auto &r_attr_map = r_node.attrs_; | |||||
| // 1.Verify node's map<string, AttrValue> size | |||||
| if (attr_map.size() != r_attr_map.size()) { | |||||
| GELOGE(GRAPH_FAILED, "Size of node's attr map verify failed, node name: %s.", this->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| // 2.Verify node's map<string, AttrValue> key, verify values is temporarily not implemented | |||||
| for (const auto &it : attr_map) { | |||||
| if (r_attr_map.count(it.first) == 0) { | |||||
| GELOGE(GRAPH_FAILED, "Key of node's attr map verify failed, node name: %s key name: %s.", this->GetName().c_str(), | |||||
| it.first.c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeMembersAreEqual(const Node &r_node) const { | |||||
| return ((((this->op_ != nullptr) && (r_node.op_ != nullptr) && (IsEqual(*(this->op_), *(r_node.op_), "node.op_"))) || | |||||
| ((this->op_ == nullptr) && (r_node.op_ == nullptr))) && | |||||
| IsEqual(this->has_init_, r_node.has_init_, "node.has_init_") && | |||||
| IsEqual(this->anchor_status_updated_, r_node.anchor_status_updated_, "node.anchor_status_updated_") && | |||||
| IsEqual(this->send_event_id_list_, r_node.send_event_id_list_, "node.send_event_id_list_") && | |||||
| IsEqual(this->recv_event_id_list_, r_node.recv_event_id_list_, "node.recv_event_id_list_")); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(const AnchorPtr &left_anchor, | |||||
| const AnchorPtr &right_anchor, | |||||
| size_t i) const { | |||||
| GE_IF_BOOL_EXEC(left_anchor == nullptr, GELOGE(GRAPH_FAILED, "left_anchor is null."); return false); | |||||
| GE_IF_BOOL_EXEC(right_anchor == nullptr, GELOGE(GRAPH_FAILED, "right_anchor is null."); return false); | |||||
| const auto anchor_peer_size = left_anchor->GetPeerAnchors().size(); | |||||
| const auto right_anchor_peer_size = right_anchor->GetPeerAnchors().size(); | |||||
| // Firstly, verify anchor's peer anchors size equal or not | |||||
| if (anchor_peer_size != right_anchor_peer_size) { | |||||
| GELOGE(GRAPH_FAILED, | |||||
| "Size of anchor's peer anchors verify failed, node name: %s " | |||||
| "anchor_peer_size [%zu] is different form [%zu] at index [%zu].", | |||||
| this->GetName().c_str(), anchor_peer_size, right_anchor_peer_size, i); | |||||
| return false; | |||||
| } | |||||
| // Secondly, verify anchor's peer anchor owner node equal or not | |||||
| for (size_t j = 0; j < anchor_peer_size; j++) { | |||||
| const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); | |||||
| const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); | |||||
| if (peer_node == nullptr || r_peer_node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", | |||||
| this->GetName().c_str(), i, j); | |||||
| return false; | |||||
| } | |||||
| // Determine the connection relationship by linking the node's name | |||||
| if (peer_node->GetName() != r_peer_node->GetName()) { | |||||
| GELOGE(GRAPH_FAILED, | |||||
| "anchor's peer node name verify failed, node name: %s index[%zu]" | |||||
| "peer node name %s is different from %s at index [%zu].", | |||||
| this->GetName().c_str(), i, peer_node->GetName().c_str(), r_peer_node->GetName().c_str(), j); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeInConnectsAreEqual(const Node &r_node) const { | |||||
| // 1.Verify all in data and control anchors size | |||||
| const auto in_data_anchor_size = this->GetAllInDataAnchors().size(); | |||||
| const auto r_in_data_anchor_size = r_node.GetAllInDataAnchors().size(); | |||||
| if (in_data_anchor_size != r_in_data_anchor_size) { | |||||
| GELOGE(GRAPH_FAILED, "Size of node's in data anchors verify failed, node name: %s.", this->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| const auto l_in_anchors = this->GetAllInAnchors(); | |||||
| const auto r_in_anchors = r_node.GetAllInAnchors(); | |||||
| // Data anchors size equal, all anchors size not equal, means control anchor size not equal | |||||
| const auto in_control_anchor_size = l_in_anchors.size() - in_data_anchor_size; | |||||
| const auto r_in_control_anchor_size = r_in_anchors.size() - r_in_data_anchor_size; | |||||
| if (in_control_anchor_size != r_in_control_anchor_size) { | |||||
| GELOGE(GRAPH_FAILED, "Size of node's in control anchors verify failed, node name: %s.", this->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| // 2.Verify all in data and control anchors connect info | |||||
| for (size_t i = 0; i < this->GetAllInAnchors().size(); i++) { | |||||
| // Verify data anchors | |||||
| if (i < in_data_anchor_size) { | |||||
| const auto &in_anchor = l_in_anchors.at(i); | |||||
| const auto &r_in_anchor = r_in_anchors.at(i); | |||||
| if (!(NodeAnchorIsEqual(in_anchor, r_in_anchor, i))) { | |||||
| GELOGE(GRAPH_FAILED, "Node's in data control anchor verify failed, node name: %s.", this->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } else { | |||||
| // Verify control anchors | |||||
| const auto &in_control_anchor = l_in_anchors.at(i); | |||||
| const auto &r_in_control_anchor = r_in_anchors.at(i); | |||||
| if (!(NodeAnchorIsEqual(in_control_anchor, r_in_control_anchor, i - in_data_anchor_size))) { | |||||
| GELOGE(GRAPH_FAILED, "Node's in control anchor verify failed, node name: %s.", this->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeOutConnectsAreEqual(const Node &r_node) const { | |||||
| // 1.Verify all out data and control anchors size | |||||
| const auto l_out_data_anchors = this->GetAllOutDataAnchors(); | |||||
| const auto r_out_data_anchors = r_node.GetAllOutDataAnchors(); | |||||
| const auto out_data_anchor_size = l_out_data_anchors.size(); | |||||
| const auto r_out_data_anchor_size = r_out_data_anchors.size(); | |||||
| if (out_data_anchor_size != r_out_data_anchor_size) { | |||||
| GELOGE(GRAPH_FAILED, "Size of node's out data anchors verify failed, node name: %s.", this->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| const auto l_out_anchors = this->GetAllOutAnchors(); | |||||
| const auto r_out_anchors = r_node.GetAllOutAnchors(); | |||||
| // Data anchors size equal, all anchors size not equal, means control anchor size not equal | |||||
| const auto out_control_anchor_size = l_out_anchors.size() - out_data_anchor_size; | |||||
| const auto r_out_control_anchor_size = r_out_anchors.size() - r_out_data_anchor_size; | |||||
| if (out_control_anchor_size != r_out_control_anchor_size) { | |||||
| GELOGE(GRAPH_FAILED, "Size of node's out control anchors verify failed, node name: %s.", this->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| // 2.Verify all out data and control anchors connect info | |||||
| for (size_t i = 0; i < this->GetAllOutAnchors().size(); i++) { | |||||
| // Verify data anchors | |||||
| if (i < out_data_anchor_size) { | |||||
| const auto &out_anchor = l_out_data_anchors.at(i); | |||||
| const auto &r_out_anchor = r_out_data_anchors.at(i); | |||||
| if (!(NodeAnchorIsEqual(out_anchor, r_out_anchor, i))) { | |||||
| GELOGE(GRAPH_FAILED, "Node's out data control anchor verify failed, node name: %s.", this->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } else { | |||||
| // Verify control anchors | |||||
| const auto &out_control_anchor = l_out_anchors.at(i); | |||||
| const auto &r_out_control_anchor = r_out_anchors.at(i); | |||||
| if (!(NodeAnchorIsEqual(out_control_anchor, r_out_control_anchor, i - out_data_anchor_size))) { | |||||
| GELOGE(GRAPH_FAILED, "Node's out control anchor verify failed, node name: %s.", this->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::operator==(const Node &r_node) const { | |||||
| return (NodeMembersAreEqual(r_node) && NodeAttrsAreEqual(r_node) && NodeInConnectsAreEqual(r_node) && | |||||
| NodeOutConnectsAreEqual(r_node)); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const NodePtr &input_node) { | |||||
| // This function is deprecated, please use other two overloaded functions | |||||
| GE_CHECK_NOTNULL(input_node); | |||||
| // Input_node ---> this | |||||
| auto out_anchors = input_node->GetAllOutDataAnchors(); | |||||
| if (out_anchors.size() != 1) { | |||||
| GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); | |||||
| auto op_desc = input_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (op_->AddInputDesc(op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "add input desc failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<InDataAnchor> anchor = ComGraphMakeShared<InDataAnchor>(shared_from_this(), in_data_anchors_.size()); | |||||
| if (anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| in_data_anchors_.push_back(anchor); | |||||
| (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const uint32_t &index, | |||||
| NodePtr input_node) { | |||||
| GE_CHECK_NOTNULL(input_node); | |||||
| // Input_node ---> this | |||||
| auto out_anchors = input_node->GetAllOutDataAnchors(); | |||||
| if (out_anchors.size() != 1) { | |||||
| GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, only support 1", out_anchors.size()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| GE_CHECK_NOTNULL(op_); | |||||
| auto op_desc = input_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (op_->AddInputDesc(index, op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "add input desc failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (index < GetAllInDataAnchors().size()) { | |||||
| (void)out_anchors.at(0)->LinkTo(in_data_anchors_[index]); | |||||
| } else { | |||||
| std::shared_ptr<InDataAnchor> anchor = | |||||
| ComGraphMakeShared<InDataAnchor>(shared_from_this(), in_data_anchors_.size()); | |||||
| if (anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| in_data_anchors_.push_back(anchor); | |||||
| (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFromForParse(const NodePtr &input_node) { | |||||
| // This function is used for ParseWeights. | |||||
| GE_CHECK_NOTNULL(input_node); | |||||
| // Input_node ---> this | |||||
| auto out_anchors = input_node->GetAllOutDataAnchors(); | |||||
| if (out_anchors.size() != 1) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| std::shared_ptr<InDataAnchor> anchor = ComGraphMakeShared<InDataAnchor>(shared_from_this(), in_data_anchors_.size()); | |||||
| if (anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "out_anchor size is:%zu, make anchor failed", out_anchors.size()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| in_data_anchors_.push_back(anchor); | |||||
| (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const string &name, NodePtr input_node) { | |||||
| GE_CHECK_NOTNULL(input_node); | |||||
| // Input_node ---> this | |||||
| auto out_anchors = input_node->GetAllOutDataAnchors(); | |||||
| if (out_anchors.size() != 1) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "out_anchor size is:%zu, only support 1", out_anchors.size()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| GE_CHECK_NOTNULL(op_); | |||||
| auto input_op_desc = input_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(input_op_desc); | |||||
| auto index = op_->GetInputIndexByName(name); | |||||
| if (index != -1) { | |||||
| if (index >= static_cast<int>(in_data_anchors_.size())) { | |||||
| GELOGE(GRAPH_FAILED, "op %s get input name %s 's index %d is illegal.", op_->GetName().c_str(), name.c_str(), | |||||
| index); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| (void)out_anchors.at(0)->LinkTo(in_data_anchors_[index]); | |||||
| } else { | |||||
| std::shared_ptr<InDataAnchor> anchor = | |||||
| ComGraphMakeShared<InDataAnchor>(shared_from_this(), in_data_anchors_.size()); | |||||
| if (anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "in_data_anchors_size is:%zu, malloc shared_ptr failed.", in_data_anchors_.size()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| in_data_anchors_.push_back(anchor); | |||||
| (void)out_anchors.at(0)->LinkTo(in_data_anchors_.back()); | |||||
| } | |||||
| if (op_->AddInputDesc(name, input_op_desc->GetOutputDesc(0)) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "add input desc failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr Node::GetOwnerComputeGraph() const { | |||||
| return owner_graph_.lock(); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::SetOwnerComputeGraph(const ComputeGraphPtr &graph) { | |||||
| if (graph == nullptr) { | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| owner_graph_ = graph; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<InDataAnchorPtr> Node::GetAllInDataAnchors() const { | |||||
| return Vistor<InDataAnchorPtr>(shared_from_this(), in_data_anchors_); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<OutDataAnchorPtr> Node::GetAllOutDataAnchors() const { | |||||
| return Vistor<OutDataAnchorPtr>(shared_from_this(), out_data_anchors_); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllInDataAnchorsSize() const { | |||||
| return in_data_anchors_.size(); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllOutDataAnchorsSize() const { | |||||
| return out_data_anchors_.size(); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::GetAllInAnchors() const { | |||||
| std::vector<AnchorPtr> vec; | |||||
| // Push back in_data_anchors_ | |||||
| for (const auto &in_anchor_iter : Vistor<InDataAnchorPtr>(shared_from_this(), in_data_anchors_)) { | |||||
| auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_anchor_iter); | |||||
| if (in_anchor != nullptr) { | |||||
| vec.push_back(in_anchor); | |||||
| } | |||||
| } | |||||
| // Push back in_control_anchor_ | |||||
| if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) || | |||||
| (in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) { | |||||
| auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_control_anchor_); | |||||
| if (in_anchor != nullptr) { | |||||
| vec.push_back(in_anchor); | |||||
| } | |||||
| } | |||||
| return Node::Vistor<AnchorPtr>(shared_from_this(), vec); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::GetAllOutAnchors() const { | |||||
| std::vector<AnchorPtr> vec; | |||||
| // Push back out_data_anchors_ | |||||
| for (const auto &out_anchor_iter : Vistor<OutDataAnchorPtr>(shared_from_this(), out_data_anchors_)) { | |||||
| auto out_anchor = Anchor::DynamicAnchorCast<Anchor>(out_anchor_iter); | |||||
| if (out_anchor != nullptr) { | |||||
| vec.push_back(out_anchor); | |||||
| } | |||||
| } | |||||
| // Push back out_control_anchor_ | |||||
| if (out_control_anchor_->GetPeerInControlAnchors().size() > 0 || | |||||
| out_control_anchor_->GetPeerInDataAnchors().size() > 0) { | |||||
| auto out_anchor = Anchor::DynamicAnchorCast<Anchor>(out_control_anchor_); | |||||
| if (out_anchor != nullptr) { | |||||
| vec.push_back(out_anchor); | |||||
| } | |||||
| } | |||||
| return Node::Vistor<AnchorPtr>(shared_from_this(), vec); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const { | |||||
| if (idx < 0 || idx >= static_cast<int>(in_data_anchors_.size())) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19019", {"opname", "index", "anchorname", "optype"}, | |||||
| {GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()}); | |||||
| GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx, | |||||
| GetType().c_str()); | |||||
| return nullptr; | |||||
| } else { | |||||
| return in_data_anchors_[idx]; | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const { | |||||
| // Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ | |||||
| if (idx < -1 || idx >= static_cast<int>(in_data_anchors_.size())) { | |||||
| GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); | |||||
| return nullptr; | |||||
| } else { | |||||
| // Return control anchor | |||||
| if (idx == -1) { | |||||
| auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_control_anchor_); | |||||
| return in_anchor; | |||||
| } | |||||
| // Return data anchor | |||||
| return in_data_anchors_[idx]; | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const { | |||||
| // Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ | |||||
| if (idx < -1 || idx >= static_cast<int>(out_data_anchors_.size())) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"}, | |||||
| { | |||||
| GetName().c_str(), | |||||
| std::to_string(idx), | |||||
| "out_anchor", | |||||
| GetType().c_str(), | |||||
| }); | |||||
| GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx, | |||||
| GetType().c_str()); | |||||
| return nullptr; | |||||
| } else { | |||||
| // Return control anchor | |||||
| if (idx == -1) { | |||||
| auto out_anchor = Anchor::DynamicAnchorCast<Anchor>(out_control_anchor_); | |||||
| return out_anchor; | |||||
| } | |||||
| // Return data anchor | |||||
| return out_data_anchors_[idx]; | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const { | |||||
| if (idx < 0 || idx >= static_cast<int>(out_data_anchors_.size())) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19019", {"opname", "index", "anchorname", "optype"}, | |||||
| {GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()}); | |||||
| GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx, | |||||
| GetType().c_str()); | |||||
| return nullptr; | |||||
| } else { | |||||
| return out_data_anchors_[idx]; | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchorPtr Node::GetInControlAnchor() const { | |||||
| return in_control_anchor_; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchorPtr Node::GetOutControlAnchor() const { | |||||
| return out_control_anchor_; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetInNodes() const { | |||||
| std::vector<NodePtr> vec; | |||||
| for (const auto &in_anchor : in_data_anchors_) { | |||||
| GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); | |||||
| auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (out_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto node = out_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| if (in_control_anchor_ != nullptr) { | |||||
| if (in_control_anchor_->IsPeerOutAnchorsEmpty()) { | |||||
| return Node::Vistor<NodePtr>(shared_from_this(), vec); | |||||
| } | |||||
| auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors(); | |||||
| for (const auto &out_anchor : peer_out_anchors) { | |||||
| GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "in_control_anchor_ peer out data anchors is nullptr"); | |||||
| auto node = out_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); | |||||
| for (const auto &out_control_anchor : peer_out_control_anchors) { | |||||
| GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, | |||||
| "in_control_anchor_ peer out control anchors is nullptr"); | |||||
| auto node = out_control_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| } | |||||
| return Node::Vistor<NodePtr>(shared_from_this(), vec); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::IsAllInNodesSeen( | |||||
| std::unordered_set<Node *> &nodes_seen) const { | |||||
| for (const auto &in_anchor : in_data_anchors_) { | |||||
| GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); | |||||
| auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (out_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto node = out_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { | |||||
| continue; | |||||
| } | |||||
| if (nodes_seen.count(node.get()) == 0) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| if (in_control_anchor_ != nullptr) { | |||||
| if (in_control_anchor_->IsPeerOutAnchorsEmpty()) { | |||||
| return true; | |||||
| } | |||||
| auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); | |||||
| for (const auto &out_control_anchor : peer_out_control_anchors) { | |||||
| GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, "out_control_anchor is nullptr"); | |||||
| auto node = out_control_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { | |||||
| continue; | |||||
| } | |||||
| if (nodes_seen.count(node.get()) == 0) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetInDataNodes() const { | |||||
| std::vector<NodePtr> vec; | |||||
| for (const auto &in_anchor : in_data_anchors_) { | |||||
| GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "in_data_anchor is nullptr"); | |||||
| auto anchor_ptr = in_anchor->GetPeerOutAnchor(); | |||||
| if (anchor_ptr == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto node = anchor_ptr->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| return Node::Vistor<NodePtr>(shared_from_this(), vec); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetInControlNodes() const { | |||||
| std::vector<NodePtr> vec; | |||||
| if (in_control_anchor_ != nullptr) { | |||||
| for (const auto &in_anchor : in_control_anchor_->GetPeerOutControlAnchors()) { | |||||
| GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerOutControlAnchors is nullptr"); | |||||
| auto node = in_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| } | |||||
| return Node::Vistor<NodePtr>(shared_from_this(), vec); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetOutNodes() const { | |||||
| std::vector<NodePtr> vec; | |||||
| for (const auto &out_anchor : out_data_anchors_) { | |||||
| GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); | |||||
| for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
| GE_CHK_BOOL_EXEC((peer_in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr"); | |||||
| auto node = peer_in_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| } | |||||
| if (out_control_anchor_ != nullptr) { | |||||
| auto peer_in_control_anchors = out_control_anchor_->GetPeerInControlAnchors(); | |||||
| for (const auto &in_control_anchor : peer_in_control_anchors) { | |||||
| GE_CHK_BOOL_EXEC(in_control_anchor != nullptr, continue, | |||||
| "out_control_anchor_ peer in control anchors is nullptr"); | |||||
| auto node = in_control_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| } | |||||
| return Node::Vistor<NodePtr>(shared_from_this(), vec); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetInAllNodes() const { | |||||
| std::vector<NodePtr> vec; | |||||
| for (const auto &in_node : GetInDataNodes()) { | |||||
| vec.push_back(in_node); | |||||
| } | |||||
| for (const auto &in_control_node : GetInControlNodes()) { | |||||
| vec.push_back(in_control_node); | |||||
| } | |||||
| return Node::Vistor<NodePtr>(shared_from_this(), vec); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetOutDataNodes() const { | |||||
| std::vector<NodePtr> vec; | |||||
| for (const auto &out_anchor : out_data_anchors_) { | |||||
| GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); | |||||
| for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
| GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInDataAnchors is nullptr"); | |||||
| auto node = in_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| } | |||||
| return Node::Vistor<NodePtr>(shared_from_this(), vec); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetOutDataNodesSize() const { | |||||
| uint32_t out_nums = 0; | |||||
| for (const auto &out_anchor : out_data_anchors_) { | |||||
| GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); | |||||
| out_nums += out_anchor->GetPeerInDataNodesSize(); | |||||
| } | |||||
| return out_nums; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetOutControlNodes() const { | |||||
| std::vector<NodePtr> vec; | |||||
| for (const auto &out_anchor : out_data_anchors_) { | |||||
| GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "out_data_anchors_ is nullptr"); | |||||
| for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) { | |||||
| GE_CHK_BOOL_EXEC((in_anchor != nullptr), continue, "GetPeerInControlAnchors is nullptr"); | |||||
| auto node = in_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| } | |||||
| if (out_control_anchor_ != nullptr) { | |||||
| for (const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) { | |||||
| GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); | |||||
| auto node = in_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| } | |||||
| return Node::Vistor<NodePtr>(shared_from_this(), vec); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetOutAllNodes() const { | |||||
| std::vector<NodePtr> vec; | |||||
| for (const auto &out_anchor : out_data_anchors_) { | |||||
| GE_CHK_BOOL_EXEC((out_anchor != nullptr), { continue; }, "out_data_anchors_ is nullptr"); | |||||
| for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
| GE_CHK_BOOL_EXEC((in_anchor != nullptr), { continue; }, "GetPeerInDataAnchors is nullptr"); | |||||
| auto node = in_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| for (const auto &in_anchor : out_anchor->GetPeerInControlAnchors()) { | |||||
| GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); | |||||
| auto node = in_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| } | |||||
| if (out_control_anchor_ != nullptr) { | |||||
| for (const auto &in_anchor : out_control_anchor_->GetPeerAnchors()) { | |||||
| GE_CHK_BOOL_EXEC(in_anchor != nullptr, continue, "GetPeerInControlAnchors is nullptr"); | |||||
| auto node = in_anchor->GetOwnerNode(); | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||||
| vec.push_back(node); | |||||
| } | |||||
| } | |||||
| return Node::Vistor<NodePtr>(shared_from_this(), vec); | |||||
| } | |||||
| graphStatus Node::InferShapeAndType() const { | |||||
| Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); | |||||
| graphStatus ret = ShapeRefiner::InferShapeAndType(shared_from_this(), op); | |||||
| return ret; | |||||
| } | |||||
| graphStatus Node::InferOriginFormat() const { | |||||
| Operator op = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); | |||||
| // Get infer func and execute | |||||
| GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); | |||||
| return op_->CallInferFormatFunc(op); | |||||
| } | |||||
| graphStatus Node::Verify() const { | |||||
| const string data_type = "Data"; | |||||
| const string aipp_data_type = "AippData"; | |||||
| const string const_type = "Const"; | |||||
| const string variable_type = "Variable"; | |||||
| bool is_unknown_graph = GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||||
| GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); | |||||
| if (!is_unknown_graph) { | |||||
| for (const auto &in_anchor_ptr : GetAllInDataAnchors()) { | |||||
| GE_IF_BOOL_EXEC(in_anchor_ptr == nullptr, GELOGW("in anchor ptr is null"); continue); | |||||
| bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type || | |||||
| op_->GetType() == const_type || op_->GetType() == variable_type || | |||||
| op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0; | |||||
| if (!valid_anchor) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"opname", "index"}, | |||||
| {GetName(), std::to_string(in_anchor_ptr->GetIdx())}); | |||||
| GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| string frameworkop_type = "FrameworkOp"; | |||||
| bool need_update_name = op_->GetType() != frameworkop_type && !is_unknown_graph; | |||||
| if (need_update_name) { | |||||
| auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_->GetType()); | |||||
| if (node_op.IsEmpty()) { | |||||
| GELOGW("get op from OperatorFactory fail. opType: %s", op_->GetType().c_str()); | |||||
| } else { | |||||
| GELOGD("get op from OperatorFactory success. opType: %s", op_->GetType().c_str()); | |||||
| auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); | |||||
| if (temp_op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "temp op desc is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (!op_->UpdateInputName(temp_op_desc->GetAllInputName())) { | |||||
| GELOGW("Verify UpdateInputName failed"); | |||||
| } | |||||
| if (!op_->UpdateOutputName(temp_op_desc->GetAllOutputName())) { | |||||
| GELOGW("Verify UpdateOutputName failed"); | |||||
| } | |||||
| } | |||||
| node_op.BreakConnect(); | |||||
| } | |||||
| GE_IF_BOOL_EXEC(is_unknown_graph, return GRAPH_SUCCESS;); | |||||
| if (op_->CommonVerify() == GRAPH_SUCCESS) { | |||||
| Operator op_proxy = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); | |||||
| auto verify_func = op_->GetVerifyFunc(); | |||||
| if (verify_func == nullptr) { | |||||
| verify_func = OperatorFactoryImpl::GetVerifyFunc(GetType()); | |||||
| } | |||||
| if (verify_func != nullptr) { | |||||
| return (graphStatus)verify_func(op_proxy); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "%s Verify failed.", op_->GetType().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr Node::GetOpDesc() const { return op_; } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(const OpDescPtr &op_desc) { | |||||
| GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); | |||||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return GRAPH_PARAM_INVALID, "Param OpDesc is nullptr"); | |||||
| GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID, | |||||
| "Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(), | |||||
| op_desc->GetInputsSize()); | |||||
| GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID, | |||||
| "Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(), | |||||
| op_desc->GetOutputsSize()); | |||||
| op_ = op_desc; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<std::pair<NodePtr, OutDataAnchorPtr>> | |||||
| Node::GetInDataNodesAndAnchors() const { | |||||
| std::vector<std::pair<NodePtr, OutDataAnchorPtr>> vec; | |||||
| for (const auto &p : in_data_anchors_) { | |||||
| if (p == nullptr) { | |||||
| GELOGW("indata anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| auto anchor_ptr = p->GetPeerOutAnchor(); | |||||
| if (anchor_ptr == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto node = anchor_ptr->GetOwnerNode(); | |||||
| if (node == nullptr) { | |||||
| GELOGW("src node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| vec.push_back(std::make_pair(node, anchor_ptr)); | |||||
| } | |||||
| return Node::Vistor<std::pair<NodePtr, OutDataAnchorPtr>>(shared_from_this(), vec); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<std::pair<NodePtr, InDataAnchorPtr>> | |||||
| Node::GetOutDataNodesAndAnchors() const { | |||||
| std::vector<std::pair<NodePtr, InDataAnchorPtr>> vec; | |||||
| for (const auto &p : out_data_anchors_) { | |||||
| if (p == nullptr) { | |||||
| GELOGW("out data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| for (const auto &in_anchor : p->GetPeerInDataAnchors()) { | |||||
| if (in_anchor == nullptr) { | |||||
| GELOGW("dst in data anchor is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| auto node = in_anchor->GetOwnerNode(); | |||||
| if (node == nullptr) { | |||||
| GELOGW("dst node is nullptr, node %s:%s", GetType().c_str(), GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| vec.push_back(std::make_pair(node, in_anchor)); | |||||
| } | |||||
| } | |||||
| return Node::Vistor<std::pair<NodePtr, InDataAnchorPtr>>(shared_from_this(), vec); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,79 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <stdint.h> | |||||
| #include <functional> | |||||
| #include <vector> | |||||
| #include "debug/ge_log.h" | |||||
| #include "debug/ge_util.h" | |||||
| using namespace std; | |||||
| namespace ge { | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| BroadCastInfer(const function<vector<int64_t>()>& get_in1_shape, const function<vector<int64_t>()>& get_in2_shape, | |||||
| const function<void(const vector<int64_t>& outShape)>& set_out_shape) { | |||||
| auto x1_shape = get_in1_shape(); | |||||
| auto x2_shape = get_in2_shape(); | |||||
| vector<int64_t> y_shape; | |||||
| if (x1_shape.empty()) { | |||||
| y_shape = x2_shape; | |||||
| set_out_shape(y_shape); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| if (x2_shape.empty()) { | |||||
| y_shape = x1_shape; | |||||
| set_out_shape(y_shape); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| int len_diff = static_cast<int>(x1_shape.size() - x2_shape.size()); | |||||
| if (len_diff >= 0) { | |||||
| for (int i = 0; i < len_diff; i++) { | |||||
| y_shape.push_back(x1_shape[i]); | |||||
| } | |||||
| int x2_shape_size = static_cast<int>(x2_shape.size()); | |||||
| for (int i = 0; i < x2_shape_size; i++) { | |||||
| bool shapeFlag = | |||||
| ((x1_shape[i + len_diff] != x2_shape[i]) && (std::min(x1_shape[i + len_diff], x2_shape[i]) != 1)); | |||||
| if (shapeFlag) { | |||||
| GE_LOGE("operands could not be broadcast together"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| y_shape.push_back(std::max(x1_shape[i + len_diff], x2_shape[i])); | |||||
| } | |||||
| } else { | |||||
| for (int i = 0; i < -len_diff; i++) { | |||||
| y_shape.push_back(x2_shape[i]); | |||||
| } | |||||
| int x1_shape_size = static_cast<int>(x1_shape.size()); | |||||
| for (int i = 0; i < x1_shape_size; i++) { | |||||
| bool shapeFlag = | |||||
| ((x1_shape[i] != x2_shape[i - len_diff]) && (std::min(x1_shape[i], x2_shape[i - len_diff]) != 1)); | |||||
| if (shapeFlag) { | |||||
| GE_LOGE("operands could not be broadcast together"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| y_shape.push_back(std::max(x1_shape[i], x2_shape[i - len_diff])); | |||||
| } | |||||
| } | |||||
| set_out_shape(y_shape); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,48 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/operator_factory_impl.h" | |||||
| #include "debug/ge_log.h" | |||||
| namespace ge { | |||||
| Operator OperatorFactory::CreateOperator(const std::string &operator_name, const std::string &operator_type) { | |||||
| return OperatorFactoryImpl::CreateOperator(operator_name, operator_type); | |||||
| } | |||||
| graphStatus OperatorFactory::GetOpsTypeList(std::vector<std::string> &all_ops) { | |||||
| return OperatorFactoryImpl::GetOpsTypeList(all_ops); | |||||
| } | |||||
| bool OperatorFactory::IsExistOp(const string &operator_type) { return OperatorFactoryImpl::IsExistOp(operator_type); } | |||||
| OperatorCreatorRegister::OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator) { | |||||
| (void)OperatorFactoryImpl::RegisterOperatorCreator(operator_type, op_creator); | |||||
| } | |||||
| InferShapeFuncRegister::InferShapeFuncRegister(const std::string &operator_type, | |||||
| const InferShapeFunc &infer_shape_func) { | |||||
| (void)OperatorFactoryImpl::RegisterInferShapeFunc(operator_type, infer_shape_func); | |||||
| } | |||||
| InferFormatFuncRegister::InferFormatFuncRegister(const std::string &operator_type, | |||||
| const InferFormatFunc &infer_format_func) { | |||||
| (void)OperatorFactoryImpl::RegisterInferFormatFunc(operator_type, infer_format_func); | |||||
| } | |||||
| VerifyFuncRegister::VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func) { | |||||
| (void)OperatorFactoryImpl::RegisterVerifyFunc(operator_type, verify_func); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,149 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/operator_factory_impl.h" | |||||
| #include "debug/ge_log.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace ge { | |||||
| shared_ptr<std::map<string, OpCreator>> OperatorFactoryImpl::operator_creators_; | |||||
| shared_ptr<std::map<string, InferShapeFunc>> OperatorFactoryImpl::operator_infershape_funcs_; | |||||
| shared_ptr<std::map<string, InferFormatFunc>> OperatorFactoryImpl::operator_inferformat_funcs_; | |||||
| shared_ptr<std::map<string, VerifyFunc>> OperatorFactoryImpl::operator_verify_funcs_; | |||||
| Operator OperatorFactoryImpl::CreateOperator(const std::string &operator_name, const std::string &operator_type) { | |||||
| if (operator_creators_ == nullptr) { | |||||
| return Operator(); | |||||
| } | |||||
| auto it = operator_creators_->find(operator_type); | |||||
| if (it == operator_creators_->end()) { | |||||
| GELOGW("no OpProto of [%s] registered", operator_type.c_str()); | |||||
| return Operator(); | |||||
| } | |||||
| return it->second(operator_name); | |||||
| } | |||||
| graphStatus OperatorFactoryImpl::GetOpsTypeList(std::vector<std::string> &all_ops) { | |||||
| all_ops.clear(); | |||||
| if (operator_creators_ != nullptr) { | |||||
| for (auto it = operator_creators_->begin(); it != operator_creators_->end(); ++it) { | |||||
| all_ops.emplace_back(it->first); | |||||
| } | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "no operator creators found"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool OperatorFactoryImpl::IsExistOp(const string &operator_type) { | |||||
| if (operator_creators_ == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto it = operator_creators_->find(operator_type); | |||||
| if (it == operator_creators_->end()) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| InferShapeFunc OperatorFactoryImpl::GetInferShapeFunc(const std::string &operator_type) { | |||||
| if (operator_infershape_funcs_ == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto it = operator_infershape_funcs_->find(operator_type); | |||||
| if (it == operator_infershape_funcs_->end()) { | |||||
| return nullptr; | |||||
| } | |||||
| return it->second; | |||||
| } | |||||
| InferFormatFunc OperatorFactoryImpl::GetInferFormatFunc(const std::string &operator_type) { | |||||
| if (operator_inferformat_funcs_ == nullptr) { | |||||
| GELOGI("operator_inferformat_funcs_ is null"); | |||||
| return nullptr; | |||||
| } | |||||
| auto it = operator_inferformat_funcs_->find(operator_type); | |||||
| if (it == operator_inferformat_funcs_->end()) { | |||||
| return nullptr; | |||||
| } | |||||
| return it->second; | |||||
| } | |||||
| VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) { | |||||
| if (operator_verify_funcs_ == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto it = operator_verify_funcs_->find(operator_type); | |||||
| if (it == operator_verify_funcs_->end()) { | |||||
| return nullptr; | |||||
| } | |||||
| return it->second; | |||||
| } | |||||
| graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { | |||||
| if (operator_creators_ == nullptr) { | |||||
| operator_creators_.reset(new (std::nothrow) std::map<string, OpCreator>()); | |||||
| } | |||||
| auto it = operator_creators_->find(operator_type); | |||||
| if (it != operator_creators_->end()) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| (void)operator_creators_->emplace(operator_type, op_creator); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus OperatorFactoryImpl::RegisterInferShapeFunc(const std::string &operator_type, | |||||
| InferShapeFunc const infer_shape_func) { | |||||
| if (operator_infershape_funcs_ == nullptr) { | |||||
| GELOGI("operator_infershape_funcs_ init"); | |||||
| operator_infershape_funcs_.reset(new (std::nothrow) std::map<string, InferShapeFunc>()); | |||||
| } | |||||
| auto it = operator_infershape_funcs_->find(operator_type); | |||||
| if (it != operator_infershape_funcs_->end()) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| (void)operator_infershape_funcs_->emplace(operator_type, infer_shape_func); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus OperatorFactoryImpl::RegisterInferFormatFunc(const std::string &operator_type, | |||||
| InferFormatFunc const infer_format_func) { | |||||
| if (operator_inferformat_funcs_ == nullptr) { | |||||
| GELOGI("operator_inferformat_funcs_ init"); | |||||
| operator_inferformat_funcs_.reset(new (std::nothrow) std::map<string, InferFormatFunc>()); | |||||
| } | |||||
| auto it = operator_inferformat_funcs_->find(operator_type); | |||||
| if (it != operator_inferformat_funcs_->end()) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| (void)operator_inferformat_funcs_->emplace(operator_type, infer_format_func); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus OperatorFactoryImpl::RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func) { | |||||
| if (operator_verify_funcs_ == nullptr) { | |||||
| GELOGI("operator_verify_funcs_ init"); | |||||
| operator_verify_funcs_.reset(new (std::nothrow) std::map<string, VerifyFunc>()); | |||||
| } | |||||
| auto it = operator_verify_funcs_->find(operator_type); | |||||
| if (it != operator_verify_funcs_->end()) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| (void)operator_verify_funcs_->emplace(operator_type, verify_func); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,187 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/opsproto_manager.h" | |||||
| #include <cstdlib> | |||||
| #include <algorithm> | |||||
| #include <functional> | |||||
| #include <iostream> | |||||
| #include <sstream> | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/debug/ge_log.h" | |||||
| namespace ge { | |||||
| OpsProtoManager *OpsProtoManager::Instance() { | |||||
| static OpsProtoManager instance; | |||||
| return &instance; | |||||
| } | |||||
| bool OpsProtoManager::Initialize(const std::map<std::string, std::string> &options) { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| if (is_init_) { | |||||
| GELOGI("OpsProtoManager is already initialized."); | |||||
| return true; | |||||
| } | |||||
| /*lint -e1561*/ | |||||
| auto proto_iter = options.find("ge.opsProtoLibPath"); | |||||
| /*lint +e1561*/ | |||||
| if (proto_iter == options.end()) { | |||||
| GELOGW("ge.opsProtoLibPath option not set, return."); | |||||
| return false; | |||||
| } | |||||
| pluginPath_ = proto_iter->second; | |||||
| LoadOpsProtoPluginSo(pluginPath_); | |||||
| is_init_ = true; | |||||
| return true; | |||||
| } | |||||
| void OpsProtoManager::Finalize() { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| if (!is_init_) { | |||||
| GELOGI("OpsProtoManager is not initialized."); | |||||
| return; | |||||
| } | |||||
| for (auto handle : handles_) { | |||||
| if (handle != nullptr) { | |||||
| if (dlclose(handle) != 0) { | |||||
| GELOGW("failed to close handle, message: %s", dlerror()); | |||||
| continue; | |||||
| } | |||||
| GELOGI("close opsprotomanager handler success"); | |||||
| } else { | |||||
| GELOGW("close opsprotomanager handler failure, handler is nullptr"); | |||||
| } | |||||
| } | |||||
| is_init_ = false; | |||||
| } | |||||
| static std::vector<std::string> Split(const std::string &str, char delim) { | |||||
| std::vector<std::string> elems; | |||||
| if (str.empty()) { | |||||
| elems.emplace_back(""); | |||||
| return elems; | |||||
| } | |||||
| std::stringstream ss(str); | |||||
| std::string item; | |||||
| while (getline(ss, item, delim)) { | |||||
| elems.push_back(item); | |||||
| } | |||||
| auto str_size = str.size(); | |||||
| if (str_size > 0 && str[str_size - 1] == delim) { | |||||
| elems.emplace_back(""); | |||||
| } | |||||
| return elems; | |||||
| } | |||||
| static void FindParserSo(const std::string &path, std::vector<std::string> &file_list) { | |||||
| // Lib plugin path not exist | |||||
| if (path.empty()) { | |||||
| GELOGI("realPath is empty"); | |||||
| return; | |||||
| } | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path.size() >= PATH_MAX, return, "path is invalid"); | |||||
| char resolved_path[PATH_MAX] = {0}; | |||||
| // Nullptr is returned when the path does not exist or there is no permission | |||||
| // Return absolute path when path is accessible | |||||
| if (realpath(path.c_str(), resolved_path) == nullptr) { | |||||
| GELOGW("the path [%s] not exsit.", path.c_str()); | |||||
| return; | |||||
| } | |||||
| struct dirent *dent = nullptr; | |||||
| DIR *dir = opendir(resolved_path); | |||||
| // Lib plugin path not exist | |||||
| if (dir == nullptr) { | |||||
| GELOGW("Open directory %s failed,maybe it is not exit or not a dir", resolved_path); | |||||
| return; | |||||
| } | |||||
| while ((dent = readdir(dir)) != nullptr) { | |||||
| if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) { | |||||
| continue; | |||||
| } | |||||
| std::string name = dent->d_name; | |||||
| std::string full_name = path + "/" + name; | |||||
| const std::string so_suff = ".so"; | |||||
| if (dent->d_type != DT_DIR && name.size() >= so_suff.size() && | |||||
| name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) { | |||||
| file_list.push_back(full_name); | |||||
| GELOGI("OpsProtoManager Parse full name = %s \n", full_name.c_str()); | |||||
| } | |||||
| } | |||||
| if (closedir(dir) != 0) { | |||||
| GELOGW("close dir fail."); | |||||
| } | |||||
| } | |||||
| static void GetPluginSoFileList(const std::string &path, std::vector<std::string> &file_list) { | |||||
| // Support multi lib directory with ":" as delimiter | |||||
| std::vector<std::string> v_path = Split(path, ':'); | |||||
| for (size_t i = 0; i < v_path.size(); ++i) { | |||||
| FindParserSo(v_path[i], file_list); | |||||
| GELOGI("OpsProtoManager full name = %s", v_path[i].c_str()); | |||||
| } | |||||
| } | |||||
| void OpsProtoManager::LoadOpsProtoPluginSo(std::string &path) { | |||||
| if (path.empty()) { | |||||
| GELOGE(GRAPH_FAILED, "filePath is invalid. please check your text file %s.", path.c_str()); | |||||
| return; | |||||
| } | |||||
| std::vector<std::string> file_list; | |||||
| // If there is .so file in the lib path | |||||
| GetPluginSoFileList(path, file_list); | |||||
| // Not found any .so file in the lib path | |||||
| if (file_list.empty()) { | |||||
| GELOGE(GRAPH_FAILED, "OpsProtoManager can not find any plugin file in pluginPath: %s \n", path.c_str()); | |||||
| return; | |||||
| } | |||||
| // Warning message | |||||
| GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted."); | |||||
| // Load .so file | |||||
| for (auto elem : file_list) { | |||||
| void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||||
| if (handle == nullptr) { | |||||
| GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); | |||||
| continue; | |||||
| } else { | |||||
| // Close dl when the program exist, not close here | |||||
| GELOGI("OpsProtoManager plugin load %s success.", elem.c_str()); | |||||
| handles_.push_back(handle); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,104 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "./ge_context.h" | |||||
| #include "./ge_global_options.h" | |||||
| #include "./ge_local_context.h" | |||||
| #include "framework/common/ge_types.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| const int64_t kMinTrainingTraceJobId = 256; | |||||
| const int kDecimal = 10; | |||||
| const char *kHostExecPlacement = "HOST"; | |||||
| } // namespace | |||||
| GEContext &GetContext() { | |||||
| static GEContext ge_context{}; | |||||
| return ge_context; | |||||
| } | |||||
| graphStatus GEContext::GetOption(const std::string &key, std::string &option) { | |||||
| return GetThreadLocalContext().GetOption(key, option); | |||||
| } | |||||
| bool GEContext::GetHostExecFlag() { | |||||
| std::string exec_placement; | |||||
| if (GetThreadLocalContext().GetOption(GE_OPTION_EXEC_PLACEMENT, exec_placement) != GRAPH_SUCCESS) { | |||||
| GELOGW("get option OPTION_EXEC_PLACEMENT failed."); | |||||
| return false; | |||||
| } | |||||
| GELOGD("Option ge.exec.placement is %s.", exec_placement.c_str()); | |||||
| return exec_placement == kHostExecPlacement; | |||||
| } | |||||
| std::map<std::string, std::string> &GetMutableGlobalOptions() { | |||||
| static std::map<std::string, std::string> global_options{}; | |||||
| return global_options; | |||||
| } | |||||
| void GEContext::Init() { | |||||
| string session_id; | |||||
| (void)GetOption("ge.exec.sessionId", session_id); | |||||
| try { | |||||
| session_id_ = static_cast<uint64_t>(std::stoi(session_id.c_str())); | |||||
| } catch (std::invalid_argument &) { | |||||
| GELOGW("%s transform to int failed.", session_id.c_str()); | |||||
| } catch (std::out_of_range &) { | |||||
| GELOGW("%s transform to int failed.", session_id.c_str()); | |||||
| } | |||||
| string device_id; | |||||
| (void)GetOption("ge.exec.deviceId", device_id); | |||||
| try { | |||||
| device_id_ = static_cast<uint32_t>(std::stoi(device_id.c_str())); | |||||
| } catch (std::invalid_argument &) { | |||||
| GELOGW("%s transform to int failed.", device_id.c_str()); | |||||
| } catch (std::out_of_range &) { | |||||
| GELOGW("%s transform to int failed.", device_id.c_str()); | |||||
| } | |||||
| string job_id; | |||||
| (void)GetOption("ge.exec.jobId", job_id); | |||||
| std::string s_job_id = ""; | |||||
| for (auto c : job_id) { | |||||
| if (c >= '0' && c <= '9') { | |||||
| s_job_id += c; | |||||
| } | |||||
| } | |||||
| if (s_job_id == "") { | |||||
| trace_id_ = kMinTrainingTraceJobId; | |||||
| return; | |||||
| } | |||||
| int64_t d_job_id = std::strtoll(s_job_id.c_str(), nullptr, kDecimal); | |||||
| if (d_job_id < kMinTrainingTraceJobId) { | |||||
| trace_id_ = d_job_id + kMinTrainingTraceJobId; | |||||
| } else { | |||||
| trace_id_ = d_job_id; | |||||
| } | |||||
| } | |||||
| uint64_t GEContext::SessionId() { return session_id_; } | |||||
| uint32_t GEContext::DeviceId() { return device_id_; } | |||||
| uint64_t GEContext::TraceId() { return trace_id_; } | |||||
| void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; } | |||||
| void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } | |||||
| } // namespace ge | |||||
| @@ -1,60 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "./ge_local_context.h" | |||||
| #include <utility> | |||||
| namespace ge { | |||||
| namespace { | |||||
| thread_local GEThreadLocalContext thread_context; | |||||
| } | |||||
| GEThreadLocalContext &GetThreadLocalContext() { return thread_context; } | |||||
| graphStatus GEThreadLocalContext::GetOption(const string &key, string &option) { | |||||
| auto graph_iter = graph_options_.find(key); | |||||
| if (graph_iter != graph_options_.end()) { | |||||
| option = graph_iter->second; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| auto session_iter = session_options_.find(key); | |||||
| if (session_iter != session_options_.end()) { | |||||
| option = session_iter->second; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| auto global_iter = global_options_.find(key); | |||||
| if (global_iter != global_options_.end()) { | |||||
| option = global_iter->second; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| void GEThreadLocalContext::SetGlobalOption(map<string, string> options_map) { | |||||
| global_options_.clear(); | |||||
| global_options_ = std::move(options_map); | |||||
| } | |||||
| void GEThreadLocalContext::SetSessionOption(map<string, string> options_map) { | |||||
| session_options_.clear(); | |||||
| session_options_ = std::move(options_map); | |||||
| } | |||||
| void GEThreadLocalContext::SetGraphOption(map<std::string, string> options_map) { | |||||
| graph_options_.clear(); | |||||
| graph_options_ = std::move(options_map); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,455 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/ref_relation.h" | |||||
| #include <unordered_set> | |||||
| #include <unordered_map> | |||||
| #include "utils/mem_utils.h" | |||||
| #include "debug/ge_log.h" | |||||
| #include "debug/ge_op_types.h" | |||||
| #include "debug/ge_util.h" | |||||
| #include "debug/ge_attr_define.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| using namespace std; | |||||
| using namespace ge; | |||||
| namespace ge { | |||||
| namespace { | |||||
| const char *kRefIndex = "_parent_node_index"; | |||||
| const string kWhile = "While"; | |||||
| const string kIf = "If"; | |||||
| const string kCase = "Case"; | |||||
| const uint16_t kMaxElementNum = 100; | |||||
| std::unordered_set<string> function_op = {kWhile, kIf, kCase}; | |||||
| } // namespace | |||||
| /* Impl */ | |||||
| class RefRelations::Impl { | |||||
| public: | |||||
| graphStatus LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) { | |||||
| unsigned long number = static_cast<unsigned long>(reinterpret_cast<uintptr_t>(key.node.get())); | |||||
| std::string lookup_key = | |||||
| key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx) + std::to_string(number); | |||||
| auto iter = look_up_table_.find(lookup_key); | |||||
| if (iter != look_up_table_.end()) { | |||||
| for (auto &c : iter->second) { | |||||
| result.insert(c); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GELOGW("can not find any relations! key value is %s", lookup_key.c_str()); | |||||
| return GRAPH_SUCCESS; | |||||
| }; | |||||
| graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); | |||||
| graphStatus Clear() { | |||||
| GELOGD("Start clear boundary reflections between main graph and sub graph!"); | |||||
| look_up_table_.clear(); | |||||
| values_.clear(); | |||||
| return GRAPH_SUCCESS; | |||||
| }; | |||||
| private: | |||||
| graphStatus BuildLookUpTables(); | |||||
| graphStatus BuildRefRelationsForBranch(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||||
| vector<vector<RefCell>> &node_refs); | |||||
| graphStatus BuildRefRelationsForWhile(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||||
| vector<vector<RefCell>> &node_refs); | |||||
| graphStatus BuildRelationsWithFuncNodeType(const NodePtr &root_node, | |||||
| const vector<vector<NodePtr>> &classed_data_nodes, | |||||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||||
| vector<vector<RefCell>> &node_refs); | |||||
| void GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes, | |||||
| vector<NodePtr> &netoutput_nodes, const std::vector<std::string> &sub_graph_names, | |||||
| const std::string &node_type); | |||||
| graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph); | |||||
| graphStatus ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, vector<vector<NodePtr>> &classed_data_nodes); | |||||
| graphStatus ProcessSubgraphNetoutput(const vector<NodePtr> &netoutput_nodes, | |||||
| vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes); | |||||
| std::unordered_map<string, vector<RefCell>> look_up_table_; | |||||
| std::vector<vector<vector<RefCell>>> values_; | |||||
| }; | |||||
| // Node Level | |||||
| graphStatus RefRelations::Impl::BuildRefRelationsForBranch( | |||||
| const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||||
| GELOGD("Enter BuildRefRelationsForBranch!"); | |||||
| size_t ref_i = 0; | |||||
| for (const auto &ref_i_data_nodes : classed_data_nodes) { | |||||
| vector<RefCell> in_ref_i_all_refs; | |||||
| RefCell cell_root; | |||||
| cell_root.node_name = root_node->GetName(); | |||||
| cell_root.node = root_node; | |||||
| cell_root.in_out = NODE_IN; | |||||
| cell_root.in_out_idx = ref_i; | |||||
| in_ref_i_all_refs.emplace_back(cell_root); | |||||
| for (const auto &data : ref_i_data_nodes) { | |||||
| RefCell cell_in; | |||||
| RefCell cell_out; | |||||
| cell_in.node_name = data->GetName(); | |||||
| cell_in.node = data; | |||||
| cell_in.in_out = NODE_IN; | |||||
| cell_in.in_out_idx = 0; | |||||
| cell_out.node_name = data->GetName(); | |||||
| cell_out.node = data; | |||||
| cell_out.in_out = NODE_OUT; | |||||
| cell_out.in_out_idx = 0; | |||||
| in_ref_i_all_refs.emplace_back(cell_in); | |||||
| in_ref_i_all_refs.emplace_back(cell_out); | |||||
| } | |||||
| node_refs.emplace_back(in_ref_i_all_refs); | |||||
| ref_i++; | |||||
| } | |||||
| size_t ref_o = 0; | |||||
| for (const auto &ref_o_net_nodes : classed_netoutput_nodes) { | |||||
| vector<RefCell> out_ref_i_all_refs; | |||||
| RefCell cell_root; | |||||
| cell_root.node_name = root_node->GetName(); | |||||
| cell_root.node = root_node; | |||||
| cell_root.in_out = NODE_OUT; | |||||
| cell_root.in_out_idx = ref_o; | |||||
| out_ref_i_all_refs.emplace_back(cell_root); | |||||
| for (const auto &ele : ref_o_net_nodes) { | |||||
| RefCell cell_netoutput_in; | |||||
| cell_netoutput_in.node_name = (ele.first)->GetName(); | |||||
| cell_netoutput_in.node = ele.first; | |||||
| cell_netoutput_in.in_out = NODE_IN; | |||||
| cell_netoutput_in.in_out_idx = ele.second; | |||||
| out_ref_i_all_refs.emplace_back(cell_netoutput_in); | |||||
| } | |||||
| node_refs.emplace_back(out_ref_i_all_refs); | |||||
| ref_o++; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus RefRelations::Impl::BuildLookUpTables() { | |||||
| GELOGD("start to build look up table!"); | |||||
| for (size_t i = 0; i < values_.size(); i++) { | |||||
| vector<vector<RefCell>> &val = values_[i]; | |||||
| for (const auto &ele : val) { | |||||
| for (const auto &ref_cell : ele) { | |||||
| string key = ref_cell.node_name + std::to_string(ref_cell.in_out) + std::to_string(ref_cell.in_out_idx) + | |||||
| std::to_string(static_cast<unsigned long>(reinterpret_cast<uintptr_t>(ref_cell.node.get()))); | |||||
| look_up_table_[key] = ele; | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus RefRelations::Impl::BuildRefRelationsForWhile( | |||||
| const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||||
| GELOGD("Enter BuildRefRelations for while op!"); | |||||
| // data_nodes has been sorted | |||||
| // for while, input num must be same as output num | |||||
| auto input_num = root_node->GetAllInDataAnchorsSize(); | |||||
| NodePtr netoutput = nullptr; | |||||
| size_t ref_i = 0; | |||||
| while (ref_i < input_num) { | |||||
| auto &ref_i_data_nodes = classed_data_nodes[ref_i]; | |||||
| auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i]; | |||||
| vector<RefCell> ref_i_all_refs; | |||||
| RefCell cell_root_i; | |||||
| RefCell cell_root_o; | |||||
| cell_root_i.node_name = root_node->GetName(); | |||||
| cell_root_i.node = root_node; | |||||
| cell_root_i.in_out = NODE_IN; | |||||
| cell_root_i.in_out_idx = ref_i; | |||||
| ref_i_all_refs.emplace_back(cell_root_i); | |||||
| cell_root_o.node_name = root_node->GetName(); | |||||
| cell_root_o.node = root_node; | |||||
| cell_root_o.in_out = NODE_OUT; | |||||
| cell_root_o.in_out_idx = ref_i; | |||||
| ref_i_all_refs.emplace_back(cell_root_o); | |||||
| for (const auto &data : ref_i_data_nodes) { | |||||
| RefCell cell_in; | |||||
| RefCell cell_out; | |||||
| cell_in.node_name = data->GetName(); | |||||
| cell_in.node = data; | |||||
| cell_in.in_out = NODE_IN; | |||||
| cell_in.in_out_idx = 0; | |||||
| cell_out.node_name = data->GetName(); | |||||
| cell_out.node = data; | |||||
| cell_out.in_out = NODE_OUT; | |||||
| cell_out.in_out_idx = 0; | |||||
| ref_i_all_refs.emplace_back(cell_in); | |||||
| ref_i_all_refs.emplace_back(cell_out); | |||||
| } | |||||
| for (const auto &ele : ref_i_net_nodes) { | |||||
| RefCell cell_netoutput_in; | |||||
| RefCell cell_netoutput_out; | |||||
| cell_netoutput_in.node_name = (ele.first)->GetName(); | |||||
| cell_netoutput_in.node = ele.first; | |||||
| cell_netoutput_in.in_out = NODE_IN; | |||||
| cell_netoutput_in.in_out_idx = ele.second; | |||||
| ref_i_all_refs.emplace_back(cell_netoutput_in); | |||||
| netoutput = ele.first; | |||||
| } | |||||
| node_refs.emplace_back(ref_i_all_refs); | |||||
| ref_i++; | |||||
| } | |||||
| /* There exist scene like the follows, it means data0 data1 netoutput 0'th | |||||
| * and 1'th tensor should be the same addr. | |||||
| * Data0 Data1 | |||||
| * \/ | |||||
| * /\ | |||||
| * netoutput | |||||
| */ | |||||
| if (netoutput == nullptr) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) { | |||||
| auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (peer_out_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); | |||||
| if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { | |||||
| GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (netoutput->GetName()).c_str()); | |||||
| continue; | |||||
| } | |||||
| if (peer_out_data_node->GetType() != DATA) { | |||||
| continue; | |||||
| } | |||||
| auto in_data_anchor_idx = in_anchor->GetIdx(); | |||||
| auto net_in_desc = netoutput->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx)); | |||||
| int ref_d = 0; | |||||
| int ref_n = 0; | |||||
| (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d); | |||||
| (void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n); | |||||
| node_refs[ref_d].insert(node_refs[ref_d].end(), node_refs[ref_n].begin(), node_refs[ref_n].end()); | |||||
| node_refs[ref_n].insert(node_refs[ref_n].end(), node_refs[ref_d].begin(), node_refs[ref_d].end()); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| // build ref relations according to diff func op type | |||||
| graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( | |||||
| const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||||
| // data_nodes has been sorted | |||||
| auto node_type = root_node->GetType(); | |||||
| auto status = GRAPH_SUCCESS; | |||||
| if (node_type != kWhile) { | |||||
| status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||||
| } else { | |||||
| status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||||
| } | |||||
| return status; | |||||
| } | |||||
| void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes, | |||||
| vector<NodePtr> &netoutput_nodes, | |||||
| const std::vector<std::string> &sub_graph_names, | |||||
| const std::string &node_type) { | |||||
| int sub_graph_idx = 0; | |||||
| for (const auto &name : sub_graph_names) { | |||||
| auto sub_graph = root_graph.GetSubgraph(name); | |||||
| if (sub_graph == nullptr) { | |||||
| GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { | |||||
| auto sub_graph_node_type = sub_graph_node->GetType(); | |||||
| if (sub_graph_node_type == DATA) { | |||||
| data_nodes.emplace_back(sub_graph_node); | |||||
| } else if (sub_graph_node_type == NETOUTPUT) { | |||||
| // if while, the first subgraph must be cond subgraph. | |||||
| // There is no meaning for refs ,so continue | |||||
| if (node_type == kWhile && sub_graph_idx == 0) { | |||||
| continue; | |||||
| } | |||||
| netoutput_nodes.emplace_back(sub_graph_node); | |||||
| } | |||||
| continue; | |||||
| } | |||||
| sub_graph_idx++; | |||||
| } | |||||
| } | |||||
| graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) { | |||||
| auto parent_graph_ptr = graph.GetParentGraph(); | |||||
| if (parent_graph_ptr == nullptr) { | |||||
| root_graph = graph; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr); | |||||
| if (root_graph_ptr == nullptr) { | |||||
| GE_LOGE("Get null root graph"); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| root_graph = *root_graph_ptr; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, | |||||
| vector<vector<NodePtr>> &classed_data_nodes) { | |||||
| GELOGD("start to process subgraph data nodes!"); | |||||
| int max_ref_idx = 0; | |||||
| for (const auto &e : data_nodes) { | |||||
| int i; | |||||
| bool is_exist = true; | |||||
| is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i); | |||||
| if (!is_exist) { | |||||
| GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s", e->GetName().c_str(), kRefIndex); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| max_ref_idx = (i > max_ref_idx) ? i : max_ref_idx; | |||||
| } | |||||
| while (!data_nodes.empty()) { | |||||
| auto data = data_nodes.back(); | |||||
| data_nodes.pop_back(); | |||||
| int ref_idx = 0; | |||||
| (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx); | |||||
| if (ref_idx >= static_cast<int>(classed_data_nodes.size())) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| classed_data_nodes[ref_idx].emplace_back(data); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( | |||||
| const vector<NodePtr> &netoutput_nodes, vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes) { | |||||
| GELOGD("[RefRelations]Start to process subgraph netoutput!"); | |||||
| for (const auto &sub_netoutput_node : netoutput_nodes) { | |||||
| auto op_desc = sub_netoutput_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) { | |||||
| auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx()); | |||||
| if (in_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it", | |||||
| sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| int ref_o; | |||||
| if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { | |||||
| if (ref_o >= static_cast<int>(classed_netoutput_nodes.size())) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| classed_netoutput_nodes[ref_o].emplace_back( | |||||
| std::pair<NodePtr, size_t>({sub_netoutput_node, static_cast<size_t>(in_data_anchor->GetIdx())})); | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { | |||||
| GELOGD("Start to build ref relations!"); | |||||
| /* First Step: Get root graph */ | |||||
| ge::ComputeGraph &root_graph = graph; | |||||
| auto status = GetRootGraph(graph, root_graph); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| return status; | |||||
| } | |||||
| for (const auto &node : graph.GetAllNodes()) { | |||||
| auto node_type = node->GetType(); | |||||
| std::vector<NodePtr> ref_nodes; | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||||
| if (sub_graph_names.empty()) { | |||||
| continue; | |||||
| } | |||||
| vector<NodePtr> data_nodes; | |||||
| vector<NodePtr> netoutput_nodes; | |||||
| // Get data and netoutput of sub_graph | |||||
| GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type); | |||||
| size_t max_elem_num = (data_nodes.size() > kMaxElementNum) ? data_nodes.size() : kMaxElementNum; | |||||
| vector<vector<NodePtr>> classed_data_nodes(max_elem_num); // according to ref_idx | |||||
| vector<vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes(max_elem_num); // according to ref_idx | |||||
| status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "classfy data nodes failed!"); | |||||
| return status; | |||||
| } | |||||
| // for netoutput | |||||
| // check netoutput | |||||
| // here main graph output number must be the same as every sub_graph netoutput node | |||||
| // key: netoutput node_ptr ,<ref_idx, net_in_idx> | |||||
| status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "process netoutput failed!"); | |||||
| return status; | |||||
| } | |||||
| vector<vector<RefCell>> node_refs; | |||||
| status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(status, "BuildRelationsWithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str()); | |||||
| return status; | |||||
| } | |||||
| if (!node_refs.empty()) { | |||||
| values_.push_back(node_refs); | |||||
| } | |||||
| } | |||||
| /* Seconde Step: generate map */ | |||||
| status = BuildLookUpTables(); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(status, "Build look up tables failed!"); | |||||
| return status; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| /* Ref Relations Interface */ | |||||
| RefRelations::RefRelations() { | |||||
| impl_ = MakeShared<Impl>(); | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "MakeShared failed!"); | |||||
| return; | |||||
| } | |||||
| } | |||||
| graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) { | |||||
| GE_CHECK_NOTNULL(impl_); | |||||
| return impl_->LookUpRefRelations(key, result); | |||||
| } | |||||
| graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) { | |||||
| GE_CHECK_NOTNULL(impl_); | |||||
| return impl_->BuildRefRelations(root_graph); | |||||
| } | |||||
| graphStatus RefRelations::Clear() { | |||||
| GE_CHECK_NOTNULL(impl_); | |||||
| return impl_->Clear(); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,96 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/runtime_inference_context.h" | |||||
| #include <cstdint> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace ge { | |||||
| std::map<std::string, std::unique_ptr<RuntimeInferenceContext>> RuntimeInferenceContext::contexts_; | |||||
| std::mutex RuntimeInferenceContext::ctx_mu_; | |||||
| graphStatus RuntimeInferenceContext::CreateContext(const std::string &context_id) { | |||||
| GELOGI("To create context. session id = %s", context_id.c_str()); | |||||
| auto ctx = std::unique_ptr<RuntimeInferenceContext>(new (std::nothrow) RuntimeInferenceContext()); | |||||
| if (ctx == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Failed to create instance of RuntimeInferenceContext. context_id = %s", context_id.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::lock_guard<std::mutex> lk(ctx_mu_); | |||||
| auto emplace_ret = contexts_.emplace(context_id, std::move(ctx)); | |||||
| if (!emplace_ret.second) { | |||||
| GELOGE(GRAPH_FAILED, "Old context not destroyed"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| void RuntimeInferenceContext::DestroyContext(const std::string &context_id) { | |||||
| GELOGI("To destroy context. session id = %s", context_id.c_str()); | |||||
| std::lock_guard<std::mutex> lk(ctx_mu_); | |||||
| contexts_.erase(context_id); | |||||
| } | |||||
| graphStatus RuntimeInferenceContext::GetContext(const std::string &context_id, RuntimeInferenceContext **ctx) { | |||||
| std::lock_guard<std::mutex> lk(ctx_mu_); | |||||
| auto it = contexts_.find(context_id); | |||||
| if (it != contexts_.end()) { | |||||
| *ctx = it->second.get(); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GELOGD("Runtime inference context not created. session id = %s", context_id.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus RuntimeInferenceContext::SetTensor(int64_t node_id, int output_id, Tensor &&tensor) { | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| auto &output_tensors = tensors_[node_id]; | |||||
| if (static_cast<uint32_t>(output_id) >= output_tensors.size()) { | |||||
| output_tensors.resize(output_id + 1); | |||||
| } | |||||
| GELOGD("Set tensor for node_id = %ld, output_id = %d", node_id, output_id); | |||||
| output_tensors[output_id] = std::move(tensor); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus RuntimeInferenceContext::GetTensor(int64_t node_id, int output_id, Tensor &tensor) { | |||||
| if (output_id < 0) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| auto iter = tensors_.find(node_id); | |||||
| if (iter == tensors_.end()) { | |||||
| GELOGE(INTERNAL_ERROR, "Node not register. Id = %ld", node_id); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| auto &output_tensors = iter->second; | |||||
| if (static_cast<uint32_t>(output_id) >= output_tensors.size()) { | |||||
| GELOGE(GRAPH_FAILED, "Node output is not registered. node_id = %ld, output index = %d", node_id, output_id); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GELOGD("Get tensor for node_id = %ld, output_id = %d", node_id, output_id); | |||||
| tensor = output_tensors[output_id]; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,688 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/shape_refiner.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "debug/ge_log.h" | |||||
| #include "debug/ge_op_types.h" | |||||
| #include "external/graph/operator.h" | |||||
| #include "external/graph/operator_factory.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "utils/node_utils.h" | |||||
| #include "utils/op_desc_utils.h" | |||||
| #include "utils/tensor_utils.h" | |||||
| #include "utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| const uint32_t kWhileBodySubGraphIdx = 1; | |||||
| graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { | |||||
| GELOGD("Enter reverse brush while body subgraph process!"); | |||||
| auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); | |||||
| if (sub_graph_body == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get while body graph failed!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (const auto &node_sub : sub_graph_body->GetAllNodes()) { | |||||
| for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) { | |||||
| auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i); | |||||
| GE_IF_BOOL_EXEC(input_desc == nullptr, | |||||
| GELOGW("Get null input by index %zu from node %s ", i, node_sub->GetName().c_str()); | |||||
| continue); | |||||
| (void)input_desc->SetUnknownDimNumShape(); | |||||
| } | |||||
| for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) { | |||||
| auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i); | |||||
| (void)output_desc->SetUnknownDimNumShape(); | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus UpdataOutputForMultiBatcch(const ConstNodePtr &node, | |||||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) { | |||||
| // check sub_graph shape. Get max for update. | |||||
| for (size_t i = 0; i < ref_out_tensors.size(); ++i) { | |||||
| if (ref_out_tensors[i].empty()) { | |||||
| continue; | |||||
| } | |||||
| int64_t max_size = 0; | |||||
| size_t max_shape_index = 0; | |||||
| auto &ref_out_tensor = ref_out_tensors[i].at(0); | |||||
| const auto &ref_out_tensor_shape = ref_out_tensor.MutableShape(); | |||||
| for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) { | |||||
| auto &tensor = ref_out_tensors[i].at(j); | |||||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||||
| GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto shape = tensor.MutableShape(); | |||||
| if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { | |||||
| GELOGE(GRAPH_FAILED, "node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", | |||||
| node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| int64_t size = 1; | |||||
| for (auto dim : shape.GetDims()) { | |||||
| if (INT64_MAX / dim < size) { | |||||
| GELOGE(PARAM_INVALID, "The shape size overflow"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| size *= dim; | |||||
| } | |||||
| if (size > max_size) { | |||||
| max_size = size; | |||||
| max_shape_index = j; | |||||
| } | |||||
| } | |||||
| (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index)); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node, | |||||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) { | |||||
| GELOGD("Enter update parent node shape for class branch op process"); | |||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { | |||||
| return UpdataOutputForMultiBatcch(node, ref_out_tensors); | |||||
| } | |||||
| // check sub_graph shape.If not same ,do unknown shape process | |||||
| for (size_t i = 0; i < ref_out_tensors.size(); i++) { | |||||
| if (ref_out_tensors[i].empty()) { | |||||
| continue; | |||||
| } | |||||
| auto ref_out_tensor = ref_out_tensors[i].at(0); | |||||
| ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); | |||||
| for (auto &tensor : ref_out_tensors[i]) { | |||||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||||
| GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto shape = tensor.MutableShape(); | |||||
| if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { | |||||
| GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, | |||||
| shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||||
| ref_out_tensor_shape = GeShape(UNKNOWN_RANK); | |||||
| break; | |||||
| } | |||||
| for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { | |||||
| if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { | |||||
| continue; | |||||
| } | |||||
| GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, | |||||
| j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||||
| (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); | |||||
| } | |||||
| } | |||||
| (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors, | |||||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) { | |||||
| GELOGD("Enter update parent node shape for class while op process"); | |||||
| if (ref_data_tensors.size() != ref_out_tensors.size()) { | |||||
| GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(), | |||||
| ref_data_tensors.size(), ref_out_tensors.size()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (size_t i = 0; i < ref_data_tensors.size(); i++) { | |||||
| if (ref_out_tensors[i].size() != 1) { | |||||
| GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| bool is_need_reverse_brush = false; | |||||
| // check input and output | |||||
| for (size_t i = 0; i < ref_out_tensors.size(); i++) { | |||||
| if (ref_out_tensors[i].empty()) { | |||||
| continue; | |||||
| } | |||||
| auto ref_out_tensor = ref_out_tensors[i].at(0); | |||||
| auto tmp_shape = ref_out_tensor.MutableShape(); | |||||
| // ref_i's data and output tensor shape should be same | |||||
| for (auto &tensor : ref_data_tensors[i]) { | |||||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||||
| GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto shape = tensor.MutableShape(); | |||||
| if (shape.GetDims() != tmp_shape.GetDims()) { | |||||
| ref_out_tensor.SetUnknownDimNumShape(); | |||||
| is_need_reverse_brush = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); | |||||
| } | |||||
| // reverse refresh while body shape | |||||
| if (is_need_reverse_brush) { | |||||
| return ReverseBrushWhileBodySubGraph(node); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||||
| if (sub_graph_names.empty()) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||||
| for (const auto &name : sub_graph_names) { | |||||
| if (name.empty()) { | |||||
| GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| auto sub_graph = root_graph->GetSubgraph(name); | |||||
| if (sub_graph == nullptr) { | |||||
| GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (const auto &node_sub : sub_graph->GetDirectNode()) { | |||||
| if (node_sub->GetType() != DATA) { | |||||
| continue; | |||||
| } | |||||
| int ref_i; | |||||
| auto data_opdesc = node_sub->GetOpDesc(); | |||||
| if (data_opdesc == nullptr) { | |||||
| GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||||
| GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { | |||||
| continue; | |||||
| } | |||||
| auto input_desc = op_desc->MutableInputDesc(ref_i); | |||||
| if (input_desc == nullptr) { | |||||
| GE_LOGE( | |||||
| "The ref index(%d) on the data %s on the sub graph %s " | |||||
| "parent node %s are incompatible, inputs num %u", | |||||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), | |||||
| node->GetName().c_str()); | |||||
| auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", | |||||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| ret = data_opdesc->UpdateOutputDesc(0, *input_desc); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s", | |||||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph, NodePtr &netoutput, | |||||
| const ConstNodePtr &node, | |||||
| std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) { | |||||
| auto sub_nodes = sub_graph->GetDirectNode(); | |||||
| for (size_t i = sub_nodes.size(); i > 0; --i) { | |||||
| auto sub_node = sub_nodes.at(i - 1); | |||||
| if (sub_node->GetType() == NETOUTPUT) { | |||||
| netoutput = sub_node; | |||||
| } | |||||
| if (sub_node->GetType() == DATA) { | |||||
| if (sub_node->GetOpDesc() == nullptr) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| int ref_i; | |||||
| if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||||
| GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) { | |||||
| GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(), | |||||
| ref_i, node->GetAllInDataAnchorsSize()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||||
| if (sub_graph_names.empty()) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize()); | |||||
| std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize()); | |||||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||||
| for (const auto &name : sub_graph_names) { | |||||
| if (name.empty()) { | |||||
| GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| auto sub_graph = root_graph->GetSubgraph(name); | |||||
| if (sub_graph == nullptr) { | |||||
| GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| NodePtr netoutput = nullptr; | |||||
| auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (netoutput == nullptr) { | |||||
| GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto netoutput_opdesc = netoutput->GetOpDesc(); | |||||
| if (netoutput_opdesc == nullptr) { | |||||
| GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { | |||||
| auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx()); | |||||
| if (edge_desc == nullptr) { | |||||
| GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(), | |||||
| node->GetName().c_str(), edge_anchor->GetIdx()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GELOGI("Netoutput in anchor index is %zu, input tensor dim is %zu", edge_anchor->GetIdx(), | |||||
| edge_desc->GetShape().GetDimNum()); | |||||
| int ref_i; | |||||
| if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||||
| // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. | |||||
| continue; | |||||
| } | |||||
| GELOGI("Parent node index of edge desc is %d", ref_i); | |||||
| if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| ref_out_tensors[ref_i].emplace_back(*edge_desc); | |||||
| } | |||||
| } | |||||
| if (node->GetType() == WHILE) { | |||||
| return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); | |||||
| } | |||||
| return UpdateParentNodeForBranch(node, ref_out_tensors); | |||||
| } | |||||
| string Serial(const vector<int64_t> &dims) { | |||||
| string serial_string; | |||||
| serial_string += "["; | |||||
| for (int64_t dim : dims) { | |||||
| serial_string += std::to_string(dim) + " "; | |||||
| } | |||||
| serial_string += "]"; | |||||
| return serial_string; | |||||
| } | |||||
| graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) { | |||||
| GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | |||||
| GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); | |||||
| for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) { | |||||
| auto in_idx = in_anchor->GetIdx(); | |||||
| auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (peer_out_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); | |||||
| if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { | |||||
| continue; | |||||
| } | |||||
| int peer_out_idx = peer_out_data_anchor->GetIdx(); | |||||
| auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(peer_out_idx)); | |||||
| // check shape and dtype continuity. do not stop process | |||||
| auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_idx)); | |||||
| if (in_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto in_shape = in_desc->GetShape().GetDims(); | |||||
| auto in_dtype = in_desc->GetDataType(); | |||||
| auto peer_out_shape = peer_out_desc->GetShape().GetDims(); | |||||
| auto peer_out_dtype = peer_out_desc->GetDataType(); | |||||
| if (peer_out_dtype != in_dtype) { | |||||
| GELOGW( | |||||
| "current node [%s] [%d]\'th out_dtype is [%s].peer output node [%s] [%d]\'th " | |||||
| "output_dtype is [%s].The two dtype should be same! Please check graph and fix it", | |||||
| node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(), | |||||
| peer_out_data_node->GetName().c_str(), peer_out_idx, TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str()); | |||||
| } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) { | |||||
| string in_shape_str = Serial(in_shape); | |||||
| string peer_out_shape_str = Serial(peer_out_shape); | |||||
| GELOGW( | |||||
| "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " | |||||
| "input_shape is [%s].The two shape should be same! Please check graph and fix it", | |||||
| node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), peer_out_data_node->GetName().c_str(), peer_out_idx, | |||||
| peer_out_shape_str.c_str()); | |||||
| } | |||||
| // refresh current node input desc | |||||
| in_desc->SetOriginShape(peer_out_desc->GetOriginShape()); | |||||
| in_desc->SetShape(peer_out_desc->GetShape()); | |||||
| in_desc->SetDataType(peer_out_desc->GetDataType()); | |||||
| in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType()); | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| (void)peer_out_desc->GetShapeRange(shape_range); | |||||
| in_desc->SetShapeRange(shape_range); | |||||
| ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast<uint32_t>(peer_out_desc->GetShape().GetDims().size())); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { | |||||
| if (!IsLogEnable(GE, DLOG_DEBUG)) { | |||||
| return; | |||||
| } | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "node is null"); | |||||
| return; | |||||
| } | |||||
| ge::OpDescPtr op_desc = node->GetOpDesc(); | |||||
| GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); | |||||
| std::string str; | |||||
| if (op_desc->GetInputsSize() != 0) { | |||||
| std::string input_desc_str = "input shape: "; | |||||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||||
| input_desc_str += "["; | |||||
| for (int64_t dim : input_desc->GetShape().GetDims()) { | |||||
| input_desc_str += std::to_string(dim) + " "; | |||||
| } | |||||
| input_desc_str += "]"; | |||||
| input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) + ":" + | |||||
| TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " "; | |||||
| } | |||||
| str += input_desc_str; | |||||
| input_desc_str = "input origin shape: "; | |||||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||||
| input_desc_str += "["; | |||||
| for (int64_t dim : input_desc->GetOriginShape().GetDims()) { | |||||
| input_desc_str += std::to_string(dim) + " "; | |||||
| } | |||||
| input_desc_str += "]"; | |||||
| input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) + ":" + | |||||
| TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) + " "; | |||||
| } | |||||
| str += input_desc_str; | |||||
| } | |||||
| if (op_desc->GetAllOutputsDescSize() != 0) { | |||||
| std::string output_desc_str = "output shape: "; | |||||
| for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | |||||
| if (output_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| output_desc_str += "["; | |||||
| for (int64_t dim : output_desc->GetShape().GetDims()) { | |||||
| output_desc_str += std::to_string(dim) + " "; | |||||
| } | |||||
| output_desc_str += "]"; | |||||
| output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) + ":" + | |||||
| TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " "; | |||||
| } | |||||
| str += output_desc_str; | |||||
| output_desc_str = "output origin shape: "; | |||||
| for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | |||||
| if (output_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| output_desc_str += "["; | |||||
| for (int64_t dim : output_desc->GetOriginShape().GetDims()) { | |||||
| output_desc_str += std::to_string(dim) + " "; | |||||
| } | |||||
| output_desc_str += "]"; | |||||
| output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) + ":" + | |||||
| TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) + " "; | |||||
| } | |||||
| str += output_desc_str; | |||||
| } | |||||
| GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), str.c_str()); | |||||
| } | |||||
| graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { | |||||
| return InferShapeAndType(node, op, true); | |||||
| } | |||||
| graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| const auto &op_type = op_desc->GetType(); | |||||
| graphStatus ret; | |||||
| if (before_subgraph) { | |||||
| ret = UpdateSubGraphDataNodes(node); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| // Get infer func and execute | |||||
| ret = op_desc->CallInferFunc(op); | |||||
| if (ret == GRAPH_PARAM_INVALID) { | |||||
| // Op ir no infer func, try to get infer func from operator factory | |||||
| auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); | |||||
| if (node_op.IsEmpty()) { | |||||
| GELOGW("get op from OperatorFactory fail. opType: %s", op_type.c_str()); | |||||
| return ret; | |||||
| } | |||||
| GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); | |||||
| auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); | |||||
| node_op.BreakConnect(); | |||||
| if (temp_op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "temp op desc is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { | |||||
| GELOGW("InferShapeAndType UpdateInputName failed"); | |||||
| for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { | |||||
| if (out_desc != nullptr && out_desc->GetShape().GetDims().empty()) { | |||||
| break; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } | |||||
| if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { | |||||
| GELOGW("InferShapeAndType UpdateOutputName failed"); | |||||
| } | |||||
| op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); | |||||
| ret = op_desc->CallInferFunc(op); | |||||
| GELOGI("op CallInferFunc second. ret: %u", ret); | |||||
| } | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (!before_subgraph) { | |||||
| return UpdateParentNodeOutTensor(node); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map, | |||||
| const NodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "node is null"); | |||||
| return nullptr; | |||||
| } | |||||
| InferenceContextPtr inference_context = std::shared_ptr<InferenceContext>(InferenceContext::Create()); | |||||
| if (inference_context == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Failed to alloc InferenceContext"); | |||||
| return nullptr; | |||||
| } | |||||
| auto all_in_data_anchors = node->GetAllInDataAnchors(); | |||||
| std::vector<std::vector<ShapeAndType>> input_shapes_and_types(all_in_data_anchors.size()); | |||||
| std::vector<std::string> marks; | |||||
| bool has_input_shapes_and_types = false; | |||||
| for (const auto &in_anchor : all_in_data_anchors) { | |||||
| const auto &out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (out_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto input_node = out_anchor->GetOwnerNode(); | |||||
| if (input_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto iter = context_map.find(input_node); | |||||
| if (iter != context_map.end()) { | |||||
| const auto &src_context = iter->second; | |||||
| GE_IF_BOOL_EXEC(src_context == nullptr, GELOGE(GRAPH_FAILED, "src_context is null."); return nullptr); | |||||
| GELOGD("node:%s get %ld marks from node:%s", node->GetName().c_str(), src_context->GetMarks().size(), | |||||
| input_node->GetName().c_str()); | |||||
| for (auto mark : src_context->GetMarks()) { | |||||
| marks.push_back(mark); | |||||
| } | |||||
| auto output_idx = out_anchor->GetIdx(); | |||||
| auto input_idx = in_anchor->GetIdx(); | |||||
| auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes(); | |||||
| if (output_idx < static_cast<int>(output_shape_and_type.size())) { | |||||
| GELOGI("Add shape and type from %s:%d to %s:%d", input_node->GetName().c_str(), output_idx, | |||||
| node->GetName().c_str(), input_idx); | |||||
| input_shapes_and_types[input_idx] = output_shape_and_type[output_idx]; | |||||
| has_input_shapes_and_types = true; | |||||
| } else { | |||||
| GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx, | |||||
| output_shape_and_type.size()); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (has_input_shapes_and_types) { | |||||
| inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types)); | |||||
| } | |||||
| inference_context->SetMarks(marks); | |||||
| return inference_context; | |||||
| } | |||||
| namespace { | |||||
| thread_local std::unordered_map<NodePtr, InferenceContextPtr> context_map; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ShapeRefiner::ClearContextMap() { context_map.clear(); } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { | |||||
| return InferShapeAndType(node, true); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, | |||||
| bool before_subgraph) { | |||||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | |||||
| bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||||
| auto opdesc = node->GetOpDesc(); | |||||
| GE_IF_BOOL_EXEC(opdesc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); | |||||
| // some op can not infershape twice such as aipp | |||||
| bool need_update_input = !is_unknown_graph && !opdesc->HasAttr("has_infered_verified"); | |||||
| if (need_update_input) { | |||||
| auto status = UpdateOpInputDesc(node); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "update op input_desc failed!"); | |||||
| return status; | |||||
| } | |||||
| } | |||||
| if (node->Verify() != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| PrintInOutTensorShape(node, "before_infershape"); | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node); | |||||
| if (!is_unknown_graph) { | |||||
| auto inference_context = CreateInferenceContext(context_map, node); | |||||
| if (inference_context == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "inference context is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); | |||||
| op.SetInferenceContext(inference_context); | |||||
| } | |||||
| graphStatus status = InferShapeAndType(node, op, before_subgraph); | |||||
| if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | |||||
| if (is_unknown_graph) { | |||||
| PrintInOutTensorShape(node, "after_infershape when running"); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| for (const auto &out_anchor : node->GetAllOutDataAnchors()) { | |||||
| auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); | |||||
| ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size())); | |||||
| output_tensor->SetOriginShape(output_tensor->GetShape()); | |||||
| output_tensor->SetOriginDataType(output_tensor->GetDataType()); | |||||
| GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", | |||||
| node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), | |||||
| TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | |||||
| } | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (!is_unknown_graph) { | |||||
| auto ctx_after_infer = op.GetInferenceContext(); | |||||
| if (ctx_after_infer != nullptr) { | |||||
| GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); | |||||
| if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { | |||||
| GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), | |||||
| ctx_after_infer->GetMarks().size()); | |||||
| (void)context_map.emplace(node, ctx_after_infer); | |||||
| } | |||||
| } | |||||
| } | |||||
| PrintInOutTensorShape(node, "after_infershape"); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,704 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "external/graph/tensor.h" | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "securec.h" | |||||
| #include "utils/attr_utils.h" | |||||
| #include "utils/tensor_adapter.h" | |||||
| #include "utils/tensor_utils.h" | |||||
| #include "utils/type_utils.h" | |||||
| namespace { | |||||
| /// Extra 8 bytes store pointer of string | |||||
| /// Extra 1 byte store '\0' | |||||
| const int EXTRA_STORE_POINTER_FOR_STRING = 8; | |||||
| const int EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL = 9; | |||||
| const int64_t UNKNOWN_DIM_SIZE = -1; | |||||
| } // namespace | |||||
| namespace ge { | |||||
| // If not overflow return true | |||||
| static bool Int64MulNotOverflow(int64_t a, int64_t b) { | |||||
| if (a > 0) { | |||||
| if (b > 0) { | |||||
| if (a > (INT64_MAX / b)) { | |||||
| return false; | |||||
| } | |||||
| } else { | |||||
| if (b < (INT64_MIN / a)) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| if (b > 0) { | |||||
| if (a < (INT64_MIN / b)) { | |||||
| return false; | |||||
| } | |||||
| } else { | |||||
| if ((a != 0) && (b < (INT64_MAX / a))) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| class TensorDescImpl { | |||||
| public: | |||||
| TensorDescImpl() = default; | |||||
| ~TensorDescImpl() = default; | |||||
| TensorDescImpl(const Shape &shape, Format format, DataType dt) : shape_(shape), format_(format), data_type_(dt) {} | |||||
| Shape shape_; | |||||
| std::vector<std::pair<int64_t, int64_t>> range_; | |||||
| Format format_ = FORMAT_ND; | |||||
| Format origin_format_ = FORMAT_ND; | |||||
| DataType data_type_ = DT_FLOAT; | |||||
| Shape origin_shape_; | |||||
| int64_t size_ = 0; | |||||
| int64_t real_dim_cnt_ = 0; | |||||
| std::string name_; | |||||
| }; | |||||
| class TensorImpl { | |||||
| public: | |||||
| TensorImpl() = default; | |||||
| ~TensorImpl() = default; | |||||
| explicit TensorImpl(const TensorDesc &tensor_desc) : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)) {} | |||||
| TensorImpl(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data) | |||||
| : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data) {} | |||||
| TensorImpl(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) | |||||
| : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data, size) {} | |||||
| TensorImpl(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) | |||||
| : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), std::move(data)) {} | |||||
| GeTensor ge_tensor; | |||||
| }; | |||||
| class ShapeImpl { | |||||
| public: | |||||
| ShapeImpl() = default; | |||||
| ~ShapeImpl() = default; | |||||
| explicit ShapeImpl(const std::vector<int64_t> &dims) { | |||||
| bool is_unknown_dim_num = false; | |||||
| for (const auto &dim : dims) { | |||||
| if (dim == UNKNOWN_DIM_NUM) { | |||||
| is_unknown_dim_num = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| dims_ = is_unknown_dim_num ? std::vector<int64_t>({UNKNOWN_DIM_NUM}) : dims; | |||||
| } | |||||
| std::vector<int64_t> dims_; | |||||
| }; | |||||
| Shape::Shape() { impl_ = ComGraphMakeShared<ShapeImpl>(); } | |||||
| Shape::Shape(const std::vector<int64_t> &dims) { impl_ = ComGraphMakeShared<ShapeImpl>(dims); } | |||||
| size_t Shape::GetDimNum() const { | |||||
| if (impl_ != nullptr) { | |||||
| for (auto i : impl_->dims_) { | |||||
| if (i == UNKNOWN_DIM_NUM) { | |||||
| return 0; | |||||
| } | |||||
| } | |||||
| return impl_->dims_.size(); | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| int64_t Shape::GetDim(size_t idx) const { | |||||
| if (impl_ != nullptr) { | |||||
| if (idx >= impl_->dims_.size()) { | |||||
| return 0; | |||||
| } | |||||
| return impl_->dims_[idx]; | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| graphStatus Shape::SetDim(size_t idx, int64_t value) { | |||||
| if (impl_ != nullptr) { | |||||
| if (idx >= impl_->dims_.size()) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| impl_->dims_[idx] = value; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::vector<int64_t> Shape::GetDims() const { | |||||
| vector<int64_t> dims; | |||||
| if (impl_ != nullptr) { | |||||
| return impl_->dims_; | |||||
| } | |||||
| return dims; | |||||
| } | |||||
| int64_t Shape::GetShapeSize() const { | |||||
| if (impl_ != nullptr) { | |||||
| if (impl_->dims_.empty()) { | |||||
| return 0; | |||||
| } | |||||
| int64_t size = 1; | |||||
| for (auto i : impl_->dims_) { | |||||
| if (i == UNKNOWN_DIM_NUM || i == UNKNOWN_DIM) { | |||||
| return UNKNOWN_DIM_SIZE; | |||||
| } | |||||
| if (!Int64MulNotOverflow(size, i)) { | |||||
| GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); | |||||
| size = 0; | |||||
| return size; | |||||
| } | |||||
| size *= i; | |||||
| } | |||||
| return size; | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| TensorDesc::TensorDesc() { | |||||
| impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665 | |||||
| } | |||||
| TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) { | |||||
| impl = ComGraphMakeShared<TensorDescImpl>(shape, format, dt); // lint !e665 | |||||
| SetRealDimCnt(shape.GetDimNum()); | |||||
| } | |||||
| TensorDesc::TensorDesc(const TensorDesc &desc) { | |||||
| // Copy | |||||
| impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665 | |||||
| if (desc.impl != nullptr && impl != nullptr) { | |||||
| *impl = *desc.impl; | |||||
| } | |||||
| } | |||||
| TensorDesc::TensorDesc(TensorDesc &&desc) { | |||||
| // Move | |||||
| impl = std::move(desc.impl); | |||||
| } | |||||
| TensorDesc &TensorDesc::operator=(const TensorDesc &desc) { | |||||
| // Copy | |||||
| if (&desc != this) { | |||||
| impl = ComGraphMakeShared<TensorDescImpl>(); | |||||
| if (desc.impl != nullptr && impl != nullptr) { | |||||
| *impl = *desc.impl; | |||||
| } | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| TensorDesc &TensorDesc::operator=(TensorDesc &&desc) { | |||||
| if (&desc != this) { | |||||
| impl = std::move(desc.impl); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| void TensorDesc::Update(const Shape &shape, Format format, DataType dt) { | |||||
| if (impl != nullptr) { | |||||
| impl->shape_ = shape; | |||||
| impl->format_ = format; | |||||
| impl->data_type_ = dt; | |||||
| } | |||||
| } | |||||
| Shape TensorDesc::GetShape() const { | |||||
| if (impl != nullptr) { | |||||
| return impl->shape_; | |||||
| } | |||||
| return Shape(); | |||||
| } | |||||
| void TensorDesc::SetShape(const Shape &shape) { | |||||
| if (impl != nullptr) { | |||||
| impl->shape_ = shape; | |||||
| } | |||||
| } | |||||
| // set shape with -2, it stand for unknown shape | |||||
| graphStatus TensorDesc::SetUnknownDimNumShape() { | |||||
| if (impl != nullptr) { | |||||
| impl->shape_ = Shape({UNKNOWN_DIM_NUM}); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GELOGE(GRAPH_FAILED, "Set unknown shape failed,because no impl class!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // for unknown shape | |||||
| graphStatus TensorDesc::SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range) { | |||||
| if (impl != nullptr) { | |||||
| impl->range_ = range; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GELOGE(GRAPH_FAILED, "SetShapeRange failed!impl is nullptr!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus TensorDesc::GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const { | |||||
| if (impl != nullptr) { | |||||
| range = impl->range_; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GELOGE(GRAPH_FAILED, "impl is nullptr!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| Shape TensorDesc::GetOriginShape() const { | |||||
| if (impl != nullptr) { | |||||
| return impl->origin_shape_; | |||||
| } | |||||
| return Shape(); | |||||
| } | |||||
| void TensorDesc::SetOriginShape(const Shape &origin_shape) { | |||||
| if (impl != nullptr) { | |||||
| impl->origin_shape_ = origin_shape; | |||||
| } | |||||
| } | |||||
| Format TensorDesc::GetFormat() const { | |||||
| if (impl != nullptr) { | |||||
| return impl->format_; | |||||
| } | |||||
| return FORMAT_RESERVED; | |||||
| } | |||||
| void TensorDesc::SetFormat(Format format) { | |||||
| if (impl != nullptr) { | |||||
| impl->format_ = format; | |||||
| } | |||||
| } | |||||
| Format TensorDesc::GetOriginFormat() const { | |||||
| if (impl != nullptr) { | |||||
| return impl->origin_format_; | |||||
| } | |||||
| return FORMAT_RESERVED; | |||||
| } | |||||
| void TensorDesc::SetOriginFormat(Format origin_format) { | |||||
| if (impl != nullptr) { | |||||
| impl->origin_format_ = origin_format; | |||||
| } | |||||
| } | |||||
| DataType TensorDesc::GetDataType() const { | |||||
| if (impl != nullptr) { | |||||
| return impl->data_type_; | |||||
| } | |||||
| return DT_UNDEFINED; | |||||
| } | |||||
| void TensorDesc::SetDataType(DataType dt) { | |||||
| if (impl != nullptr) { | |||||
| impl->data_type_ = dt; | |||||
| } | |||||
| } | |||||
| void TensorDesc::SetSize(int64_t size) { | |||||
| if (impl != nullptr) { | |||||
| impl->size_ = size; | |||||
| } | |||||
| } | |||||
| int64_t TensorDesc::GetSize() const { | |||||
| if (impl != nullptr) { | |||||
| return impl->size_; | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| void TensorDesc::SetRealDimCnt(const int64_t real_dim_cnt) { | |||||
| if (impl != nullptr) { | |||||
| impl->real_dim_cnt_ = real_dim_cnt; | |||||
| } | |||||
| } | |||||
| int64_t TensorDesc::GetRealDimCnt() const { | |||||
| if (impl != nullptr) { | |||||
| return impl->real_dim_cnt_; | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| std::string TensorDesc::GetName() const { | |||||
| if (impl != nullptr) { | |||||
| return impl->name_; | |||||
| } | |||||
| return ""; | |||||
| } | |||||
| void TensorDesc::SetName(const std::string &name) { | |||||
| if (impl != nullptr) { | |||||
| impl->name_ = name; | |||||
| } | |||||
| } | |||||
| Tensor::Tensor() { impl = ComGraphMakeShared<TensorImpl>(); } | |||||
| Tensor::Tensor(const TensorDesc &tensor_desc) { | |||||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc); // lint !e665 | |||||
| } | |||||
| Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data) { | |||||
| uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); | |||||
| DataType data_type = tensor_desc.GetDataType(); | |||||
| uint32_t type_length; | |||||
| bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); | |||||
| if (!ret) { | |||||
| GELOGW("datatype %d is not found.", data_type); | |||||
| } | |||||
| auto data_size = data.size(); | |||||
| if (ret && (shape_size || (data_size != type_length))) { | |||||
| if (type_length != 0 && UINT64_MAX / type_length < shape_size) { | |||||
| GELOGW("mul overflow: %lu, %u", shape_size, type_length); | |||||
| } else { | |||||
| if (shape_size * type_length != data_size) { | |||||
| GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | |||||
| data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data); // lint !e665 | |||||
| } | |||||
| Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) { | |||||
| uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); | |||||
| DataType data_type = tensor_desc.GetDataType(); | |||||
| uint32_t type_length; | |||||
| bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); | |||||
| if (!ret) { | |||||
| GELOGW("datatype %d is not found.", data_type); | |||||
| } | |||||
| if (ret && (shape_size || (size != type_length))) { | |||||
| if (type_length != 0 && UINT64_MAX / type_length < shape_size) { | |||||
| GELOGW("mul overflow: %lu, %u", shape_size, type_length); | |||||
| } else { | |||||
| if (shape_size * type_length != size) { | |||||
| GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | |||||
| size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); // lint !e665 | |||||
| } | |||||
| Tensor::Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) { | |||||
| uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); | |||||
| DataType data_type = tensor_desc.GetDataType(); | |||||
| uint32_t type_length; | |||||
| bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); | |||||
| if (!ret) { | |||||
| GELOGW("datatype %d is not found.", data_type); | |||||
| } | |||||
| auto data_size = data.size(); | |||||
| if (ret && (shape_size || (data_size != type_length))) { | |||||
| if (type_length != 0 && UINT64_MAX / type_length < shape_size) { | |||||
| GELOGW("mul overflow: %lu, %u", shape_size, type_length); | |||||
| } else { | |||||
| if (shape_size * type_length != data_size) { | |||||
| GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | |||||
| data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| impl = ComGraphMakeShared<TensorImpl>(std::move(tensor_desc), std::move(data)); // lint !e665 | |||||
| } | |||||
| TensorDesc Tensor::GetTensorDesc() const { | |||||
| if (impl != nullptr) { | |||||
| return TensorAdapter::GeTensorDesc2TensorDesc(impl->ge_tensor.MutableTensorDesc()); | |||||
| } | |||||
| return TensorDesc(); | |||||
| } | |||||
| graphStatus Tensor::SetTensorDesc(const TensorDesc &tensor_desc) { | |||||
| if (impl != nullptr) { | |||||
| impl->ge_tensor.SetTensorDesc(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| const uint8_t *Tensor::GetData() const { | |||||
| if (impl != nullptr) { | |||||
| return impl->ge_tensor.GetData().data(); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| uint8_t *Tensor::GetData() { | |||||
| if (impl != nullptr) { | |||||
| return impl->ge_tensor.MutableData().data(); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| size_t Tensor::GetSize() const { | |||||
| if (impl != nullptr) { | |||||
| return impl->ge_tensor.GetData().size(); | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| graphStatus Tensor::SetData(std::vector<uint8_t> &&data) { | |||||
| if (impl != nullptr) { | |||||
| (void)impl->ge_tensor.SetData(data); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus Tensor::SetData(const std::vector<uint8_t> &data) { | |||||
| if (impl != nullptr) { | |||||
| (void)impl->ge_tensor.SetData(data); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus Tensor::SetData(const uint8_t *data, size_t size) { | |||||
| if (impl != nullptr) { | |||||
| (void)impl->ge_tensor.SetData(data, size); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus Tensor::SetData(const std::string &data) { | |||||
| if (impl != nullptr && (!data.empty())) { | |||||
| /// Extra 8 bytes store pointer of string | |||||
| /// Extra 1 byte store '\0' | |||||
| size_t total_size = data.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL; | |||||
| std::unique_ptr<char[]> buff(new (std::nothrow) char[total_size]()); | |||||
| if (buff == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "allocate string raw data buff failed"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| uint64_t *p = reinterpret_cast<uint64_t *>(buff.get()); | |||||
| // Front 8 bytes store pointer of string | |||||
| char *raw_data = buff.get() + EXTRA_STORE_POINTER_FOR_STRING; | |||||
| p[0] = reinterpret_cast<uintptr_t>(raw_data); | |||||
| int32_t memcpy_ret = memcpy_s(raw_data, total_size - EXTRA_STORE_POINTER_FOR_STRING, data.c_str(), data.size() + 1); | |||||
| GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed"); | |||||
| (void)impl->ge_tensor.SetData(reinterpret_cast<const uint8_t *>(buff.get()), total_size); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus Tensor::SetData(const std::vector<std::string> &data) { | |||||
| if (impl != nullptr) { | |||||
| if (data.empty()) { | |||||
| GELOGE(GRAPH_FAILED, "there is no data, please check the input variable"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| size_t total_size = 0; | |||||
| for (auto str : data) { | |||||
| /// Extra 8 bytes store pointer of each string | |||||
| /// Extra 1 byte store '\0' | |||||
| total_size += (str.size() + EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL); | |||||
| } | |||||
| std::unique_ptr<char[]> buff(new (std::nothrow) char[total_size]); | |||||
| if (buff == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "allocate string raw data buff failed"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| uint64_t *p = reinterpret_cast<uint64_t *>(buff.get()); | |||||
| // Front some bytes store pointer of each string | |||||
| char *raw_data = buff.get() + data.size() * sizeof(uint64_t); | |||||
| uint64_t ptr_size = data.size() * sizeof(uint64_t); | |||||
| for (size_t i = 0; i < data.size(); ++i) { | |||||
| p[i] = reinterpret_cast<uintptr_t>(raw_data); | |||||
| if (total_size < ptr_size) { | |||||
| GELOGE(GRAPH_FAILED, "Subtraction invalid, total_size: %zu, ptr_size: %lu", total_size, ptr_size); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| int32_t memcpy_ret = memcpy_s(raw_data, total_size - ptr_size, data[i].c_str(), data[i].size() + 1); | |||||
| GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed"); | |||||
| raw_data += (data[i].size() + 1); | |||||
| ptr_size += (data[i].size() + 1); | |||||
| } | |||||
| (void)impl->ge_tensor.SetData(reinterpret_cast<const uint8_t *>(buff.get()), total_size); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus Tensor::IsValid() { | |||||
| uint64_t shape_size = GetTensorDesc().GetShape().GetShapeSize(); | |||||
| DataType data_type = GetTensorDesc().GetDataType(); | |||||
| uint32_t type_length; | |||||
| bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); | |||||
| if (!ret) { | |||||
| GELOGW("datatype %d is not found.", data_type); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| size_t data_size = GetSize(); | |||||
| if (data_type != DT_STRING) { | |||||
| if (shape_size || (data_size != type_length)) { | |||||
| if (type_length != 0 && UINT64_MAX / type_length < shape_size) { | |||||
| GELOGW("mul overflow: %lu, %u", shape_size, type_length); | |||||
| } else { | |||||
| if (shape_size * type_length != data_size) { | |||||
| GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | |||||
| data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| Tensor Tensor::Clone() const { | |||||
| Tensor tensor; | |||||
| if (impl != nullptr && tensor.impl != nullptr) { | |||||
| tensor.impl->ge_tensor = impl->ge_tensor.Clone(); | |||||
| } | |||||
| return tensor; | |||||
| } | |||||
| GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_desc) { | |||||
| GeTensorDesc ge_tensor_desc(GeShape(tensor_desc.GetShape().GetDims()), tensor_desc.GetFormat(), | |||||
| tensor_desc.GetDataType()); | |||||
| ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); | |||||
| ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); | |||||
| ge_tensor_desc.SetName(tensor_desc.GetName()); | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| auto status = tensor_desc.GetShapeRange(shape_range); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Get shape range failed!"); | |||||
| return ge_tensor_desc; | |||||
| } | |||||
| status = ge_tensor_desc.SetShapeRange(shape_range); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Set shape range failed!"); | |||||
| return ge_tensor_desc; | |||||
| } | |||||
| auto size = tensor_desc.GetSize(); | |||||
| TensorUtils::SetSize(ge_tensor_desc, size); | |||||
| auto real_dim_cnt = static_cast<uint32_t>(tensor_desc.GetRealDimCnt()); | |||||
| TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); | |||||
| return ge_tensor_desc; | |||||
| } | |||||
| TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_desc) { | |||||
| TensorDesc tensor_desc(Shape(ge_tensor_desc.GetShape().GetDims()), ge_tensor_desc.GetFormat(), | |||||
| ge_tensor_desc.GetDataType()); | |||||
| tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); | |||||
| tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); | |||||
| tensor_desc.SetName(ge_tensor_desc.GetName()); | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| auto status = ge_tensor_desc.GetShapeRange(shape_range); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Get shape range failed!"); | |||||
| return tensor_desc; | |||||
| } | |||||
| status = tensor_desc.SetShapeRange(shape_range); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Set shape range failed!"); | |||||
| return tensor_desc; | |||||
| } | |||||
| int64_t size = 0; | |||||
| (void)TensorUtils::GetSize(ge_tensor_desc, size); | |||||
| tensor_desc.SetSize(size); | |||||
| uint32_t real_dim_cnt = 0; | |||||
| (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); | |||||
| tensor_desc.SetRealDimCnt(real_dim_cnt); | |||||
| return tensor_desc; | |||||
| } | |||||
| GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) { | |||||
| GeTensorPtr ge_tensor; | |||||
| if (tensor.impl != nullptr) { | |||||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor.Clone()); // lint !e665 | |||||
| } | |||||
| return ge_tensor; | |||||
| } | |||||
| Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) { | |||||
| Tensor tensor; | |||||
| if (ge_tensor != nullptr && tensor.impl != nullptr) { | |||||
| tensor.impl->ge_tensor = ge_tensor->Clone(); | |||||
| } | |||||
| return tensor; | |||||
| } | |||||
| ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { | |||||
| GeTensorPtr ge_tensor; | |||||
| if (tensor.impl != nullptr) { | |||||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665 | |||||
| } | |||||
| return ge_tensor; | |||||
| } | |||||
| GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) { | |||||
| GeTensorPtr ge_tensor; | |||||
| if (tensor.impl != nullptr) { | |||||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665 | |||||
| } | |||||
| return ge_tensor; | |||||
| } | |||||
| const GeTensor TensorAdapter::AsGeTensor(const Tensor &tensor) { | |||||
| if (tensor.impl != nullptr) { | |||||
| return tensor.impl->ge_tensor; | |||||
| } | |||||
| return GeTensor(); | |||||
| } | |||||
| GeTensor TensorAdapter::AsGeTensor(Tensor &tensor) { | |||||
| if (tensor.impl != nullptr) { | |||||
| return tensor.impl->ge_tensor; | |||||
| } | |||||
| return GeTensor(); | |||||
| } | |||||
| const Tensor TensorAdapter::AsTensor(const GeTensor &ge_tensor) { | |||||
| Tensor tensor; | |||||
| if (tensor.impl != nullptr) { | |||||
| tensor.impl->ge_tensor = ge_tensor; | |||||
| } | |||||
| return tensor; | |||||
| } | |||||
| Tensor TensorAdapter::AsTensor(GeTensor &ge_tensor) { | |||||
| Tensor tensor; | |||||
| if (tensor.impl != nullptr) { | |||||
| tensor.impl->ge_tensor = ge_tensor; | |||||
| } | |||||
| return tensor; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,102 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "utils/anchor_utils.h" | |||||
| #include <algorithm> | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace ge { | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Format AnchorUtils::GetFormat(const DataAnchorPtr &data_anchor) { | |||||
| if (data_anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "The input data anchor is invalid."); | |||||
| return FORMAT_RESERVED; | |||||
| } | |||||
| return data_anchor->format_; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetFormat(const DataAnchorPtr &data_anchor, | |||||
| Format data_format) { | |||||
| if ((data_anchor == nullptr) || (data_format == FORMAT_RESERVED)) { | |||||
| GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid ."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| data_anchor->format_ = data_format; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| // Get anchor status | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorStatus AnchorUtils::GetStatus(const DataAnchorPtr &data_anchor) { | |||||
| if (data_anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "The input data anchor is invalid."); | |||||
| return ANCHOR_RESERVED; | |||||
| } | |||||
| return data_anchor->status_; | |||||
| } | |||||
| // Set anchor status | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetStatus(const DataAnchorPtr &data_anchor, | |||||
| AnchorStatus anchor_status) { | |||||
| if ((data_anchor == nullptr) || (anchor_status == ANCHOR_RESERVED)) { | |||||
| GELOGE(GRAPH_FAILED, "The input data anchor or input data format is invalid ."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| data_anchor->status_ = anchor_status; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool AnchorUtils::HasControlEdge(const AnchorPtr &anchor) { | |||||
| auto control_anchor = Anchor::DynamicAnchorCast<ControlAnchor>(anchor); | |||||
| if (control_anchor != nullptr) { | |||||
| return (control_anchor->GetPeerAnchors().size() != 0); | |||||
| } | |||||
| auto data_anchor = Anchor::DynamicAnchorCast<DataAnchor>(anchor); | |||||
| if (data_anchor) { | |||||
| for (const auto &peer : data_anchor->GetPeerAnchors()) { | |||||
| auto peer_cast = Anchor::DynamicAnchorCast<ControlAnchor>(peer); | |||||
| if (peer_cast) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| GELOGE(GRAPH_FAILED, "the anchor is neither control anchor nor data anchor"); | |||||
| return false; | |||||
| } | |||||
| bool AnchorUtils::IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst) { | |||||
| GE_CHK_BOOL_EXEC(src != nullptr, return false, "src is null."); | |||||
| GE_CHK_BOOL_RET_STATUS_NOLOG(src->IsLinkedWith(dst), false); | |||||
| auto src_control_anchor = Anchor::DynamicAnchorCast<ControlAnchor>(src); | |||||
| auto dst_control_anchor = Anchor::DynamicAnchorCast<ControlAnchor>(dst); | |||||
| return (src_control_anchor || dst_control_anchor); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int AnchorUtils::GetIdx(const AnchorPtr &anchor) { | |||||
| // Check if it can add edge between DataAnchor | |||||
| auto data_anchor = Anchor::DynamicAnchorCast<DataAnchor>(anchor); | |||||
| if (data_anchor != nullptr) { | |||||
| return data_anchor->GetIdx(); | |||||
| } | |||||
| // Check if it can add edge between ControlAnchor | |||||
| auto control_anchor = Anchor::DynamicAnchorCast<ControlAnchor>(anchor); | |||||
| if (control_anchor != nullptr) { | |||||
| return control_anchor->GetIdx(); | |||||
| } | |||||
| return -1; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,206 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ | |||||
| #define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ | |||||
| #include <google/protobuf/map.h> | |||||
| #include <google/protobuf/repeated_field.h> | |||||
| #include <google/protobuf/stubs/port.h> | |||||
| #include <graph/anchor.h> | |||||
| #include <graph/debug/ge_log.h> | |||||
| #include <graph/debug/ge_util.h> | |||||
| #include <graph/detail/attributes_holder.h> | |||||
| #include <graph/ge_tensor.h> | |||||
| #include <graph/graph.h> | |||||
| #include <graph/model.h> | |||||
| #include <graph/node.h> | |||||
| #include <graph/utils/graph_utils.h> | |||||
| #include <graph/utils/type_utils.h> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <sstream> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "proto/ge_ir.pb.h" | |||||
| #include "proto/onnx.pb.h" | |||||
| namespace ge { | |||||
| const int kOffsetToString = 2; | |||||
| /// | |||||
| /// @ingroup ge_ir_utils | |||||
| /// @brief RepeatedField->String | |||||
| /// @param [in] const rpd_field RepeatedField | |||||
| /// @return String | |||||
| /// | |||||
| template <typename T> | |||||
| const std::string ToString(const google::protobuf::RepeatedField<T> &rpd_field) { | |||||
| std::stringstream ss; | |||||
| ss << "["; | |||||
| for (const T &x : rpd_field) { | |||||
| ss << x; | |||||
| ss << ", "; | |||||
| } | |||||
| std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); | |||||
| str_ret += "]"; | |||||
| return str_ret; | |||||
| } | |||||
| /// | |||||
| /// @ingroup ge_ir_utils | |||||
| /// @brief RepeatedPtrField->String | |||||
| /// @param [in] const rpd_field RepeatedPtrField | |||||
| /// @return String | |||||
| /// | |||||
| template <typename T> | |||||
| const std::string ToString(const google::protobuf::RepeatedPtrField<T> &rpd_ptr_field) { | |||||
| std::stringstream ss; | |||||
| ss << "["; | |||||
| for (const T &x : rpd_ptr_field) { | |||||
| ss << x; | |||||
| ss << ", "; | |||||
| } | |||||
| std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); | |||||
| str_ret += "]"; | |||||
| return str_ret; | |||||
| } | |||||
| /// | |||||
| /// @ingroup ge_ir_utils | |||||
| /// @brief check, if not equal, log with tag | |||||
| /// @param [in] const left_value, right_value reference, log_info_tag | |||||
| /// @return bool | |||||
| /// | |||||
| template <typename T> | |||||
| bool IsEqual(const T &l_value, const T &r_value, const std::string &log_info_tag) { | |||||
| if (l_value == r_value) { | |||||
| return true; | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "Check failed with %s", log_info_tag.c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| class OnnxUtils { | |||||
| public: | |||||
| enum DumpLevel { NO_DUMP = 0, DUMP_ALL = 1, DUMP_WITH_OUT_DATA = 2, DUMP_WITH_OUT_DESC = 3, DUMP_LEVEL_END }; | |||||
| static bool ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto); | |||||
| static bool ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model); | |||||
| private: | |||||
| // Part 1: from IR convert to ONNX Protobuf | |||||
| static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, | |||||
| const std::string &name, void *data); | |||||
| static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, | |||||
| const std::string &name, ::google::protobuf::RepeatedField<::google::protobuf::int64> data); | |||||
| static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, | |||||
| const std::string &name, ::google::protobuf::RepeatedField<bool> data); | |||||
| static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, | |||||
| const std::string &name, ::google::protobuf::RepeatedField<float> data); | |||||
| static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, | |||||
| const std::string &name, ::google::protobuf::RepeatedPtrField<::std::string> data); | |||||
| static void AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto); | |||||
| static void AddAttrProtoFromAttribute(const std::pair<const std::string, ge::GeAttrValue> &string_attr_value, | |||||
| onnx::NodeProto *node_proto); | |||||
| static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc); | |||||
| static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map<std::string, ge::proto::AttrDef> &attr_map, | |||||
| onnx::NodeProto *node_proto, const std::string &prefix = "", | |||||
| const std::string &suffix = ""); | |||||
| static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto); | |||||
| static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); | |||||
| static void EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto); | |||||
| static bool EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto); | |||||
| static bool EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto); | |||||
| static bool EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto); | |||||
| static void EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type); | |||||
| static void EncodeValueInfo(const NodePtr &n, onnx::ValueInfoProto *v); | |||||
| static bool EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto); | |||||
| /// Part 2: from ONNX Protobuf convert to IR | |||||
| /// Describes node's link relationships | |||||
| struct NodeLinkInfo { | |||||
| std::string src_node_name; | |||||
| int32_t src_out_index; | |||||
| NodePtr dst_node; | |||||
| int32_t dst_in_index; | |||||
| std::string dst_node_name; | |||||
| }; | |||||
| // Parse node name and index | |||||
| static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index); | |||||
| static ge::DataType DecodeDataType(onnx::TensorProto_DataType data_type); | |||||
| static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<std::string> &strings); | |||||
| static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<int64_t> &ints); | |||||
| static void DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value); | |||||
| static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value); | |||||
| static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, | |||||
| const std::string &attr_name_for_output_desc, int32_t index, | |||||
| OpDescPtr &op_desc); | |||||
| static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, | |||||
| const std::string &attr_name_for_input_desc, int32_t index, | |||||
| OpDescPtr &op_desc); | |||||
| static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | |||||
| const std::string &attr_name_for_input_output_desc, int32_t index, | |||||
| OpDescPtr &op_desc); | |||||
| static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def); | |||||
| static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); | |||||
| static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); | |||||
| static bool DecodeNodeLink(const std::vector<onnx::NodeProto> &node_proto_vector, | |||||
| const std::map<std::string, NodePtr> &node_map); | |||||
| static bool DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &node); | |||||
| static bool DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ | |||||
| @@ -1,32 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef COMMON_GRAPH_UTILS_MEM_UTILS_H_ | |||||
| #define COMMON_GRAPH_UTILS_MEM_UTILS_H_ | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| namespace ge { | |||||
| template <typename _Tp, typename... _Args> | |||||
| static inline std::shared_ptr<_Tp> MakeShared(_Args &&... __args) { | |||||
| typedef typename std::remove_const<_Tp>::type _Tp_nc; | |||||
| std::shared_ptr<_Tp> ret(new (std::nothrow) _Tp_nc(std::forward<_Args>(__args)...)); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| #endif // COMMON_GRAPH_UTILS_MEM_UTILS_H_ | |||||
| @@ -1,956 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "utils/node_utils.h" | |||||
| #include "utils/op_desc_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "debug/ge_op_types.h" | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/types.h" | |||||
| #include "utils/tensor_utils.h" | |||||
| #include "utils/type_utils.h" | |||||
| namespace ge { | |||||
| std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{}; | |||||
| std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{}; | |||||
| const std::set<std::string> kConstOpTypes = {"Const", "Constant"}; | |||||
| const std::set<std::string> kIfOpTypes = {"If", "_If", "StatelessIf"}; | |||||
| const std::set<std::string> kWhileOpTypes = {"While", "_While", "StatelessWhile"}; | |||||
| const std::set<std::string> kCaseOpTypes = {"Case"}; | |||||
| const std::set<std::string> kForOpTypes = {"For"}; | |||||
| bool OpShapeIsUnknown(const OpDescPtr &desc) { | |||||
| for (const auto &ptr : desc->GetAllInputsDescPtr()) { | |||||
| auto ge_shape = ptr->GetShape(); | |||||
| for (const auto &dim : ge_shape.GetDims()) { | |||||
| if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| for (const auto &ptr : desc->GetAllOutputsDescPtr()) { | |||||
| auto ge_shape = ptr->GetShape(); | |||||
| for (const auto &dim : ge_shape.GetDims()) { | |||||
| if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node, | |||||
| const uint32_t &event_id) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| map_send_info_[node].push_back(event_id); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node, | |||||
| const uint32_t &event_id) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| map_recv_info_[node].push_back(event_id); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto find = map_send_info_.find(node); | |||||
| if (find == map_send_info_.end()) { | |||||
| return GRAPH_FAILED; | |||||
| } else { | |||||
| vec_send = find->second; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto find = map_recv_info_.find(node); | |||||
| if (find == map_recv_info_.end()) { | |||||
| return GRAPH_FAILED; | |||||
| } else { | |||||
| vec_recv = find->second; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() { | |||||
| map_send_info_.clear(); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() { | |||||
| map_recv_info_.clear(); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) { | |||||
| GE_CHECK_NOTNULL(src); | |||||
| NodePtr cur_ptr; | |||||
| if (depth < 1) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (int i = 0; i < depth; i++) { | |||||
| if (src->GetOutDataNodes().size() != 1) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| cur_ptr = src->GetOutDataNodes().at(0); | |||||
| GE_CHECK_NOTNULL(cur_ptr); | |||||
| } | |||||
| dst = cur_ptr; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data, | |||||
| InControlAnchorPtr &in_control) { | |||||
| GE_CHECK_NOTNULL(node_ptr); | |||||
| for (const auto &p : node_ptr->GetAllOutDataAnchors()) { | |||||
| GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr"); | |||||
| for (const auto &p_in : p->GetPeerInControlAnchors()) { | |||||
| GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInDataAnchors is nullptr"); | |||||
| out_data = p; | |||||
| in_control = p_in; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { | |||||
| GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, | |||||
| "node or in_data_anchor is nullptr"); | |||||
| bool find_flag = false; | |||||
| uint32_t index = 0; | |||||
| vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end(); | |||||
| for (const auto &tmp : node_ptr->in_data_anchors_) { | |||||
| if (tmp == in_data_anchor) { | |||||
| find_flag = true; | |||||
| auto iter = node_ptr->in_data_anchors_.begin() + index; | |||||
| if (iter != node_ptr->in_data_anchors_.end()) { | |||||
| it = node_ptr->in_data_anchors_.erase(iter); | |||||
| } | |||||
| break; | |||||
| } | |||||
| index++; | |||||
| } | |||||
| for (; it != node_ptr->in_data_anchors_.end(); ++it) { | |||||
| (*it)->SetIdx(index); | |||||
| index++; | |||||
| } | |||||
| if (!find_flag) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) { | |||||
| GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr"); | |||||
| GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed"); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus NodeUtils::SetAllAnchorStatus(Node &node) { | |||||
| node.anchor_status_updated_ = true; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) { | |||||
| GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr"); | |||||
| return IsAnchorStatusSet(*node_ptr); | |||||
| } | |||||
| bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; } | |||||
| graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) { | |||||
| if ((origin_node == nullptr) || (new_node == nullptr)) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors(); | |||||
| auto new_out_data_anchors = new_node->GetAllOutDataAnchors(); | |||||
| if (origin_out_data_anchors.size() != new_out_data_anchors.size()) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (size_t i = 0; i < origin_out_data_anchors.size(); ++i) { | |||||
| for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) { | |||||
| GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, | |||||
| "unlink peer_anchor failed"); | |||||
| GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, | |||||
| "linkto peer_anchor failed"); | |||||
| } | |||||
| for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) { | |||||
| GE_CHK_BOOL_EXEC(origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, continue, | |||||
| "unlink peer_anchor failed"); | |||||
| GE_CHK_BOOL_EXEC(new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, | |||||
| "linkto peer_anchor failed"); | |||||
| } | |||||
| } | |||||
| auto origin_out_control_anchor = origin_node->GetOutControlAnchor(); | |||||
| GE_CHECK_NOTNULL(origin_out_control_anchor); | |||||
| auto new_out_control_anchor = new_node->GetOutControlAnchor(); | |||||
| GE_CHECK_NOTNULL(new_out_control_anchor); | |||||
| for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) { | |||||
| GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, | |||||
| "linkto peer_anchor failed"); | |||||
| } | |||||
| for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) { | |||||
| GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, continue, | |||||
| "linkto peer_anchor failed"); | |||||
| } | |||||
| origin_out_control_anchor->UnlinkAll(); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool NodeUtils::IsConst(const Node &node) { | |||||
| auto src_node_type = node.GetType(); | |||||
| bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP)); | |||||
| return is_const; | |||||
| } | |||||
| void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) { | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "node is null"); | |||||
| return; | |||||
| } | |||||
| UpdateIsInputConst(*node_ptr); | |||||
| } | |||||
| /// | |||||
| /// update is_input_const | |||||
| /// @param node | |||||
| /// @return void | |||||
| /// | |||||
| void NodeUtils::UpdateIsInputConst(Node &node) { | |||||
| std::vector<bool> is_input_const; | |||||
| size_t anchor_num = node.GetAllInDataAnchors().size(); | |||||
| for (size_t i = 0; i < anchor_num; i++) { | |||||
| auto in_anchor = node.GetInDataAnchor(static_cast<int>(i)); | |||||
| if (in_anchor == nullptr) { | |||||
| is_input_const.push_back(false); | |||||
| continue; | |||||
| } | |||||
| auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (peer_out_anchor == nullptr) { | |||||
| is_input_const.push_back(false); | |||||
| continue; | |||||
| } | |||||
| auto src_node = peer_out_anchor->GetOwnerNode(); | |||||
| if (src_node == nullptr) { | |||||
| is_input_const.push_back(false); | |||||
| continue; | |||||
| } | |||||
| if (IsConst(*(src_node))) { | |||||
| is_input_const.push_back(true); | |||||
| } else { | |||||
| is_input_const.push_back(false); | |||||
| } | |||||
| } | |||||
| if (node.GetOpDesc() == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr"); | |||||
| return; | |||||
| } | |||||
| node.GetOpDesc()->SetIsInputConst(is_input_const); | |||||
| } | |||||
| void NodeUtils::UnlinkAll(const Node &node) { | |||||
| for (const auto &anchor : node.GetAllOutAnchors()) { | |||||
| anchor->UnlinkAll(); | |||||
| } | |||||
| for (const auto &anchor : node.GetAllInAnchors()) { | |||||
| anchor->UnlinkAll(); | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) { | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto op_desc = node_ptr->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||||
| if (is_unknown_graph) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { | |||||
| auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); | |||||
| auto out_dims = output_tensor->GetShape().GetDims(); | |||||
| auto out_dtype = output_tensor->GetDataType(); | |||||
| ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size())); | |||||
| output_tensor->SetOriginShape(output_tensor->GetShape()); | |||||
| output_tensor->SetOriginDataType(output_tensor->GetDataType()); | |||||
| GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", | |||||
| node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), | |||||
| TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | |||||
| for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
| if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); | |||||
| continue; | |||||
| } | |||||
| auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); | |||||
| if (peer_input_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); | |||||
| continue; | |||||
| } | |||||
| // check shape and dtype continuity. do not stop process | |||||
| auto peer_input_dims = peer_input_desc->GetShape().GetDims(); | |||||
| auto peer_input_dtype = peer_input_desc->GetDataType(); | |||||
| if (out_dtype != peer_input_dtype) { | |||||
| GELOGW( | |||||
| "current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th " | |||||
| "input_dtype is [%s].The two dtype should be same! Please check graph and fix it", | |||||
| node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(), | |||||
| peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), | |||||
| TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str()); | |||||
| } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) { | |||||
| string out_shape_str, peer_in_shape_str; | |||||
| out_shape_str += "["; | |||||
| for (int64_t dim : out_dims) { | |||||
| out_shape_str += std::to_string(dim) + " "; | |||||
| } | |||||
| out_shape_str += "]"; | |||||
| peer_in_shape_str += "["; | |||||
| for (int64_t dim : peer_input_dims) { | |||||
| peer_in_shape_str += std::to_string(dim) + " "; | |||||
| } | |||||
| peer_in_shape_str += "]"; | |||||
| GELOGW( | |||||
| "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th " | |||||
| "input_shape is [%s].The two shape should be same! Please check graph and fix it", | |||||
| node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(), | |||||
| peer_anchor->GetOwnerNode()->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str()); | |||||
| } | |||||
| GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", | |||||
| peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), | |||||
| output_tensor->GetDataType(), output_tensor->GetOriginDataType()); | |||||
| peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); | |||||
| peer_input_desc->SetShape(output_tensor->GetShape()); | |||||
| peer_input_desc->SetDataType(output_tensor->GetDataType()); | |||||
| peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| (void)output_tensor->GetShapeRange(shape_range); | |||||
| peer_input_desc->SetShapeRange(shape_range); | |||||
| ge::TensorUtils::SetRealDimCnt(*peer_input_desc, | |||||
| static_cast<uint32_t>(output_tensor->GetShape().GetDims().size())); | |||||
| GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", | |||||
| peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), | |||||
| peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, | |||||
| uint32_t num) { | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Input node is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); | |||||
| const auto &op_desc = node->GetOpDesc(); | |||||
| for (size_t i = op_desc->GetInputsSize(); i < num; ++i) { | |||||
| if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Add input desc failed"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto anchor = ComGraphMakeShared<InDataAnchor>(node, i); | |||||
| if (anchor == nullptr) { | |||||
| GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| node->in_data_anchors_.push_back(anchor); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, | |||||
| uint32_t num) { | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Input node is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| const auto &op_desc = node->GetOpDesc(); | |||||
| while (op_desc->GetInputsSize() > num) { | |||||
| if (!OpDescUtils::ClearInputDesc(op_desc, num)) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| auto input_names = op_desc->GetAllInputName(); | |||||
| (void)op_desc->UpdateInputName(input_names); | |||||
| auto is_input_const = op_desc->GetIsInputConst(); | |||||
| is_input_const.resize(num); | |||||
| op_desc->SetIsInputConst(is_input_const); | |||||
| while (node->in_data_anchors_.size() > num) { | |||||
| node->in_data_anchors_.pop_back(); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node, | |||||
| uint32_t num) { | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Input node is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); | |||||
| const OpDescPtr &op_desc = node->GetOpDesc(); | |||||
| for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) { | |||||
| if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Add output desc failed"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto anchor = ComGraphMakeShared<OutDataAnchor>(node, i); | |||||
| if (anchor == nullptr) { | |||||
| GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| node->out_data_anchors_.push_back(anchor); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node, | |||||
| uint32_t num) { | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Input node is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| const auto &op_desc = node->GetOpDesc(); | |||||
| auto output_names = op_desc->GetAllOutputName(); | |||||
| while (op_desc->GetOutputsSize() > num) { | |||||
| if (!OpDescUtils::ClearOutputDesc(op_desc, num)) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| (void)op_desc->UpdateOutputName(output_names); | |||||
| while (node->out_data_anchors_.size() > num) { | |||||
| node->out_data_anchors_.pop_back(); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool NodeUtils::IsInNodesEmpty(const Node &node) { | |||||
| for (const auto &in_anchor : node.in_data_anchors_) { | |||||
| if (in_anchor != nullptr) { | |||||
| auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (out_anchor != nullptr) { | |||||
| if (out_anchor->GetOwnerNode() != nullptr) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| if ((node.in_control_anchor_ != nullptr) && (!node.in_control_anchor_->IsPeerOutAnchorsEmpty())) { | |||||
| auto peer_out_control_anchors = node.in_control_anchor_->GetPeerOutControlAnchors(); | |||||
| for (const auto &out_control_anchor : peer_out_control_anchors) { | |||||
| if (out_control_anchor != nullptr) { | |||||
| if (out_control_anchor->GetOwnerNode() != nullptr) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) { | |||||
| auto desc = node.GetOpDesc(); | |||||
| if (desc == nullptr) { | |||||
| return GeTensorDesc(); | |||||
| } | |||||
| return desc->GetOutputDesc(index); | |||||
| } | |||||
| GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) { | |||||
| auto desc = node.GetOpDesc(); | |||||
| if (desc == nullptr) { | |||||
| return GeTensorDesc(); | |||||
| } | |||||
| return desc->GetInputDesc(index); | |||||
| } | |||||
| graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) { | |||||
| auto desc = node.GetOpDesc(); | |||||
| if (desc == nullptr) { | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto output_desc = desc->MutableOutputDesc(index); | |||||
| if (output_desc == nullptr) { | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| output_desc->SetShape(shape); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) { | |||||
| auto desc = node.GetOpDesc(); | |||||
| if (desc == nullptr) { | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto input_desc = desc->MutableInputDesc(index); | |||||
| if (input_desc == nullptr) { | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| input_desc->SetShape(shape); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { | |||||
| auto desc = node.GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(desc); | |||||
| // check self | |||||
| is_unknow = OpShapeIsUnknown(desc); | |||||
| if (is_unknow) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| auto sub_graph_names = desc->GetSubgraphInstanceNames(); | |||||
| if (sub_graph_names.empty()) { | |||||
| return GRAPH_SUCCESS; | |||||
| } else { | |||||
| auto owner_graph = node.GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(owner_graph); | |||||
| auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||||
| if (root_graph == nullptr) { | |||||
| GE_LOGE("Node %s gets null root graph", node.GetName().c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| for (auto &sub_graph_name : sub_graph_names) { | |||||
| auto sub_graph = root_graph->GetSubgraph(sub_graph_name); | |||||
| GE_CHECK_NOTNULL(sub_graph); | |||||
| for (const auto &node_ptr : sub_graph->GetDirectNode()) { | |||||
| auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GE_LOGE("get node unknown shape status failed!"); | |||||
| return status; | |||||
| } | |||||
| if (is_unknow) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| std::string NodeUtils::GetNodeType(const Node &node) { | |||||
| if (node.GetType() != FRAMEWORKOP) { | |||||
| return node.GetType(); | |||||
| } | |||||
| std::string type; | |||||
| (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||||
| return type; | |||||
| } | |||||
| std::string NodeUtils::GetNodeType(const NodePtr &node) { return node == nullptr ? "" : GetNodeType(*node); } | |||||
| graphStatus NodeUtils::GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus NodeUtils::GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { | |||||
| auto op_desc = node.GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||||
| if (root_graph == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); | |||||
| } | |||||
| graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) { | |||||
| if (subgraph == nullptr) { | |||||
| GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto op_desc = node.GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||||
| if (root_graph == nullptr) { | |||||
| GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); | |||||
| return ret; | |||||
| } | |||||
| subgraph->SetParentNode(node.shared_from_this()); | |||||
| subgraph->SetParentGraph(node.GetOwnerComputeGraph()); | |||||
| return root_graph->AddSubgraph(subgraph); | |||||
| } | |||||
| /// | |||||
| /// Check if node is input of subgraph | |||||
| /// @param [in] node | |||||
| /// @return bool | |||||
| /// | |||||
| bool NodeUtils::IsSubgraphInput(const NodePtr &node) { | |||||
| if ((node == nullptr) || (node->GetOpDesc() == nullptr) || | |||||
| (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { | |||||
| return false; | |||||
| } | |||||
| auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); | |||||
| if (parent_op_desc == nullptr) { | |||||
| return false; | |||||
| } | |||||
| // dynamic shape unknown graph false | |||||
| // dynamic shape known graph with functional subgraph maybe true | |||||
| if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { | |||||
| if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { | |||||
| return false; | |||||
| } else { | |||||
| if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); | |||||
| } | |||||
| /// | |||||
| /// Check if node is output of subgraph | |||||
| /// @param [in] node | |||||
| /// @return bool | |||||
| /// | |||||
| bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { | |||||
| if ((node == nullptr) || (node->GetOpDesc() == nullptr) || | |||||
| (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) { | |||||
| return false; | |||||
| } | |||||
| auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); | |||||
| if (parent_op_desc == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { | |||||
| if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { | |||||
| return false; | |||||
| } else { | |||||
| if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { | |||||
| if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| /// | |||||
| /// @brief Get subgraph original input node. | |||||
| /// @param [in] node | |||||
| /// @return Node | |||||
| /// | |||||
| NodePtr NodeUtils::GetParentInput(const Node &node) { | |||||
| uint32_t parent_index = 0; | |||||
| if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
| return nullptr; | |||||
| } | |||||
| // Subgraph Data Node, check for constant input. | |||||
| const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | |||||
| const NodePtr &parent_node = graph->GetParentNode(); | |||||
| GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr); | |||||
| const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index); | |||||
| GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr); | |||||
| const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr); | |||||
| return peer_out_anchor->GetOwnerNode(); | |||||
| } | |||||
| NodePtr NodeUtils::GetParentInput(const NodePtr &node) { return node == nullptr ? node : GetParentInput(*node); } | |||||
| /// | |||||
| /// @brief Get is dynamic shape graph from node. | |||||
| /// @param [in] node | |||||
| /// @return bool | |||||
| /// | |||||
| bool NodeUtils::IsDynamicShape(const Node &node) { | |||||
| const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||||
| if (graph == nullptr) { | |||||
| return false; | |||||
| } | |||||
| bool is_dynamic_shape = false; | |||||
| (void)AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); | |||||
| return is_dynamic_shape; | |||||
| } | |||||
| bool NodeUtils::IsDynamicShape(const NodePtr &node) { return node == nullptr ? false : IsDynamicShape(*node); } | |||||
| /// | |||||
| /// @brief Check is varying_input for while node | |||||
| /// @param [in] node: Data node for subgraph | |||||
| /// @return bool | |||||
| /// | |||||
| bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (node->GetType() != DATA) { | |||||
| return false; // not input_node for subgraph | |||||
| } | |||||
| const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); | |||||
| if (parent_node == nullptr) { | |||||
| return false; // root graph | |||||
| } | |||||
| if (kWhileOpTypes.count(parent_node->GetType()) == 0) { | |||||
| return false; // not input_node for while subgraph | |||||
| } | |||||
| uint32_t index_i = 0; | |||||
| if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { | |||||
| GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| bool varying_flag = true; | |||||
| for (const auto &item : node->GetOutDataNodesAndAnchors()) { | |||||
| if (item.first->GetType() != NETOUTPUT) { | |||||
| continue; | |||||
| } | |||||
| OpDescPtr op_desc = item.first->GetOpDesc(); | |||||
| uint32_t index_o = 0; | |||||
| if ((op_desc == nullptr) || | |||||
| !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { | |||||
| continue; // input for while-cond subgraph | |||||
| } | |||||
| if (index_i != index_o) { | |||||
| continue; // varying input for while-body subgraph | |||||
| } | |||||
| varying_flag = false; | |||||
| break; | |||||
| } | |||||
| return varying_flag; | |||||
| } | |||||
| /// | |||||
| /// @brief Get subgraph input is constant. | |||||
| /// @param [in] node | |||||
| /// @param [out] string | |||||
| /// @return bool | |||||
| /// | |||||
| bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) { | |||||
| if (node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) { | |||||
| type = node->GetType(); | |||||
| return true; | |||||
| } | |||||
| if (node->GetType() != DATA) { | |||||
| return false; // not subgraph input node | |||||
| } | |||||
| const auto &parent = GetParentInput(node); | |||||
| return GetConstOpType(parent, type); | |||||
| } | |||||
| /// | |||||
| /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. | |||||
| /// @param [in] node | |||||
| /// @return return GRAPH_SUCCESS if remove successfully, other for failed. | |||||
| /// | |||||
| Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| auto subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||||
| if (subgraph_names.empty()) { | |||||
| return GRAPH_SUCCESS; | |||||
| } else { | |||||
| auto owner_graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(owner_graph); | |||||
| auto root_graph = GraphUtils::FindRootGraph(owner_graph); | |||||
| GE_CHECK_NOTNULL(root_graph); | |||||
| std::unordered_set<std::string> subgraph_to_remove; | |||||
| for (auto &subgraph_name : subgraph_names) { | |||||
| std::deque<std::string> queue; | |||||
| queue.push_back(subgraph_name); | |||||
| subgraph_to_remove.insert(subgraph_name); | |||||
| op_desc->RemoveSubgraphInstanceName(subgraph_name); | |||||
| while (!queue.empty()) { | |||||
| auto graph_name = queue.front(); | |||||
| queue.pop_front(); | |||||
| auto subgraph = root_graph->GetSubgraph(graph_name); | |||||
| GE_CHECK_NOTNULL(subgraph); | |||||
| for (const auto &sub_node : subgraph->GetDirectNode()) { | |||||
| auto sub_op_desc = sub_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(sub_op_desc); | |||||
| auto sub_names = sub_op_desc->GetSubgraphInstanceNames(); | |||||
| // Subgraph and all nodes in it will be removed later, | |||||
| // no need to remove 'SubgraphInstanceName' in op desc here. | |||||
| for (auto &name : sub_names) { | |||||
| if (subgraph_to_remove.insert(name).second) { | |||||
| queue.push_back(name); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| // Remove subgraph from root_graph | |||||
| for (const auto &name : subgraph_to_remove) { | |||||
| GELOGI("Remove subgraph:%s.", name.c_str()); | |||||
| root_graph->RemoveSubgraph(name); | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| /// | |||||
| /// @brief Get subgraph input data node by index. | |||||
| /// @param [in] node | |||||
| /// @return Node | |||||
| /// | |||||
| vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { | |||||
| vector<NodePtr> in_data_node_vec; | |||||
| auto op_desc = node.GetOpDesc(); | |||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); | |||||
| auto subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||||
| if (subgraph_names.empty()) { | |||||
| GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); | |||||
| return in_data_node_vec; | |||||
| } | |||||
| auto compute_graph = node.GetOwnerComputeGraph(); | |||||
| for (const std::string &instance_name : subgraph_names) { | |||||
| auto subgraph = compute_graph->GetSubgraph(instance_name); | |||||
| for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { | |||||
| int parent_index = -1; | |||||
| if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { | |||||
| (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); | |||||
| if (parent_index == index) { | |||||
| in_data_node_vec.emplace_back(node_in_subgraph); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return in_data_node_vec; | |||||
| } | |||||
| /// | |||||
| /// @brief Get subgraph input data node by index. | |||||
| /// @param [in] node | |||||
| /// @return Node | |||||
| /// | |||||
| vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) { | |||||
| vector<NodePtr> out_data_node_vec; | |||||
| auto op_desc = node.GetOpDesc(); | |||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); | |||||
| auto subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||||
| if (subgraph_names.empty()) { | |||||
| GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); | |||||
| return out_data_node_vec; | |||||
| } | |||||
| auto compute_graph = node.GetOwnerComputeGraph(); | |||||
| for (const std::string &instance_name : subgraph_names) { | |||||
| auto subgraph = compute_graph->GetSubgraph(instance_name); | |||||
| for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { | |||||
| if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { | |||||
| out_data_node_vec.emplace_back(node_in_subgraph); | |||||
| } | |||||
| } | |||||
| } | |||||
| return out_data_node_vec; | |||||
| } | |||||
| NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) { | |||||
| if (node.GetInDataAnchor(index) == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); | |||||
| } | |||||
| vector<pair<InDataAnchorPtr, NodePtr>> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) { | |||||
| vector<pair<InDataAnchorPtr, NodePtr>> out_data_nodes; | |||||
| auto out_data_anchor = node.GetOutDataAnchor(index); | |||||
| if (out_data_anchor == nullptr) { | |||||
| return out_data_nodes; | |||||
| } | |||||
| for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||||
| if (peer_in_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (peer_in_anchor->GetOwnerNode() == nullptr) { | |||||
| continue; | |||||
| } | |||||
| out_data_nodes.emplace_back(std::make_pair(peer_in_anchor, peer_in_anchor->GetOwnerNode())); | |||||
| } | |||||
| return out_data_nodes; | |||||
| } | |||||
| ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); } | |||||
| } // namespace ge | |||||
| @@ -1,778 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "utils/op_desc_utils.h" | |||||
| #include <algorithm> | |||||
| #include "debug/ge_attr_define.h" | |||||
| #include "debug/ge_op_types.h" | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/ge_attr_value.h" | |||||
| #include "utils/graph_utils.h" | |||||
| #include "utils/node_utils.h" | |||||
| using std::vector; | |||||
| /*lint -e512 -e737 -e752*/ | |||||
| namespace ge { | |||||
| const char OP_DESC_QUANT_PARAMS[] = "quantize_factor"; | |||||
| static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1; | |||||
| bool OpDescUtils::ClearInputDesc(const NodePtr &node) { | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); | |||||
| GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr"); | |||||
| vector<int> index_list; | |||||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||||
| if (in_anchor->GetPeerOutAnchor() == nullptr) { | |||||
| index_list.push_back(in_anchor->GetIdx()); | |||||
| } | |||||
| } | |||||
| std::sort(index_list.begin(), index_list.end()); | |||||
| // Node's in anchor index need shrink | |||||
| for (size_t i = 0; i < index_list.size(); ++i) { | |||||
| auto iter = node->GetOpDesc()->inputs_desc_.begin() + index_list[i]; | |||||
| if (iter < node->GetOpDesc()->inputs_desc_.end()) { | |||||
| (void)node->GetOpDesc()->inputs_desc_.erase(iter); | |||||
| } else { | |||||
| GELOGW("inputs_desc_ iterator out of range."); | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(OpDescPtr op_desc, | |||||
| const uint32_t index) { | |||||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr"); | |||||
| GE_CHK_BOOL_EXEC(index < op_desc->inputs_desc_.size(), return false, "index %u is invalid.", index); | |||||
| auto iter = op_desc->inputs_desc_.begin() + index; | |||||
| if (iter < op_desc->inputs_desc_.end()) { | |||||
| (void)op_desc->inputs_desc_.erase(iter); | |||||
| } else { | |||||
| GELOGW("inputs_desc_ iterator out of range."); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) { | |||||
| GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return false, "op_desc is nullptr"); | |||||
| return op_desc->HasAttr(OP_DESC_QUANT_PARAMS); | |||||
| } | |||||
| bool OpDescUtils::ClearOutputDesc(const NodePtr &node) { | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); | |||||
| GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, return false, "opdesc is nullptr"); | |||||
| vector<int> index_list; | |||||
| for (const auto &out_anchor : node->GetAllOutDataAnchors()) { | |||||
| if (out_anchor->GetPeerInDataAnchors().empty()) { | |||||
| index_list.push_back(out_anchor->GetIdx()); | |||||
| } | |||||
| } | |||||
| std::sort(index_list.begin(), index_list.end()); | |||||
| // Node's out anchor index need shrink | |||||
| for (size_t i = 0; i < index_list.size(); ++i) { | |||||
| auto iter = node->GetOpDesc()->outputs_desc_.begin() + index_list[i]; | |||||
| if (iter < node->GetOpDesc()->outputs_desc_.end()) { | |||||
| (void)node->GetOpDesc()->outputs_desc_.erase(iter); | |||||
| } else { | |||||
| GELOGW("outputs_desc_ iterator out of range."); | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc, | |||||
| uint32_t index) { | |||||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is nullptr"); | |||||
| GE_CHK_BOOL_EXEC(index < op_desc->outputs_desc_.size(), return false, "index %u is invalid.", index); | |||||
| auto iter = op_desc->outputs_desc_.begin() + index; | |||||
| if (iter < op_desc->outputs_desc_.end()) { | |||||
| (void)op_desc->outputs_desc_.erase(iter); | |||||
| } else { | |||||
| GELOGW("outputs_desc_ iterator out of range."); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| OpDescUtils::GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant) { | |||||
| GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); | |||||
| GeAttrValue attr_value; | |||||
| GE_CHK_BOOL_EXEC_INFO(op_desc->GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED, | |||||
| "GetQuantizeFactorParams failed"); | |||||
| return attr_value.GetValue<QuantizeFactorParams>(quant); | |||||
| } | |||||
| graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant) { | |||||
| GeAttrValue attr_value; | |||||
| GE_CHK_BOOL_EXEC_INFO(op_desc.GetAttr(OP_DESC_QUANT_PARAMS, attr_value) == GRAPH_SUCCESS, return GRAPH_FAILED, | |||||
| "GetQuantizeFactorParams failed"); | |||||
| return attr_value.GetValue<QuantizeFactorParams>(quant); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) { | |||||
| GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); | |||||
| return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732 | |||||
| } | |||||
| graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) { | |||||
| return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732 | |||||
| } | |||||
| GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { | |||||
| GeTensorPtr weight = nullptr; | |||||
| if (!AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight)) { | |||||
| GELOGW("MutableTensor error"); | |||||
| } | |||||
| return weight; | |||||
| } | |||||
| GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(OpDescPtr op_desc) { | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "op_desc is null"); | |||||
| return nullptr; | |||||
| } | |||||
| return MutableWeights(*op_desc); | |||||
| } | |||||
| graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) { | |||||
| if (weight == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "weight is null"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED; | |||||
| } | |||||
| graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) { | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| GE_CHECK_NOTNULL(weight); | |||||
| return SetWeights(*op_desc, weight); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights(const ge::Node &node) { | |||||
| auto weights = MutableWeights(node); | |||||
| vector<ConstGeTensorPtr> ret(weights.size()); | |||||
| std::copy(weights.begin(), weights.end(), ret.begin()); | |||||
| return ret; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetWeights( | |||||
| const ge::ConstNodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| return vector<ge::ConstGeTensorPtr>(); | |||||
| } | |||||
| return GetWeights(*node); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputNode( | |||||
| const ge::Node &node) { | |||||
| vector<ge::NodePtr> ret; | |||||
| auto in_anchors = node.GetAllInDataAnchors(); | |||||
| for (const auto &in_anchor : in_anchors) { | |||||
| auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (out_anchor == nullptr) { | |||||
| // normally out_anchor could be null, this is ok | |||||
| GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| auto in_node = out_anchor->GetOwnerNode(); | |||||
| while (true) { | |||||
| if (in_node == nullptr) { | |||||
| break; | |||||
| } | |||||
| if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | |||||
| ret.push_back(in_node); | |||||
| break; | |||||
| } else if (in_node->GetType() == DATA) { | |||||
| if (NodeUtils::IsWhileVaryingInput(in_node)) { | |||||
| break; | |||||
| } | |||||
| in_node = NodeUtils::GetParentInput(in_node); | |||||
| } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { | |||||
| bool is_constant = false; | |||||
| (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); | |||||
| if (!is_constant) { | |||||
| break; | |||||
| } | |||||
| // Enter node has and only has one input | |||||
| if (in_node->GetInDataNodes().size() != 1) { | |||||
| GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(), | |||||
| in_node->GetInDataNodes().size()); | |||||
| break; | |||||
| } | |||||
| in_node = in_node->GetInDataNodes().at(0); | |||||
| } else { | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData( | |||||
| const vector<ge::NodePtr> &input_nodes) { | |||||
| vector<ConstGeTensorPtr> ret; | |||||
| for (const auto &input_node : input_nodes) { | |||||
| auto temp_weight = MutableWeights(input_node->GetOpDesc()); | |||||
| if (temp_weight == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str()); | |||||
| return vector<ConstGeTensorPtr>(); | |||||
| } | |||||
| ret.push_back(temp_weight); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) { | |||||
| if (NodeUtils::IsAnchorStatusSet(node)) { | |||||
| size_t input_num = 0; | |||||
| for (const auto &anchor : node.GetAllInDataAnchors()) { | |||||
| if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { | |||||
| input_num++; | |||||
| continue; | |||||
| } | |||||
| } | |||||
| return input_num; // lint !e712 | |||||
| } else { | |||||
| GE_IF_BOOL_EXEC( | |||||
| node.GetInDataNodes().size() < GetConstInputs(node).size(), | |||||
| GELOGE(GRAPH_FAILED, "%zu is smaller than %zu", node.GetInDataNodes().size(), GetConstInputs(node).size()); | |||||
| return 0); | |||||
| return node.GetInDataNodes().size() - GetConstInputs(node).size(); | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) { | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Node is nullptr"); | |||||
| return 0; | |||||
| } | |||||
| return GetNonConstInputsSize(*node); | |||||
| } | |||||
| GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const) { | |||||
| GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GeTensorDesc(), "node.GetOpDesc() is nullptr!"); | |||||
| size_t i = 0; | |||||
| if (NodeUtils::IsAnchorStatusSet(node)) { | |||||
| for (const auto &anchor : node.GetAllInDataAnchors()) { | |||||
| if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { | |||||
| if (index_non_const == i) { | |||||
| return node.GetOpDesc()->GetInputDesc(static_cast<uint32_t>(anchor->GetIdx())); | |||||
| } | |||||
| ++i; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (const auto &anchor : node.GetAllInDataAnchors()) { | |||||
| auto peer_anchor = anchor->GetPeerOutAnchor(); | |||||
| if (peer_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto owner_node = peer_anchor->GetOwnerNode(); | |||||
| if (owner_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (owner_node->GetType() == CONSTANT) { | |||||
| continue; | |||||
| } | |||||
| if (index_non_const == i) { | |||||
| return node.GetOpDesc()->GetInputDesc(anchor->GetIdx()); | |||||
| } | |||||
| ++i; | |||||
| } | |||||
| } | |||||
| return GeTensorDesc(); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc | |||||
| OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const) { | |||||
| CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc()); | |||||
| return GetNonConstInputTensorDesc(*node, index_non_const); | |||||
| } | |||||
| bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) { | |||||
| bool ret = false; | |||||
| size_t i = 0; | |||||
| if (NodeUtils::IsAnchorStatusSet(node)) { | |||||
| for (const auto &anchor : node.GetAllInDataAnchors()) { | |||||
| if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { | |||||
| if (index_non_const == i) { | |||||
| index = static_cast<size_t>(anchor->GetIdx()); | |||||
| ret = true; | |||||
| } | |||||
| ++i; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (const auto &anchor : node.GetAllInDataAnchors()) { | |||||
| auto peer_anchor = anchor->GetPeerOutAnchor(); | |||||
| if (peer_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto owner_node = peer_anchor->GetOwnerNode(); | |||||
| if (owner_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (owner_node->GetType() == CONSTANT) { | |||||
| continue; | |||||
| } | |||||
| if (index_non_const == i) { | |||||
| index = static_cast<size_t>(anchor->GetIdx()); | |||||
| ret = true; | |||||
| } | |||||
| ++i; | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node, | |||||
| size_t index_non_const, | |||||
| size_t &index) { | |||||
| CHECK_FALSE_EXEC(node != nullptr, return false); | |||||
| return GetNonConstInputIndex(*node, index_non_const, index); | |||||
| } | |||||
| bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) { | |||||
| bool ret = false; | |||||
| if (index < node.GetAllInDataAnchors().size()) { | |||||
| if (NodeUtils::IsAnchorStatusSet(node)) { | |||||
| ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); // lint !e712 | |||||
| } else { | |||||
| for (const auto &anchor : node.GetAllInDataAnchors()) { | |||||
| if (anchor->GetIdx() != static_cast<int>(index)) { | |||||
| continue; | |||||
| } | |||||
| auto peer_anchor = anchor->GetPeerOutAnchor(); | |||||
| if (peer_anchor == nullptr) { | |||||
| break; | |||||
| } | |||||
| auto owner_node = peer_anchor->GetOwnerNode(); | |||||
| if (owner_node == nullptr) { | |||||
| break; | |||||
| } | |||||
| ret = (owner_node->GetType() != CONSTANT); | |||||
| } | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node, | |||||
| size_t index) { | |||||
| CHECK_FALSE_EXEC(node != nullptr, return false); | |||||
| return IsNonConstInput(*node, index); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs( | |||||
| const ge::ConstNodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| return vector<ge::NodePtr>(); | |||||
| } | |||||
| return GetConstInputs(*node); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUtils::GetNonConstTensorDesc( | |||||
| const ge::ConstNodePtr &node) { | |||||
| if (node == nullptr || node->GetOpDesc() == nullptr) { | |||||
| return vector<ge::GeTensorDesc>(); | |||||
| } | |||||
| vector<ge::GeTensorDesc> ret; | |||||
| if (NodeUtils::IsAnchorStatusSet(*node)) { | |||||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||||
| if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { | |||||
| ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||||
| auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (out_anchor == nullptr || out_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { | |||||
| continue; | |||||
| } | |||||
| if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { | |||||
| ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
| } | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils::GetConstInputs(const ge::Node &node) { | |||||
| vector<ge::NodePtr> ret; | |||||
| auto in_anchors = node.GetAllInDataAnchors(); | |||||
| for (const auto &in_anchor : in_anchors) { | |||||
| auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (out_anchor == nullptr) continue; | |||||
| auto in_node = out_anchor->GetOwnerNode(); | |||||
| if (in_node->GetType() == CONSTANT) { | |||||
| ret.push_back(in_node); | |||||
| } else if (in_node->GetType() == SWITCH && node.GetType() == MATMUL) { | |||||
| // const --> switch --> matmul | |||||
| auto switch_input = GetConstInputs(*in_node); | |||||
| if (switch_input.size() > 0) { | |||||
| ret.insert(ret.end(), switch_input.begin(), switch_input.end()); | |||||
| } | |||||
| } else if (in_node->GetType() == DATA) { | |||||
| auto parent = NodeUtils::GetParentInput(in_node); | |||||
| if ((parent != nullptr) && (parent->GetType() == CONSTANT)) { | |||||
| ret.push_back(parent); | |||||
| } | |||||
| } | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) { | |||||
| vector<GeTensorPtr> ret; | |||||
| auto op_desc = node.GetOpDesc(); | |||||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!"); | |||||
| // Place holder operator, try to get the weight from parent node | |||||
| // when parent node is const operator | |||||
| if (node.GetType() == PLACEHOLDER) { | |||||
| std::string parent_op; | |||||
| (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op); | |||||
| // This if judgment is necessary because the current subgraph optimization is multithreaded | |||||
| // and the parent node of the PLD operation should be a stable type, such as const | |||||
| if (parent_op == CONSTANT || parent_op == CONSTANTOP) { | |||||
| NodePtr parent_node = nullptr; | |||||
| parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); | |||||
| if (parent_node != nullptr) { | |||||
| op_desc = parent_node->GetOpDesc(); | |||||
| GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| // Const operator, take the weight directly | |||||
| if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) { | |||||
| auto weight = MutableWeights(op_desc); | |||||
| if (weight == nullptr) { | |||||
| GELOGI("const op has no weight, op name:%s", node.GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| ret.push_back(weight); | |||||
| return ret; | |||||
| } | |||||
| if (node.GetType() == DATA) { | |||||
| auto parent = NodeUtils::GetParentInput(node); | |||||
| if ((parent != nullptr) && NodeUtils::IsConst(*parent)) { | |||||
| auto weight = MutableWeights(parent->GetOpDesc()); | |||||
| if (weight == nullptr) { | |||||
| GELOGI("const op has no weight, op name:%s", parent->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| ret.push_back(weight); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| // Other operators, get weights from connected constop | |||||
| auto input_nodes = GetConstInputs(node); | |||||
| for (const auto &input_node : input_nodes) { | |||||
| auto temp_weight = MutableWeights(input_node->GetOpDesc()); | |||||
| if (temp_weight == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "const op's weight is null, name: %s", input_node->GetName().c_str()); | |||||
| return vector<GeTensorPtr>(); | |||||
| } | |||||
| ret.push_back(temp_weight); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::NodePtr node) { | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Node is nullptr"); | |||||
| return vector<ge::GeTensorPtr>(); | |||||
| } | |||||
| return MutableWeights(*node); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| OpDescUtils::SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights) { | |||||
| GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return GRAPH_PARAM_INVALID, "node.GetOpDesc is nullptr!"); | |||||
| if (node.GetOpDesc()->GetType() == CONSTANT) { | |||||
| if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) { | |||||
| return SetWeights(node.GetOpDesc(), weights[0]); | |||||
| } | |||||
| GELOGI("const op weight size %zu should be 1", weights.size()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto input_nodes = GetConstInputs(node); | |||||
| if (weights.size() < input_nodes.size()) { | |||||
| GELOGE(GRAPH_FAILED, "weights count can't be less than const input count"); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| ge::GeAttrValue::NAMED_ATTRS named_attrs; | |||||
| (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights); | |||||
| vector<ge::GeTensorPtr> copy_weights; | |||||
| (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights); | |||||
| for (size_t i = 0; i < input_nodes.size(); ++i) { | |||||
| if (input_nodes[i]->GetOpDesc() != nullptr) { | |||||
| SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]); | |||||
| } | |||||
| } | |||||
| // If set more weights than constop, need to add constop | |||||
| for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) { | |||||
| // Use org weight before SetWeights Overwrite | |||||
| auto const_opdesc = CreateConstOp(copy_weights[i]); | |||||
| GE_CHECK_NOTNULL(const_opdesc); | |||||
| auto owner_graph = node.GetOwnerComputeGraph(); | |||||
| if (owner_graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "node's graph is empty, name: %s", node.GetName().c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto const_node = owner_graph->AddNodeFront(const_opdesc); | |||||
| GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS, return GRAPH_FAILED, "graph add link failed!"); | |||||
| std::vector<ge::NodePtr> original_nodes; | |||||
| ge::GraphUtils::RecordOriginalNames(original_nodes, const_node); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) { | |||||
| GE_CHK_BOOL_EXEC(tensor_ptr != nullptr, return nullptr, "tensor_ptr is nullptr!"); | |||||
| shared_ptr<OpDesc> const_opdesc = ComGraphMakeShared<OpDesc>(); | |||||
| if (const_opdesc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "failed to make_shared "); | |||||
| return nullptr; | |||||
| } | |||||
| CHECK_FALSE_EXEC(SetWeights(const_opdesc, tensor_ptr) == ge::GRAPH_SUCCESS, return nullptr); | |||||
| const_opdesc->SetType(CONSTANT); | |||||
| thread_local int64_t const_count = 0; | |||||
| const_opdesc->SetName("dynamic_const_" + std::to_string(GetTid()) + "_" + std::to_string(const_count)); | |||||
| GELOGI("add const op: %s", const_opdesc->GetName().c_str()); | |||||
| ++const_count; | |||||
| (void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc()); | |||||
| GELOGI("after add const op: %s", const_opdesc->GetName().c_str()); | |||||
| return const_opdesc; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| OpDescUtils::AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) { | |||||
| GE_CHECK_NOTNULL(in_anchor); | |||||
| GE_CHECK_NOTNULL(tensor_ptr); | |||||
| auto const_opdesc = CreateConstOp(tensor_ptr); | |||||
| GE_CHECK_NOTNULL(const_opdesc); | |||||
| auto in_node = in_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(in_node); | |||||
| auto owner_graph = in_node->GetOwnerComputeGraph(); | |||||
| if (owner_graph == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "node's graph is empty, name: %s", in_node->GetName().c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc); | |||||
| GE_CHECK_NOTNULL(const_node); | |||||
| if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Addedge const to node failed."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| OpDescUtils::SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| return SetWeights(*node, weights); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto const_ops = GetConstInputs(node); | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| if (graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Graph is nullptr"); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| for (const auto &const_op : const_ops) { | |||||
| GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "Isolate removed node: %s, type: %s failed", | |||||
| const_op->GetName().c_str(), const_op->GetType().c_str()); | |||||
| GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op), | |||||
| "Remove node: %s, type: %s without relink failed", const_op->GetName().c_str(), | |||||
| const_op->GetType().c_str()); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| /// | |||||
| /// @brief Add input | |||||
| /// @param [in] name | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { | |||||
| inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Add input | |||||
| /// @param [in] name | |||||
| /// @param [in] tensor | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name, | |||||
| const GeTensorDesc &tensor) { | |||||
| inputs_.emplace_back(std::make_pair(name, tensor)); | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Add dynamic input | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, | |||||
| uint32_t num) { | |||||
| for (uint32_t i = 0; i < num; i++) { | |||||
| inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Add dynamic input | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @param [in] tensor | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput( | |||||
| const std::string &name, uint32_t num, const GeTensorDesc &tensor) { | |||||
| for (uint32_t i = 0; i < num; i++) { | |||||
| inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Add output | |||||
| /// @param [in] name | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { | |||||
| outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Add output | |||||
| /// @param [in] name | |||||
| /// @param [in] tensor | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name, | |||||
| const GeTensorDesc &tensor) { | |||||
| outputs_.emplace_back(std::make_pair(name, tensor)); | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Add dynamic output | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, | |||||
| uint32_t num) { | |||||
| for (uint32_t i = 0; i < num; i++) { | |||||
| outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Add dynamic output | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @param [in] tensor | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput( | |||||
| const std::string &name, uint32_t num, const GeTensorDesc &tensor) { | |||||
| for (uint32_t i = 0; i < num; i++) { | |||||
| outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Build op_desc | |||||
| /// @return OpDescPtr | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { | |||||
| OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_)); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "OpDesc is nullptr"); | |||||
| return nullptr; | |||||
| } | |||||
| for (auto &input : inputs_) { | |||||
| if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Add input_desc failed."); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| for (auto &output : outputs_) { | |||||
| if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Add output_desc failed."); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| return op_desc; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName( | |||||
| const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) { | |||||
| const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); | |||||
| auto iter = subgraph_names_to_index.find(subgraph_name); | |||||
| if (iter == subgraph_names_to_index.end()) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, | |||||
| "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists", | |||||
| subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||||
| subgraph_name.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); | |||||
| } | |||||
| } // namespace ge | |||||
| /*lint +e512 +e737 +e752*/ | |||||
| @@ -1,68 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef COMMON_GRAPH_UTILS_STRING_UTILS_H_ | |||||
| #define COMMON_GRAPH_UTILS_STRING_UTILS_H_ | |||||
| #include <algorithm> | |||||
| #include <functional> | |||||
| #include <sstream> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "securec.h" | |||||
| namespace ge { | |||||
| class StringUtils { | |||||
| public: | |||||
| static std::string &Ltrim(std::string &s) { | |||||
| (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); | |||||
| return s; | |||||
| } | |||||
| static std::string &Rtrim(std::string &s) { | |||||
| (void)s.erase(std::find_if(s.rbegin(), s.rend(), [](int c) { return !std::isspace(c); }).base(), s.end()); | |||||
| return s; | |||||
| } | |||||
| /// @ingroup domi_common | |||||
| /// @brief trim space | |||||
| static std::string &Trim(std::string &s) { return Ltrim(Rtrim(s)); } | |||||
| // split string | |||||
| static std::vector<std::string> Split(const std::string &str, char delim) { | |||||
| std::vector<std::string> elems; | |||||
| if (str.empty()) { | |||||
| elems.emplace_back(""); | |||||
| return elems; | |||||
| } | |||||
| std::stringstream ss(str); | |||||
| std::string item; | |||||
| while (getline(ss, item, delim)) { | |||||
| elems.push_back(item); | |||||
| } | |||||
| auto str_size = str.size(); | |||||
| if (str_size > 0 && str[str_size - 1] == delim) { | |||||
| elems.emplace_back(""); | |||||
| } | |||||
| return elems; | |||||
| } | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // COMMON_GRAPH_UTILS_STRING_UTILS_H_ | |||||
| @@ -1,401 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include <cmath> | |||||
| #include "debug/ge_log.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "graph/types.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| // When nc1hwc0 dim size = 5, calc element count directly. | |||||
| const uint32_t kNc1hwc0CalcByDimsSize = 5; | |||||
| // Unknown shape element num | |||||
| const int64_t kElementCntUnknownShape = -1; | |||||
| // Unknown shape mem size | |||||
| const int64_t kMemSizeUnknownShape = -1; | |||||
| // Nchw and nhwc dim size must be 4 | |||||
| const uint32_t kDimSize4d = 4; | |||||
| // C1HWNCoC0 dim size must be 6 | |||||
| const uint32_t kDimSizeC1hwncoc0 = 6; | |||||
| // Cube size is 16 | |||||
| const uint32_t kTheCubeSize = 16; | |||||
| // Default c0 size equals cube size. | |||||
| const uint32_t kC0SizeDefault = kTheCubeSize; | |||||
| // Size equals int8 cube size is 32 | |||||
| const uint32_t kC0SizeInt8 = 32; | |||||
| // NCHW dim N index | |||||
| const int32_t kNchwDimIdxN = 0; | |||||
| // NCHW dim C index | |||||
| const int32_t kNchwDimIdxC = 1; | |||||
| // NCHW dim H index | |||||
| const int32_t kNchwDimIdxH = 2; | |||||
| // NCHW dim W index | |||||
| const int32_t kNchwDimIdxW = 3; | |||||
| const int kDataMemAlignSize = 32; | |||||
| const int kNum2 = 2; | |||||
| } // namespace | |||||
| /// | |||||
| /// Check if a * b overflow. | |||||
| /// @param a multiplier | |||||
| /// @param b Multiplicand | |||||
| /// @return true: overflow | |||||
| /// false: not overflow | |||||
| /// | |||||
| static bool CheckMultiplyOverflowInt64(const int64_t &a, const int64_t &b) { | |||||
| if (a > 0) { | |||||
| if (b > 0) { | |||||
| if (a > (INT64_MAX / b)) { | |||||
| return true; | |||||
| } | |||||
| } else { | |||||
| if (b < (INT64_MIN / a)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| if (b > 0) { | |||||
| if (a < (INT64_MIN / b)) { | |||||
| return true; | |||||
| } | |||||
| } else { | |||||
| if ((a != 0) && (b < (INT64_MAX / a))) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| /// | |||||
| /// Calculate element num by dims directly. | |||||
| /// @param dims dim info | |||||
| /// @param element_cnt element count | |||||
| /// @return GRAPH_SUCCESS:success | |||||
| /// other:failed | |||||
| /// | |||||
| static graphStatus CalcElementCntByDims(const std::vector<int64_t> &dims, int64_t &element_cnt) { | |||||
| element_cnt = 1; | |||||
| for (int64_t dim : dims) { | |||||
| if (CheckMultiplyOverflowInt64(element_cnt, dim)) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19013", {"function", "var1", "var2"}, | |||||
| {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)}); | |||||
| GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| element_cnt *= dim; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| /// | |||||
| /// Calculate fixed dims element num. | |||||
| /// @param dims dim info | |||||
| /// @param fixed_dim_size fixed dim size | |||||
| /// @param element_cnt element count | |||||
| /// @return GRAPH_SUCCESS:success | |||||
| /// other:failed | |||||
| /// | |||||
| static graphStatus CalcElementCntOfFixedDims(const std::vector<int64_t> &dims, Format format, uint32_t fixed_dim_size, | |||||
| int64_t &element_cnt) { | |||||
| if (dims.size() != fixed_dim_size) { | |||||
| GELOGW("Format %d(%s) need dim size=%u but %zu, calc as ND.", format, | |||||
| TypeUtils::FormatToSerialString(format).c_str(), fixed_dim_size, dims.size()); | |||||
| } | |||||
| return CalcElementCntByDims(dims, element_cnt); | |||||
| } | |||||
| /// | |||||
| /// Get dim c0 size by type | |||||
| /// @param data_type data type | |||||
| /// @return c0 size | |||||
| /// | |||||
| static uint32_t GetDimC0(DataType &data_type) { | |||||
| bool is_int8_size = (data_type == DT_INT8) || (data_type == DT_UINT8) || (data_type == DT_DUAL_SUB_UINT8) || | |||||
| (data_type == DT_DUAL_SUB_INT8) || (data_type == DT_BOOL) || (data_type == DT_QINT8); | |||||
| return is_int8_size ? kC0SizeInt8 : kC0SizeDefault; | |||||
| } | |||||
| /// | |||||
| /// Calculate nc1hwc0 element num. | |||||
| /// @param dims dim info | |||||
| /// @param data_type data type | |||||
| /// @param element_cnt element count | |||||
| /// @return GRAPH_SUCCESS:success | |||||
| /// other:failed | |||||
| /// | |||||
| static graphStatus CalcElementCntOfNc1hwc0(const std::vector<int64_t> &dims, DataType data_type, int64_t &element_cnt) { | |||||
| // When nc1hwc0 dims size = 5, no need split dim c | |||||
| if (dims.size() == kNc1hwc0CalcByDimsSize) { | |||||
| return CalcElementCntByDims(dims, element_cnt); | |||||
| } else if (dims.size() != kDimSize4d) { | |||||
| GELOGE(GRAPH_FAILED, "CalcElementCntOfNc1hwc0 failed as dims.size=%zu is not %u or %u.", dims.size(), kDimSize4d, | |||||
| kNc1hwc0CalcByDimsSize); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto c0 = static_cast<int64_t>(GetDimC0(data_type)); | |||||
| // Nc1hwc0 dims is according to nchw, dim c index is 1. | |||||
| auto c1 = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0)); | |||||
| // Store dims is split c to c1 and c0. | |||||
| std::vector<int64_t> store_dims = {dims[kNchwDimIdxN], c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0}; | |||||
| return CalcElementCntByDims(store_dims, element_cnt); | |||||
| } | |||||
| /// | |||||
| /// Calculate FractalZ element num. | |||||
| /// @param dims dim info | |||||
| /// @param data_type data type | |||||
| /// @param element_cnt element count | |||||
| /// @return GRAPH_SUCCESS:success | |||||
| /// other:failed | |||||
| /// | |||||
| static graphStatus CalcElementCntOfFractalZ(const std::vector<int64_t> &dims, DataType data_type, | |||||
| int64_t &element_cnt) { | |||||
| static char *parser_priority = std::getenv("PARSER_PRIORITY"); | |||||
| if (parser_priority != nullptr && string(parser_priority) == "cce") { | |||||
| if (dims.size() != kDimSize4d) { | |||||
| GELOGE(GRAPH_FAILED, "CalcElementCntOfFractalZ failed as dims.size=%zu is not %u.", dims.size(), kDimSize4d); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto c0 = static_cast<int64_t>(GetDimC0(data_type)); | |||||
| // FractalZ dims is according to nchw, dim c index is 1. | |||||
| auto c1 = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0)); | |||||
| // Spread NC1HWC0 as a two dimension array, n as column dimension, | |||||
| // C1HWC0 as row dimension | |||||
| std::vector<int64_t> r_count_vec = {c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0}; | |||||
| int64_t r_count = 1; | |||||
| graphStatus graph_status = CalcElementCntByDims(r_count_vec, r_count); | |||||
| if (graph_status != GRAPH_SUCCESS) { | |||||
| GELOGE(graph_status, "Calc [%ld, %ld, %ld, %ld] element count failed.", c1, dims[kNchwDimIdxH], | |||||
| dims[kNchwDimIdxW], c0); | |||||
| return graph_status; | |||||
| } | |||||
| // Cube count in n | |||||
| auto nc_cnt = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxN] * 1.0 / kTheCubeSize)); | |||||
| // Cube count in vertical direction(C1HWC0) | |||||
| int64_t vc_cnt = r_count / c0; | |||||
| // Element count in each cube | |||||
| int64_t cube_elem_cnt = c0 * kTheCubeSize; | |||||
| if (CheckMultiplyOverflowInt64(nc_cnt, vc_cnt)) { | |||||
| GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", nc_cnt, vc_cnt); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // Read data times needed by cube | |||||
| int64_t c_cnt = nc_cnt * vc_cnt; | |||||
| if (CheckMultiplyOverflowInt64(c_cnt, cube_elem_cnt)) { | |||||
| GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld is overflow.", c_cnt, cube_elem_cnt); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // Element count after fractal arrangement | |||||
| element_cnt = c_cnt * cube_elem_cnt; | |||||
| return GRAPH_SUCCESS; | |||||
| } else { | |||||
| return CalcElementCntByDims(dims, element_cnt); | |||||
| } | |||||
| } | |||||
| /// | |||||
| /// Calculate tensor element num. | |||||
| /// @param dims dim info | |||||
| /// @param format tensor format | |||||
| /// @param data_type data type | |||||
| /// @param element_cnt element count | |||||
| /// @return GRAPH_SUCCESS:success | |||||
| /// other:failed | |||||
| /// | |||||
| static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format format, DataType data_type, | |||||
| int64_t &element_cnt) { | |||||
| const string format_str = TypeUtils::FormatToSerialString(format); | |||||
| // Check dims | |||||
| for (size_t i = 0; i < dims.size(); ++i) { | |||||
| int64_t dim = dims[i]; | |||||
| if (dim < 0) { | |||||
| GELOGI("It's unknown shape, as dims[%zu]=%ld negative, format=%d(%s).", i, dim, format, format_str.c_str()); | |||||
| element_cnt = kElementCntUnknownShape; | |||||
| return GRAPH_SUCCESS; | |||||
| } else if (dim == 0) { | |||||
| GELOGI("No need calc element count, as dims[%zu]=%ld, format=%d(%s).", i, dim, format, format_str.c_str()); | |||||
| element_cnt = 0; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } | |||||
| graphStatus graph_status; | |||||
| switch (format) { | |||||
| case FORMAT_ND: | |||||
| case FORMAT_MD: | |||||
| graph_status = CalcElementCntByDims(dims, element_cnt); | |||||
| break; | |||||
| case FORMAT_NCHW: | |||||
| case FORMAT_HWCN: | |||||
| case FORMAT_NHWC: | |||||
| case FORMAT_CHWN: | |||||
| graph_status = CalcElementCntOfFixedDims(dims, format, kDimSize4d, element_cnt); | |||||
| break; | |||||
| case FORMAT_C1HWNCoC0: | |||||
| graph_status = CalcElementCntOfFixedDims(dims, format, kDimSizeC1hwncoc0, element_cnt); | |||||
| break; | |||||
| case FORMAT_NC1HWC0: | |||||
| graph_status = CalcElementCntOfNc1hwc0(dims, data_type, element_cnt); | |||||
| break; | |||||
| case FORMAT_FRACTAL_Z: | |||||
| graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt); | |||||
| break; | |||||
| case FORMAT_FRACTAL_NZ: | |||||
| case FORMAT_FRACTAL_ZZ: | |||||
| case FORMAT_NDHWC: | |||||
| case FORMAT_NCDHW: | |||||
| case FORMAT_DHWCN: | |||||
| case FORMAT_DHWNC: | |||||
| case FORMAT_FRACTAL_Z_3D: | |||||
| case FORMAT_FRACTAL_Z_3D_TRANSPOSE: | |||||
| case FORMAT_NDC1HWC0: | |||||
| case FORMAT_FRACTAL_Z_C04: | |||||
| case FORMAT_FRACTAL_ZN_LSTM: | |||||
| case FORMAT_NC1HWC0_C04: | |||||
| graph_status = CalcElementCntByDims(dims, element_cnt); | |||||
| break; | |||||
| default: | |||||
| GELOGE(GRAPH_FAILED, "unsupported format, format=%d(%s).", format, format_str.c_str()); | |||||
| graph_status = GRAPH_FAILED; | |||||
| break; | |||||
| } | |||||
| const string type_str = TypeUtils::DataTypeToSerialString(data_type); | |||||
| if (graph_status == GRAPH_SUCCESS) { | |||||
| GELOGD( | |||||
| "CalcTensorElementCnt end, format=%d(%s)," | |||||
| " data_type=%d(%s), element_cnt=%ld.", | |||||
| format, format_str.c_str(), data_type, type_str.c_str(), element_cnt); | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "CalcTensorElementCnt failed, format=%d(%s), data_type=%d(%s).", format, format_str.c_str(), | |||||
| data_type, type_str.c_str()); | |||||
| } | |||||
| return graph_status; | |||||
| } | |||||
| /// | |||||
| /// Calculate tensor mem size. | |||||
| /// @param shape tensor shape | |||||
| /// @param format tensor format | |||||
| /// @param data_type tensor data type | |||||
| /// @param mem_size -1 means unknown shape,other means mem size | |||||
| /// @return GRAPH_SUCCESS:success, other:failed | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTensorMemSize(const GeShape &shape, | |||||
| Format format, | |||||
| DataType data_type, | |||||
| int64_t &mem_size) { | |||||
| const string format_str = TypeUtils::FormatToSerialString(format); | |||||
| const string type_str = TypeUtils::DataTypeToSerialString(data_type); | |||||
| uint32_t type_size = 0; | |||||
| bool result = TypeUtils::GetDataTypeLength(data_type, type_size); | |||||
| if (!result) { | |||||
| GELOGE(GRAPH_FAILED, "GetDataTypeLength failed, data_type=%d(%s).", data_type, type_str.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::vector<int64_t> dims = shape.GetDims(); | |||||
| int64_t element_cnt = 0; | |||||
| graphStatus status = CalcTensorElementCnt(dims, format, data_type, element_cnt); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| GELOGE(status, "CalcTensorElementCnt failed, status=%u format=%d(%s) data_type=%d(%s).", status, format, | |||||
| format_str.c_str(), data_type, type_str.c_str()); | |||||
| return status; | |||||
| } | |||||
| // Support unknown shape | |||||
| if (element_cnt < 0) { | |||||
| mem_size = kMemSizeUnknownShape; | |||||
| GELOGD( | |||||
| "element_cnt is unknown. " | |||||
| "format=%d(%s), data_type=%d(%s), mem_size=%ld", | |||||
| format, format_str.c_str(), data_type, type_str.c_str(), mem_size); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| auto type_size_int64 = static_cast<int64_t>(type_size); | |||||
| if (CheckMultiplyOverflowInt64(element_cnt, type_size_int64)) { | |||||
| GELOGE(GRAPH_FAILED, "CalcTensorMemSize overflow, when multiplying %ld and %ld, format=%d(%s), data_type=%d(%s).", | |||||
| element_cnt, type_size_int64, format, format_str.c_str(), data_type, type_str.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| mem_size = element_cnt * type_size_int64; | |||||
| GELOGD( | |||||
| "CalcTensorMemSize end, " | |||||
| "format=%d(%s), data_type=%d(%s), mem_size=%ld", | |||||
| format, format_str.c_str(), data_type, type_str.c_str(), mem_size); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { | |||||
| graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); | |||||
| if (graph_status != GRAPH_SUCCESS) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| // 64-byte alignment, if size is 0, align to 32 bytes | |||||
| if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) { | |||||
| GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp); | |||||
| } else { | |||||
| size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { | |||||
| GeShape output_shape = desc_temp.GetShape(); | |||||
| Format format = desc_temp.GetFormat(); | |||||
| DataType data_type = desc_temp.GetDataType(); | |||||
| int64_t output_mem_size = 0; | |||||
| graphStatus graph_status = CalcTensorMemSize(output_shape, format, data_type, output_mem_size); | |||||
| if (graph_status != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "CalcTensorMemSize failed!"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (output_mem_size < 0) { | |||||
| GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]", | |||||
| output_mem_size, INT64_MAX); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| size_temp = output_mem_size; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,684 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/tuning_utils.h" | |||||
| #include "../debug/ge_util.h" | |||||
| #include "../debug/ge_op_types.h" | |||||
| namespace ge { | |||||
| const std::string peer_node_name_attr = "_peerNodeName"; | |||||
| const std::string parent_node_name_attr = "_parentNodeName"; | |||||
| const std::string alias_name_attr = "_aliasName"; | |||||
| const std::string parent_node_attr = "parentNode"; | |||||
| const std::string parent_node_anchor_index_attr = "_parentNodeAnchorIndex"; | |||||
| const std::string tuning_subgraph_prefix = "/aicore_subgraph_"; | |||||
| const std::string non_tuning_subgraph_prefix = "/subgraph_"; | |||||
| const std::set<std::string> kPartitionOpTypes = {PLACEHOLDER, END}; | |||||
| const std::set<std::string> kExeTypes = {DATA, NETOUTPUT}; | |||||
| NodeNametoNodeNameMap TuningUtils::data_2_netoutput_; | |||||
| NodetoNodeNameMap TuningUtils::data_node_2_netoutput_; | |||||
| NodetoNodeMap TuningUtils::data_node_2_netoutput_node_; | |||||
| NodeSet TuningUtils::netoutput_nodes_; | |||||
| NodeSet TuningUtils::merged_graph_nodes_; | |||||
| SubgraphCreateOutNode TuningUtils::create_output_; | |||||
| std::mutex TuningUtils::mutex_; | |||||
| std::string TuningUtils::PrintCheckLog() { | |||||
| std::stringstream ss; | |||||
| ss << "d2n:{"; | |||||
| for (const auto &pair : data_2_netoutput_) { | |||||
| ss << "data:" << pair.first << "-" | |||||
| << "netoutput:" << pair.second; | |||||
| ss << " | "; | |||||
| } | |||||
| ss << "}"; | |||||
| ss << "netoutputs:{"; | |||||
| for (const auto &node : netoutput_nodes_) { | |||||
| ss << "netoutput:" << node->GetName(); | |||||
| ss << " | "; | |||||
| } | |||||
| ss << "}"; | |||||
| return ss.str(); | |||||
| } | |||||
| std::string TuningUtils::GetNodeNameByAnchor(const Anchor *anchor) { | |||||
| if (anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Anchor is nullptr"); | |||||
| return "Null"; | |||||
| } | |||||
| auto node = anchor->GetOwnerNode(); | |||||
| return node == nullptr ? "Null" : node->GetName(); | |||||
| } | |||||
| // part 1 | |||||
| graphStatus TuningUtils::ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_subgraphs, | |||||
| std::vector<ComputeGraphPtr> non_tuning_subgraphs, bool exe_flag, | |||||
| const std::string &path, const std::string &user_path) { | |||||
| int64_t i = 0; | |||||
| int64_t j = 0; | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| for (auto &subgraph : tuning_subgraphs) { | |||||
| create_output_.emplace(subgraph, nullptr); | |||||
| auto help_info = HelpInfo{i, exe_flag, true, path, user_path}; | |||||
| if (MakeExeGraph(subgraph, help_info) != SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "TUU:subgraph %zu generate exe graph failed", i); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| i++; | |||||
| } | |||||
| for (auto &subgraph : non_tuning_subgraphs) { | |||||
| create_output_.emplace(subgraph, nullptr); | |||||
| auto help_info = HelpInfo{j, true, false, path, user_path}; | |||||
| if (MakeExeGraph(subgraph, help_info) != SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "TUU:non tuning_subgraph %zu generate exe graph failed", j); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| j++; | |||||
| } | |||||
| create_output_.clear(); | |||||
| return SUCCESS; | |||||
| } | |||||
| // +---------------+ | |||||
| // | pld pld | | |||||
| // | \ / | | |||||
| // | relu relu | | |||||
| // | \ / | | |||||
| // | add | | |||||
| // | | | | |||||
| // | end | | |||||
| // +---------------+ | |||||
| // | | |||||
| // | | |||||
| // V | |||||
| // +---------------+ | |||||
| // | data data | | |||||
| // | \ / | | |||||
| // | relu relu | | |||||
| // | \ / | | |||||
| // | add | | |||||
| // | | | | |||||
| // | netoutput | | |||||
| // +---------------+ | |||||
| graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) { | |||||
| GE_CHECK_NOTNULL(exe_graph); | |||||
| // if not make exe, just dump and return | |||||
| if (!help_info.exe_flag) { | |||||
| DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); | |||||
| GELOGI("TUU:just return, dump original sub_graph[%s]index[%d]", exe_graph->GetName().c_str(), help_info.index); | |||||
| return SUCCESS; | |||||
| } | |||||
| // modify sub graph | |||||
| for (NodePtr &node : exe_graph->GetDirectNode()) { | |||||
| // 1.handle pld | |||||
| if (node->GetType() == PLACEHOLDER) { | |||||
| if (HandlePld(node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), | |||||
| exe_graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| // 2.handle end | |||||
| if (node->GetType() == END) { | |||||
| if (HandleEnd(node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), | |||||
| exe_graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| graphStatus ret = exe_graph->TopologicalSorting(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| // dump subgraphs which modified by us | |||||
| if (help_info.user_path.empty()) { | |||||
| DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); | |||||
| } else { | |||||
| GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void TuningUtils::DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path) { | |||||
| if (!path.empty()) { | |||||
| if (is_tuning_graph) { | |||||
| GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); | |||||
| } else { | |||||
| GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); | |||||
| } | |||||
| } else { | |||||
| path = "./"; | |||||
| if (is_tuning_graph) { | |||||
| GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); | |||||
| } else { | |||||
| GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); | |||||
| } | |||||
| } | |||||
| } | |||||
| graphStatus TuningUtils::CreateDataNode(NodePtr &node, NodePtr &data_node) { | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| auto data_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), DATA); | |||||
| GE_CHECK_NOTNULL(data_op_desc); | |||||
| auto pld_op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(pld_op_desc); | |||||
| auto output_desc = pld_op_desc->GetOutputDesc(0); // only one output for pld and data | |||||
| // data inputdesc & outputdesc set as same | |||||
| if (data_op_desc->AddInputDesc(output_desc) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (data_op_desc->AddOutputDesc(output_desc) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| data_node = graph->AddNode(data_op_desc); | |||||
| GE_CHECK_NOTNULL(data_node); | |||||
| if (data_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node) { | |||||
| auto op_desc = data_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| auto pld_desc = pld->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(pld_desc); | |||||
| // inherit | |||||
| // a. set `end's input node type` as attr | |||||
| std::string parent_op_type; | |||||
| if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) { | |||||
| GELOGE(FAILED, "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void)AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type); | |||||
| // b. set `end's input node name` as attr | |||||
| std::string parent_op_name; | |||||
| if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) { | |||||
| GELOGE(FAILED, "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void)AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name); | |||||
| // c. set `end's input node's out anchor index` as attr | |||||
| int parent_node_anchor_index; | |||||
| if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) { | |||||
| GELOGE(FAILED, "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void)AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index); | |||||
| GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(), | |||||
| data_node->GetName().c_str(), data_node->GetType().c_str()); | |||||
| // d. set `end node name` as attr | |||||
| std::string peer_end_name; | |||||
| if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) { | |||||
| GELOGE(FAILED, "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void)AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name); | |||||
| GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(), | |||||
| data_node->GetName().c_str(), data_node->GetType().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::ChangePld2Data(NodePtr &node, NodePtr &data_node) { | |||||
| auto type_pld = node->GetType(); | |||||
| auto type_data = data_node->GetType(); | |||||
| if (type_pld != PLACEHOLDER || type_data != DATA) { | |||||
| GELOGE(FAILED, "TUU:Failed to change node %s from type %s to type %s", node->GetName().c_str(), type_pld.c_str(), | |||||
| type_data.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| std::vector<int> output_map(node->GetAllOutDataAnchorsSize()); | |||||
| for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { | |||||
| output_map[i] = static_cast<int>(i); | |||||
| } | |||||
| auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to replace node %s by node %s error node %u", node->GetName().c_str(), | |||||
| data_node->GetName().c_str(), ret); | |||||
| return FAILED; | |||||
| } | |||||
| NodeUtils::UnlinkAll(*node); | |||||
| ret = GraphUtils::RemoveNodeWithoutRelink(graph, node); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)", node->GetName().c_str(), | |||||
| node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str()); | |||||
| return ret; | |||||
| } | |||||
| graphStatus TuningUtils::HandlePld(NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| NodePtr data_node = nullptr; | |||||
| // 1. create data node | |||||
| if (CreateDataNode(node, data_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 2. add necessary info to data_node for recovery whole graph | |||||
| if (AddAttrToDataNodeForMergeGraph(node, data_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 3. replace pld node by data node created before | |||||
| if (ChangePld2Data(node, data_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("TUU:pld[%s] handle success", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::CreateNetOutput(NodePtr &node, NodePtr &out_node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| auto search = create_output_.find(graph); | |||||
| if (search == create_output_.end()) { | |||||
| GELOGE(FAILED, "TUU:node %s's owner sub graph %s not exist in create_output map", node->GetName().c_str(), | |||||
| graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (search->second != nullptr) { | |||||
| out_node = search->second; | |||||
| GELOGD("TUU:sub graph %s has created output node, just return", graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| auto out_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), NETOUTPUT); | |||||
| GE_CHECK_NOTNULL(out_op_desc); | |||||
| out_node = graph->AddNode(out_op_desc); | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| if (out_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); | |||||
| return FAILED; | |||||
| } | |||||
| create_output_[graph] = out_node; | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node) { | |||||
| GE_CHECK_NOTNULL(end); | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| auto op_desc = out_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| std::vector<std::string> alias_names = {}; | |||||
| (void)AttrUtils::GetListStr(op_desc, alias_name_attr, alias_names); | |||||
| alias_names.push_back(end->GetName()); | |||||
| (void)AttrUtils::SetListStr(op_desc, alias_name_attr, alias_names); | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { | |||||
| GE_CHECK_NOTNULL(end_node); | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| // get end in node is control node or normal node | |||||
| AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr) | |||||
| ? Anchor::DynamicAnchorCast<Anchor>(end_node->GetInControlAnchor()) | |||||
| : Anchor::DynamicAnchorCast<Anchor>(end_node->GetInDataAnchor(0)); | |||||
| auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1 | |||||
| if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", | |||||
| GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), end_node->GetName().c_str(), | |||||
| end_node->GetOwnerComputeGraph()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // add edge between `end in node` and `out_node` | |||||
| if (src_anchor->IsTypeOf<OutDataAnchor>()) { | |||||
| std::shared_ptr<InDataAnchor> anchor = | |||||
| ComGraphMakeShared<InDataAnchor>(out_node, out_node->GetAllInDataAnchors().size()); | |||||
| GE_CHECK_NOTNULL(anchor); | |||||
| out_node->in_data_anchors_.push_back(anchor); | |||||
| if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", | |||||
| GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(), | |||||
| end_node->GetOwnerComputeGraph()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto end_op_desc = end_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(end_op_desc); | |||||
| auto out_node_op_desc = out_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(out_node_op_desc); | |||||
| // end node always has one input | |||||
| if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0)) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } else if (src_anchor->IsTypeOf<OutControlAnchor>()) { | |||||
| auto anchor = out_node->GetInControlAnchor(); | |||||
| if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", | |||||
| GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(), | |||||
| end_node->GetOwnerComputeGraph()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } else { | |||||
| GELOGE(FAILED, "TUU: node_name:%s, graph_name:%s handled failed", end_node->GetName().c_str(), | |||||
| end_node->GetOwnerComputeGraph()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { | |||||
| GE_CHECK_NOTNULL(end_node); | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| auto type_end = end_node->GetType(); | |||||
| auto type_out = out_node->GetType(); | |||||
| if (type_end != END || type_out != NETOUTPUT) { | |||||
| GELOGE(FAILED, "TUU:Failed to change end_node %s from type %s to type %s", end_node->GetName().c_str(), | |||||
| type_end.c_str(), type_out.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // link all `end nodes's in node` to this out_node | |||||
| if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:end_node [%s] LinkEnd2NetOutput failed.", end_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // remove `end node` | |||||
| NodeUtils::UnlinkAll(*end_node); | |||||
| auto graph = end_node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| if (GraphUtils::RemoveNodeWithoutRelink(graph, end_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:end node [%s] RemoveNodeWithoutRelink failed.", end_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::HandleEnd(NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| NodePtr out_node = nullptr; | |||||
| // 1. create net_output node , add only one NetOutput node to one subgraph | |||||
| if (CreateNetOutput(node, out_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 2. add necessary info to out_node for recovery whole graph | |||||
| if (AddAttrToNetOutputForMergeGraph(node, out_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 3. replace all end nodes by one output node created before | |||||
| if (ChangeEnd2NetOutput(node, out_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("TUU:end[%s] handle success", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| // part 2 | |||||
| graphStatus TuningUtils::ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph) { | |||||
| // 1. get all subgraph object | |||||
| std::vector<ComputeGraphPtr> graphs; | |||||
| // options format like {index:"subgraph_path"} | |||||
| for (const auto &pair : options) { | |||||
| ComputeGraphPtr compute_graph = ComGraphMakeShared<ComputeGraph>(std::to_string(pair.first)); | |||||
| if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), *compute_graph)) { | |||||
| GELOGE(FAILED, "TUU:load graph from file failed"); | |||||
| } | |||||
| graphs.push_back(compute_graph); | |||||
| } | |||||
| // 2. merge graph | |||||
| ComputeGraphPtr merged_graph = ComGraphMakeShared<ComputeGraph>("whole_graph_after_tune"); | |||||
| GE_CHECK_NOTNULL(merged_graph); | |||||
| if (MergeAllSubGraph(graphs, merged_graph) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:MergeGraph failed"); | |||||
| return FAILED; | |||||
| } | |||||
| // 3. set parent graph | |||||
| for (const auto &node : merged_graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (node->SetOwnerComputeGraph(merged_graph) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:node %s set owner graph failed", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| graph = GraphUtils::CreateGraphFromComputeGraph(merged_graph); | |||||
| return SUCCESS; | |||||
| } | |||||
| // +----------------------------------+ | |||||
| // | const const | | |||||
| // | \ / | | |||||
| // | netoutput(end,end) | | |||||
| // +----------------------------------+ | |||||
| // + | |||||
| // +----------------------------------+ | |||||
| // | data(pld) data(pld) | | |||||
| // | \ / | | |||||
| // | relu relu | | |||||
| // | \ / | | |||||
| // | \ / | | |||||
| // | add | | |||||
| // | | | | |||||
| // | netoutput(end) | | |||||
| // +----------------------------------+ | |||||
| // + | |||||
| // +----------------------------------+ | |||||
| // | data(pld) | | |||||
| // | / | | |||||
| // | netoutput | | |||||
| // +----------------------------------+ | |||||
| // | | |||||
| // | | |||||
| // V | |||||
| // +----------------------------------+ | |||||
| // | const const | | |||||
| // | \ / | | |||||
| // | relu relu | | |||||
| // | \ / | | |||||
| // | \ / | | |||||
| // | add | | |||||
| // | | | | |||||
| // | netoutput | | |||||
| // +----------------------------------+ | |||||
| graphStatus TuningUtils::MergeAllSubGraph(std::vector<ComputeGraphPtr> &subgraphs, | |||||
| ComputeGraphPtr &output_merged_compute_graph) { | |||||
| GE_CHECK_NOTNULL(output_merged_compute_graph); | |||||
| // 1. handle all subgraphs | |||||
| for (auto &subgraph : subgraphs) { | |||||
| Status ret_status = MergeSubGraph(subgraph); | |||||
| if (ret_status != SUCCESS) { | |||||
| GELOGE(ret_status, "TUU:subgraph %s merge failed", subgraph->GetName().c_str()); | |||||
| return ret_status; | |||||
| } | |||||
| } | |||||
| for (const auto &node : merged_graph_nodes_) { | |||||
| (void)output_merged_compute_graph->AddNode(node); | |||||
| GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(), node->GetName().c_str()); | |||||
| } | |||||
| // 2. remove data and output node added by us | |||||
| if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| graphStatus ret = output_merged_compute_graph->TopologicalSorting(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", output_merged_compute_graph->GetName().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| GELOGD("TUU:Print-%s", PrintCheckLog().c_str()); | |||||
| GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::MergeSubGraph(ComputeGraphPtr &subgraph) { | |||||
| for (auto &node : subgraph->GetDirectNode()) { | |||||
| if (kPartitionOpTypes.count(node->GetType()) > 0) { | |||||
| GELOGE(FAILED, "TUU:subgraph passed in should not contain nodes of end or pld type"); | |||||
| return FAILED; | |||||
| } | |||||
| // handle data converted from pld node | |||||
| if (node->GetType() == DATA) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| std::string peer_out_name; | |||||
| bool has_valid_str = (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty()); | |||||
| if (has_valid_str) { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| data_2_netoutput_.emplace(op_desc->GetName(), peer_out_name); | |||||
| data_node_2_netoutput_.emplace(node, peer_out_name); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| // handle netoutput converted from end node | |||||
| if (node->GetType() == NETOUTPUT) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| std::vector<string> out_alias_name; | |||||
| bool has_valid_str = | |||||
| (AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty()); | |||||
| if (has_valid_str) { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| netoutput_nodes_.insert(node); | |||||
| } | |||||
| } | |||||
| { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| merged_graph_nodes_.emplace(node); | |||||
| } | |||||
| GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str()); | |||||
| } | |||||
| GELOGI("TUU:merge subgraph %s success", subgraph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| // 1. traverse | |||||
| for (auto &pair : data_node_2_netoutput_) { | |||||
| auto data_node = pair.first; | |||||
| GE_CHECK_NOTNULL(data_node); | |||||
| auto netoutput_name = pair.second; | |||||
| auto netoutput_node = graph->FindNode(netoutput_name); | |||||
| GE_CHECK_NOTNULL(netoutput_node); | |||||
| data_node_2_netoutput_node_.emplace(data_node, netoutput_node); | |||||
| // 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor` | |||||
| AnchorPtr data_out_anchor = (data_node->GetOutDataAnchor(0)->GetFirstPeerAnchor() == nullptr) | |||||
| ? Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutControlAnchor()) | |||||
| : Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutDataAnchor(0)); | |||||
| AnchorPtr net_output_in_anchor = nullptr; | |||||
| AnchorPtr src_out_anchor = nullptr; | |||||
| if (GetInAndOutAnchorPair(data_node, netoutput_node, net_output_in_anchor, src_out_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:get out node:%s 's in anchor related with data node:%s failed", | |||||
| netoutput_node->GetName().c_str(), data_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 3. relink | |||||
| if (GraphUtils::RemoveEdge(src_out_anchor, net_output_in_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", | |||||
| GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(net_output_in_anchor.get()).c_str(), net_output_in_anchor->GetIdx(), | |||||
| data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GE_CHECK_NOTNULL(data_out_anchor); | |||||
| for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) { | |||||
| if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", | |||||
| GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), | |||||
| data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", | |||||
| GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), | |||||
| data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| // 4. remove out nodes added by us | |||||
| for (auto &node : netoutput_nodes_) { | |||||
| NodeUtils::UnlinkAll(*node); | |||||
| if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor, | |||||
| AnchorPtr &src_out_anchor) { | |||||
| // 1. get `data parent node name`, i.e. `netoutput input node name` | |||||
| std::string netoutput_input_name; | |||||
| auto op_desc = data_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (!AttrUtils::GetStr(op_desc, parent_node_name_attr, netoutput_input_name)) { | |||||
| GELOGE(FAILED, "TUU:Failed to get parent node attr from node %s", op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 2. find index | |||||
| int parent_node_anchor_index; | |||||
| if (!AttrUtils::GetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index)) { | |||||
| GELOGE(FAILED, "TUU:Failed to get parent node anchor index attr from node %s", op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 3.find in data or ctrl anchor by 1&2 step | |||||
| for (auto &in_anchor : out_node->GetAllInAnchors()) { | |||||
| GE_CHECK_NOTNULL(in_anchor); | |||||
| for (auto &src_anchor : in_anchor->GetPeerAnchors()) { // get all peer anchors for ctrl | |||||
| GE_CHECK_NOTNULL(src_anchor); | |||||
| auto src_node = src_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| if (src_node->GetName() == netoutput_input_name && src_anchor->GetIdx() == parent_node_anchor_index) { | |||||
| dest_in_anchor = in_anchor; | |||||
| src_out_anchor = src_anchor; | |||||
| GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s", | |||||
| out_node->GetName().c_str(), dest_in_anchor->GetIdx(), netoutput_input_name.c_str(), | |||||
| parent_node_anchor_index, data_node->GetName().c_str()); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| GE_CHECK_NOTNULL(dest_in_anchor); | |||||
| GE_CHECK_NOTNULL(src_out_anchor); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,448 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "debug/ge_util.h" | |||||
| using domi::domiTensorFormat_t; | |||||
| namespace ge { | |||||
| static const std::map<Format, std::string> kFormatToStringMap = { | |||||
| {FORMAT_NCHW, "NCHW"}, | |||||
| {FORMAT_NHWC, "NHWC"}, | |||||
| {FORMAT_ND, "ND"}, | |||||
| {FORMAT_NC1HWC0, "NC1HWC0"}, | |||||
| {FORMAT_FRACTAL_Z, "FRACTAL_Z"}, | |||||
| {FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, | |||||
| {FORMAT_NHWC1C0, "NHWC1C0"}, | |||||
| {FORMAT_FSR_NCHW, "FSR_NCHW"}, | |||||
| {FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, | |||||
| {FORMAT_C1HWNC0, "C1HWNC0"}, | |||||
| {FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, | |||||
| {FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, | |||||
| {FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, | |||||
| {FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, | |||||
| {FORMAT_CHWN, "CHWN"}, | |||||
| {FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, | |||||
| {FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, | |||||
| {FORMAT_BN_WEIGHT, "BN_WEIGHT"}, | |||||
| {FORMAT_FILTER_HWCK, "FILTER_HWCK"}, | |||||
| {FORMAT_HWCN, "HWCN"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, | |||||
| {FORMAT_MD, "MD"}, | |||||
| {FORMAT_NDHWC, "NDHWC"}, | |||||
| {FORMAT_NCDHW, "NCDHW"}, | |||||
| {FORMAT_DHWCN, "DHWCN"}, | |||||
| {FORMAT_DHWNC, "DHWNC"}, | |||||
| {FORMAT_NDC1HWC0, "NDC1HWC0"}, | |||||
| {FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, | |||||
| {FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, | |||||
| {FORMAT_C1HWNCoC0, "C1HWNCoC0"}, | |||||
| {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | |||||
| {FORMAT_CN, "CN"}, | |||||
| {FORMAT_NC, "NC"}, | |||||
| {FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, | |||||
| {FORMAT_FRACTAL_Z_G, "FRACTAL_Z_G"}, | |||||
| {FORMAT_RESERVED, "FORMAT_RESERVED"}, | |||||
| {FORMAT_ALL, "ALL"}}; | |||||
| static const std::map<domiTensorFormat_t, Format> kDomiFormatToGeFormat = { | |||||
| {domi::DOMI_TENSOR_NCHW, FORMAT_NCHW}, | |||||
| {domi::DOMI_TENSOR_NHWC, FORMAT_NHWC}, | |||||
| {domi::DOMI_TENSOR_ND, FORMAT_ND}, | |||||
| {domi::DOMI_TENSOR_NC1HWC0, FORMAT_NC1HWC0}, | |||||
| {domi::DOMI_TENSOR_FRACTAL_Z, FORMAT_FRACTAL_Z}, | |||||
| {domi::DOMI_TENSOR_NC1C0HWPAD, FORMAT_NC1C0HWPAD}, | |||||
| {domi::DOMI_TENSOR_NHWC1C0, FORMAT_NHWC1C0}, | |||||
| {domi::DOMI_TENSOR_FSR_NCHW, FORMAT_FSR_NCHW}, | |||||
| {domi::DOMI_TENSOR_FRACTAL_DECONV, FORMAT_FRACTAL_DECONV}, | |||||
| {domi::DOMI_TENSOR_BN_WEIGHT, FORMAT_BN_WEIGHT}, | |||||
| {domi::DOMI_TENSOR_CHWN, FORMAT_CHWN}, | |||||
| {domi::DOMI_TENSOR_FILTER_HWCK, FORMAT_FILTER_HWCK}, | |||||
| {domi::DOMI_TENSOR_NDHWC, FORMAT_NDHWC}, | |||||
| {domi::DOMI_TENSOR_NCDHW, FORMAT_NCDHW}, | |||||
| {domi::DOMI_TENSOR_DHWCN, FORMAT_DHWCN}, | |||||
| {domi::DOMI_TENSOR_DHWNC, FORMAT_DHWNC}, | |||||
| {domi::DOMI_TENSOR_RESERVED, FORMAT_RESERVED}}; | |||||
| static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | |||||
| "FRACTAL_Z", | |||||
| "NC1C0HWPAD", | |||||
| "NHWC1C0", | |||||
| "FRACTAL_DECONV", | |||||
| "C1HWNC0", | |||||
| "FRACTAL_DECONV_TRANSPOSE", | |||||
| "FRACTAL_DECONV_SP_STRIDE_TRANS", | |||||
| "NC1HWC0_C04", | |||||
| "FRACTAL_Z_C04", | |||||
| "FRACTAL_DECONV_SP_STRIDE8_TRANS", | |||||
| "NC1KHKWHWC0", | |||||
| "C1HWNCoC0", | |||||
| "FRACTAL_ZZ", | |||||
| "FRACTAL_NZ", | |||||
| "NDC1HWC0", | |||||
| "FORMAT_FRACTAL_Z_3D", | |||||
| "FORMAT_FRACTAL_Z_3D_TRANSPOSE", | |||||
| "FORMAT_FRACTAL_ZN_LSTM", | |||||
| "FORMAT_FRACTAL_Z_G"}; | |||||
| static const std::map<std::string, Format> kDataFormatMap = { | |||||
| {"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; | |||||
| static const std::map<std::string, Format> kStringToFormatMap = { | |||||
| {"NCHW", FORMAT_NCHW}, | |||||
| {"NHWC", FORMAT_NHWC}, | |||||
| {"ND", FORMAT_ND}, | |||||
| {"NC1HWC0", FORMAT_NC1HWC0}, | |||||
| {"FRACTAL_Z", FORMAT_FRACTAL_Z}, | |||||
| {"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, | |||||
| {"NHWC1C0", FORMAT_NHWC1C0}, | |||||
| {"FSR_NCHW", FORMAT_FSR_NCHW}, | |||||
| {"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, | |||||
| {"C1HWNC0", FORMAT_C1HWNC0}, | |||||
| {"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, | |||||
| {"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, | |||||
| {"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, | |||||
| {"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, | |||||
| {"CHWN", FORMAT_CHWN}, | |||||
| {"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, | |||||
| {"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, | |||||
| {"BN_WEIGHT", FORMAT_BN_WEIGHT}, | |||||
| {"FILTER_HWCK", FORMAT_FILTER_HWCK}, | |||||
| {"HWCN", FORMAT_HWCN}, | |||||
| {"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, | |||||
| {"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, | |||||
| {"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, | |||||
| {"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, | |||||
| {"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, | |||||
| {"MD", FORMAT_MD}, | |||||
| {"C1HWNCoC0", FORMAT_C1HWNCoC0}, | |||||
| {"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, | |||||
| {"NDHWC", FORMAT_NDHWC}, | |||||
| {"NCDHW", FORMAT_NCDHW}, | |||||
| {"DHWCN", FORMAT_DHWCN}, | |||||
| {"DHWNC", FORMAT_DHWNC}, | |||||
| {"NDC1HWC0", FORMAT_NDC1HWC0}, | |||||
| {"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, | |||||
| {"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, | |||||
| {"CN", FORMAT_CN}, | |||||
| {"NC", FORMAT_NC}, | |||||
| {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, | |||||
| {"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G}, | |||||
| {"FORMAT_RESERVED", FORMAT_RESERVED}, | |||||
| {"ALL", FORMAT_ALL}, | |||||
| {"NULL", FORMAT_NULL}}; | |||||
| static const std::map<DataType, std::string> kDataTypeToStringMap = { | |||||
| {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||||
| {DT_FLOAT, "DT_FLOAT"}, // float type | |||||
| {DT_FLOAT16, "DT_FLOAT16"}, // fp16 type | |||||
| {DT_INT8, "DT_INT8"}, // int8 type | |||||
| {DT_INT16, "DT_INT16"}, // int16 type | |||||
| {DT_UINT16, "DT_UINT16"}, // uint16 type | |||||
| {DT_UINT8, "DT_UINT8"}, // uint8 type | |||||
| {DT_INT32, "DT_INT32"}, // uint32 type | |||||
| {DT_INT64, "DT_INT64"}, // int64 type | |||||
| {DT_UINT32, "DT_UINT32"}, // unsigned int32 | |||||
| {DT_UINT64, "DT_UINT64"}, // unsigned int64 | |||||
| {DT_BOOL, "DT_BOOL"}, // bool type | |||||
| {DT_DOUBLE, "DT_DOUBLE"}, // double type | |||||
| {DT_DUAL, "DT_DUAL"}, // dual output type | |||||
| {DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type | |||||
| {DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type | |||||
| {DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type | |||||
| {DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type | |||||
| {DT_QINT8, "DT_QINT8"}, // qint8 type | |||||
| {DT_QINT16, "DT_QINT16"}, // qint16 type | |||||
| {DT_QINT32, "DT_QINT32"}, // qint32 type | |||||
| {DT_QUINT8, "DT_QUINT8"}, // quint8 type | |||||
| {DT_QUINT16, "DT_QUINT16"}, // quint16 type | |||||
| {DT_RESOURCE, "DT_RESOURCE"}, // resource type | |||||
| {DT_STRING_REF, "DT_STRING_REF"}, // string ref type | |||||
| {DT_STRING, "DT_STRING"}, // string type | |||||
| }; | |||||
| static const std::map<std::string, DataType> kStringTodataTypeMap = { | |||||
| {"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. | |||||
| {"DT_FLOAT", DT_FLOAT}, // float type | |||||
| { | |||||
| "DT_FLOAT16", | |||||
| DT_FLOAT16, | |||||
| }, // fp16 type | |||||
| {"DT_INT8", DT_INT8}, // int8 type | |||||
| {"DT_INT16", DT_INT16}, // int16 type | |||||
| {"DT_UINT16", DT_UINT16}, // uint16 type | |||||
| {"DT_UINT8", DT_UINT8}, // uint8 type | |||||
| {"DT_INT32", DT_INT32}, // uint32 type | |||||
| {"DT_INT64", DT_INT64}, // int64 type | |||||
| {"DT_UINT32", DT_UINT32}, // unsigned int32 | |||||
| {"DT_UINT64", DT_UINT64}, // unsigned int64 | |||||
| {"DT_BOOL", DT_BOOL}, // bool type | |||||
| {"DT_DOUBLE", DT_DOUBLE}, // double type | |||||
| {"DT_DUAL", DT_DUAL}, // dual output type | |||||
| {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type | |||||
| {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type | |||||
| {"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type | |||||
| {"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type | |||||
| {"DT_QINT8", DT_QINT8}, // qint8 type | |||||
| {"DT_QINT16", DT_QINT16}, // qint16 type | |||||
| {"DT_QINT32", DT_QINT32}, // qint32 type | |||||
| {"DT_QUINT8", DT_QUINT8}, // quint8 type | |||||
| {"DT_QUINT16", DT_QUINT16}, // quint16 type | |||||
| {"DT_RESOURCE", DT_RESOURCE}, // resource type | |||||
| {"DT_STRING_REF", DT_STRING_REF}, // string ref type | |||||
| {"DT_STRING", DT_STRING}, // string type | |||||
| }; | |||||
| static const std::map<ge::DataType, uint32_t> kDataTypeToLength = { | |||||
| {DT_BOOL, sizeof(bool)}, | |||||
| {DT_INT64, sizeof(int64_t)}, | |||||
| {DT_UINT64, sizeof(int64_t)}, | |||||
| {DT_FLOAT, sizeof(float)}, | |||||
| {DT_INT32, sizeof(int32_t)}, | |||||
| {DT_UINT32, sizeof(int32_t)}, | |||||
| {DT_INT8, sizeof(char)}, | |||||
| {DT_UINT8, sizeof(char)}, | |||||
| {DT_INT16, sizeof(int16_t)}, | |||||
| {DT_UINT16, sizeof(int16_t)}, | |||||
| {DT_FLOAT16, sizeof(int16_t)}, | |||||
| {DT_DOUBLE, sizeof(double)}, | |||||
| {DT_DUAL, sizeof(float) + sizeof(int8_t)}, | |||||
| {DT_DUAL_SUB_INT8, sizeof(int8_t)}, | |||||
| {DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, | |||||
| {DT_COMPLEX64, sizeof(int64_t)}, | |||||
| {DT_COMPLEX128, sizeof(int64_t) * 2}, | |||||
| {DT_QINT8, sizeof(int8_t)}, | |||||
| {DT_QINT16, sizeof(int16_t)}, | |||||
| {DT_QINT32, sizeof(int32_t)}, | |||||
| {DT_QUINT8, sizeof(uint8_t)}, | |||||
| {DT_QUINT16, sizeof(uint16_t)}, | |||||
| {DT_STRING_REF, sizeof(uint64_t) * 2}, | |||||
| {DT_STRING, sizeof(uint64_t)}, | |||||
| {DT_RESOURCE, sizeof(uint64_t)}, | |||||
| }; | |||||
| static const std::map<domi::FrameworkType, std::string> kFmkTypeToString = { | |||||
| {domi::CAFFE, "caffe"}, {domi::MINDSPORE, "mindspore"}, {domi::TENSORFLOW, "tensorflow"}, | |||||
| {domi::ANDROID_NN, "android_nn"}, {domi::ONNX, "onnx"}, {domi::FRAMEWORK_RESERVED, "framework_reserved"}, | |||||
| }; | |||||
| static const std::map<domi::ImplyType, std::string> kImplyTypeToString = { | |||||
| {domi::ImplyType::BUILDIN, "buildin"}, {domi::ImplyType::TVM, "tvm"}, {domi::ImplyType::CUSTOM, "custom"}, | |||||
| {domi::ImplyType::AI_CPU, "ai_cpu"}, {domi::ImplyType::CCE, "cce"}, {domi::ImplyType::GELOCAL, "gelocal"}, | |||||
| {domi::ImplyType::HCCL, "hccl"}, {domi::ImplyType::INVALID, "invalid"}}; | |||||
| std::string TypeUtils::ImplyTypeToSerialString(domi::ImplyType imply_type) { | |||||
| auto it = kImplyTypeToString.find(imply_type); | |||||
| if (it != kImplyTypeToString.end()) { | |||||
| return it->second; | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "ImplyTypeToSerialString: imply_type not support %u", imply_type); | |||||
| return "UNDEFINED"; | |||||
| } | |||||
| } | |||||
| bool TypeUtils::IsDataTypeValid(DataType dt) { | |||||
| uint32_t num = static_cast<uint32_t>(dt); | |||||
| GE_CHK_BOOL_EXEC((num <= DT_UNDEFINED), return false, "The DataType is invalid"); | |||||
| return true; | |||||
| } | |||||
| std::string TypeUtils::DataTypeToSerialString(DataType data_type) { | |||||
| auto it = kDataTypeToStringMap.find(data_type); | |||||
| if (it != kDataTypeToStringMap.end()) { | |||||
| return it->second; | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "DataTypeToSerialString: datatype not support %u", data_type); | |||||
| return "UNDEFINED"; | |||||
| } | |||||
| } | |||||
| DataType TypeUtils::SerialStringToDataType(const std::string &str) { | |||||
| auto it = kStringTodataTypeMap.find(str); | |||||
| if (it != kStringTodataTypeMap.end()) { | |||||
| return it->second; | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "SerialStringToDataType: datatype not support %s", str.c_str()); | |||||
| return DT_UNDEFINED; | |||||
| } | |||||
| } | |||||
| bool TypeUtils::IsFormatValid(Format format) { | |||||
| uint32_t num = static_cast<uint32_t>(format); | |||||
| GE_CHK_BOOL_EXEC((num <= FORMAT_RESERVED), return false, "The Format is invalid"); | |||||
| return true; | |||||
| } | |||||
| bool TypeUtils::IsInternalFormat(Format format) { | |||||
| std::string serial_format = FormatToSerialString(format); | |||||
| auto iter = kInternalFormat.find(serial_format); | |||||
| bool result = (iter == kInternalFormat.end()) ? false : true; | |||||
| return result; | |||||
| } | |||||
| std::string TypeUtils::FormatToSerialString(Format format) { | |||||
| auto it = kFormatToStringMap.find(format); | |||||
| if (it != kFormatToStringMap.end()) { | |||||
| return it->second; | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "Format not support %u", format); | |||||
| return "RESERVED"; | |||||
| } | |||||
| } | |||||
| Format TypeUtils::SerialStringToFormat(const std::string &str) { | |||||
| auto it = kStringToFormatMap.find(str); | |||||
| if (it != kStringToFormatMap.end()) { | |||||
| return it->second; | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str()); | |||||
| return FORMAT_RESERVED; | |||||
| } | |||||
| } | |||||
| Format TypeUtils::DataFormatToFormat(const std::string &str) { | |||||
| auto it = kDataFormatMap.find(str); | |||||
| if (it != kDataFormatMap.end()) { | |||||
| return it->second; | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "Format not support %s", str.c_str()); | |||||
| return FORMAT_RESERVED; | |||||
| } | |||||
| } | |||||
| Format TypeUtils::DomiFormatToFormat(domi::domiTensorFormat_t domi_format) { | |||||
| auto it = kDomiFormatToGeFormat.find(domi_format); | |||||
| if (it != kDomiFormatToGeFormat.end()) { | |||||
| return it->second; | |||||
| } | |||||
| GELOGE(GRAPH_FAILED, "do not find domi Format %d from map", domi_format); | |||||
| return FORMAT_RESERVED; | |||||
| } | |||||
| std::string TypeUtils::FmkTypeToSerialString(domi::FrameworkType fmk_type) { | |||||
| auto it = kFmkTypeToString.find(fmk_type); | |||||
| if (it != kFmkTypeToString.end()) { | |||||
| return it->second; | |||||
| } else { | |||||
| GELOGW("Framework type not support %d.", fmk_type); | |||||
| return ""; | |||||
| } | |||||
| } | |||||
| static inline void CopyDataFromBuffer(vector<uint8_t> &data, const Buffer &buffer) { | |||||
| data.clear(); | |||||
| if (buffer.GetData() != nullptr && buffer.GetSize() != 0) { | |||||
| data.assign(buffer.GetData(), buffer.GetData() + buffer.GetSize()); | |||||
| } | |||||
| } | |||||
| graphStatus Usr2DefQuantizeFactor(const UsrQuantizeFactor &usr, QuantizeFactor &def) { | |||||
| def.scale_mode = uint32_t(usr.scale_mode); | |||||
| def.set_scale_value(usr.scale_value.data(), usr.scale_value.size()); | |||||
| def.scale_offset = usr.scale_offset; | |||||
| def.set_offset_data_value(usr.offset_data_value.data(), usr.offset_data_value.size()); | |||||
| def.offset_data_offset = usr.offset_data_offset; | |||||
| def.set_offset_weight_value(usr.offset_weight_value.data(), usr.offset_weight_value.size()); | |||||
| def.offset_weight_offset = usr.offset_weight_offset; | |||||
| def.set_offset_pad_value(usr.offset_pad_value.data(), usr.offset_pad_value.size()); | |||||
| def.offset_pad_offset = usr.offset_pad_offset; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus Def2UsrQuantizeFactor(const QuantizeFactor &def, UsrQuantizeFactor &usr) { | |||||
| usr.scale_mode = UsrQuantizeScaleMode(def.scale_mode); | |||||
| CopyDataFromBuffer(usr.scale_value, def.scale_value); | |||||
| usr.scale_offset = def.scale_offset; | |||||
| CopyDataFromBuffer(usr.offset_data_value, def.offset_data_value); | |||||
| usr.offset_data_offset = def.offset_data_offset; | |||||
| CopyDataFromBuffer(usr.offset_weight_value, def.offset_weight_value); | |||||
| usr.offset_weight_offset = def.offset_weight_offset; | |||||
| CopyDataFromBuffer(usr.offset_pad_value, def.offset_pad_value); | |||||
| usr.offset_pad_offset = def.offset_pad_offset; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus Usr2DefUsrQuantizeCalcFactor(const UsrQuantizeCalcFactor &usr, QuantizeCalcFactor &def) { | |||||
| def.set_offsetw(usr.offsetw.data(), usr.offsetw.size()); | |||||
| def.offsetw_offset = usr.offsetw_offset; | |||||
| def.set_offsetd(usr.offsetd.data(), usr.offsetd.size()); | |||||
| def.offsetd_offset = usr.offsetd_offset; | |||||
| def.set_scalereq(usr.scalereq.data(), usr.scalereq.size()); | |||||
| def.scaledreq_offset = usr.scaledreq_offset; | |||||
| def.set_offsetdnext(usr.offsetdnext.data(), usr.offsetdnext.size()); | |||||
| def.offsetdnext_offset = usr.offsetdnext_offset; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus Def2UsrQuantizeCalcFactor(const QuantizeCalcFactor &def, UsrQuantizeCalcFactor &usr) { | |||||
| CopyDataFromBuffer(usr.offsetw, def.offsetw); | |||||
| usr.offsetw_offset = def.offsetw_offset; | |||||
| CopyDataFromBuffer(usr.offsetd, def.offsetd); | |||||
| usr.offsetd_offset = def.offsetd_offset; | |||||
| CopyDataFromBuffer(usr.scalereq, def.scalereq); | |||||
| usr.scaledreq_offset = def.scaledreq_offset; | |||||
| CopyDataFromBuffer(usr.offsetdnext, def.offsetdnext); | |||||
| usr.offsetdnext_offset = def.offsetdnext_offset; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus TypeUtils::Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def) { | |||||
| def.quantize_algo = uint32_t(usr.quantize_algo); | |||||
| def.scale_type = uint32_t(usr.scale_type); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.quantize_param, def.quantize_param), | |||||
| "Usr2DefQuantizeFactor quantize_param failed"); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.dequantize_param, def.dequantize_param), | |||||
| "Usr2DefQuantizeFactor dequantize_param failed"); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefQuantizeFactor(usr.requantize_param, def.requantize_param), | |||||
| "Usr2DefQuantizeFactor requantize_param failed"); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(Usr2DefUsrQuantizeCalcFactor(usr.quantizecalc_param, def.quantizecalc_param), | |||||
| "Usr2DefQuantizeFactor quantizecalc_param failed"); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus TypeUtils::Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr) { | |||||
| usr.quantize_algo = UsrQuantizeAlgorithm(def.quantize_algo); | |||||
| usr.scale_type = UsrQuantizeScaleType(def.scale_type); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.quantize_param, usr.quantize_param), | |||||
| "Def2UsrQuantizeFactor quantize_param failed"); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.dequantize_param, usr.dequantize_param), | |||||
| "Def2UsrQuantizeFactor dequantize_param failed"); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeFactor(def.requantize_param, usr.requantize_param), | |||||
| "Def2UsrQuantizeFactor requantize_param failed"); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(Def2UsrQuantizeCalcFactor(def.quantizecalc_param, usr.quantizecalc_param), | |||||
| "Def2UsrQuantizeCalcFactor quantizecalc_param failed"); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool TypeUtils::GetDataTypeLength(ge::DataType data_type, uint32_t &length) { | |||||
| auto it = kDataTypeToLength.find(data_type); | |||||
| if (it != kDataTypeToLength.end()) { | |||||
| length = it->second; | |||||
| return true; | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "data_type not support %d", data_type); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| bool TypeUtils::CheckUint64MulOverflow(uint64_t a, uint32_t b) { | |||||
| // Not overflow | |||||
| if (a == 0) { | |||||
| return false; | |||||
| } | |||||
| if ((ULLONG_MAX / a) >= b) { | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,75 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ | |||||
| #define INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "./ge_error_codes.h" | |||||
| using std::make_shared; | |||||
| using std::map; | |||||
| using std::pair; | |||||
| using std::string; | |||||
| using std::to_string; | |||||
| using std::unique_ptr; | |||||
| using std::vector; | |||||
| namespace ge { | |||||
| class AttrValueImpl; | |||||
| /*lint -e148*/ | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { | |||||
| public: | |||||
| using INT = int64_t; | |||||
| using FLOAT = float; | |||||
| using STR = std::string; | |||||
| AttrValue(); | |||||
| ~AttrValue() = default; | |||||
| // GetValue, not list type | |||||
| template <typename T, typename DT> | |||||
| graphStatus GetValue(DT &val) const { | |||||
| T valGet; | |||||
| auto status = GetValue(valGet); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| return status; | |||||
| } | |||||
| val = DT(valGet); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| template <typename T, typename DT> | |||||
| static T CreateFrom(DT &&val) { | |||||
| return val; | |||||
| } | |||||
| std::shared_ptr<AttrValueImpl> impl; | |||||
| private: | |||||
| #define VALUE_SET_GET_DEC(DT) graphStatus GetValue(DT &val) const; | |||||
| VALUE_SET_GET_DEC(AttrValue::STR) | |||||
| VALUE_SET_GET_DEC(AttrValue::INT) | |||||
| VALUE_SET_GET_DEC(AttrValue::FLOAT) | |||||
| #undef VALUE_SET_GET_DEC | |||||
| }; | |||||
| /*lint +e148*/ | |||||
| } // namespace ge | |||||
| #endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ | |||||
| @@ -1,38 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ | |||||
| #define INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ | |||||
| namespace ge { | |||||
| #ifdef HOST_VISIBILITY | |||||
| #define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) | |||||
| #else | |||||
| #define GE_FUNC_HOST_VISIBILITY | |||||
| #endif | |||||
| #ifdef DEV_VISIBILITY | |||||
| #define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) | |||||
| #else | |||||
| #define GE_FUNC_DEV_VISIBILITY | |||||
| #endif | |||||
| using graphStatus = uint32_t; | |||||
| const graphStatus GRAPH_FAILED = 0xFFFFFFFF; | |||||
| const graphStatus GRAPH_SUCCESS = 0; | |||||
| const graphStatus GRAPH_PARAM_INVALID = 50331649; | |||||
| } // namespace ge | |||||
| #endif // INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ | |||||
| @@ -1,81 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GRAPH_GRAPH_H_ | |||||
| #define INC_EXTERNAL_GRAPH_GRAPH_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "./operator.h" | |||||
| namespace ge { | |||||
| class GraphImpl; | |||||
| using GraphImplPtr = std::shared_ptr<GraphImpl>; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { | |||||
| friend class GraphUtils; | |||||
| public: | |||||
| explicit Graph(const std::string &name); | |||||
| Graph() = default; | |||||
| ~Graph() = default; | |||||
| Graph &SetInputs(const std::vector<Operator> &inputs); | |||||
| Graph &SetOutputs(const std::vector<Operator> &outputs); | |||||
| Graph &SetOutputs(const std::vector<std::pair<Operator, std::vector<size_t>>> &output_indexs); | |||||
| Graph &SetOutputs(const std::vector<std::pair<ge::Operator, std::string>> &outputs); | |||||
| Graph &SetTargets(const std::vector<Operator> &targets); | |||||
| bool IsValid() const; | |||||
| graphStatus AddOp(const ge::Operator &op); | |||||
| graphStatus FindOpByName(const string &name, ge::Operator &op) const; | |||||
| graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const; | |||||
| graphStatus GetAllOpName(std::vector<string> &op_name) const; | |||||
| graphStatus SaveToFile(const string &file_name) const; | |||||
| graphStatus LoadFromFile(const string &file_name); | |||||
| const std::string &GetName() const; | |||||
| /// | |||||
| /// Set is need train iteration. | |||||
| /// If set true, it means this graph need to be run iteration some | |||||
| /// times(according variant "npu_runconfig/iterations_per_loop"). | |||||
| /// @param need_iteration need_iteration:whether to set iteration or not | |||||
| /// | |||||
| void SetNeedIteration(bool need_iteration); | |||||
| private: | |||||
| GraphImplPtr impl_{nullptr}; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_EXTERNAL_GRAPH_GRAPH_H_ | |||||
| @@ -1,76 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ | |||||
| #define INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "./tensor.h" | |||||
| #include "./types.h" | |||||
| namespace ge { | |||||
| class InferenceContext; | |||||
| using InferenceContextPtr = std::shared_ptr<InferenceContext>; | |||||
| class ShapeAndTypeImpl; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ShapeAndType { | |||||
| public: | |||||
| ShapeAndType(); | |||||
| ~ShapeAndType() = default; | |||||
| ShapeAndType(const Shape &shape, DataType dataType); | |||||
| void SetShape(const Shape &shape); | |||||
| void SetType(DataType dataType); | |||||
| Shape GetShape() const; | |||||
| DataType GetDataType() const; | |||||
| private: | |||||
| std::shared_ptr<ShapeAndTypeImpl> shape_and_type_impl_; | |||||
| }; | |||||
| class InferenceContextImpl; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | |||||
| public: | |||||
| ~InferenceContext() = default; | |||||
| InferenceContext(const InferenceContext &context) = delete; | |||||
| InferenceContext(const InferenceContext &&context) = delete; | |||||
| InferenceContext &operator=(const InferenceContext &context) = delete; | |||||
| InferenceContext &operator=(const InferenceContext &&context) = delete; | |||||
| void SetInputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types); | |||||
| const std::vector<std::vector<ShapeAndType>> &GetInputHandleShapesAndTypes() const; | |||||
| const std::vector<std::vector<ShapeAndType>> &GetOutputHandleShapesAndTypes() const; | |||||
| void SetOutputHandleShapesAndTypes(const std::vector<std::vector<ShapeAndType>> &shapes_and_types); | |||||
| void SetOutputHandleShapesAndTypes(std::vector<std::vector<ShapeAndType>> &&shapes_and_types); | |||||
| void SetMarks(const std::vector<std::string> &marks); | |||||
| const std::vector<std::string> &GetMarks() const; | |||||
| static std::unique_ptr<InferenceContext> Create(); | |||||
| private: | |||||
| explicit InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||||
| std::shared_ptr<InferenceContextImpl> inference_context_impl_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ | |||||
| @@ -1,289 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GRAPH_OPERATOR_H_ | |||||
| #define INC_EXTERNAL_GRAPH_OPERATOR_H_ | |||||
| #include <functional> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "./ge_error_codes.h" | |||||
| #include "./inference_context.h" | |||||
| #include "./tensor.h" | |||||
| #ifndef USER_GE_LOGI | |||||
| #define USER_GE_LOGI(...) | |||||
| #endif // USER_GE_LOGI | |||||
| #ifndef USER_GE_LOGW | |||||
| #define USER_GE_LOGW(...) | |||||
| #endif // USER_GE_LOGW | |||||
| #ifndef USER_GE_LOGE | |||||
| #define USER_GE_LOGE(...) | |||||
| #endif // USER_GE_LOGE | |||||
| #define DYNAMIC_OUTPUT_TD_NUM(name) ("__dynamic_output_" + name + "_cnt") | |||||
| #define DYNAMIC_INPUT_TD_NUM(name) ("__dynamic_input_" + name + "_cnt") | |||||
| namespace ge { | |||||
| class Operator; | |||||
| class OperatorImpl; | |||||
| class NodeUtils; | |||||
| class NamedAttrs; | |||||
| class Graph; | |||||
| class AttrValue; | |||||
| class Node; | |||||
| using SubgraphBuilder = std::function<Graph()>; | |||||
| using OperatorImplPtr = std::shared_ptr<OperatorImpl>; | |||||
| using OperatorPtr = std::shared_ptr<Operator>; | |||||
| class OpIO; | |||||
| using OutHandler = std::shared_ptr<OpIO>; | |||||
| using InHandler = std::shared_ptr<OpIO>; | |||||
| using std::function; | |||||
| using std::shared_ptr; | |||||
| using std::string; | |||||
| /*lint -e148*/ | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
| public: | |||||
| friend class OperatorImpl; | |||||
| friend class GraphBuilderImpl; | |||||
| friend class NodeUtils; | |||||
| using OpInt = int64_t; | |||||
| using OpFloat = float; | |||||
| using OpString = string; | |||||
| using OpBool = bool; | |||||
| using OpTensor = Tensor; | |||||
| using OpType = ge::DataType; | |||||
| using OpNamedAttrs = ge::NamedAttrs; | |||||
| using OpListInt = std::vector<int64_t>; | |||||
| using OpListFloat = std::vector<float>; | |||||
| using OpListString = std::vector<string>; | |||||
| using OpListBool = std::vector<bool>; | |||||
| using OpListTensor = std::vector<Tensor>; | |||||
| using OpBytes = std::vector<uint8_t>; | |||||
| using OpListListInt = std::vector<std::vector<int64_t>>; | |||||
| using OpListType = std::vector<ge::DataType>; | |||||
| using OpListNamedAttrs = std::vector<ge::NamedAttrs>; | |||||
| Operator() {} | |||||
| explicit Operator(const string &type); | |||||
| Operator(const string &name, const string &type); // lint !e148 | |||||
| virtual ~Operator() = default; | |||||
| bool IsEmpty() const; | |||||
| string GetName() const; | |||||
| string GetOpType() const; | |||||
| // Only has one output index = 0 | |||||
| Operator &SetInput(const string &dst_name, const Operator &src_oprt); | |||||
| Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148 | |||||
| Operator &SetInput(const string &dst_name, const Operator &src_oprt, uint32_t index); | |||||
| Operator &AddControlInput(const Operator &src_oprt); | |||||
| graphStatus GetInputConstData(const string &dst_name, Tensor &data) const; | |||||
| TensorDesc GetInputDesc(const string &name) const; | |||||
| TensorDesc GetInputDesc(uint32_t index) const; | |||||
| int GetDynamicOutputNum(const string &name) const; | |||||
| int GetDynamicInputNum(const string &name) const; | |||||
| graphStatus TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const; | |||||
| graphStatus UpdateInputDesc(const string &name, const TensorDesc &tensor_desc); | |||||
| TensorDesc GetOutputDesc(const string &name) const; | |||||
| TensorDesc GetOutputDesc(uint32_t index) const; | |||||
| graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); // lint !e148 | |||||
| TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const; | |||||
| graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 | |||||
| TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const; | |||||
| graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 | |||||
| graphStatus InferShapeAndType(); // lint !e148 | |||||
| void SetInferenceContext(const InferenceContextPtr &inference_context); | |||||
| InferenceContextPtr GetInferenceContext() const; | |||||
| graphStatus VerifyAllAttr(bool disable_common_verifier = false); // lint !e148 | |||||
| size_t GetInputsSize() const; | |||||
| size_t GetOutputsSize() const; | |||||
| const std::map<std::string, std::string> GetAllAttrNamesAndTypes() const; | |||||
| Operator &SetAttr(const string &name, int64_t attr_value); | |||||
| Operator &SetAttr(const string &name, int32_t attr_value); | |||||
| Operator &SetAttr(const string &name, uint32_t attr_value); | |||||
| graphStatus GetAttr(const string &name, int64_t &attr_value) const; | |||||
| graphStatus GetAttr(const string &name, int32_t &attr_value) const; | |||||
| graphStatus GetAttr(const string &name, uint32_t &attr_value) const; | |||||
| Operator &SetAttr(const string &name, const std::vector<int64_t> &attr_value); | |||||
| Operator &SetAttr(const string &name, const std::vector<int32_t> &attr_value); | |||||
| Operator &SetAttr(const string &name, const std::vector<uint32_t> &attr_value); | |||||
| Operator &SetAttr(const string &name, std::initializer_list<int64_t> &&attr_value); | |||||
| graphStatus GetAttr(const string &name, std::vector<int64_t> &attr_value) const; | |||||
| graphStatus GetAttr(const string &name, std::vector<int32_t> &attr_value) const; | |||||
| graphStatus GetAttr(const string &name, std::vector<uint32_t> &attr_value) const; | |||||
| Operator &SetAttr(const string &name, float attr_value); | |||||
| graphStatus GetAttr(const string &name, float &attr_value) const; | |||||
| Operator &SetAttr(const string &name, const std::vector<float> &attr_value); | |||||
| graphStatus GetAttr(const string &name, std::vector<float> &attr_value) const; | |||||
| Operator &SetAttr(const string &name, AttrValue &&attr_value); | |||||
| graphStatus GetAttr(const string &name, AttrValue &attr_value) const; | |||||
| Operator &SetAttr(const string &name, const string &attr_value); | |||||
| graphStatus GetAttr(const string &name, string &attr_value) const; | |||||
| Operator &SetAttr(const string &name, const std::vector<string> &attr_value); | |||||
| graphStatus GetAttr(const string &name, std::vector<string> &attr_value) const; | |||||
| Operator &SetAttr(const string &name, bool attr_value); | |||||
| graphStatus GetAttr(const string &name, bool &attr_value) const; | |||||
| Operator &SetAttr(const string &name, const std::vector<bool> &attr_value); | |||||
| graphStatus GetAttr(const string &name, std::vector<bool> &attr_value) const; | |||||
| Operator &SetAttr(const string &name, const Tensor &attr_value); | |||||
| graphStatus GetAttr(const string &name, Tensor &attr_value) const; | |||||
| Operator &SetAttr(const string &name, const std::vector<Tensor> &attr_value); | |||||
| graphStatus GetAttr(const string &name, std::vector<Tensor> &attr_value) const; | |||||
| // Bytes type | |||||
| Operator &SetAttr(const string &name, const OpBytes &attr_value); | |||||
| // Bytes type | |||||
| graphStatus GetAttr(const string &name, OpBytes &attr_value) const; | |||||
| Operator &SetAttr(const string &name, const std::vector<std::vector<int64_t>> &attr_value); | |||||
| graphStatus GetAttr(const string &name, std::vector<std::vector<int64_t>> &attr_value) const; | |||||
| Operator &SetAttr(const string &name, const std::vector<ge::DataType> &attr_value); | |||||
| graphStatus GetAttr(const string &name, std::vector<ge::DataType> &attr_value) const; | |||||
| Operator &SetAttr(const string &name, const ge::DataType &attr_value); | |||||
| graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; | |||||
| // func type | |||||
| Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value); | |||||
| graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const; | |||||
| Operator &SetAttr(const string &name, const std::vector<ge::NamedAttrs> &attr_value); | |||||
| graphStatus GetAttr(const string &name, std::vector<ge::NamedAttrs> &attr_value) const; | |||||
| void BreakConnect() const; | |||||
| size_t GetSubgraphNamesCount() const; | |||||
| std::vector<std::string> GetSubgraphNames() const; | |||||
| SubgraphBuilder GetSubgraphBuilder(const string &name) const; | |||||
| Graph GetSubgraph(const string &name) const; | |||||
| SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const; | |||||
| Graph GetDynamicSubgraph(const string &name, uint32_t index) const; | |||||
| protected: | |||||
| void AttrRegister(const string &name, float attr_value); | |||||
| void AttrRegister(const string &name, const std::vector<float> &attr_value); | |||||
| void AttrRegister(const string &name, int64_t attr_value); | |||||
| void AttrRegister(const string &name, const std::vector<int64_t> &attr_value); | |||||
| void AttrRegister(const string &name, const string &attr_value); | |||||
| void AttrRegister(const string &name, const std::vector<string> &attr_value); | |||||
| void AttrRegister(const string &name, bool attr_value); | |||||
| void AttrRegister(const string &name, const std::vector<bool> &attr_value); | |||||
| void AttrRegister(const string &name, const Tensor &attr_value); | |||||
| void AttrRegister(const string &name, const std::vector<Tensor> &attr_value); | |||||
| void AttrRegister(const string &name, const OpBytes &attr_value); | |||||
| void AttrRegister(const string &name, const std::vector<std::vector<int64_t>> &attr_value); | |||||
| void AttrRegister(const string &name, const std::vector<ge::DataType> &attr_value); | |||||
| void AttrRegister(const string &name, const ge::DataType &attr_value); | |||||
| void AttrRegister(const string &name, const ge::NamedAttrs &attr_value); | |||||
| void AttrRegister(const string &name, const std::vector<ge::NamedAttrs> &attr_value); | |||||
| explicit Operator(OperatorImplPtr &&op_impl); | |||||
| void InputRegister(const string &name); | |||||
| void OptionalInputRegister(const string &name); | |||||
| void InferFuncRegister(const std::function<graphStatus(Operator &)> &func); | |||||
| void VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func); | |||||
| void InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func); | |||||
| void OutputRegister(const string &name); | |||||
| void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); | |||||
| void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index); | |||||
| void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); | |||||
| void RequiredAttrRegister(const string &name); | |||||
| graphStatus VerifyAll(); // lint !e148 | |||||
| // Only has one output index = 0 | |||||
| Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt); | |||||
| Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, | |||||
| const string &name); // lint !e148 | |||||
| void SubgraphRegister(const string &ir_name, bool dynamic); | |||||
| void SubgraphCountRegister(const string &ir_name, uint32_t count); | |||||
| void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); | |||||
| private: | |||||
| Operator &SetInput(const string &dst_name, const OutHandler &out_handler); // lint !e148 | |||||
| OutHandler GetOutput(const string &name) const; | |||||
| OutHandler GetOutput(uint32_t index) const; | |||||
| OperatorImplPtr GetOperatorImplPtr() const; | |||||
| OperatorImplPtr operator_impl_{nullptr}; | |||||
| graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const; | |||||
| std::shared_ptr<const Node> GetNode() const; | |||||
| }; | |||||
| /*lint +e148*/ | |||||
| } // namespace ge | |||||
| #endif // INC_EXTERNAL_GRAPH_OPERATOR_H_ | |||||
| @@ -1,68 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ | |||||
| #define INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "./operator.h" | |||||
| #include "./ge_error_codes.h" | |||||
| namespace ge { | |||||
| using OpCreator = std::function<Operator(const std::string &)>; | |||||
| using InferShapeFunc = std::function<graphStatus(Operator &)>; | |||||
| using InferFormatFunc = std::function<graphStatus(Operator &)>; | |||||
| using VerifyFunc = std::function<graphStatus(Operator &)>; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactory { | |||||
| public: | |||||
| static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); | |||||
| static graphStatus GetOpsTypeList(std::vector<std::string> &all_ops); | |||||
| static bool IsExistOp(const string &operator_type); | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorCreatorRegister { | |||||
| public: | |||||
| OperatorCreatorRegister(const string &operator_type, OpCreator const &op_creator); | |||||
| ~OperatorCreatorRegister() = default; | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferShapeFuncRegister { | |||||
| public: | |||||
| InferShapeFuncRegister(const std::string &operator_type, const InferShapeFunc &infer_shape_func); | |||||
| ~InferShapeFuncRegister() = default; | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferFormatFuncRegister { | |||||
| public: | |||||
| InferFormatFuncRegister(const std::string &operator_type, const InferFormatFunc &infer_format_func); | |||||
| ~InferFormatFuncRegister() = default; | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY VerifyFuncRegister { | |||||
| public: | |||||
| VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func); | |||||
| ~VerifyFuncRegister() = default; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ | |||||
| @@ -1,376 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ | |||||
| #define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ | |||||
| #include <functional> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/operator.h" | |||||
| #include "graph/operator_factory.h" | |||||
| #include "graph/tensor.h" | |||||
| #include "graph/types.h" | |||||
| #include "graph/graph.h" | |||||
| namespace ge { | |||||
| using std::function; | |||||
| using std::string; | |||||
| using std::vector; | |||||
| class OpReg { | |||||
| public: | |||||
| OpReg &N() { return *this; } | |||||
| OpReg &ATTR() { return *this; } | |||||
| OpReg &REQUIRED_ATTR() { return *this; } | |||||
| OpReg &INPUT() { return *this; } | |||||
| OpReg &OPTIONAL_INPUT() { return *this; } | |||||
| OpReg &OUTPUT() { return *this; } | |||||
| OpReg &GRAPH() { return *this; } | |||||
| OpReg &DYNAMIC_GRAPH() { return *this; } | |||||
| OpReg &INFER_SHAPE_AND_TYPE() { return *this; } | |||||
| }; | |||||
| #define REG_OP(x) \ | |||||
| namespace op { \ | |||||
| class x : public Operator { \ | |||||
| typedef x _THIS_TYPE; \ | |||||
| \ | |||||
| public: \ | |||||
| explicit x(const string &name) : Operator(name, #x) { __##x(); } \ | |||||
| x() : Operator(#x) { __##x(); } \ | |||||
| \ | |||||
| private: \ | |||||
| void __##x() { \ | |||||
| OpReg() | |||||
| #define ATTR(x, Type, ...) \ | |||||
| N(); \ | |||||
| __attr_##x(); \ | |||||
| } \ | |||||
| \ | |||||
| public: \ | |||||
| static const string name_attr_##x() { return #x; } \ | |||||
| Op##Type get_attr_##x() const { \ | |||||
| Op##Type ret = __VA_ARGS__; \ | |||||
| if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ | |||||
| return ret; \ | |||||
| } \ | |||||
| return ret; \ | |||||
| } \ | |||||
| _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ | |||||
| Operator::SetAttr(#x, v); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| _THIS_TYPE &set_attr_##x(const function<Op##Type()> &v) { return *this; } \ | |||||
| \ | |||||
| private: \ | |||||
| void __attr_##x() { \ | |||||
| Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \ | |||||
| string attr_name(#x); \ | |||||
| (void)OpReg() | |||||
| #define REQUIRED_ATTR(x, Type) \ | |||||
| N(); \ | |||||
| __required_attr_##x(); \ | |||||
| } \ | |||||
| \ | |||||
| public: \ | |||||
| static const string name_attr_##x() { return #x; } \ | |||||
| Op##Type get_attr_##x() const { \ | |||||
| Op##Type ret; \ | |||||
| if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ | |||||
| return ret; \ | |||||
| } \ | |||||
| return ret; \ | |||||
| } \ | |||||
| _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ | |||||
| Operator::SetAttr(#x, v); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| _THIS_TYPE &set_attr_##x(const function<Op##Type()> &v) { return *this; } \ | |||||
| \ | |||||
| private: \ | |||||
| void __required_attr_##x() { \ | |||||
| Operator::RequiredAttrRegister(#x); \ | |||||
| string attr_name(#x); \ | |||||
| (void)OpReg() | |||||
| #define INPUT(x, t) \ | |||||
| N(); \ | |||||
| __input_##x(); \ | |||||
| } \ | |||||
| \ | |||||
| public: \ | |||||
| static const string name_in_##x() { return #x; } \ | |||||
| _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \ | |||||
| Operator::SetInput(#x, v, srcName); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ | |||||
| Operator::SetInput(#x, v, index); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| _THIS_TYPE &set_input_##x(Operator &v) { \ | |||||
| Operator::SetInput(#x, v); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \ | |||||
| graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ | |||||
| return Operator::UpdateInputDesc(#x, tensorDesc); \ | |||||
| } \ | |||||
| \ | |||||
| private: \ | |||||
| void __input_##x() { \ | |||||
| Operator::InputRegister(#x); \ | |||||
| (void)OpReg() | |||||
| #define OPTIONAL_INPUT(x, t) \ | |||||
| N(); \ | |||||
| __optional_input_##x(); \ | |||||
| } \ | |||||
| \ | |||||
| public: \ | |||||
| static const string name_in_##x() { return #x; } \ | |||||
| _THIS_TYPE &set_input_##x(Operator &v) { \ | |||||
| Operator::SetInput(#x, v); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| _THIS_TYPE &set_input_##x(Operator &v, const string &srcName) { \ | |||||
| Operator::SetInput(#x, v, srcName); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ | |||||
| Operator::SetInput(#x, v, index); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| TensorDesc get_input_desc_##x() const { return Operator::GetInputDesc(#x); } \ | |||||
| graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ | |||||
| return Operator::UpdateInputDesc(#x, tensorDesc); \ | |||||
| } \ | |||||
| \ | |||||
| private: \ | |||||
| void __optional_input_##x() { \ | |||||
| Operator::OptionalInputRegister(#x); \ | |||||
| (void)OpReg() | |||||
| #define OUTPUT(x, t) \ | |||||
| N(); \ | |||||
| __out_##x(); \ | |||||
| } \ | |||||
| \ | |||||
| public: \ | |||||
| static const string name_out_##x() { return #x; } \ | |||||
| TensorDesc get_output_desc_##x() const { return Operator::GetOutputDesc(#x); } \ | |||||
| graphStatus update_output_desc_##x(const TensorDesc &tensorDesc) { \ | |||||
| return Operator::UpdateOutputDesc(#x, tensorDesc); \ | |||||
| } \ | |||||
| \ | |||||
| private: \ | |||||
| void __out_##x() { \ | |||||
| Operator::OutputRegister(#x); \ | |||||
| (void)OpReg() | |||||
| #define DYNAMIC_INPUT(x, t) \ | |||||
| N(); \ | |||||
| __dy_input_##x(); \ | |||||
| } \ | |||||
| \ | |||||
| public: \ | |||||
| _THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \ | |||||
| Operator::DynamicInputRegister(#x, num, isPushBack); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| _THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \ | |||||
| Operator::DynamicInputRegisterByIndex(#x, num, index); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { return Operator::GetDynamicInputDesc(#x, index); } \ | |||||
| graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ | |||||
| return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ | |||||
| } \ | |||||
| _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \ | |||||
| Operator::SetInput(#x, dstIndex, v); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const string &srcName) { \ | |||||
| Operator::SetInput(#x, dstIndex, v, srcName); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| \ | |||||
| private: \ | |||||
| void __dy_input_##x() { \ | |||||
| Operator::DynamicInputRegister(#x, 0, true); \ | |||||
| (void)OpReg() | |||||
| #define DYNAMIC_OUTPUT(x, t) \ | |||||
| N(); \ | |||||
| __dy_output_##x(); \ | |||||
| } \ | |||||
| \ | |||||
| public: \ | |||||
| _THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \ | |||||
| Operator::DynamicOutputRegister(#x, num, isPushBack); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { return Operator::GetDynamicOutputDesc(#x, index); } \ | |||||
| graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ | |||||
| return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \ | |||||
| } \ | |||||
| \ | |||||
| private: \ | |||||
| void __dy_output_##x() { \ | |||||
| Operator::DynamicOutputRegister(#x, 0, true); \ | |||||
| (void)OpReg() | |||||
| #define GRAPH(x) \ | |||||
| N(); \ | |||||
| __graph_##x(); \ | |||||
| } \ | |||||
| \ | |||||
| public: \ | |||||
| static const string name_graph_##x() { return #x; } \ | |||||
| SubgraphBuilder get_subgraph_builder_##x() const { return Operator::GetSubgraphBuilder(#x); } \ | |||||
| _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \ | |||||
| Operator::SetSubgraphBuilder(#x, 0, v); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| Graph get_subgraph_##x() const { return Operator::GetSubgraph(#x); } \ | |||||
| \ | |||||
| private: \ | |||||
| void __graph_##x() { \ | |||||
| Operator::SubgraphRegister(#x, false); \ | |||||
| Operator::SubgraphCountRegister(#x, 1); \ | |||||
| (void)OpReg() | |||||
| #define DYNAMIC_GRAPH(x) \ | |||||
| N(); \ | |||||
| __graph_##x(); \ | |||||
| } \ | |||||
| \ | |||||
| public: \ | |||||
| static const string name_graph_##x() { return #x; } \ | |||||
| _THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \ | |||||
| Operator::SubgraphCountRegister(#x, num); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \ | |||||
| return Operator::GetDynamicSubgraphBuilder(#x, index); \ | |||||
| } \ | |||||
| Graph get_dynamic_subgraph_##x(uint32_t index) const { return Operator::GetDynamicSubgraph(#x, index); } \ | |||||
| _THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index, const SubgraphBuilder &v) { \ | |||||
| Operator::SetSubgraphBuilder(#x, index, v); \ | |||||
| return *this; \ | |||||
| } \ | |||||
| \ | |||||
| private: \ | |||||
| void __graph_##x() { \ | |||||
| Operator::SubgraphRegister(#x, true); \ | |||||
| (void)OpReg() | |||||
| #define PASTE(g_register, y) g_register##y | |||||
| #define __OP_END_IMPL__(x, y) \ | |||||
| N(); \ | |||||
| } \ | |||||
| static_assert( \ | |||||
| std::is_same<x, _THIS_TYPE>::value, \ | |||||
| "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ | |||||
| } \ | |||||
| ; \ | |||||
| static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ | |||||
| } | |||||
| #define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) | |||||
| // Specialized shape inferencer macro | |||||
| #define IMPLEMT_INFERFUNC(op_name, func_name) \ | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) | |||||
| #define IMPLEMT_COMMON_INFERFUNC(func_name) \ | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(Operator &op) | |||||
| #define IMPLEMT_INFERFORMAT_FUNC(op_name, func_name) \ | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) | |||||
| // Specialized verifier macro | |||||
| #define IMPLEMT_VERIFIER(op_name, func_name) \ | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name op) | |||||
| #define INFER_VERIFY_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); } | |||||
| #define COMMON_INFER_VERIFY_FUNC(x) [&](Operator &v) { return x(v); } | |||||
| #define INFER_FORMAT_FUNC(op_name, x) [&](Operator &v) { return x((op::op_name &)v); } | |||||
| #define __INFER_FUNC_REG_IMPL__(op_name, x, n) static const InferShapeFuncRegister PASTE(if_register, n)(#op_name, x) | |||||
| #define __VERIFY_FUNC_REG_IMPL__(op_name, x, n) static const VerifyFuncRegister PASTE(vf_register, n)(#op_name, x) | |||||
| // Infer format func register | |||||
| #define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, x, n) \ | |||||
| static const InferFormatFuncRegister PASTE(ff_register, n)(#op_name, x) | |||||
| // Shape inferencer & verifier register macro | |||||
| #define INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) | |||||
| #define COMMON_INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC(x), __COUNTER__) | |||||
| #define VERIFY_FUNC_REG(op_name, x) __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) | |||||
| // Infer format func reg | |||||
| #define INFER_FORMAT_FUNC_REG(op_name, x) \ | |||||
| __INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, x), __COUNTER__) | |||||
| // Common shape inferencer | |||||
| #define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ | |||||
| [](Operator op) -> graphStatus { \ | |||||
| auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \ | |||||
| auto x_type = op.GetInputDesc(in_name).GetDataType(); \ | |||||
| TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ | |||||
| op_output_desc.SetShape(ge::Shape(x_shape)); \ | |||||
| op_output_desc.SetOriginShape(ge::Shape(x_shape)); \ | |||||
| op_output_desc.SetDataType(x_type); \ | |||||
| return op.UpdateOutputDesc(out_name, op_output_desc); \ | |||||
| } | |||||
| graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape, | |||||
| const function<vector<int64_t>()> &get_in2_shape, | |||||
| const function<void(const vector<int64_t> &y_shape)> &set_out_shape); | |||||
| #define BROADCAST_INFER(in1_name, in2_name, out_name) \ | |||||
| [](Operator op) -> graphStatus { \ | |||||
| return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \ | |||||
| [&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \ | |||||
| [&](const vector<int64_t> &y_shape) { \ | |||||
| TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ | |||||
| op_output_desc.SetShape(ge::Shape(y_shape)); \ | |||||
| (void)op.UpdateOutputDesc(out_name, op_output_desc); \ | |||||
| }); \ | |||||
| } | |||||
| } // namespace ge | |||||
| #endif // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ | |||||
| @@ -1,131 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GRAPH_TENSOR_H_ | |||||
| #define INC_EXTERNAL_GRAPH_TENSOR_H_ | |||||
| #include <atomic> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <utility> | |||||
| #include "./ge_error_codes.h" | |||||
| #include "./types.h" | |||||
| namespace ge { | |||||
| class ShapeImpl; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape { | |||||
| public: | |||||
| Shape(); | |||||
| ~Shape() = default; | |||||
| explicit Shape(const std::vector<int64_t> &dims); | |||||
| size_t GetDimNum() const; | |||||
| // If the idx is invalid, return 0 | |||||
| int64_t GetDim(size_t idx) const; | |||||
| graphStatus SetDim(size_t idx, int64_t value); | |||||
| std::vector<int64_t> GetDims() const; | |||||
| int64_t GetShapeSize() const; | |||||
| private: | |||||
| std::shared_ptr<ShapeImpl> impl_; | |||||
| }; | |||||
| class TensorDescImpl; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { | |||||
| public: | |||||
| TensorDesc(); | |||||
| ~TensorDesc() = default; | |||||
| explicit TensorDesc(Shape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); | |||||
| // Copy | |||||
| TensorDesc(const TensorDesc &desc); | |||||
| // Move | |||||
| TensorDesc(TensorDesc &&desc); | |||||
| // Copy | |||||
| TensorDesc &operator=(const TensorDesc &desc); | |||||
| // Move | |||||
| TensorDesc &operator=(TensorDesc &&desc); | |||||
| void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); | |||||
| Shape GetShape() const; | |||||
| void SetShape(const Shape &shape); | |||||
| // set shape with -2, it stand for unknown shape | |||||
| graphStatus SetUnknownDimNumShape(); | |||||
| // for unknown shape | |||||
| graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range); | |||||
| graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const; | |||||
| Format GetFormat() const; | |||||
| void SetFormat(Format format); | |||||
| Shape GetOriginShape() const; | |||||
| void SetOriginShape(const Shape &originShape); | |||||
| Format GetOriginFormat() const; | |||||
| void SetOriginFormat(Format originFormat); | |||||
| DataType GetDataType() const; | |||||
| void SetDataType(DataType dt); | |||||
| std::string GetName() const; | |||||
| void SetName(const std::string &name); | |||||
| // Attr acess | |||||
| void SetSize(int64_t size); | |||||
| int64_t GetSize() const; | |||||
| int64_t GetRealDimCnt() const; | |||||
| void SetRealDimCnt(const int64_t realDimCnt); | |||||
| private: | |||||
| std::shared_ptr<TensorDescImpl> impl; | |||||
| }; | |||||
| class TensorImpl; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor { | |||||
| public: | |||||
| Tensor(); | |||||
| ~Tensor() = default; | |||||
| explicit Tensor(const TensorDesc &tensorDesc); | |||||
| Tensor(const TensorDesc &tensorDesc, const std::vector<uint8_t> &data); | |||||
| Tensor(const TensorDesc &tensorDesc, const uint8_t *data, size_t size); | |||||
| Tensor(TensorDesc &&tensorDesc, std::vector<uint8_t> &&data); | |||||
| TensorDesc GetTensorDesc() const; | |||||
| graphStatus SetTensorDesc(const TensorDesc &tensorDesc); | |||||
| const uint8_t *GetData() const; | |||||
| uint8_t *GetData(); | |||||
| size_t GetSize() const; | |||||
| graphStatus SetData(std::vector<uint8_t> &&data); | |||||
| graphStatus SetData(const std::vector<uint8_t> &data); | |||||
| graphStatus SetData(const uint8_t *data, size_t size); | |||||
| graphStatus SetData(const std::string &data); | |||||
| graphStatus SetData(const std::vector<std::string> &data); | |||||
| graphStatus IsValid(); | |||||
| Tensor Clone() const; | |||||
| private: | |||||
| std::shared_ptr<TensorImpl> impl; | |||||
| friend class TensorAdapter; | |||||
| }; | |||||
| } // namespace ge | |||||
| /*lint +e148*/ | |||||
| #endif // INC_EXTERNAL_GRAPH_TENSOR_H_ | |||||
| @@ -1,240 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GRAPH_TYPES_H_ | |||||
| #define INC_EXTERNAL_GRAPH_TYPES_H_ | |||||
| #include <atomic> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| namespace ge { | |||||
| static const int64_t UNKNOWN_DIM = -1; | |||||
| static const int64_t UNKNOWN_DIM_NUM = -2; | |||||
| static const std::vector<int64_t> UNKNOWN_SHAPE = {-1}; | |||||
| static const std::vector<int64_t> UNKNOWN_RANK = {-2}; | |||||
| #ifdef HOST_VISIBILITY | |||||
| #define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) | |||||
| #else | |||||
| #define GE_FUNC_HOST_VISIBILITY | |||||
| #endif | |||||
| #ifdef DEV_VISIBILITY | |||||
| #define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) | |||||
| #else | |||||
| #define GE_FUNC_DEV_VISIBILITY | |||||
| #endif | |||||
| enum DataType { | |||||
| DT_FLOAT = 0, // float type | |||||
| DT_FLOAT16 = 1, // fp16 type | |||||
| DT_INT8 = 2, // int8 type | |||||
| DT_INT16 = 6, // int16 type | |||||
| DT_UINT16 = 7, // uint16 type | |||||
| DT_UINT8 = 4, // uint8 type | |||||
| DT_INT32 = 3, // | |||||
| DT_INT64 = 9, // int64 type | |||||
| DT_UINT32 = 8, // unsigned int32 | |||||
| DT_UINT64 = 10, // unsigned int64 | |||||
| DT_BOOL = 12, // bool type | |||||
| DT_DOUBLE = 11, // double type | |||||
| DT_STRING = 13, // string type | |||||
| DT_DUAL_SUB_INT8 = 14, // dual output int8 type | |||||
| DT_DUAL_SUB_UINT8 = 15, // dual output uint8 type | |||||
| DT_COMPLEX64 = 16, // complex64 type | |||||
| DT_COMPLEX128 = 17, // complex128 type | |||||
| DT_QINT8 = 18, // qint8 type | |||||
| DT_QINT16 = 19, // qint16 type | |||||
| DT_QINT32 = 20, // qint32 type | |||||
| DT_QUINT8 = 21, // quint8 type | |||||
| DT_QUINT16 = 22, // quint16 type | |||||
| DT_RESOURCE = 23, // resource type | |||||
| DT_STRING_REF = 24, // string ref type | |||||
| DT_DUAL = 25, // dual output type | |||||
| DT_UNDEFINED // Used to indicate a DataType field has not been set. | |||||
| }; | |||||
| inline int GetSizeByDataType(DataType data_type) { | |||||
| static int data_type_size[DT_UNDEFINED] = { | |||||
| 4, // DT_FLOAT = 0, float type | |||||
| 2, // DT_FLOAT16 = 1, fp16 type | |||||
| 1, // DT_INT8 = 2, int8 type | |||||
| 4, // DT_INT32 = 3, | |||||
| 1, // DT_UINT8 = 4, uint8 type | |||||
| -1, | |||||
| 2, // DT_INT16 = 6, int16 type | |||||
| 2, // DT_UINT16 = 7, uint16 type | |||||
| 4, // DT_UINT32 = 8, unsigned int32 | |||||
| 8, // DT_INT64 = 9, int64 type | |||||
| 8, // DT_UINT64 = 10, unsigned int64 | |||||
| 8, // DT_DOUBLE = 11, double type | |||||
| 1, // DT_BOOL = 12, bool type | |||||
| -1, // DT_STRING = 13, string type | |||||
| 1, // DT_DUAL_SUB_INT8 = 14, dual output int8 type | |||||
| 1, // DT_DUAL_SUB_UINT8 = 15, dual output uint8 type | |||||
| 8, // DT_COMPLEX64 = 16, complex64 type | |||||
| 16, // DT_COMPLEX128 = 17, complex128 type | |||||
| 1, // DT_QINT8 = 18, qint8 type | |||||
| 2, // DT_QINT16 = 19, qint16 type | |||||
| 4, // DT_QINT32 = 20, qint32 type | |||||
| 1, // DT_QUINT8 = 21, quint8 type | |||||
| 2, // DT_QUINT16 = 22, quint16 type | |||||
| -1, // DT_RESOURCE = 23, resource type | |||||
| -1, // DT_STRING_REF = 24, string ref type | |||||
| 5, // DT_DUAL = 25, dual output type (float + int8) | |||||
| // DT_UNDEFINED Used to indicate a DataType field has not been set. | |||||
| }; | |||||
| if (data_type >= DT_UNDEFINED) { | |||||
| return -1; | |||||
| } | |||||
| return data_type_size[data_type]; | |||||
| } | |||||
| enum Format { | |||||
| FORMAT_NCHW = 0, // NCHW | |||||
| FORMAT_NHWC, // NHWC | |||||
| FORMAT_ND, // Nd Tensor | |||||
| FORMAT_NC1HWC0, // NC1HWC0 | |||||
| FORMAT_FRACTAL_Z, // FRACTAL_Z | |||||
| FORMAT_NC1C0HWPAD, | |||||
| FORMAT_NHWC1C0, | |||||
| FORMAT_FSR_NCHW, | |||||
| FORMAT_FRACTAL_DECONV, | |||||
| FORMAT_C1HWNC0, | |||||
| FORMAT_FRACTAL_DECONV_TRANSPOSE, | |||||
| FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, | |||||
| FORMAT_NC1HWC0_C04, // NC1HWC0, C0 =4 | |||||
| FORMAT_FRACTAL_Z_C04, // FRACZ, C0 =4 | |||||
| FORMAT_CHWN, | |||||
| FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, | |||||
| FORMAT_HWCN, | |||||
| FORMAT_NC1KHKWHWC0, // KH,KW kernel h& kernel w maxpooling max output format | |||||
| FORMAT_BN_WEIGHT, | |||||
| FORMAT_FILTER_HWCK, // filter input tensor format | |||||
| FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20, | |||||
| FORMAT_HASHTABLE_LOOKUP_KEYS, | |||||
| FORMAT_HASHTABLE_LOOKUP_VALUE, | |||||
| FORMAT_HASHTABLE_LOOKUP_OUTPUT, | |||||
| FORMAT_HASHTABLE_LOOKUP_HITS = 24, | |||||
| FORMAT_C1HWNCoC0, | |||||
| FORMAT_MD, | |||||
| FORMAT_NDHWC, | |||||
| FORMAT_FRACTAL_ZZ, | |||||
| FORMAT_FRACTAL_NZ, | |||||
| FORMAT_NCDHW, | |||||
| FORMAT_DHWCN, // 3D filter input tensor format | |||||
| FORMAT_NDC1HWC0, | |||||
| FORMAT_FRACTAL_Z_3D, | |||||
| FORMAT_CN, | |||||
| FORMAT_NC, | |||||
| FORMAT_DHWNC, | |||||
| FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format | |||||
| FORMAT_FRACTAL_ZN_LSTM, | |||||
| FORMAT_FRACTAL_Z_G, | |||||
| FORMAT_RESERVED, | |||||
| FORMAT_ALL, | |||||
| FORMAT_NULL | |||||
| }; | |||||
| // for unknown shape op type | |||||
| enum UnknowShapeOpType { | |||||
| DEPEND_IN_SHAPE = 1, // op out shape get by input shape | |||||
| DEPEND_CONST_VALUE = 2, // op out shape get by const op value | |||||
| DEPEND_SHAPE_RANGE = 3, // op out shape get by range | |||||
| DEPEND_COMPUTE = 4 // op out shape get by totally computing | |||||
| }; | |||||
| struct TensorDescInfo { | |||||
| Format format_ = FORMAT_RESERVED; // tbe op register support format | |||||
| DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype | |||||
| }; | |||||
| enum DeviceType { | |||||
| NPU = 0, | |||||
| CPU = 1, | |||||
| }; | |||||
| class TensorTypeImpl; | |||||
| struct TensorType { | |||||
| explicit TensorType(DataType dt); | |||||
| TensorType(const std::initializer_list<DataType> &types); | |||||
| static TensorType ALL() { | |||||
| return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, | |||||
| DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, | |||||
| DT_QUINT8, DT_RESOURCE, DT_STRING, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; | |||||
| } | |||||
| static TensorType QuantifiedType() { return TensorType{DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, DT_QUINT8}; } | |||||
| static TensorType OrdinaryType() { | |||||
| return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, | |||||
| DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; | |||||
| } | |||||
| static TensorType BasicType() { | |||||
| return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, | |||||
| DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, | |||||
| DT_QUINT16, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; | |||||
| } | |||||
| static TensorType NumberType() { | |||||
| return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, | |||||
| DT_INT8, DT_QINT32, DT_QINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; | |||||
| } | |||||
| static TensorType RealNumberType() { | |||||
| return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, | |||||
| DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; | |||||
| } | |||||
| static TensorType ComplexDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64}; } | |||||
| static TensorType IntegerDataType() { | |||||
| return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; | |||||
| } | |||||
| static TensorType SignedDataType() { return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8}; } | |||||
| static TensorType UnsignedDataType() { return TensorType{DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; } | |||||
| static TensorType FloatingDataType() { return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; } | |||||
| static TensorType IndexNumberType() { return TensorType{DT_INT32, DT_INT64}; } | |||||
| static TensorType UnaryDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; } | |||||
| static TensorType FLOAT() { return TensorType{DT_FLOAT, DT_FLOAT16}; } | |||||
| std::shared_ptr<TensorTypeImpl> tensor_type_impl_; | |||||
| }; | |||||
| } // namespace ge | |||||
| namespace domi { | |||||
| enum class ImplyType : unsigned int { | |||||
| BUILDIN = 0, // Built in operator, normally executed by OME | |||||
| TVM, // Compile to TVM bin file for execution | |||||
| CUSTOM, // User defined calculation logic, executed by CPU | |||||
| AI_CPU, // AICPU | |||||
| CCE, // Cce | |||||
| GELOCAL, // GE local, do node need execute by device | |||||
| HCCL, // Hccl | |||||
| INVALID = 0xFFFFFFFF, | |||||
| }; | |||||
| } // namespace domi | |||||
| #endif // INC_EXTERNAL_GRAPH_TYPES_H_ | |||||
| @@ -1,163 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_REGISTER_REGISTER_H_ | |||||
| #define INC_EXTERNAL_REGISTER_REGISTER_H_ | |||||
| #include <functional> | |||||
| #include <initializer_list> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "graph/operator.h" | |||||
| #include "register/register_error_codes.h" | |||||
| #include "register/register_fmk_types.h" | |||||
| #include "register/register_types.h" | |||||
| using std::make_shared; | |||||
| using std::map; | |||||
| using std::pair; | |||||
| using std::string; | |||||
| using std::to_string; | |||||
| using std::unique_ptr; | |||||
| using std::vector; | |||||
| /*lint -e148*/ | |||||
| namespace ge { | |||||
| class Operator; | |||||
| class TensorDesc; | |||||
| class Tensor; | |||||
| class TBEPluginManager; | |||||
| } // namespace ge | |||||
| namespace google { | |||||
| namespace protobuf { | |||||
| class Message; | |||||
| } | |||||
| } // namespace google | |||||
| namespace domi { | |||||
| const int64_t kMaxNameLength = 1048576; // 1M | |||||
| enum DynamicType { kInvalid = 0, kInput = 1, kOutput = 2 }; | |||||
| struct DynamicInputOutputInfo { | |||||
| DynamicType type; // input/output | |||||
| const char *port_name; | |||||
| int64_t port_name_len; | |||||
| const char *attr_name; | |||||
| int64_t attr_name_len; | |||||
| DynamicInputOutputInfo() | |||||
| : type(kInvalid), port_name(nullptr), port_name_len(0), attr_name(nullptr), attr_name_len(0) {} | |||||
| DynamicInputOutputInfo(DynamicType type, const char *port_name, int64_t port_name_len, const char *attr_name, | |||||
| int64_t attr_name_len) | |||||
| : type(type), | |||||
| port_name(port_name), | |||||
| port_name_len(port_name_len), | |||||
| attr_name(attr_name), | |||||
| attr_name_len(attr_name_len) {} | |||||
| }; | |||||
| Status AutoMappingByOpFn(const ge::Operator &op_src, ge::Operator &op); | |||||
| Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, ge::Operator &op, | |||||
| const vector<DynamicInputOutputInfo> &dynamic_name_attr_value); | |||||
| Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | |||||
| Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | |||||
| std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value, | |||||
| int in_pos = -1, int out_pos = -1); | |||||
| Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function<int(int data_index)> &input, | |||||
| const std::function<int(int netoutput_index)> &output); | |||||
| Status AutoMappingSubgraphIndex(const ge::Graph &graph, | |||||
| const std::function<Status(int data_index, int &parent_input_index)> &input, | |||||
| const std::function<Status(int netoutput_index, int &parent_output_index)> &output); | |||||
| using google::protobuf::Message; | |||||
| class OpRegistrationDataImpl; | |||||
| using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | |||||
| using ParseParamByOpFunc = std::function<domi::Status(const ge::Operator &, ge::Operator &)>; | |||||
| using FusionParseParamFunc = | |||||
| std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | |||||
| using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>; | |||||
| using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>; | |||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||||
| public: | |||||
| OpRegistrationData(const std::string &om_optype); | |||||
| ~OpRegistrationData(); | |||||
| OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type); | |||||
| OpRegistrationData &OriginOpType(const std::initializer_list<std::string> &ori_optype_list); | |||||
| OpRegistrationData &OriginOpType(const std::string &ori_optype); | |||||
| OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); | |||||
| OpRegistrationData &ParseParamsByOperatorFn(const ParseParamByOpFunc &parse_param_by_op_fn); | |||||
| OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); | |||||
| OpRegistrationData &FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn); | |||||
| OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); | |||||
| OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | |||||
| OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | |||||
| OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); | |||||
| OpRegistrationData &InputReorderVector(const vector<int> &input_order); | |||||
| domi::ImplyType GetImplyType() const; | |||||
| std::string GetOmOptype() const; | |||||
| std::set<std::string> GetOriginOpTypeSet() const; | |||||
| domi::FrameworkType GetFrameworkType() const; | |||||
| ParseParamFunc GetParseParamFn() const; | |||||
| ParseParamByOpFunc GetParseParamByOperatorFn() const; | |||||
| FusionParseParamFunc GetFusionParseParamFn() const; | |||||
| FusionParseParamByOpFunc GetFusionParseParamByOpFn() const; | |||||
| ParseSubgraphFunc GetParseSubgraphPostFn() const; | |||||
| private: | |||||
| std::shared_ptr<OpRegistrationDataImpl> impl_; | |||||
| friend class OpRegistry; | |||||
| friend class OpRegistrationTbe; | |||||
| friend class ge::TBEPluginManager; | |||||
| }; | |||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { | |||||
| public: | |||||
| OpReceiver(OpRegistrationData ®_data); | |||||
| ~OpReceiver() {} | |||||
| }; | |||||
| #define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name) | |||||
| #define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name) | |||||
| #define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \ | |||||
| static OpReceiver register_op##ctr __attribute__((unused)) = OpRegistrationData(name) | |||||
| } // namespace domi | |||||
| namespace ge { | |||||
| using OpRegistrationData = domi::OpRegistrationData; | |||||
| using OpReceiver = domi::OpReceiver; | |||||
| } // namespace ge | |||||
| /*lint +e148*/ | |||||
| #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | |||||
| @@ -1,39 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ | |||||
| #define INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ | |||||
| #define SYSID_FWK 3 // Subsystem ID | |||||
| #define MODID_COMMON 0 // Common module ID | |||||
| #define DECLARE_ERRORNO(sysid, modid, name, value) \ | |||||
| const domi::Status name = \ | |||||
| ((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); | |||||
| #define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) | |||||
| namespace domi { | |||||
| using Status = uint32_t; | |||||
| // General error code | |||||
| DECLARE_ERRORNO(0, 0, SUCCESS, 0); | |||||
| DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF); | |||||
| DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649 | |||||
| DECLARE_ERRORNO(SYSID_FWK, 1, SCOPE_NOT_CHANGED, 201); | |||||
| } // namespace domi | |||||
| #endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ | |||||
| @@ -1,37 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ | |||||
| #define INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ | |||||
| #include <string> | |||||
| namespace domi { | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief AI framework types | |||||
| /// | |||||
| enum FrameworkType { | |||||
| CAFFE = 0, | |||||
| MINDSPORE = 1, | |||||
| TENSORFLOW = 3, | |||||
| ANDROID_NN, | |||||
| ONNX, | |||||
| FRAMEWORK_RESERVED, | |||||
| }; | |||||
| } // namespace domi | |||||
| #endif // INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ | |||||
| @@ -1,59 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ | |||||
| #define INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ | |||||
| namespace domi { | |||||
| #ifdef HOST_VISIBILITY | |||||
| #define FMK_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) | |||||
| #else | |||||
| #define FMK_FUNC_HOST_VISIBILITY | |||||
| #endif | |||||
| #ifdef DEV_VISIBILITY | |||||
| #define FMK_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) | |||||
| #else | |||||
| #define FMK_FUNC_DEV_VISIBILITY | |||||
| #endif | |||||
| /// CCE defined constant | |||||
| /// | |||||
| /// @ingroup domi | |||||
| /// @brief original tensor type | |||||
| /// | |||||
| typedef enum tagDomiTensorFormat { | |||||
| DOMI_TENSOR_NCHW = 0, // < NCHW | |||||
| DOMI_TENSOR_NHWC, // < NHWC | |||||
| DOMI_TENSOR_ND, // < Nd Tensor | |||||
| DOMI_TENSOR_NC1HWC0, // < NC1HWC0 | |||||
| DOMI_TENSOR_FRACTAL_Z, // < FRACTAL_Z | |||||
| DOMI_TENSOR_NC1C0HWPAD, | |||||
| DOMI_TENSOR_NHWC1C0, | |||||
| DOMI_TENSOR_FSR_NCHW, | |||||
| DOMI_TENSOR_FRACTAL_DECONV, | |||||
| DOMI_TENSOR_BN_WEIGHT, | |||||
| DOMI_TENSOR_CHWN, // Android NN Depth CONV | |||||
| DOMI_TENSOR_FILTER_HWCK, // filter input tensor format | |||||
| DOMI_TENSOR_NDHWC, | |||||
| DOMI_TENSOR_NCDHW, | |||||
| DOMI_TENSOR_DHWCN, // 3D filter input tensor format | |||||
| DOMI_TENSOR_DHWNC, | |||||
| DOMI_TENSOR_RESERVED | |||||
| } domiTensorFormat_t; | |||||
| } // namespace domi | |||||
| #endif // INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ | |||||
| @@ -1,334 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ | |||||
| #define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <map> | |||||
| #include <unordered_map> | |||||
| #include "ge/ge_api_error_codes.h" | |||||
| #include "register/register_error_codes.h" | |||||
| #include "register/register_types.h" | |||||
| #include "graph/operator.h" | |||||
| #define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt) \ | |||||
| do { \ | |||||
| if (!(cond)) { \ | |||||
| if ((fusion_rlt) != nullptr) { \ | |||||
| (fusion_rlt)->SetType(ge::kScopeInvalidType); \ | |||||
| } \ | |||||
| return; \ | |||||
| } \ | |||||
| } while (0) | |||||
| namespace domi { | |||||
| class TensorFlowModelParser; | |||||
| } // namespace domi | |||||
| namespace ge { | |||||
| const int32_t kFusionDisableIndex = 99999; | |||||
| const char *const kScopeToMultiNodes = "ScopeToMultiNodes"; | |||||
| const char *const kScopeInvalidType = "ScopeInvalidType"; | |||||
| const char *const kInputFromFusionScope = "InputFromFusionScope"; | |||||
| const char *const kOutputToFusionScope = "OutputToFusionScope"; | |||||
| class ScopePattern; | |||||
| using ScopeFusionPatterns = std::vector<std::vector<ScopePattern *>>; | |||||
| class ScopePassManager; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope { | |||||
| public: | |||||
| Scope(); | |||||
| Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr); | |||||
| ~Scope(); | |||||
| const std::string &Name() const; | |||||
| const std::string &SubType() const; | |||||
| const std::unordered_map<std::string, ge::OperatorPtr> &AllNodesMap() const; | |||||
| Scope *GetSubScope(const std::string &scope_name) const; | |||||
| const std::string LastName() const; | |||||
| const std::vector<Scope *> &GetAllSubScopes() const; | |||||
| const Scope *GetFatherScope() const; | |||||
| private: | |||||
| class ScopeImpl; | |||||
| std::unique_ptr<ScopeImpl> impl_; | |||||
| friend class ScopeBasePass; | |||||
| friend class ScopeTree; | |||||
| friend class NodeOpTypeFeature; | |||||
| friend class NodeAttrFeature; | |||||
| friend class ScopeFeature; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult { | |||||
| public: | |||||
| FusionScopesResult(); | |||||
| Status Init(); | |||||
| ~FusionScopesResult(); | |||||
| void SetName(const std::string &name); | |||||
| void SetType(const std::string &type); | |||||
| void SetDescription(const std::string &description); | |||||
| const std::string &Name() const; | |||||
| const std::vector<ge::OperatorPtr> &Nodes() const; | |||||
| void InsertInputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map); | |||||
| void InsertOutputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map); | |||||
| class InnerNodeInfo { | |||||
| public: | |||||
| explicit InnerNodeInfo(const std::string &fusion_node_name); | |||||
| InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type); | |||||
| InnerNodeInfo(InnerNodeInfo &&other) noexcept; | |||||
| InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept; | |||||
| InnerNodeInfo(const InnerNodeInfo &) = delete; | |||||
| InnerNodeInfo &operator=(const InnerNodeInfo &) = delete; | |||||
| ~InnerNodeInfo(); | |||||
| InnerNodeInfo &SetName(const std::string &name); | |||||
| InnerNodeInfo &SetType(const std::string &type); | |||||
| InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx); | |||||
| InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx); | |||||
| ge::graphStatus BuildInnerNode(); | |||||
| ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format); | |||||
| ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format); | |||||
| ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format); | |||||
| ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format); | |||||
| ge::Operator *MutableOperator(); | |||||
| std::string GetName() const; | |||||
| std::string GetType() const; | |||||
| std::vector<std::pair<std::string, int32_t>> GetInputs() const; | |||||
| std::vector<std::pair<std::string, int32_t>> GetOutputs() const; | |||||
| private: | |||||
| class InnerNodeInfoImpl; | |||||
| std::unique_ptr<InnerNodeInfoImpl> impl_; | |||||
| }; | |||||
| InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type); | |||||
| InnerNodeInfo *MutableRecentInnerNode(); | |||||
| InnerNodeInfo *MutableInnerNode(uint32_t index); | |||||
| ge::graphStatus CheckInnerNodesInfo(); | |||||
| private: | |||||
| class FusionScopesResultImpl; | |||||
| std::unique_ptr<FusionScopesResultImpl> impl_; | |||||
| friend class ScopeGraph; | |||||
| friend class ScopeBasePass; | |||||
| friend class TensorFlowModelParser; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree { | |||||
| public: | |||||
| ScopeTree(); | |||||
| Status Init(); | |||||
| ScopeTree(const ScopeTree &scopetree) = delete; | |||||
| ScopeTree &operator=(const ScopeTree &scopetree) = delete; | |||||
| ~ScopeTree(); | |||||
| const std::vector<Scope *> &GetAllScopes() const; | |||||
| private: | |||||
| class ScopeTreeImpl; | |||||
| std::unique_ptr<ScopeTreeImpl> impl_; | |||||
| friend class ScopeGraph; | |||||
| friend class ScopeBasePass; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph { | |||||
| public: | |||||
| ScopeGraph(); | |||||
| Status Init(); | |||||
| ScopeGraph(const ScopeGraph &scope_graph) = delete; | |||||
| ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete; | |||||
| ~ScopeGraph(); | |||||
| const ScopeTree *GetScopeTree() const; | |||||
| const std::unordered_map<std::string, ge::OperatorPtr> &GetNodesMap() const; | |||||
| private: | |||||
| class ScopeGraphImpl; | |||||
| std::unique_ptr<ScopeGraphImpl> impl_; | |||||
| friend class ScopePassManager; | |||||
| friend class ScopeBasePass; | |||||
| friend class TensorFlowModelParser; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue { | |||||
| public: | |||||
| ScopeAttrValue(); | |||||
| ScopeAttrValue(ScopeAttrValue const &attr_value); | |||||
| ScopeAttrValue &operator=(ScopeAttrValue const &attr_value); | |||||
| ~ScopeAttrValue(); | |||||
| void SetIntValue(int64_t value); | |||||
| void SetFloatValue(float value); | |||||
| void SetStringValue(std::string value); | |||||
| void SetBoolValue(bool value); | |||||
| private: | |||||
| class ScopeAttrValueImpl; | |||||
| std::unique_ptr<ScopeAttrValueImpl> impl_; | |||||
| friend class NodeAttrFeature; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature { | |||||
| public: | |||||
| virtual bool Match(const Scope *scope) = 0; | |||||
| virtual ~ScopeBaseFeature(){}; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature { | |||||
| public: | |||||
| NodeOpTypeFeature(std::string nodeType, int num, int step = 0); | |||||
| NodeOpTypeFeature(NodeOpTypeFeature const &feature); | |||||
| NodeOpTypeFeature &operator=(NodeOpTypeFeature const &feature); | |||||
| ~NodeOpTypeFeature(); | |||||
| bool Match(const Scope *scope) override; | |||||
| private: | |||||
| class NodeOpTypeFeatureImpl; | |||||
| std::unique_ptr<NodeOpTypeFeatureImpl> impl_; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature { | |||||
| public: | |||||
| NodeAttrFeature(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue &attr_value); | |||||
| NodeAttrFeature(NodeAttrFeature const &feature); | |||||
| NodeAttrFeature &operator=(NodeAttrFeature const &feature); | |||||
| ~NodeAttrFeature(); | |||||
| bool Match(const Scope *scope) override; | |||||
| private: | |||||
| class NodeAttrFeatureImpl; | |||||
| std::unique_ptr<NodeAttrFeatureImpl> impl_; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature { | |||||
| public: | |||||
| ScopeFeature(std::string sub_type, int32_t num, std::string suffix = "", std::string sub_scope_mask = "", | |||||
| int step = 0); | |||||
| ScopeFeature(ScopeFeature const &feature); | |||||
| ScopeFeature &operator=(ScopeFeature const &feature); | |||||
| ~ScopeFeature(); | |||||
| bool Match(const Scope *scope) override; | |||||
| private: | |||||
| class ScopeFeatureImpl; | |||||
| std::unique_ptr<ScopeFeatureImpl> impl_; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern { | |||||
| public: | |||||
| ScopePattern(); | |||||
| ~ScopePattern(); | |||||
| ScopePattern &SetSubType(const std::string &sub_type); | |||||
| ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature); | |||||
| ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature); | |||||
| ScopePattern &AddScopeFeature(ScopeFeature feature); | |||||
| private: | |||||
| class ScopePatternImpl; | |||||
| std::unique_ptr<ScopePatternImpl> impl_; | |||||
| friend class ScopeBasePass; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult { | |||||
| public: | |||||
| ScopesResult(); | |||||
| ScopesResult(ScopesResult const &result); | |||||
| ScopesResult &operator=(ScopesResult const &result); | |||||
| ~ScopesResult(); | |||||
| void SetScopes(std::vector<Scope *> &scopes); | |||||
| void SetNodes(std::vector<ge::OperatorPtr> &nodes); | |||||
| private: | |||||
| class ScopesResultImpl; | |||||
| std::unique_ptr<ScopesResultImpl> impl_; | |||||
| friend class ScopeBasePass; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBasePass { | |||||
| public: | |||||
| ScopeBasePass(); | |||||
| virtual ~ScopeBasePass(); | |||||
| protected: | |||||
| // Subclasses implement respective fusion strategies and build the Patterns | |||||
| virtual std::vector<ScopeFusionPatterns> DefinePatterns() = 0; | |||||
| // Define the name of the scope pass | |||||
| virtual std::string PassName() = 0; | |||||
| // Subclasses implement respective multi-scope or operator fusion methods across scopes | |||||
| virtual Status LastMatchScopesAndOPs(std::shared_ptr<ScopeGraph> &scope_graph, | |||||
| std::vector<ScopesResult> &results) = 0; | |||||
| // Subclasses implement their own results and set the input and output of the final fusion operator | |||||
| virtual void GenerateFusionResult(const std::vector<Scope *> &scopes, FusionScopesResult *fusion_rlt) = 0; | |||||
| private: | |||||
| class ScopeBasePassImpl; | |||||
| std::unique_ptr<ScopeBasePassImpl> impl_; | |||||
| friend class ge::ScopePassManager; | |||||
| friend class ScopeBasePassImpl; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry { | |||||
| public: | |||||
| using CreateFn = ScopeBasePass *(*)(); | |||||
| ~ScopeFusionPassRegistry(); | |||||
| static ScopeFusionPassRegistry &GetInstance() { | |||||
| static ScopeFusionPassRegistry instance; | |||||
| return instance; | |||||
| } | |||||
| void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general); | |||||
| private: | |||||
| ScopeFusionPassRegistry(); | |||||
| class ScopeFusionPassRegistryImpl; | |||||
| /*lint -e148*/ | |||||
| std::unique_ptr<ScopeFusionPassRegistryImpl> impl_; | |||||
| friend class TensorFlowModelParser; | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil { | |||||
| public: | |||||
| static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value); | |||||
| static void FreeScopePatterns(ScopeFusionPatterns &patterns); | |||||
| static void FreeOneBatchPattern(std::vector<ScopePattern *> &one_batch_pattern); | |||||
| }; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar { | |||||
| public: | |||||
| ScopeFusionPassRegistrar(const char *pass_name, ScopeBasePass *(*create_fn)(), bool is_general); | |||||
| ~ScopeFusionPassRegistrar() {} | |||||
| }; | |||||
| #define REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general) \ | |||||
| REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, scope_pass, is_general) | |||||
| #define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \ | |||||
| REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) | |||||
| #define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) \ | |||||
| static ::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \ | |||||
| ::ge::ScopeFusionPassRegistrar( \ | |||||
| pass_name, []() -> ::ge::ScopeBasePass * { return new (std::nothrow) scope_pass(); }, is_general) | |||||
| } // namespace ge | |||||
| #endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ | |||||
| @@ -1,284 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_ANCHOR_H_ | |||||
| #define INC_GRAPH_ANCHOR_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/ge_error_codes.h" | |||||
| #include "graph/range_vistor.h" | |||||
| #include "graph/types.h" | |||||
| namespace ge { | |||||
| enum AnchorStatus { | |||||
| ANCHOR_SUSPEND = 0, // dat null | |||||
| ANCHOR_CONST = 1, | |||||
| ANCHOR_DATA = 2, // Effective | |||||
| ANCHOR_RESERVED = 3 | |||||
| }; | |||||
| using std::string; | |||||
| using std::vector; | |||||
| class Node; | |||||
| using NodePtr = std::shared_ptr<Node>; | |||||
| class Edge; | |||||
| using EdgePtr = std::shared_ptr<Edge>; | |||||
| class Anchor; | |||||
| using AnchorPtr = std::shared_ptr<Anchor>; | |||||
| class DataAnchor; | |||||
| using DataAnchorPtr = std::shared_ptr<DataAnchor>; | |||||
| class InDataAnchor; | |||||
| using InDataAnchorPtr = std::shared_ptr<InDataAnchor>; | |||||
| class OutDataAnchor; | |||||
| using OutDataAnchorPtr = std::shared_ptr<OutDataAnchor>; | |||||
| class ControlAnchor; | |||||
| using ControlAnchorPtr = std::shared_ptr<ControlAnchor>; | |||||
| class InControlAnchor; | |||||
| using InControlAnchorPtr = std::shared_ptr<InControlAnchor>; | |||||
| class OutControlAnchor; | |||||
| using OutControlAnchorPtr = std::shared_ptr<OutControlAnchor>; | |||||
| using ConstAnchor = const Anchor; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable_shared_from_this<Anchor> { | |||||
| friend class AnchorUtils; | |||||
| public: | |||||
| using TYPE = const char *; | |||||
| template <class T> | |||||
| using Vistor = RangeVistor<T, std::shared_ptr<ConstAnchor>>; | |||||
| Anchor(const NodePtr &ownerNode, int idx); | |||||
| virtual ~Anchor() = default; | |||||
| protected: | |||||
| // Whether the two anchor is equal | |||||
| virtual bool Equal(AnchorPtr anchor) const = 0; | |||||
| virtual bool IsTypeOf(TYPE type) const; | |||||
| public: | |||||
| // Get all peer anchors connected to current anchor | |||||
| Vistor<AnchorPtr> GetPeerAnchors() const; | |||||
| // Get peer anchor size | |||||
| size_t GetPeerAnchorsSize() const; | |||||
| // Get first peer anchor | |||||
| AnchorPtr GetFirstPeerAnchor() const; | |||||
| // Get the anchor belong to which node | |||||
| NodePtr GetOwnerNode() const; | |||||
| // Remove all links with the anchor | |||||
| void UnlinkAll() noexcept; | |||||
| // Remove link with the given anchor | |||||
| graphStatus Unlink(const AnchorPtr &peer); | |||||
| // Replace peer with new peers | |||||
| graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer); | |||||
| // Judge if the anchor is linked with the given anchor | |||||
| bool IsLinkedWith(const AnchorPtr &peer); | |||||
| // Get anchor index of the node | |||||
| int GetIdx() const; | |||||
| // set anchor index of the node | |||||
| void SetIdx(int index); | |||||
| protected: | |||||
| // All peer anchors connected to current anchor | |||||
| vector<std::weak_ptr<Anchor>> peer_anchors_; | |||||
| // The owner node of anchor | |||||
| std::weak_ptr<Node> owner_node_; | |||||
| // The index of current anchor | |||||
| int idx_; | |||||
| template <class T> | |||||
| static Anchor::TYPE TypeOf() { | |||||
| static_assert(std::is_base_of<Anchor, T>::value, "T must be a Anchor!"); | |||||
| return __PRETTY_FUNCTION__; | |||||
| } | |||||
| public: | |||||
| template <class T> | |||||
| static std::shared_ptr<T> DynamicAnchorCast(AnchorPtr anchorPtr) { | |||||
| static_assert(std::is_base_of<Anchor, T>::value, "T must be a Anchor!"); | |||||
| if (anchorPtr == nullptr || !anchorPtr->IsTypeOf<T>()) { | |||||
| return nullptr; | |||||
| } | |||||
| return std::static_pointer_cast<T>(anchorPtr); | |||||
| } | |||||
| template <typename T> | |||||
| bool IsTypeOf() { | |||||
| return IsTypeOf(TypeOf<T>()); | |||||
| } | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY DataAnchor : public Anchor { | |||||
| friend class AnchorUtils; | |||||
| public: | |||||
| explicit DataAnchor(const NodePtr &ownerNode, int idx); | |||||
| virtual ~DataAnchor() = default; | |||||
| protected: | |||||
| bool IsTypeOf(TYPE type) const override; | |||||
| private: | |||||
| Format format_{FORMAT_ND}; | |||||
| AnchorStatus status_{ANCHOR_SUSPEND}; | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataAnchor { | |||||
| friend class OutDataAnchor; | |||||
| friend class OutControlAnchor; | |||||
| public: | |||||
| explicit InDataAnchor(const NodePtr &ownerNode, int idx); | |||||
| virtual ~InDataAnchor() = default; | |||||
| // Get source out data anchor | |||||
| OutDataAnchorPtr GetPeerOutAnchor() const; | |||||
| // Build connection from OutDataAnchor to InDataAnchor | |||||
| graphStatus LinkFrom(const OutDataAnchorPtr &src); | |||||
| protected: | |||||
| bool Equal(AnchorPtr anchor) const override; | |||||
| bool IsTypeOf(TYPE type) const override; | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchor : public DataAnchor { | |||||
| friend class InDataAnchor; | |||||
| friend class AnchorUtils; | |||||
| public: | |||||
| template <class T> | |||||
| using Vistor = RangeVistor<T, std::shared_ptr<ConstAnchor>>; | |||||
| explicit OutDataAnchor(const NodePtr &ownerNode, int idx); | |||||
| virtual ~OutDataAnchor() = default; | |||||
| // Get dst in data anchor(one or more) | |||||
| Vistor<InDataAnchorPtr> GetPeerInDataAnchors() const; | |||||
| uint32_t GetPeerInDataNodesSize() const; | |||||
| // Get dst in control anchor(one or more) | |||||
| Vistor<InControlAnchorPtr> GetPeerInControlAnchors() const; | |||||
| // Build connection from OutDataAnchor to InDataAnchor | |||||
| graphStatus LinkTo(const InDataAnchorPtr &dest); | |||||
| // Build connection from OutDataAnchor to InControlAnchor | |||||
| graphStatus LinkTo(const InControlAnchorPtr &dest); | |||||
| protected: | |||||
| bool Equal(AnchorPtr anchor) const override; | |||||
| bool IsTypeOf(TYPE type) const override; | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ControlAnchor : public Anchor { | |||||
| public: | |||||
| explicit ControlAnchor(const NodePtr &ownerNode); | |||||
| explicit ControlAnchor(const NodePtr &ownerNode, int idx); | |||||
| virtual ~ControlAnchor() = default; | |||||
| protected: | |||||
| bool IsTypeOf(TYPE type) const override; | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchor : public ControlAnchor { | |||||
| friend class OutControlAnchor; | |||||
| friend class OutDataAnchor; | |||||
| public: | |||||
| explicit InControlAnchor(const NodePtr &ownerNode); | |||||
| explicit InControlAnchor(const NodePtr &ownerNode, int idx); | |||||
| virtual ~InControlAnchor() = default; | |||||
| // Get source out control anchors | |||||
| Vistor<OutControlAnchorPtr> GetPeerOutControlAnchors() const; | |||||
| bool IsPeerOutAnchorsEmpty() const { return peer_anchors_.empty(); } | |||||
| // Get source out data anchors | |||||
| Vistor<OutDataAnchorPtr> GetPeerOutDataAnchors() const; | |||||
| // Build connection from OutControlAnchor to InControlAnchor | |||||
| graphStatus LinkFrom(const OutControlAnchorPtr &src); | |||||
| protected: | |||||
| bool Equal(AnchorPtr anchor) const override; | |||||
| bool IsTypeOf(TYPE type) const override; | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchor : public ControlAnchor { | |||||
| friend class InControlAnchor; | |||||
| public: | |||||
| template <class T> | |||||
| using Vistor = RangeVistor<T, std::shared_ptr<ConstAnchor>>; | |||||
| explicit OutControlAnchor(const NodePtr &ownerNode); | |||||
| explicit OutControlAnchor(const NodePtr &ownerNode, int idx); | |||||
| virtual ~OutControlAnchor() = default; | |||||
| // Get dst in control anchor(one or more) | |||||
| Vistor<InControlAnchorPtr> GetPeerInControlAnchors() const; | |||||
| // Get dst data anchor in control anchor(one or more) | |||||
| Vistor<InDataAnchorPtr> GetPeerInDataAnchors() const; | |||||
| // Build connection from OutControlAnchor to InControlAnchor | |||||
| graphStatus LinkTo(const InControlAnchorPtr &dest); | |||||
| // Build connection from OutDataAnchor to InDataAnchor | |||||
| graphStatus LinkTo(const InDataAnchorPtr &dest); | |||||
| protected: | |||||
| bool Equal(AnchorPtr anchor) const override; | |||||
| bool IsTypeOf(TYPE type) const override; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_ANCHOR_H_ | |||||
| @@ -1,191 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ | |||||
| #define INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/ge_attr_value.h" | |||||
| namespace ge { | |||||
| class GeAttrValue; | |||||
| class _GeSerializable { | |||||
| public: | |||||
| template <typename T> | |||||
| struct ge_serializable_int64_t_support_type { | |||||
| using DT = typename std::remove_cv<T>::type; | |||||
| static const bool value = std::is_same<DT, uint64_t>::value // by cast | |||||
| || std::is_same<DT, int32_t>::value || std::is_same<DT, uint32_t>::value || | |||||
| std::is_same<DT, int16_t>::value || std::is_same<DT, uint16_t>::value || | |||||
| std::is_same<DT, int8_t>::value || std::is_same<DT, uint8_t>::value; | |||||
| }; | |||||
| template <typename T, typename T::__ge_serializable = 0> | |||||
| static GeAttrValue SaveItemAsAttrValue(const T &t) { | |||||
| return GeAttrValue::CreateFrom(t); | |||||
| } | |||||
| template <typename T, typename T::__ge_serializable = 0> | |||||
| static GeAttrValue SaveItemAsAttrValue(const vector<T> &t) { | |||||
| return GeAttrValue::CreateFrom(t); | |||||
| } | |||||
| template <typename T, GeAttrValue::enable_if_type_valid_t<T> = 0, typename DT = typename std::remove_cv<T>::type> | |||||
| static GeAttrValue SaveItemAsAttrValue(const T &t) { | |||||
| return GeAttrValue::CreateFrom<DT>(t); | |||||
| } | |||||
| // int64_t support type | |||||
| template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0> | |||||
| static GeAttrValue SaveItemAsAttrValue(const T &t) { | |||||
| return GeAttrValue::CreateFrom<GeAttrValue::INT>(t); | |||||
| } | |||||
| // vector int64_t support type | |||||
| template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0> | |||||
| static GeAttrValue SaveItemAsAttrValue(const vector<T> &t) { | |||||
| return GeAttrValue::CreateFrom<GeAttrValue::LIST_INT>(t); | |||||
| } | |||||
| template <typename T, typename T::__ge_serializable = 0> | |||||
| static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { | |||||
| return attrVal.GetValue(t); | |||||
| } | |||||
| template <typename T, typename T::__ge_serializable = 0> | |||||
| static graphStatus LoadItemFromAttrValue(vector<T> &t, GeAttrValue &attrVal) { | |||||
| return attrVal.GetValue(t); | |||||
| } | |||||
| template <typename T, GeAttrValue::enable_if_type_valid_t<T> = 0, typename DT = typename std::remove_cv<T>::type> | |||||
| static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { | |||||
| return attrVal.GetValue<DT>(t); | |||||
| } | |||||
| template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0> | |||||
| static graphStatus LoadItemFromAttrValue(T &t, GeAttrValue &attrVal) { | |||||
| return attrVal.GetValue<GeAttrValue::INT>(t); | |||||
| } | |||||
| template <typename T, typename std::enable_if<ge_serializable_int64_t_support_type<T>::value, int>::type = 0> | |||||
| static graphStatus LoadItemFromAttrValue(vector<T> &t, GeAttrValue &attrVal) { | |||||
| return attrVal.GetValue<GeAttrValue::LIST_INT>(t); | |||||
| } | |||||
| template <class T, class... Args> | |||||
| static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { | |||||
| GeAttrValue itemVal = SaveItemAsAttrValue(item); | |||||
| (void)namedAttrs.SetAttr(itemName, itemVal); | |||||
| SaveItem(namedAttrs, args...); | |||||
| } | |||||
| static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) {} | |||||
| template <class T, class... Args> | |||||
| static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { | |||||
| auto itemVal = namedAttrs.GetItem(itemName); | |||||
| auto status = LoadItemFromAttrValue(item, itemVal); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| return status; | |||||
| } | |||||
| return LoadItem(namedAttrs, args...); | |||||
| } | |||||
| static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| }; | |||||
| #define _GE_FI(a) #a, a | |||||
| #define _GE_MAP_FIELDS1(a1) _GE_FI(a1) | |||||
| #define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2) | |||||
| #define _GE_MAP_FIELDS3(a1, a2, a3) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3) | |||||
| #define _GE_MAP_FIELDS4(a1, a2, a3, a4) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4) | |||||
| #define _GE_MAP_FIELDS5(a1, a2, a3, a4, a5) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5) | |||||
| #define _GE_MAP_FIELDS6(a1, a2, a3, a4, a5, a6) _GE_FI(a1), _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6) | |||||
| #define _GE_MAP_FIELDS7(a1, a2, a3, a4, a5, a6, a7) \ | |||||
| _GE_FI(a1) \ | |||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7) | |||||
| #define _GE_MAP_FIELDS8(a1, a2, a3, a4, a5, a6, a7, a8) \ | |||||
| _GE_FI(a1) \ | |||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8) | |||||
| #define _GE_MAP_FIELDS9(a1, a2, a3, a4, a5, a6, a7, a8, a9) \ | |||||
| _GE_FI(a1) \ | |||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9) | |||||
| #define _GE_MAP_FIELDS10(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10) \ | |||||
| _GE_FI(a1) \ | |||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10) | |||||
| #define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \ | |||||
| _GE_FI(a1) \ | |||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||||
| _GE_FI(a11) | |||||
| #define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \ | |||||
| _GE_FI(a1) \ | |||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||||
| _GE_FI(a11), _GE_FI(a12) | |||||
| #define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \ | |||||
| _GE_FI(a1) \ | |||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13) | |||||
| #define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \ | |||||
| _GE_FI(a1) \ | |||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) | |||||
| #define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \ | |||||
| _GE_FI(a1) \ | |||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) | |||||
| #define _GE_PRIVATE_ARGS_GLUE(x, y) x y | |||||
| #define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, \ | |||||
| ...) \ | |||||
| N | |||||
| #define _GE_PRIVATE_MACRO_VAR_ARGS_IMPL(args) _GE_PRIVATE_MACRO_VAR_ARGS_IMPL_COUNT args | |||||
| #define _GE_COUNT_MACRO_VAR_ARGS(...) \ | |||||
| _GE_PRIVATE_MACRO_VAR_ARGS_IMPL((__VA_ARGS__, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)) | |||||
| #define _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) M##count | |||||
| #define _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER2(M, count) | |||||
| #define _GE_PRIVATE_MACRO_CHOOSE_HELPER(M, count) _GE_PRIVATE_MACRO_CHOOSE_HELPER1(M, count) | |||||
| #define _GE_INVOKE_VAR_MACRO(...) \ | |||||
| _GE_PRIVATE_ARGS_GLUE(_GE_PRIVATE_MACRO_CHOOSE_HELPER(_GE_MAP_FIELDS, _GE_COUNT_MACRO_VAR_ARGS(__VA_ARGS__)), \ | |||||
| (__VA_ARGS__)) | |||||
| #define GE_SERIALIZABLE(...) \ | |||||
| public: \ | |||||
| friend class ge::GeAttrValue; \ | |||||
| using __ge_serializable = int; \ | |||||
| \ | |||||
| private: \ | |||||
| ge::graphStatus Save(GeAttrValue &ar) const { \ | |||||
| GeAttrValue::NAMED_ATTRS named_attrs; \ | |||||
| _GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ | |||||
| return ar.SetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \ | |||||
| } \ | |||||
| ge::graphStatus Load(const GeAttrValue &ar) { \ | |||||
| GeAttrValue::NAMED_ATTRS named_attrs; \ | |||||
| ge::graphStatus status = ar.GetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \ | |||||
| if (status != GRAPH_SUCCESS) { \ | |||||
| return status; \ | |||||
| } \ | |||||
| return _GeSerializable::LoadItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ | |||||
| } | |||||
| // end NamedAttrs Helper: GE_SERIALIZABLE | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ | |||||
| @@ -1,82 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_BUFFER_H_ | |||||
| #define INC_GRAPH_BUFFER_H_ | |||||
| #include <graph/types.h> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "detail/attributes_holder.h" | |||||
| namespace ge { | |||||
| #ifdef HOST_VISIBILITY | |||||
| #define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) | |||||
| #else | |||||
| #define GE_FUNC_HOST_VISIBILITY | |||||
| #endif | |||||
| #ifdef DEV_VISIBILITY | |||||
| #define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) | |||||
| #else | |||||
| #define GE_FUNC_DEV_VISIBILITY | |||||
| #endif | |||||
| using std::shared_ptr; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { | |||||
| public: | |||||
| Buffer(); | |||||
| Buffer(const Buffer &other); | |||||
| explicit Buffer(std::size_t bufferSize, std::uint8_t defualtVal = 0); | |||||
| ~Buffer() = default; | |||||
| Buffer &operator=(const Buffer &other); | |||||
| static Buffer CopyFrom(const std::uint8_t *data, std::size_t bufferSize); | |||||
| const std::uint8_t *GetData() const; | |||||
| std::uint8_t *GetData(); | |||||
| std::size_t GetSize() const; | |||||
| void ClearBuffer(); | |||||
| // For compatibility | |||||
| inline const std::uint8_t *data() const { return GetData(); } | |||||
| inline std::uint8_t *data() { return GetData(); } // lint !e659 | |||||
| inline std::size_t size() const { return GetSize(); } | |||||
| inline void clear() { return ClearBuffer(); } | |||||
| uint8_t operator[](size_t index) const { // lint !e1022 !e1042 | |||||
| if (buffer_ != nullptr && index < buffer_->size()) { // lint !e574 | |||||
| return (uint8_t)(*buffer_)[index]; | |||||
| } | |||||
| return 0xff; | |||||
| } | |||||
| private: | |||||
| GeIrProtoHelper<proto::AttrDef> data_; | |||||
| std::string *buffer_ = nullptr; | |||||
| // Create from protobuf obj | |||||
| Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer); | |||||
| Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer); | |||||
| friend class GeAttrValueImp; | |||||
| friend class GeTensor; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_BUFFER_H_ | |||||
| @@ -1,308 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_COMPUTE_GRAPH_H_ | |||||
| #define INC_GRAPH_COMPUTE_GRAPH_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <deque> | |||||
| #include "detail/attributes_holder.h" | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/node.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/range_vistor.h" | |||||
| namespace ge { | |||||
| class Node; | |||||
| using NodePtr = std::shared_ptr<Node>; | |||||
| class Edge; | |||||
| using EdgePtr = std::shared_ptr<Edge>; | |||||
| class InDataAnchor; | |||||
| using InDataAnchorPtr = std::shared_ptr<InDataAnchor>; | |||||
| class OutDataAnchor; | |||||
| using OutDataAnchorPtr = std::shared_ptr<OutDataAnchor>; | |||||
| class ControlAnchor; | |||||
| using ControlAnchorPtr = std::shared_ptr<ControlAnchor>; | |||||
| class InControlAnchor; | |||||
| using InControlAnchorPtr = std::shared_ptr<InControlAnchor>; | |||||
| class OutControlAnchor; | |||||
| using OutControlAnchorPtr = std::shared_ptr<OutControlAnchor>; | |||||
| class GeAttrValue; | |||||
| using AttrValuePtr = std::shared_ptr<GeAttrValue>; | |||||
| using ConstComputeGraph = const ComputeGraph; | |||||
| class OperatorImpl; | |||||
| using OperatorImplPtr = std::shared_ptr<OperatorImpl>; | |||||
| class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public AttrHolder { | |||||
| friend class GraphUtils; | |||||
| public: | |||||
| template <class T> | |||||
| using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>; | |||||
| explicit ComputeGraph(const std::string &name); | |||||
| ~ComputeGraph() override; | |||||
| std::string GetName() const; | |||||
| void SetName(const std::string &name); | |||||
| using AttrHolder::DelAttr; | |||||
| using AttrHolder::GetAttr; | |||||
| using AttrHolder::HasAttr; | |||||
| using AttrHolder::SetAttr; | |||||
| size_t GetAllNodesSize() const; | |||||
| Vistor<NodePtr> GetAllNodes() const; | |||||
| // is_unknown_shape: false, same with GetAllNodes func | |||||
| // is_unknown_shape: true, same with GetDirectNodes func | |||||
| Vistor<NodePtr> GetNodes(bool is_unknown_shape) const; | |||||
| size_t GetDirectNodesSize() const; | |||||
| Vistor<NodePtr> GetDirectNode() const; | |||||
| Vistor<NodePtr> GetInputNodes() const; | |||||
| Vistor<NodePtr> GetOutputNodes() const; | |||||
| NodePtr FindNode(const std::string &name) const; | |||||
| NodePtr FindFirstNodeMatchType(const std::string &name) const; | |||||
| /*lint -e504*/ | |||||
| // AddNode with NodePtr | |||||
| NodePtr AddNode(NodePtr node); | |||||
| NodePtr AddNode(OpDescPtr op); | |||||
| NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize | |||||
| NodePtr AddNodeFront(NodePtr node); | |||||
| NodePtr AddNodeFront(const OpDescPtr &op); | |||||
| NodePtr AddInputNode(NodePtr node); | |||||
| NodePtr AddOutputNode(NodePtr node); | |||||
| NodePtr AddOutputNodeByIndex(NodePtr node, int32_t index); | |||||
| // insert node with specific pre_node | |||||
| NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node); | |||||
| NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node); | |||||
| graphStatus RemoveNode(const NodePtr &node); | |||||
| graphStatus RemoveInputNode(const NodePtr &node); | |||||
| graphStatus RemoveOutputNode(const NodePtr &node); | |||||
| graphStatus RemoveConstInput(const NodePtr &node); | |||||
| /// Add a subgraph to this graph. The subgraph must has a parent graph and parent node, | |||||
| /// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph | |||||
| /// must be called before add it to the root graph. and subgraph->GetParentNode()->GetOwnerGraph() | |||||
| /// must equal to subgraph->GetOwnerGraph(). | |||||
| /// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph. | |||||
| /// The subgraph's name SHOULD(not must) be the same as the parameter `name` | |||||
| graphStatus AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph); | |||||
| graphStatus AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph); | |||||
| void RemoveSubgraph(const std::string &name); | |||||
| void RemoveSubgraph(const std::shared_ptr<ComputeGraph> &subgraph); | |||||
| std::shared_ptr<ComputeGraph> GetSubgraph(const std::string &name) const; | |||||
| std::vector<std::shared_ptr<ComputeGraph>> GetAllSubgraphs() const; | |||||
| // obsolete | |||||
| std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph); | |||||
| // obsolete | |||||
| graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph); | |||||
| /// | |||||
| /// @brief Update input-mapping | |||||
| /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input | |||||
| /// @return graphStatus | |||||
| /// | |||||
| graphStatus UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping); | |||||
| /// | |||||
| /// @brief Update output-mapping | |||||
| /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output | |||||
| /// @return graphStatus | |||||
| /// | |||||
| graphStatus UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping); | |||||
| graphStatus TopologicalSorting(); | |||||
| bool IsValid() const; | |||||
| void InValid() { is_valid_flag_ = false; } | |||||
| void Dump() const; | |||||
| void Swap(ComputeGraph &graph); | |||||
| graphStatus IsolateNode(const NodePtr &node); | |||||
| graphStatus Verify(); | |||||
| graphStatus InferShape(); | |||||
| graphStatus InferOriginFormat(); | |||||
| graphStatus InferShapeInNeed(); | |||||
| graphStatus InsertEventNodes(); | |||||
| bool operator==(const ComputeGraph &r_compute_graph) const; | |||||
| /*lint +e504*/ | |||||
| const std::map<std::vector<std::string>, std::vector<std::string>> &GetShareParamLayer() const { | |||||
| return params_share_map_; | |||||
| } | |||||
| void SetShareParamLayer(const std::map<std::vector<std::string>, std::vector<std::string>> params_share_map) { | |||||
| params_share_map_ = params_share_map; | |||||
| } | |||||
| void SetInputsOrder(const std::vector<std::string> &inputs_order) { inputs_order_ = inputs_order; } | |||||
| void SetGraphOutNodes(std::map<std::string, std::vector<int32_t>> out_nodes_map) { out_nodes_map_ = out_nodes_map; } | |||||
| void AppendGraphOutNodes(std::map<std::string, std::vector<int32_t>> out_nodes_map) { | |||||
| for (auto &item : out_nodes_map) { | |||||
| (void)out_nodes_map_.emplace(item.first, item.second); | |||||
| } | |||||
| } | |||||
| shared_ptr<ComputeGraph> GetParentGraph(); | |||||
| void SetParentGraph(const shared_ptr<ComputeGraph> &parent); | |||||
| shared_ptr<Node> GetParentNode(); | |||||
| void SetParentNode(const shared_ptr<Node> &parent); | |||||
| const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; } | |||||
| void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } | |||||
| ComputeGraphPtr GetOrigGraph(void) { return origGraph_; } | |||||
| void SetOutputSize(uint32_t size) { output_size_ = size; } | |||||
| uint32_t GetOutputSize() const { return output_size_; } | |||||
| void SetInputSize(uint32_t size) { input_size_ = size; } | |||||
| uint32_t GetInputSize() const { return input_size_; } | |||||
| // false: known shape true: unknow shape | |||||
| bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; } | |||||
| void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; } | |||||
| /// | |||||
| /// Set is need train iteration. | |||||
| /// If set true, it means this graph need to be run iteration some | |||||
| /// times(according variant "npu_runconfig/iterations_per_loop"). | |||||
| /// @param need_iteration is need iteration | |||||
| /// | |||||
| void SetNeedIteration(bool need_iteration) { need_iteration_ = need_iteration; } | |||||
| void SetUserDefOutput(const std::string &output_name); | |||||
| const std::string GetOutput(); | |||||
| /// | |||||
| /// Get is need train iteration. | |||||
| /// @return is need iteration | |||||
| /// | |||||
| bool GetNeedIteration() const { return need_iteration_; } | |||||
| void SetGraphOpName(const std::map<uint32_t, std::string> &op_name_map) { op_name_map_ = op_name_map; } | |||||
| const std::map<uint32_t, std::string> &GetGraphOpName() const { return op_name_map_; } | |||||
| const std::map<OperatorImplPtr, NodePtr> &GetAllNodesInfo() const; | |||||
| void SetAllNodesInfo(const std::map<OperatorImplPtr, NodePtr> &nodes) { all_nodes_infos_ = nodes; } | |||||
| void SetGraphOutNodesInfo(std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info) { | |||||
| output_nodes_info_ = out_nodes_info; | |||||
| } | |||||
| void AppendGraphOutNodesInfo(std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info) { | |||||
| output_nodes_info_.insert(output_nodes_info_.end(), out_nodes_info.begin(), out_nodes_info.end()); | |||||
| } | |||||
| const std::vector<std::pair<NodePtr, int32_t>> &GetGraphOutNodesInfo() const { return output_nodes_info_; } | |||||
| void SetGraphTargetNodesInfo(const std::vector<NodePtr> &target_nodes_info) { | |||||
| target_nodes_info_ = target_nodes_info; | |||||
| } | |||||
| const std::vector<NodePtr> &GetGraphTargetNodesInfo() const { return target_nodes_info_; } | |||||
| void SetSessionID(uint64_t session_id) { session_id_ = session_id; } | |||||
| uint64_t GetSessionID() const { return session_id_; } | |||||
| void SetGraphID(uint32_t graph_id) { graph_id_ = graph_id; } | |||||
| uint32_t GetGraphID() const { return graph_id_; } | |||||
| void SaveDataFormat(ge::Format data_format) { data_format_ = data_format; } | |||||
| ge::Format GetDataFormat() const { return data_format_; } | |||||
| bool IsSummaryGraph() const { return is_summary_graph_; } | |||||
| void SetSummaryFlag(bool is_summary_graph) { is_summary_graph_ = is_summary_graph; } | |||||
| // Graph Before BFE | |||||
| ComputeGraphPtr origGraph_; | |||||
| protected: | |||||
| ProtoAttrMapHelper MutableAttrMap() override; | |||||
| ConstProtoAttrMapHelper GetAttrMap() const override; | |||||
| private: | |||||
| graphStatus DFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num, | |||||
| std::vector<NodePtr> &stack); | |||||
| graphStatus BFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num, | |||||
| std::deque<NodePtr> &stack); | |||||
| graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | |||||
| std::map<string, NodePtr> &breadth_node_map); | |||||
| graphStatus TopologicalSortingGraph(); | |||||
| graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum); | |||||
| Vistor<NodePtr> AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const; | |||||
| size_t GetInEdgeSize(const NodePtr &node); | |||||
| size_t GetOutEdgeSize(const NodePtr &node); | |||||
| graphStatus RemoveExtraOutEdge(const NodePtr &node); | |||||
| bool GraphMembersAreEqual(const ComputeGraph &r_graph) const; | |||||
| bool GraphAttrsAreEqual(const ComputeGraph &r_graph) const; | |||||
| bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector, | |||||
| const std::vector<NodePtr> &l_node_ptr_vector) const; | |||||
| void SetNodesOwner(); | |||||
| friend class ModelSerializeImp; | |||||
| friend class GraphDebugImp; | |||||
| friend class OnnxUtils; | |||||
| friend class TuningUtils; | |||||
| std::string name_; | |||||
| uint32_t graph_id_ = 0; | |||||
| ProtoAttrMapHelper attrs_; | |||||
| std::vector<NodePtr> nodes_; | |||||
| std::map<OperatorImplPtr, NodePtr> all_nodes_infos_; | |||||
| std::vector<NodePtr> target_nodes_info_; | |||||
| std::vector<NodePtr> input_nodes_; | |||||
| std::vector<std::string> inputs_order_; | |||||
| uint32_t input_size_ = 1; | |||||
| std::map<std::string, std::vector<int32_t>> out_nodes_map_; | |||||
| uint32_t output_size_ = 1; | |||||
| std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_; | |||||
| std::vector<std::shared_ptr<ComputeGraph>> sub_graph_; | |||||
| std::map<std::string, std::shared_ptr<ComputeGraph>> names_to_subgraph_; | |||||
| std::weak_ptr<ComputeGraph> parent_graph_; | |||||
| std::weak_ptr<Node> parent_node_; | |||||
| // the members followed should not in the ComputeGraph class | |||||
| bool is_valid_flag_; | |||||
| bool is_summary_graph_ = false; | |||||
| // Indicates whether it is need iteration | |||||
| bool need_iteration_ = false; | |||||
| std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_; | |||||
| // TaskIdx -> op_name Map | |||||
| std::map<uint32_t, std::string> op_name_map_; | |||||
| uint64_t session_id_ = 0; | |||||
| ge::Format data_format_ = ge::FORMAT_ND; | |||||
| // unknown graph indicator, default is false, mean known shape | |||||
| bool is_unknown_shape_graph_ = false; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_COMPUTE_GRAPH_H_ | |||||
| @@ -1,195 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_DEF_TYPES_H_ | |||||
| #define INC_GRAPH_DEF_TYPES_H_ | |||||
| #include <atomic> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "graph/attr_value_serializable.h" | |||||
| #include "graph/buffer.h" | |||||
| namespace ge { | |||||
| #define DEF_TYPE_DEC(type, name) \ | |||||
| inline void set_##name(const type &value) { name = value; } \ | |||||
| type *mutable_##name() { return &name; } | |||||
| #define DEF_TYPE_HAS_DEC(type, name) \ | |||||
| inline void set_##name(const type &value) { name = value; } \ | |||||
| \ | |||||
| private: \ | |||||
| bool has_mutable_##name{false}; \ | |||||
| \ | |||||
| public: \ | |||||
| bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \ | |||||
| type *mutable_##name() { \ | |||||
| has_mutable_##name = true; \ | |||||
| return &name; \ | |||||
| } | |||||
| #define DEF_TYPE_VEC_DEC(type, name) \ | |||||
| inline int name##_size() const { return name.size(); } \ | |||||
| inline void clear_##name() { name.clear(); } \ | |||||
| inline void set_##name(int index, type value) { name[index] = value; } \ | |||||
| inline void add_##name(type value) { name.push_back(value); } \ | |||||
| inline std::vector<type> *mutable_##name() { return &name; } | |||||
| #define DEF_TYPE_BYTES_DEC(name) \ | |||||
| inline void clear_##name() { name.ClearBuffer(); } \ | |||||
| inline void set_##name(const void *value, size_t size) { name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ | |||||
| inline Buffer *mutable_##name() { return &name; } | |||||
| struct CompressInfo { | |||||
| public: | |||||
| CompressInfo() {} | |||||
| CompressInfo(int32_t blockRow, int32_t blockCol, int32_t fractalK, int32_t fractalN, int32_t lastFractalK, | |||||
| int32_t lastFractalN, int32_t cubeSize, int32_t loadDir) { | |||||
| blockrow = blockRow; | |||||
| blockcol = blockCol; | |||||
| fractalk = fractalK; | |||||
| fractaln = fractalN; | |||||
| lastfractalk = lastFractalK; | |||||
| lastfractaln = lastFractalN; | |||||
| cubesize = cubeSize; | |||||
| loaddir = loadDir; | |||||
| } | |||||
| int32_t blockrow{0}; // Block row | |||||
| int32_t blockcol{0}; // Block col | |||||
| int32_t fractalk{0}; // Fractal K | |||||
| int32_t fractaln{0}; // Fractal N | |||||
| int32_t lastfractalk{0}; // K of last fractal | |||||
| int32_t lastfractaln{0}; // N of last fractal | |||||
| int32_t cubesize{0}; // Cube's length | |||||
| int32_t loaddir{0}; // Data load directtiono 0:col load 1:row load | |||||
| DEF_TYPE_DEC(int32_t, blockrow); | |||||
| DEF_TYPE_DEC(int32_t, blockcol); | |||||
| DEF_TYPE_DEC(int32_t, fractalk); | |||||
| DEF_TYPE_DEC(int32_t, fractaln); | |||||
| DEF_TYPE_DEC(int32_t, lastfractalk); | |||||
| DEF_TYPE_DEC(int32_t, lastfractaln); | |||||
| DEF_TYPE_DEC(int32_t, cubesize); | |||||
| DEF_TYPE_DEC(int32_t, loaddir); | |||||
| GE_SERIALIZABLE(blockrow, blockcol, fractalk, fractaln, lastfractalk, lastfractaln, cubesize, loaddir); | |||||
| }; | |||||
| enum QuantizeScaleType { VECTOR_SCALE = 0, SCALAR_SCALE = 1 }; | |||||
| enum QuantizeScaleMode { NORMAL_MODE = 0, SQRT_MODE = 1 }; | |||||
| enum QuantizeAlgorithm { | |||||
| NON_OFFSET_ALGO = 0, | |||||
| HALF_OFFSET_ALGO = 1, | |||||
| ALL_OFFSET_ALGO = 2, | |||||
| }; | |||||
| struct QuantizeFactor { | |||||
| public: | |||||
| // QuantizeScaleMode scale_mode; | |||||
| uint32_t scale_mode{0}; | |||||
| Buffer scale_value; | |||||
| int64_t scale_offset{0}; | |||||
| Buffer offset_data_value; | |||||
| int64_t offset_data_offset{0}; | |||||
| Buffer offset_weight_value; | |||||
| int64_t offset_weight_offset{0}; | |||||
| Buffer offset_pad_value; | |||||
| int64_t offset_pad_offset{0}; | |||||
| DEF_TYPE_DEC(uint32_t, scale_mode); | |||||
| DEF_TYPE_BYTES_DEC(scale_value); | |||||
| DEF_TYPE_DEC(int64_t, scale_offset); | |||||
| DEF_TYPE_BYTES_DEC(offset_data_value); | |||||
| DEF_TYPE_DEC(int64_t, offset_data_offset); | |||||
| DEF_TYPE_BYTES_DEC(offset_weight_value); | |||||
| DEF_TYPE_DEC(int64_t, offset_weight_offset); | |||||
| DEF_TYPE_BYTES_DEC(offset_pad_value); | |||||
| DEF_TYPE_DEC(int64_t, offset_pad_offset); | |||||
| GE_SERIALIZABLE(scale_mode, scale_value, scale_offset, offset_data_value, offset_data_offset, offset_weight_value, | |||||
| offset_weight_offset, offset_pad_value, offset_pad_offset) | |||||
| }; | |||||
| static inline bool QuantizeFactorHasData(const QuantizeFactor &factor) { | |||||
| return factor.scale_value.GetSize() > 0 || factor.offset_data_value.GetSize() > 0 || | |||||
| factor.offset_weight_value.GetSize() > 0 || factor.offset_pad_value.GetSize() > 0; | |||||
| } | |||||
| struct AllOffsetQuantizeInfo { | |||||
| public: | |||||
| AllOffsetQuantizeInfo() {} | |||||
| AllOffsetQuantizeInfo(float s, int32_t o) : scale(s), offset(o) {} | |||||
| float scale{0}; | |||||
| int32_t offset{0}; | |||||
| DEF_TYPE_DEC(float, scale); | |||||
| DEF_TYPE_DEC(int32_t, offset); | |||||
| GE_SERIALIZABLE(scale, offset) | |||||
| }; | |||||
| struct QuantizeCalcFactor { | |||||
| public: | |||||
| Buffer offsetw; | |||||
| int64_t offsetw_offset{0}; | |||||
| Buffer offsetd; | |||||
| int64_t offsetd_offset{0}; | |||||
| Buffer scalereq; | |||||
| int64_t scaledreq_offset{0}; | |||||
| Buffer offsetdnext; | |||||
| int64_t offsetdnext_offset{0}; | |||||
| DEF_TYPE_BYTES_DEC(offsetw); | |||||
| DEF_TYPE_DEC(int64_t, offsetw_offset); | |||||
| DEF_TYPE_BYTES_DEC(offsetd); | |||||
| DEF_TYPE_DEC(int64_t, offsetd_offset); | |||||
| DEF_TYPE_BYTES_DEC(scalereq); | |||||
| DEF_TYPE_DEC(int64_t, scaledreq_offset); | |||||
| DEF_TYPE_BYTES_DEC(offsetdnext); | |||||
| DEF_TYPE_DEC(int64_t, offsetdnext_offset); | |||||
| GE_SERIALIZABLE(offsetw, offsetw_offset, offsetd, offsetd_offset, scalereq, scaledreq_offset, offsetdnext, | |||||
| offsetdnext_offset); | |||||
| }; | |||||
| static inline bool QuantizeFactorHasData(const QuantizeCalcFactor &factor) { | |||||
| return factor.offsetw.GetSize() > 0 || factor.offsetd.GetSize() > 0 || factor.scalereq.GetSize() > 0 || | |||||
| factor.offsetdnext.GetSize() > 0; | |||||
| } | |||||
| struct QuantizeFactorParams { | |||||
| uint32_t quantize_algo{0}; | |||||
| uint32_t scale_type{0}; | |||||
| QuantizeFactor quantize_param; | |||||
| QuantizeFactor dequantize_param; | |||||
| QuantizeFactor requantize_param; | |||||
| QuantizeCalcFactor quantizecalc_param; | |||||
| DEF_TYPE_DEC(uint32_t, quantize_algo); | |||||
| DEF_TYPE_DEC(uint32_t, scale_type); | |||||
| DEF_TYPE_HAS_DEC(QuantizeFactor, quantize_param); | |||||
| DEF_TYPE_HAS_DEC(QuantizeFactor, dequantize_param); | |||||
| DEF_TYPE_HAS_DEC(QuantizeFactor, requantize_param); | |||||
| DEF_TYPE_HAS_DEC(QuantizeCalcFactor, quantizecalc_param); | |||||
| GE_SERIALIZABLE(quantize_algo, scale_type, quantize_param, dequantize_param, requantize_param, quantizecalc_param, | |||||
| has_mutable_quantize_param, has_mutable_dequantize_param, has_mutable_requantize_param, | |||||
| has_mutable_quantizecalc_param); | |||||
| }; | |||||
| #undef DEF_TYPE_DEC | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_DEF_TYPES_H_ | |||||
| @@ -1,120 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_DETAIL_ANY_MAP_H_ | |||||
| #define INC_GRAPH_DETAIL_ANY_MAP_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| namespace ge { | |||||
| using std::shared_ptr; | |||||
| using std::string; | |||||
| class TypeID { | |||||
| public: | |||||
| template <class T> | |||||
| static TypeID Of() { | |||||
| return TypeID(__PRETTY_FUNCTION__); | |||||
| } | |||||
| ~TypeID() = default; | |||||
| bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; } | |||||
| private: | |||||
| explicit TypeID(string type) : type_(std::move(type)) {} // lint !e30 !e32 | |||||
| string type_; | |||||
| }; | |||||
| class AnyMap { | |||||
| public: | |||||
| template <class DT> | |||||
| bool Set(const string &name, const DT &val); | |||||
| template <class T> | |||||
| bool Get(const string &name, T &retValue) const; | |||||
| bool Has(const string &name) const { return anyValues_.find(name) != anyValues_.end(); } | |||||
| void Swap(AnyMap &other) { anyValues_.swap(other.anyValues_); } | |||||
| private: | |||||
| class Placeholder { | |||||
| public: | |||||
| virtual ~Placeholder() = default; | |||||
| virtual const TypeID &GetTypeInfo() const = 0; | |||||
| }; | |||||
| template <typename VT> | |||||
| class Holder : public Placeholder { | |||||
| public: | |||||
| explicit Holder(const VT &value) : value_(value) {} | |||||
| ~Holder() override = default; | |||||
| const TypeID &GetTypeInfo() const override { | |||||
| static const TypeID typeId = TypeID::Of<VT>(); | |||||
| return typeId; | |||||
| } | |||||
| const VT value_; | |||||
| }; | |||||
| std::map<string, shared_ptr<Placeholder>> anyValues_; | |||||
| }; | |||||
| template <class DT> | |||||
| bool AnyMap::Set(const string &name, const DT &val) { | |||||
| auto it = anyValues_.find(name); | |||||
| std::shared_ptr<Holder<DT>> tmp; | |||||
| try { | |||||
| tmp = std::make_shared<Holder<DT>>(val); | |||||
| } catch (std::bad_alloc &e) { | |||||
| tmp = nullptr; | |||||
| } catch (...) { | |||||
| tmp = nullptr; | |||||
| } | |||||
| if (it == anyValues_.end()) { | |||||
| (void)anyValues_.emplace(name, tmp); | |||||
| } else { | |||||
| if (it->second && it->second->GetTypeInfo() == TypeID::Of<DT>()) { | |||||
| it->second = tmp; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| template <class T> | |||||
| bool AnyMap::Get(const string &name, T &retValue) const { | |||||
| auto it = anyValues_.find(name); | |||||
| if (it != anyValues_.end() && it->second && it->second->GetTypeInfo() == TypeID::Of<T>()) { | |||||
| auto retPtr = std::static_pointer_cast<Holder<T>>(it->second); | |||||
| retValue = retPtr->value_; | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_DETAIL_ANY_MAP_H_ | |||||
| @@ -1,165 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ | |||||
| #define INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_set> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "graph/detail/any_map.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| #include "graph/types.h" | |||||
| namespace google { | |||||
| namespace protobuf { | |||||
| class Message; | |||||
| template <typename Key, typename T> | |||||
| class Map; | |||||
| } // namespace protobuf | |||||
| } // namespace google | |||||
| namespace ge { | |||||
| using std::string; | |||||
| class GeAttrValue; | |||||
| namespace proto { | |||||
| class AttrDef; | |||||
| class TensorDef; | |||||
| class TensorDescriptor; | |||||
| class ShapeDef; | |||||
| class NamedAttrs; | |||||
| class ModelDef; | |||||
| class OpDef; | |||||
| class GraphDef; | |||||
| } // namespace proto | |||||
| using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; // lint !e1073 | |||||
| using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>; | |||||
| template <class ProtoType> | |||||
| class GeIrProtoHelper { | |||||
| public: | |||||
| GeIrProtoHelper(const ProtoMsgOwner &protoOwner, ProtoType *protoMsg) | |||||
| : protoOwner_(protoOwner), protoMsg_(protoMsg) {} | |||||
| GeIrProtoHelper() { | |||||
| protoOwner_ = std::shared_ptr<::google::protobuf::Message>(nullptr); | |||||
| protoMsg_ = nullptr; | |||||
| } | |||||
| virtual ~GeIrProtoHelper() = default; | |||||
| template <typename T> | |||||
| GeIrProtoHelper(const GeIrProtoHelper<T> &other) { | |||||
| protoOwner_ = other.protoOwner_; | |||||
| protoMsg_ = other.protoMsg_; | |||||
| } | |||||
| template <typename T> | |||||
| GeIrProtoHelper &operator=(const GeIrProtoHelper<T> &other) { | |||||
| protoOwner_ = other.protoOnwer_; | |||||
| protoMsg_ = other.protoMsg_; | |||||
| return *this; | |||||
| } | |||||
| void InitDefault(); | |||||
| template <typename T> | |||||
| bool operator==(const GeIrProtoHelper<T> &other) const { | |||||
| return protoOwner_ == other.protoOwner_ && protoMsg_ == other.protoMsg_; | |||||
| } | |||||
| inline const ProtoMsgOwner &GetProtoOwner() const { return protoOwner_; } | |||||
| inline ProtoType *GetProtoMsg() const { return protoMsg_; } | |||||
| void CopyValueFrom(const GeIrProtoHelper<const ProtoType> &other) { | |||||
| if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) { | |||||
| *protoMsg_ = *other.protoMsg_; | |||||
| } | |||||
| } | |||||
| void MoveValueFrom(GeIrProtoHelper<ProtoType> &&other) { | |||||
| if (other.protoMsg_ != nullptr && protoMsg_ != nullptr) { | |||||
| *protoMsg_ = std::move(*other.protoMsg_); | |||||
| } | |||||
| } | |||||
| void Swap(GeIrProtoHelper<ProtoType> &other) { | |||||
| protoOwner_.swap(other.protoOwner_); | |||||
| ProtoType *temp = protoMsg_; | |||||
| protoMsg_ = other.protoMsg_; | |||||
| other.protoMsg_ = temp; | |||||
| } | |||||
| // protoMsg_ is part of protoOwner_, they have the same runtime | |||||
| ProtoMsgOwner protoOwner_ = nullptr; | |||||
| ProtoType *protoMsg_ = nullptr; | |||||
| friend class GeIrProtoHelper<typename std::conditional< | |||||
| std::is_const<ProtoType>::value, typename std::remove_const<ProtoType>::type, const ProtoType>::type>; | |||||
| }; | |||||
| using ProtoAttrMapHelper = GeIrProtoHelper<ProtoAttrMap>; | |||||
| using ConstProtoAttrMapHelper = GeIrProtoHelper<const ProtoAttrMap>; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { | |||||
| public: | |||||
| AttrHolder() = default; | |||||
| virtual ~AttrHolder() = default; | |||||
| graphStatus SetAttr(const string &name, const GeAttrValue &value); | |||||
| graphStatus GetAttr(const string &name, GeAttrValue &value) const; | |||||
| bool HasAttr(const string &name) const; | |||||
| graphStatus DelAttr(const string &name); | |||||
| void CopyAttrsFrom(const AttrHolder &holder); | |||||
| void Swap(AttrHolder &holder) { | |||||
| requiredAttrs_.swap(holder.requiredAttrs_); | |||||
| extAttrs_.Swap(holder.extAttrs_); | |||||
| } | |||||
| template <class T> | |||||
| bool SetExtAttr(const string &name, const T &value) { | |||||
| return extAttrs_.Set(name, value); | |||||
| } | |||||
| template <class T> | |||||
| T TryGetExtAttr(const string &name, T defaultValue) const { | |||||
| T ret(defaultValue); | |||||
| (void)extAttrs_.Get(name, ret); | |||||
| return ret; | |||||
| } | |||||
| protected: | |||||
| graphStatus AddRequiredAttr(const std::string &name); | |||||
| const std::unordered_set<string> GetAllAttrNames() const; | |||||
| const std::map<string, GeAttrValue> GetAllAttrs() const; // lint !e1073 | |||||
| virtual ProtoAttrMapHelper MutableAttrMap() = 0; | |||||
| virtual ConstProtoAttrMapHelper GetAttrMap() const = 0; | |||||
| friend class ModelSerializeImp; | |||||
| friend class AttrUtils; | |||||
| friend class AttrUtilsHelper; | |||||
| std::vector<string> requiredAttrs_; | |||||
| private: | |||||
| AnyMap extAttrs_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ | |||||
| @@ -1,93 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ | |||||
| #define INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/detail/attributes_holder.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "graph/graph.h" | |||||
| #include "graph/node.h" | |||||
| namespace ge { | |||||
| using ComputeGraphPtr = std::shared_ptr<ComputeGraph>; | |||||
| struct NodeNameGraphReq { | |||||
| string node_name; | |||||
| int32_t index; | |||||
| ComputeGraphPtr graph; | |||||
| }; | |||||
| struct NodeNameNodeReq { | |||||
| string src_node_name; | |||||
| int32_t src_out_index; | |||||
| NodePtr dst_node; | |||||
| int32_t dst_in_index; | |||||
| string dst_node_name; | |||||
| }; | |||||
| class ModelSerializeImp { | |||||
| public: | |||||
| bool SerializeModel(const Model &model, proto::ModelDef *modeProto, bool is_dump = false); | |||||
| bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto, bool is_dump = false); | |||||
| bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto); | |||||
| bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto, bool is_dump = false); | |||||
| bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto, bool is_dump = false); | |||||
| bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto); | |||||
| bool UnserializeModel(Model &model, proto::ModelDef &modeProto); | |||||
| bool UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graphProto); | |||||
| bool UnserializeGraph(ComputeGraphPtr &graph, proto::GraphDef &graphProto); | |||||
| bool HandleNodeNameRef(); | |||||
| bool UnserializeOpDesc(OpDescPtr &opDesc, proto::OpDef &opDefProto); | |||||
| void AttrDefToOpDesc(OpDescPtr &op_desc, std::vector<string> &key_in, std::vector<string> &key_out, | |||||
| std::vector<uint32_t> &value_in, std::vector<uint32_t> &value_out, std::vector<string> &opt); | |||||
| void OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto); | |||||
| bool UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &opDefProto); | |||||
| bool UnserializeTensor(GeTensorPtr &tensor, proto::TensorDef &tensorProto); | |||||
| bool ParseNodeIndex(const string &node_index, string &nodeName, int32_t &index); | |||||
| void SetProtobufOwner(const ProtoMsgOwner &bufferProtobufOnwer) { protobuf_owner_ = bufferProtobufOnwer; } | |||||
| private: | |||||
| bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map<std::string, ComputeGraphPtr> &subgraphs); | |||||
| std::vector<NodeNameGraphReq> graph_input_node_names_; | |||||
| std::vector<NodeNameGraphReq> graph_output_node_names_; | |||||
| std::vector<NodeNameNodeReq> node_input_node_names_; | |||||
| std::map<string, NodePtr> node_map_; | |||||
| ProtoMsgOwner protobuf_owner_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ | |||||
| @@ -1,343 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_GE_ATTR_VALUE_H_ | |||||
| #define INC_GRAPH_GE_ATTR_VALUE_H_ | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "graph/buffer.h" | |||||
| #include "detail/attributes_holder.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| using std::map; | |||||
| using std::string; | |||||
| using std::vector; | |||||
| namespace ge { | |||||
| class GeTensor; | |||||
| using GeTensorPtr = std::shared_ptr<GeTensor>; | |||||
| using ConstGeTensorPtr = std::shared_ptr<const GeTensor>; | |||||
| class ComputeGraph; | |||||
| using ComputeGraphPtr = std::shared_ptr<ComputeGraph>; | |||||
| using ConstComputeGraphPtr = std::shared_ptr<const ComputeGraph>; | |||||
| class GeTensorDesc; | |||||
| class GeAttrValue; | |||||
| class GeAttrValueImp; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder { | |||||
| public: | |||||
| NamedAttrs(); | |||||
| virtual ~NamedAttrs() = default; | |||||
| void SetName(const std::string &name); | |||||
| string GetName() const; | |||||
| GeAttrValue GetItem(const string &key) const; | |||||
| protected: | |||||
| ProtoAttrMapHelper MutableAttrMap() override; | |||||
| ConstProtoAttrMapHelper GetAttrMap() const override; | |||||
| private: | |||||
| // Create namedAttrs from protobuf obj | |||||
| NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); | |||||
| GeIrProtoHelper<proto::NamedAttrs> named_attrs_; | |||||
| friend class GeAttrValueImp; | |||||
| friend class GeAttrValue; | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||||
| public: | |||||
| using INT = int64_t; | |||||
| using FLOAT = float; | |||||
| using BOOL = bool; | |||||
| using STR = std::string; | |||||
| using TENSOR = GeTensorPtr; | |||||
| using TENSOR_DESC = GeTensorDesc; | |||||
| using GRAPH = ComputeGraphPtr; | |||||
| using BYTES = Buffer; | |||||
| using NAMED_ATTRS = ge::NamedAttrs; | |||||
| using DATA_TYPE = ge::DataType; | |||||
| using LIST_INT = vector<INT>; | |||||
| using LIST_FLOAT = vector<FLOAT>; | |||||
| using LIST_BOOL = vector<BOOL>; | |||||
| using LIST_STR = vector<STR>; | |||||
| using LIST_TENSOR = vector<TENSOR>; | |||||
| using LIST_TENSOR_DESC = vector<TENSOR_DESC>; | |||||
| using LIST_GRAPH = vector<GRAPH>; | |||||
| using LIST_BYTES = vector<BYTES>; | |||||
| using LIST_NAMED_ATTRS = vector<NAMED_ATTRS>; | |||||
| using LIST_LIST_INT = vector<vector<int64_t>>; | |||||
| using LIST_DATA_TYPE = vector<ge::DataType>; | |||||
| using NamedAttrs = ge::NamedAttrs; // for cce use (ge::GeAttrValue::NamedAttrs). | |||||
| enum ValueType { | |||||
| VT_NONE = 0, | |||||
| VT_STRING, | |||||
| VT_FLOAT, | |||||
| VT_BOOL, | |||||
| VT_INT, | |||||
| VT_TENSOR_DESC, | |||||
| VT_TENSOR, | |||||
| VT_BYTES, | |||||
| VT_GRAPH, | |||||
| VT_NAMED_ATTRS, | |||||
| VT_LIST_LIST_INT, | |||||
| VT_DATA_TYPE, | |||||
| VT_LIST_BASE = 1000, | |||||
| VT_LIST_STRING = VT_LIST_BASE + VT_STRING, | |||||
| VT_LIST_FLOAT = VT_LIST_BASE + VT_FLOAT, | |||||
| VT_LIST_BOOL = VT_LIST_BASE + VT_BOOL, | |||||
| VT_LIST_INT = VT_LIST_BASE + VT_INT, | |||||
| VT_LIST_TENSOR_DESC = VT_LIST_BASE + VT_TENSOR_DESC, | |||||
| VT_LIST_TENSOR = VT_LIST_BASE + VT_TENSOR, | |||||
| VT_LIST_BYTES = VT_LIST_BASE + VT_BYTES, | |||||
| VT_LIST_GRAPH = VT_LIST_BASE + VT_GRAPH, | |||||
| VT_LIST_NAMED_ATTRS = VT_LIST_BASE + VT_NAMED_ATTRS, | |||||
| VT_LIST_DATA_TYPE = VT_LIST_BASE + VT_DATA_TYPE, | |||||
| }; | |||||
| template <class T> | |||||
| struct IsAttrTypeEnable { | |||||
| using DT = typename std::remove_cv<T>::type; | |||||
| static bool const VALUE = std::is_same<INT, DT>::value || std::is_same<FLOAT, DT>::value || | |||||
| std::is_same<BOOL, DT>::value || std::is_same<STR, DT>::value || | |||||
| std::is_same<GRAPH, DT>::value || std::is_same<TENSOR, DT>::value || | |||||
| std::is_same<TENSOR_DESC, DT>::value || std::is_same<BYTES, DT>::value || | |||||
| std::is_same<NAMED_ATTRS, DT>::value || std::is_same<DATA_TYPE, DT>::value; | |||||
| // Not has list type of NamedAttrs | |||||
| static bool const LIST_VALUE = std::is_same<LIST_INT, DT>::value || std::is_same<LIST_FLOAT, DT>::value || | |||||
| std::is_same<LIST_BOOL, DT>::value || std::is_same<LIST_STR, DT>::value || | |||||
| std::is_same<LIST_GRAPH, DT>::value || std::is_same<LIST_TENSOR, DT>::value || | |||||
| std::is_same<LIST_TENSOR_DESC, DT>::value || std::is_same<LIST_BYTES, DT>::value || | |||||
| std::is_same<LIST_NAMED_ATTRS, DT>::value || | |||||
| std::is_same<LIST_LIST_INT, DT>::value || std::is_same<LIST_DATA_TYPE, DT>::value; | |||||
| }; | |||||
| template <typename vector_type> | |||||
| // To cols | |||||
| using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, int>::type; | |||||
| template <typename one_type> | |||||
| using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type; | |||||
| template <typename val_type> | |||||
| using enable_if_type_valid_t = | |||||
| typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type; | |||||
| template <typename seriliable_type> | |||||
| using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable; | |||||
| GeAttrValue(); | |||||
| ~GeAttrValue() = default; | |||||
| // SetValue, Set initializer_list | |||||
| template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0> | |||||
| graphStatus SetValue(std::initializer_list<DT> &&val) { | |||||
| T vectorVal; | |||||
| for (auto &item : val) { | |||||
| vectorVal.push_back(item); | |||||
| } | |||||
| return SetValue(vectorVal); | |||||
| } | |||||
| // SetValue, Set vector | |||||
| template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0> | |||||
| graphStatus SetValue(const std::vector<DT> &val) { | |||||
| T vectorVal; | |||||
| for (auto item : val) { | |||||
| vectorVal.push_back(item); | |||||
| } | |||||
| return SetValue(vectorVal); | |||||
| } | |||||
| // SetValue, not list type | |||||
| template <typename T, typename DT, enable_if_one_type_valid_t<T> = 0> | |||||
| graphStatus SetValue(DT &&val) { | |||||
| return SetValue(T(std::forward<DT>(val))); | |||||
| } | |||||
| // GE_SERIALIZABLE | |||||
| template <typename T, enable_if_seriliable_type_valid_t<T> = 0> | |||||
| graphStatus SetValue(const T &t) { | |||||
| return t.Save(*this); | |||||
| } | |||||
| template <typename T, enable_if_seriliable_type_valid_t<T> = 0> | |||||
| graphStatus SetValue(const vector<T> &t) { | |||||
| vector<NamedAttrs> attrs; | |||||
| for (auto &item : t) { | |||||
| GeAttrValue val; | |||||
| item.Save(val); | |||||
| NamedAttrs attrsItem; | |||||
| (void)val.GetValue<NamedAttrs>(attrsItem); | |||||
| attrs.push_back(attrsItem); | |||||
| } | |||||
| return SetValue(attrs); | |||||
| } | |||||
| // GetValue, list value | |||||
| template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0, | |||||
| typename std::enable_if<!std::is_same<DT, GeTensorPtr>::value, int>::type = 0> | |||||
| graphStatus GetValue(std::vector<DT> &val) const { | |||||
| T valGet; | |||||
| val.clear(); | |||||
| auto status = GetValue(valGet); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| return status; | |||||
| } | |||||
| for (auto item : valGet) { | |||||
| val.push_back(item); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| // GetValue, not list type | |||||
| template <typename T, typename DT, enable_if_one_type_valid_t<T> = 0, | |||||
| typename std::enable_if<!std::is_same<DT, GeTensorPtr>::value, int>::type = 0> | |||||
| graphStatus GetValue(DT &val) const { | |||||
| T valGet; | |||||
| auto status = GetValue(valGet); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| return status; | |||||
| } | |||||
| val = DT(valGet); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| // GE_SERIALIZABLE | |||||
| template <typename T, enable_if_seriliable_type_valid_t<T> = 0> | |||||
| graphStatus GetValue(T &t) { | |||||
| return t.Load(*this); | |||||
| } | |||||
| template <typename T, enable_if_seriliable_type_valid_t<T> = 0> | |||||
| graphStatus GetValue(vector<T> &t) { | |||||
| graphStatus status; | |||||
| t.clear(); | |||||
| vector<NamedAttrs> attrs; | |||||
| status = this->GetValue(attrs); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| return status; | |||||
| } | |||||
| for (auto &attr : attrs) { | |||||
| T item; | |||||
| GeAttrValue val; | |||||
| (void)val.SetValue(attr); | |||||
| status = item.Load(val); | |||||
| if (status != GRAPH_SUCCESS) { | |||||
| return status; | |||||
| } | |||||
| t.push_back(item); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| template <typename T, typename DT, enable_if_type_valid_t<T> = 0> | |||||
| static GeAttrValue CreateFrom(DT &&val) { | |||||
| GeAttrValue valRet; | |||||
| (void)valRet.SetValue<T>(std::forward<DT>(val)); | |||||
| return valRet; | |||||
| } | |||||
| template <typename T, typename DT, enable_if_vector_type_valid_t<T> = 0> | |||||
| static GeAttrValue CreateFrom(std::initializer_list<DT> &&val) { | |||||
| GeAttrValue valRet; | |||||
| (void)valRet.SetValue<T>(std::move(val)); | |||||
| return valRet; | |||||
| } | |||||
| template <typename T, enable_if_seriliable_type_valid_t<T> = 0> | |||||
| static GeAttrValue CreateFrom(const T &val) { | |||||
| GeAttrValue valRet; | |||||
| (void)valRet.SetValue(val); | |||||
| return valRet; | |||||
| } | |||||
| template <typename T, enable_if_seriliable_type_valid_t<T> = 0> | |||||
| static GeAttrValue CreateFrom(const vector<T> &val) { | |||||
| GeAttrValue valRet; | |||||
| (void)valRet.SetValue(val); | |||||
| return valRet; | |||||
| } | |||||
| ValueType GetValueType() const; | |||||
| bool IsEmpty() const; | |||||
| GeAttrValue Copy() const; | |||||
| // For map key | |||||
| bool operator==(const GeAttrValue &other) const { return value_ == other.value_; } | |||||
| graphStatus MutableTensor(GeTensorPtr &tensor); | |||||
| graphStatus MutableListTensor(vector<GeTensorPtr> &list_tensor); | |||||
| private: | |||||
| #define VALUE_SET_GET_DEC(DT) \ | |||||
| graphStatus SetValue(const DT &val); \ | |||||
| graphStatus GetValue(DT &val) const; | |||||
| VALUE_SET_GET_DEC(GeAttrValue::STR) | |||||
| VALUE_SET_GET_DEC(GeAttrValue::INT) | |||||
| VALUE_SET_GET_DEC(GeAttrValue::FLOAT) | |||||
| VALUE_SET_GET_DEC(GeAttrValue::BOOL) | |||||
| VALUE_SET_GET_DEC(GeTensorDesc) | |||||
| VALUE_SET_GET_DEC(GeAttrValue::TENSOR) | |||||
| VALUE_SET_GET_DEC(GeAttrValue::GRAPH) | |||||
| VALUE_SET_GET_DEC(BYTES) | |||||
| VALUE_SET_GET_DEC(NamedAttrs) | |||||
| VALUE_SET_GET_DEC(ge::DataType) // lint !e665 | |||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::STR>) | |||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::INT>) | |||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::FLOAT>) | |||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::BOOL>) | |||||
| VALUE_SET_GET_DEC(vector<GeTensorDesc>) | |||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::TENSOR>) | |||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::GRAPH>) | |||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::BYTES>) | |||||
| VALUE_SET_GET_DEC(vector<NamedAttrs>) | |||||
| VALUE_SET_GET_DEC(vector<vector<int64_t>>) // lint !e665 | |||||
| VALUE_SET_GET_DEC(vector<ge::DataType>) // lint !e665 | |||||
| #undef VALUE_SET_GET_DEC | |||||
| GeIrProtoHelper<proto::AttrDef> value_; | |||||
| GeAttrValue(const ProtoMsgOwner &proto_owner, ge::proto::AttrDef *val); | |||||
| friend class AttrHolder; | |||||
| friend class ModelSerializeImp; | |||||
| friend class OnnxUtils; | |||||
| }; | |||||
| class AttrValueImpl { | |||||
| public: | |||||
| AttrValueImpl() = default; | |||||
| ~AttrValueImpl() = default; | |||||
| GeAttrValue geAttrValue_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_GE_ATTR_VALUE_H_ | |||||
| @@ -1,46 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_GE_CONTEXT_H_ | |||||
| #define INC_GRAPH_GE_CONTEXT_H_ | |||||
| #include <string> | |||||
| #include "graph/ge_error_codes.h" | |||||
| namespace ge { | |||||
| class GEContext { | |||||
| public: | |||||
| graphStatus GetOption(const std::string &key, std::string &option); | |||||
| bool GetHostExecFlag(); | |||||
| uint64_t SessionId(); | |||||
| uint32_t DeviceId(); | |||||
| uint64_t TraceId(); | |||||
| void Init(); | |||||
| void SetSessionId(uint64_t session_id); | |||||
| void SetCtxDeviceId(uint32_t device_id); | |||||
| private: | |||||
| uint64_t session_id_ = 0; | |||||
| uint32_t device_id_ = 0; | |||||
| uint64_t trace_id_ = 0; | |||||
| }; // class GEContext | |||||
| /// Get context | |||||
| /// @return | |||||
| GEContext &GetContext(); | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_GE_CONTEXT_H_ | |||||
| @@ -1,26 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_GE_GLOBAL_OPTIONS_H_ | |||||
| #define INC_GRAPH_GE_GLOBAL_OPTIONS_H_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| namespace ge { | |||||
| std::map<std::string, std::string> &GetMutableGlobalOptions(); | |||||
| } | |||||
| #endif // INC_GRAPH_GE_GLOBAL_OPTIONS_H_ | |||||
| @@ -1,44 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_GE_LOCAL_CONTEXT_H_ | |||||
| #define INC_GRAPH_GE_LOCAL_CONTEXT_H_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/ge_error_codes.h" | |||||
| using std::map; | |||||
| using std::string; | |||||
| namespace ge { | |||||
| class GEThreadLocalContext { | |||||
| public: | |||||
| graphStatus GetOption(const string &key, string &option); | |||||
| void SetGraphOption(map<std::string, string> options_map); | |||||
| void SetSessionOption(map<std::string, string> options_map); | |||||
| void SetGlobalOption(map<std::string, string> options_map); | |||||
| private: | |||||
| map<string, string> graph_options_; | |||||
| map<string, string> session_options_; | |||||
| map<string, string> global_options_; | |||||
| }; // class GEThreadLocalContext | |||||
| GEThreadLocalContext &GetThreadLocalContext(); | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ | |||||
| @@ -1,193 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_GE_TENSOR_H_ | |||||
| #define INC_GRAPH_GE_TENSOR_H_ | |||||
| #include <atomic> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "detail/attributes_holder.h" | |||||
| #include "graph/buffer.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| #include "graph/types.h" | |||||
| namespace ge { | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||||
| public: | |||||
| GeShape(); | |||||
| ~GeShape() = default; | |||||
| explicit GeShape(std::vector<int64_t> s); | |||||
| size_t GetDimNum() const; | |||||
| // If the idx is invalid, return 0 | |||||
| int64_t GetDim(size_t idx) const; | |||||
| graphStatus SetDim(size_t idx, int64_t value); | |||||
| std::vector<int64_t> GetDims() const; | |||||
| int64_t GetShapeSize() const; | |||||
| std::string ToString() const; | |||||
| /// | |||||
| /// @brief Check is unknown shape | |||||
| /// @return bool | |||||
| /// | |||||
| bool IsUnknownShape() const; | |||||
| /// | |||||
| /// @brief Check is a scalar | |||||
| /// @return bool | |||||
| /// | |||||
| bool IsScalar() const; | |||||
| GeShape(const GeShape &other); | |||||
| GeShape(GeShape &&other); | |||||
| GeShape &operator=(const GeShape &other); | |||||
| GeShape &operator=(GeShape &&other); | |||||
| private: | |||||
| GeIrProtoHelper<proto::ShapeDef> shape_def_; | |||||
| friend class GeTensorDesc; | |||||
| // Create from proto obj | |||||
| GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg); | |||||
| void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; } | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrHolder { | |||||
| friend class TensorUtils; | |||||
| friend class GeAttrValue; | |||||
| friend class ModelSerialize; | |||||
| public: | |||||
| GeTensorDesc(); | |||||
| explicit GeTensorDesc(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); | |||||
| GeTensorDesc(const GeTensorDesc &desc); | |||||
| GeTensorDesc(GeTensorDesc &&desc); | |||||
| ~GeTensorDesc() = default; | |||||
| bool operator==(const GeTensorDesc &r_ge_tensor_desc) const; | |||||
| void Update(GeShape shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); | |||||
| GeShape GetShape() const; | |||||
| GeShape &MutableShape(); | |||||
| void SetShape(GeShape shape); | |||||
| // set shape with -2, it stand for unknown shape | |||||
| void SetUnknownDimNumShape(); | |||||
| // for unknown shape | |||||
| graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range); | |||||
| graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const; | |||||
| GeShape GetOriginShape() const; | |||||
| void SetOriginShape(const GeShape &originShape); | |||||
| Format GetFormat() const; | |||||
| void SetFormat(Format format); | |||||
| Format GetOriginFormat() const; | |||||
| void SetOriginFormat(Format originFormat); | |||||
| void SetName(const std::string &name); | |||||
| const std::string GetName() const; | |||||
| DataType GetDataType() const; | |||||
| void SetDataType(DataType dt); | |||||
| DataType GetOriginDataType() const; | |||||
| void SetOriginDataType(DataType originDataType); | |||||
| std::vector<uint32_t> GetRefPortIndex() const; | |||||
| void SetRefPortByIndex(const std::vector<uint32_t> &index); | |||||
| GeTensorDesc Clone() const; | |||||
| GeTensorDesc &operator=(const GeTensorDesc &desc); | |||||
| GeTensorDesc &operator=(GeTensorDesc &&desc); | |||||
| graphStatus IsValid() const; | |||||
| protected: | |||||
| ProtoAttrMapHelper MutableAttrMap() override; | |||||
| ConstProtoAttrMapHelper GetAttrMap() const override; | |||||
| private: | |||||
| bool GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const; | |||||
| using AttrHolder::DelAttr; | |||||
| using AttrHolder::GetAllAttrs; | |||||
| using AttrHolder::GetAttr; | |||||
| using AttrHolder::HasAttr; | |||||
| using AttrHolder::SetAttr; | |||||
| void Init(); | |||||
| // Create from proto obj | |||||
| GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg); | |||||
| friend class GeTensor; | |||||
| friend class GeAttrValueImp; | |||||
| friend class ModelSerializeImp; | |||||
| friend class OnnxUtils; | |||||
| GeIrProtoHelper<proto::TensorDescriptor> tensor_descriptor_; | |||||
| // Reference from tensorDescriptor_, do not direct use | |||||
| mutable GeShape __shape_; | |||||
| void RefTo(const GeTensorDesc &tensorDesc) { tensor_descriptor_ = tensorDesc.tensor_descriptor_; } | |||||
| GeShape &ShapeReference() const; | |||||
| }; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { | |||||
| public: | |||||
| GeTensor(); | |||||
| explicit GeTensor(const GeTensorDesc &tensorDesc); | |||||
| explicit GeTensor(const GeTensorDesc &tensorDesc, const std::vector<uint8_t> &data); | |||||
| explicit GeTensor(const GeTensorDesc &tensorDesc, const Buffer &data); | |||||
| explicit GeTensor(const GeTensorDesc &tensorDesc, const uint8_t *data, size_t size); | |||||
| explicit GeTensor(GeTensorDesc &&tensorDesc, std::vector<uint8_t> &&data); | |||||
| ~GeTensor() = default; | |||||
| GeTensorDesc GetTensorDesc() const; | |||||
| GeTensorDesc &MutableTensorDesc(); | |||||
| void SetTensorDesc(const GeTensorDesc &tensorDesc); | |||||
| const Buffer GetData() const; | |||||
| Buffer MutableData(); | |||||
| graphStatus SetData(std::vector<uint8_t> &&data); | |||||
| graphStatus SetData(const std::vector<uint8_t> &data); | |||||
| graphStatus SetData(const Buffer &data); | |||||
| graphStatus SetData(const uint8_t *data, size_t size); | |||||
| GeTensor Clone() const; | |||||
| // Share value | |||||
| GeTensor(const GeTensor &other); | |||||
| // Share value | |||||
| GeTensor &operator=(const GeTensor &other); | |||||
| private: | |||||
| friend class GeAttrValueImp; | |||||
| friend class ModelSerializeImp; | |||||
| friend class OnnxUtils; | |||||
| // Create from proto obj | |||||
| GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg); | |||||
| GeIrProtoHelper<proto::TensorDef> tensor_def_; | |||||
| // Reference from tensorDef_, do not direct use | |||||
| mutable GeTensorDesc __desc_; | |||||
| GeTensorDesc &DescReference() const; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_GE_TENSOR_H_ | |||||
| @@ -1,134 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_GRAPH_UTIL_H_ | |||||
| #define INC_GRAPH_GRAPH_UTIL_H_ | |||||
| #include <string> | |||||
| #include "proto/om.pb.h" | |||||
| namespace ge { | |||||
| using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; | |||||
| bool HasOpAttr(const OpDef *opdef, std::string attr_name); | |||||
| bool GetOpAttr(const std::string &key, int32_t *value, const OpDef *opdef); | |||||
| static const char OP_TYPE_DATA[] = "Data"; | |||||
| static const char OP_TYPE_INPUT[] = "Input"; | |||||
| static const char ATTR_KEY_INPUT_FORMAT[] = "input_format"; | |||||
| static const char ATTR_KEY_OUTPUT_FORMAT[] = "output_format"; | |||||
| static const char OP_TYPE_ANN_DATA[] = "AnnData"; | |||||
| } // namespace ge | |||||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||||
| #include "toolchain/slog.h" | |||||
| const char levelStr[4][8] = {"ERROR", "WARN", "INFO", "DEBUG"}; | |||||
| #else | |||||
| #include <syslog.h> | |||||
| #include <utils/Log.h> | |||||
| const char levelStr[8][8] = {"EMERG", "ALERT", "CRIT", "ERROR", "WARNING", "NOTICE", "INFO", "DEBUG"}; | |||||
| #endif | |||||
| #ifdef _MSC_VER | |||||
| #define FUNC_NAME __FUNCTION__ | |||||
| #else | |||||
| #define FUNC_NAME __PRETTY_FUNCTION__ | |||||
| #endif | |||||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||||
| #define D_GRAPH_LOGI(MOD_NAME, fmt, ...) \ | |||||
| dlog_info(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) | |||||
| #define D_GRAPH_LOGW(MOD_NAME, fmt, ...) \ | |||||
| dlog_warn(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) | |||||
| #define D_GRAPH_LOGE(MOD_NAME, fmt, ...) \ | |||||
| dlog_error(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) | |||||
| #else | |||||
| #define D_GRAPH_LOG(level, format, ...) \ | |||||
| do { \ | |||||
| { \ | |||||
| fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ | |||||
| __FILE__, __LINE__, ##__VA_ARGS__); \ | |||||
| syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ | |||||
| ##__VA_ARGS__); \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define D_GRAPH_LOGI(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) | |||||
| #define D_GRAPH_LOGW(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) | |||||
| #define D_GRAPH_LOGE(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) | |||||
| #endif | |||||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||||
| #define GRAPH_LOGI(...) D_GRAPH_LOGI(GRAPH_MOD_NAME, __VA_ARGS__) | |||||
| #define GRAPH_LOGW(...) D_GRAPH_LOGW(GRAPH_MOD_NAME, __VA_ARGS__) | |||||
| #define GRAPH_LOGE(...) D_GRAPH_LOGE(GRAPH_MOD_NAME, __VA_ARGS__) | |||||
| #else | |||||
| #define GRAPH_LOG(level, format, ...) \ | |||||
| do { \ | |||||
| { \ | |||||
| fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ | |||||
| __FILE__, __LINE__, ##__VA_ARGS__); \ | |||||
| syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ | |||||
| ##__VA_ARGS__); \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GRAPH_LOGI(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) | |||||
| #define GRAPH_LOGW(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) | |||||
| #define GRAPH_LOGE(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) | |||||
| #endif | |||||
| #define GRAPH_CHK_STATUS_RET_NOLOG(expr) \ | |||||
| do { \ | |||||
| const domi::graphStatus _status = (expr); \ | |||||
| if (_status != domi::GRAPH_SUCCESS) { \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GRAPH_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||||
| do { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| GRAPH_LOGE(__VA_ARGS__); \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GRAPH_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ | |||||
| { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| }; | |||||
| #define GRAPH_IF_BOOL_EXEC(expr, exec_expr) \ | |||||
| { \ | |||||
| if (expr) { \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } | |||||
| #define GRAPH_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ | |||||
| do { \ | |||||
| const ::domi::graphStatus _status = (expr); \ | |||||
| if (_status) { \ | |||||
| GRAPH_LOGE(__VA_ARGS__); \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #endif // INC_GRAPH_GRAPH_UTIL_H_ | |||||
| @@ -1,94 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_MODEL_H_ | |||||
| #define INC_GRAPH_MODEL_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "detail/attributes_holder.h" | |||||
| #include "graph/ge_attr_value.h" | |||||
| #include "graph/graph.h" | |||||
| namespace ge { | |||||
| using std::map; | |||||
| using std::string; | |||||
| using std::vector; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { | |||||
| public: | |||||
| Model(); | |||||
| ~Model() = default; | |||||
| Model(const string &name, const string &custom_version); | |||||
| string GetName() const; | |||||
| void SetName(const string &name); | |||||
| uint32_t GetVersion() const; | |||||
| void SetVersion(uint32_t version) { version_ = version; } | |||||
| std::string GetPlatformVersion() const; | |||||
| void SetPlatformVersion(string version) { platform_version_ = version; } | |||||
| Graph GetGraph() const; | |||||
| void SetGraph(const Graph &graph); | |||||
| void SetAttr(const ProtoAttrMapHelper &attrs); | |||||
| using AttrHolder::GetAllAttrNames; | |||||
| using AttrHolder::GetAllAttrs; | |||||
| using AttrHolder::GetAttr; | |||||
| using AttrHolder::HasAttr; | |||||
| using AttrHolder::SetAttr; | |||||
| graphStatus Save(Buffer &buffer, bool is_dump = false) const; | |||||
| graphStatus SaveToFile(const string &file_name) const; | |||||
| // Model will be rewrite | |||||
| static graphStatus Load(const uint8_t *data, size_t len, Model &model); | |||||
| graphStatus Load(ge::proto::ModelDef &model_def); | |||||
| graphStatus LoadFromFile(const string &file_name); | |||||
| bool IsValid() const; | |||||
| protected: | |||||
| ConstProtoAttrMapHelper GetAttrMap() const override; | |||||
| ProtoAttrMapHelper MutableAttrMap() override; | |||||
| private: | |||||
| void Init(); | |||||
| ProtoAttrMapHelper attrs_; | |||||
| friend class ModelSerializeImp; | |||||
| friend class GraphDebugImp; | |||||
| friend class OnnxUtils; | |||||
| friend class ModelHelper; | |||||
| friend class ModelBuilder; | |||||
| string name_; | |||||
| uint32_t version_; | |||||
| std::string platform_version_{""}; | |||||
| Graph graph_; | |||||
| }; | |||||
| } // namespace ge | |||||
| using ModelPtr = std::shared_ptr<ge::Model>; | |||||
| #endif // INC_GRAPH_MODEL_H_ | |||||
| @@ -1,52 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_MODEL_SERIALIZE_H_ | |||||
| #define INC_GRAPH_MODEL_SERIALIZE_H_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include "graph/buffer.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/model.h" | |||||
| namespace ge { | |||||
| class ModelSerialize { | |||||
| public: | |||||
| Buffer SerializeModel(const Model &model, bool is_dump = false); | |||||
| Model UnserializeModel(const uint8_t *data, size_t len); | |||||
| Model UnserializeModel(ge::proto::ModelDef &model_def); | |||||
| Buffer SerializeGraph(const ComputeGraphPtr &graph); | |||||
| ComputeGraphPtr UnserializeGraph(const uint8_t *data, size_t len); | |||||
| Buffer SerializeOpDesc(const ConstOpDescPtr &opDesc); | |||||
| OpDescPtr UnserializeOpDesc(const uint8_t *data, size_t len); | |||||
| size_t GetSerializeModelSize(const Model &model); | |||||
| private: | |||||
| static std::map<std::string, GeAttrValue> &MutableTensorDescAttrMap(GeTensorDesc &tensorDesc); | |||||
| static const std::map<std::string, GeAttrValue> &GetTensorDescAttrMap(const GeTensorDesc &tensorDesc); | |||||
| friend class ModelSerializeImp; | |||||
| friend class GraphDebugImp; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_MODEL_SERIALIZE_H_ | |||||
| @@ -1,213 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_NODE_H_ | |||||
| #define INC_GRAPH_NODE_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <unordered_set> | |||||
| #include "graph/ge_attr_value.h" | |||||
| #include "utils/attr_utils.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/range_vistor.h" | |||||
| namespace ge { | |||||
| class ComputeGraph; | |||||
| using ComputeGraphPtr = std::shared_ptr<ComputeGraph>; | |||||
| class Node; | |||||
| using NodePtr = std::shared_ptr<Node>; | |||||
| using ConstNodePtr = std::shared_ptr<const Node>; | |||||
| using NodeRef = std::weak_ptr<Node>; | |||||
| class Anchor; | |||||
| using AnchorPtr = std::shared_ptr<Anchor>; | |||||
| class InDataAnchor; | |||||
| using InDataAnchorPtr = std::shared_ptr<InDataAnchor>; | |||||
| class OutDataAnchor; | |||||
| using OutDataAnchorPtr = std::shared_ptr<OutDataAnchor>; | |||||
| class ControlAnchor; | |||||
| using ControlAnchorPtr = std::shared_ptr<ControlAnchor>; | |||||
| class InControlAnchor; | |||||
| using InControlAnchorPtr = std::shared_ptr<InControlAnchor>; | |||||
| class OutControlAnchor; | |||||
| using OutControlAnchorPtr = std::shared_ptr<OutControlAnchor>; | |||||
| using OpDescPtr = std::shared_ptr<OpDesc>; | |||||
| using ConstNode = const Node; | |||||
| typedef std::vector<std::multimap<std::string, ge::AnchorPtr>> kFusionDataFlowVec_t; | |||||
| // Node is a component of ComputeGraph | |||||
| class Node : public std::enable_shared_from_this<Node> { | |||||
| friend class ComputeGraph; | |||||
| friend class ModelSerializeImp; | |||||
| public: | |||||
| template <class T> | |||||
| using Vistor = RangeVistor<T, std::shared_ptr<ConstNode>>; | |||||
| ~Node(); | |||||
| Node(const Node &) = delete; | |||||
| Node &operator=(const Node &) = delete; | |||||
| bool operator==(const Node &r_node) const; | |||||
| protected: | |||||
| Node() = default; | |||||
| Node(const OpDescPtr &op, const ComputeGraphPtr &ownerGraph); | |||||
| public: | |||||
| graphStatus Init(); | |||||
| std::string GetName() const; | |||||
| std::string GetType() const; | |||||
| ComputeGraphPtr GetOwnerComputeGraph() const; | |||||
| graphStatus SetOwnerComputeGraph(const ComputeGraphPtr &graph); | |||||
| Vistor<InDataAnchorPtr> GetAllInDataAnchors() const; | |||||
| Vistor<OutDataAnchorPtr> GetAllOutDataAnchors() const; | |||||
| uint32_t GetAllInDataAnchorsSize() const; | |||||
| uint32_t GetAllOutDataAnchorsSize() const; | |||||
| Vistor<AnchorPtr> GetAllOutAnchors() const; | |||||
| Vistor<AnchorPtr> GetAllInAnchors() const; | |||||
| InDataAnchorPtr GetInDataAnchor(int idx) const; | |||||
| OutDataAnchorPtr GetOutDataAnchor(int idx) const; | |||||
| InControlAnchorPtr GetInControlAnchor() const; | |||||
| OutControlAnchorPtr GetOutControlAnchor() const; | |||||
| Vistor<NodePtr> GetInNodes() const; | |||||
| Vistor<NodePtr> GetOutNodes() const; | |||||
| AnchorPtr GetInAnchor(int idx) const; | |||||
| AnchorPtr GetOutAnchor(int idx) const; | |||||
| bool IsAllInNodesSeen(std::unordered_set<Node *> &nodes_seen) const; | |||||
| // All in Data nodes | |||||
| Vistor<NodePtr> GetInDataNodes() const; | |||||
| // All in Control nodes | |||||
| Vistor<NodePtr> GetInControlNodes() const; | |||||
| // GetInAllNodes = InDataNodes + InControlNodes | |||||
| Vistor<NodePtr> GetInAllNodes() const; | |||||
| // All out Data nodes | |||||
| Vistor<NodePtr> GetOutDataNodes() const; | |||||
| uint32_t GetOutDataNodesSize() const; | |||||
| // All out Control nodes | |||||
| Vistor<NodePtr> GetOutControlNodes() const; | |||||
| // GetOutAllNodes = OutDataNodes + InControlNodes | |||||
| Vistor<NodePtr> GetOutAllNodes() const; | |||||
| // Get all in data nodes and its out-anchor | |||||
| Vistor<std::pair<NodePtr, OutDataAnchorPtr>> GetInDataNodesAndAnchors() const; | |||||
| // Get all out data nodes and its in-anchor | |||||
| Vistor<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesAndAnchors() const; | |||||
| graphStatus InferShapeAndType() const; | |||||
| graphStatus Verify() const; | |||||
| graphStatus InferOriginFormat() const; | |||||
| OpDescPtr GetOpDesc() const; | |||||
| graphStatus UpdateOpDesc(const OpDescPtr &op); | |||||
| graphStatus AddLinkFrom(const NodePtr &input_node); | |||||
| graphStatus AddLinkFrom(const uint32_t &index, NodePtr input_node); | |||||
| graphStatus AddLinkFrom(const string &name, NodePtr input_node); | |||||
| graphStatus AddLinkFromForParse(const NodePtr &input_node); | |||||
| void AddSendEventId(uint32_t event_id) { send_event_id_list_.push_back(event_id); } | |||||
| void AddRecvEventId(uint32_t event_id) { recv_event_id_list_.push_back(event_id); } | |||||
| const std::vector<uint32_t> &GetSendEventIdList() const { return send_event_id_list_; } | |||||
| const std::vector<uint32_t> &GetRecvEventIdList() const { return recv_event_id_list_; } | |||||
| void GetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { | |||||
| fusion_input_list = fusion_input_dataflow_list_; | |||||
| } | |||||
| void GetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { | |||||
| fusion_output_list = fusion_output_dataflow_list_; | |||||
| } | |||||
| void SetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { | |||||
| fusion_input_dataflow_list_ = fusion_input_list; | |||||
| } | |||||
| void SetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { | |||||
| fusion_output_dataflow_list_ = fusion_output_list; | |||||
| } | |||||
| bool GetHostNode() const { return host_node_; } | |||||
| void SetHostNode(bool is_host) { host_node_ = is_host; } | |||||
| void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } | |||||
| NodePtr GetOrigNode() { return orig_node_; } | |||||
| private: | |||||
| bool NodeMembersAreEqual(const Node &r_node) const; | |||||
| bool NodeAttrsAreEqual(const Node &r_node) const; | |||||
| bool NodeInConnectsAreEqual(const Node &r_node) const; | |||||
| bool NodeOutConnectsAreEqual(const Node &r_node) const; | |||||
| bool NodeAnchorIsEqual(const AnchorPtr &l_anchor, const AnchorPtr &r_anchor, size_t i) const; | |||||
| OpDescPtr op_; | |||||
| std::weak_ptr<ComputeGraph> owner_graph_; | |||||
| vector<InDataAnchorPtr> in_data_anchors_; | |||||
| vector<OutDataAnchorPtr> out_data_anchors_; | |||||
| InControlAnchorPtr in_control_anchor_; | |||||
| OutControlAnchorPtr out_control_anchor_; | |||||
| map<string, GeAttrValue> attrs_; // lint !e1073 | |||||
| bool has_init_{false}; | |||||
| bool host_node_{false}; | |||||
| bool anchor_status_updated_{false}; | |||||
| std::vector<uint32_t> send_event_id_list_; | |||||
| std::vector<uint32_t> recv_event_id_list_; | |||||
| kFusionDataFlowVec_t fusion_input_dataflow_list_; | |||||
| kFusionDataFlowVec_t fusion_output_dataflow_list_; | |||||
| NodePtr orig_node_; | |||||
| friend class NodeUtils; | |||||
| friend class OnnxUtils; | |||||
| friend class TuningUtils; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_NODE_H_ | |||||
| @@ -1,328 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_OP_DESC_H_ | |||||
| #define INC_GRAPH_OP_DESC_H_ | |||||
| #include <functional> | |||||
| #include <algorithm> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_set> | |||||
| #include <vector> | |||||
| #include "detail/attributes_holder.h" | |||||
| #include "graph/range_vistor.h" | |||||
| #define DYNAMIN_INPUT_NAME(name, index) (((name)) + std::to_string((index))) | |||||
| #define DYNAMIN_OUTPUT_NAME(name, index) (((name)) + std::to_string((index))) | |||||
| namespace ge { | |||||
| using std::map; | |||||
| using std::pair; | |||||
| using std::shared_ptr; | |||||
| using std::string; | |||||
| using std::vector; | |||||
| class Operator; | |||||
| class GeTensorDesc; | |||||
| using GeTensorDescPtr = shared_ptr<GeTensorDesc>; | |||||
| using ConstGeTensorDescPtr = shared_ptr<const GeTensorDesc>; | |||||
| class OpDesc; | |||||
| using OpDescPtr = shared_ptr<OpDesc>; | |||||
| using ConstOpDescPtr = shared_ptr<const OpDesc>; | |||||
| class GeAttrValue; | |||||
| using ConstOpDesc = const OpDesc; | |||||
| enum SubgraphType { kStatic, kDynamic, kSubgraphTypeEnd }; | |||||
| class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
| public: | |||||
| template <class T> | |||||
| using Vistor = RangeVistor<T, shared_ptr<ConstOpDesc>>; | |||||
| friend class GraphBuilderImpl; | |||||
| friend class OperatorImpl; | |||||
| OpDesc(const string &name, const string &type); | |||||
| OpDesc(); | |||||
| ~OpDesc(); | |||||
| bool operator==(const OpDesc &r_op_desc) const; | |||||
| string GetName() const; | |||||
| void SetName(const string &name); | |||||
| string GetType() const; | |||||
| void SetType(const string &type); | |||||
| graphStatus AddInputDesc(const GeTensorDesc &input_desc); | |||||
| graphStatus AddInputDesc(const string &name, const GeTensorDesc &input_desc); | |||||
| graphStatus AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_desc); | |||||
| graphStatus AddInputDescForward(const string &name, const unsigned int num); | |||||
| graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); | |||||
| graphStatus AddOutputDescMiddle(const string &name, const unsigned int num, size_t index); | |||||
| graphStatus AddOutputDescForward(const string &name, const unsigned int num); | |||||
| graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); | |||||
| graphStatus UpdateInputDesc(uint32_t index, const GeTensorDesc &tensor_desc); | |||||
| graphStatus UpdateInputDesc(const string &name, const GeTensorDesc &tensor_desc); | |||||
| bool InputIsSet(const string &name) const; | |||||
| GeTensorDesc GetInputDesc(uint32_t index) const; | |||||
| GeTensorDesc GetInputDesc(const string &name) const; | |||||
| Vistor<string> GetAllInputNames() const; | |||||
| GeTensorDescPtr MutableInputDesc(uint32_t index) const; | |||||
| GeTensorDescPtr MutableInputDesc(const string &name) const; | |||||
| Vistor<GeTensorDesc> GetAllInputsDesc() const; | |||||
| Vistor<GeTensorDescPtr> GetAllInputsDescPtr() const; | |||||
| size_t GetInputsSize() const; | |||||
| size_t GetAllInputsSize() const; | |||||
| graphStatus AddOutputDesc(const GeTensorDesc &output_desc); | |||||
| graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); | |||||
| graphStatus UpdateOutputDesc(uint32_t index, const GeTensorDesc &tensor_desc); | |||||
| graphStatus UpdateOutputDesc(const string &name, const GeTensorDesc &tensor_desc); | |||||
| GeTensorDesc GetOutputDesc(uint32_t index) const; | |||||
| GeTensorDesc GetOutputDesc(const string &name) const; | |||||
| GeTensorDescPtr MutableOutputDesc(uint32_t index) const; | |||||
| GeTensorDescPtr MutableOutputDesc(const string &name) const; | |||||
| uint32_t GetAllOutputsDescSize() const; | |||||
| Vistor<GeTensorDesc> GetAllOutputsDesc() const; | |||||
| Vistor<GeTensorDescPtr> GetAllOutputsDescPtr() const; | |||||
| size_t GetOutputsSize() const; | |||||
| ConstGeTensorDescPtr GetOutputDescPtr(uint32_t index) const; | |||||
| ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; | |||||
| ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const; | |||||
| ConstGeTensorDescPtr GetInputDescPtr(const string &name) const; | |||||
| graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); | |||||
| graphStatus AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index); | |||||
| graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | |||||
| bool IsOptionalInput(const string &name) const; | |||||
| bool IsOptionalInput(uint32_t index) const; | |||||
| std::map<string, uint32_t> GetAllInputName() const; | |||||
| std::map<string, uint32_t> GetAllOutputName(); | |||||
| bool UpdateInputName(std::map<string, uint32_t> inputNameIdx); | |||||
| bool UpdateOutputName(std::map<string, uint32_t> outputNameIdx); | |||||
| void AddInferFunc(const std::function<graphStatus(Operator &)> &func); | |||||
| std::function<graphStatus(Operator &)> GetInferFunc() const; | |||||
| graphStatus InferShapeAndType(); | |||||
| void AddInferFormatFunc(const std::function<graphStatus(Operator &)> &func); | |||||
| std::function<graphStatus(Operator &)> GetInferFormatFunc() const; | |||||
| graphStatus DefaultInferFormat(); | |||||
| std::function<graphStatus(Operator &)> GetVerifyFunc() const; | |||||
| void AddVerifierFunc(const std::function<graphStatus(Operator &)> &func); | |||||
| graphStatus CallInferFormatFunc(Operator &op); | |||||
| graphStatus OpVerify(); | |||||
| graphStatus CommonVerify() const; | |||||
| graphStatus AddRegisterInputName(const string &name); | |||||
| graphStatus AddRegisterOutputName(const string &name); | |||||
| vector<string> GetRegisterInputName() const; | |||||
| vector<string> GetRegisterOutputName() const; | |||||
| using AttrHolder::AddRequiredAttr; | |||||
| using AttrHolder::DelAttr; | |||||
| using AttrHolder::GetAllAttrNames; | |||||
| using AttrHolder::GetAllAttrs; | |||||
| using AttrHolder::GetAttr; | |||||
| using AttrHolder::HasAttr; | |||||
| using AttrHolder::SetAttr; | |||||
| void SetId(int64_t id); | |||||
| int64_t GetId() const; | |||||
| void SetStreamId(int64_t stream_id); | |||||
| int64_t GetStreamId() const; | |||||
| void SetInputName(const vector<string> &input_name); | |||||
| vector<string> GetInputName() const; | |||||
| void SetSrcName(const vector<string> &src_name); | |||||
| vector<string> GetSrcName() const; | |||||
| void SetSrcIndex(const vector<int64_t> &src_index); | |||||
| vector<int64_t> GetSrcIndex() const; | |||||
| void SetInputOffset(const vector<int64_t> &input); | |||||
| vector<int64_t> GetInputOffset() const; | |||||
| void SetOutputOffset(const vector<int64_t> &input); | |||||
| vector<int64_t> GetOutputOffset() const; | |||||
| void SetDstName(const vector<string> &dst_name); | |||||
| vector<string> GetDstName() const; | |||||
| void SetDstIndex(const vector<int64_t> &dst_index); | |||||
| vector<int64_t> GetDstIndex() const; | |||||
| void SetWorkspace(const vector<int64_t> &workspace); | |||||
| vector<int64_t> GetWorkspace() const; | |||||
| void SetWorkspaceBytes(const vector<int64_t> &workspace_bytes); | |||||
| vector<int64_t> GetWorkspaceBytes() const; | |||||
| void SetIsInputConst(const vector<bool> &is_input_const); | |||||
| vector<bool> GetIsInputConst() const; | |||||
| void SetOpInferDepends(const vector<string> &depend_names); | |||||
| vector<string> GetOpInferDepends() const; | |||||
| string GetInputNameByIndex(uint32_t index) const; | |||||
| int GetInputIndexByName(const string &name) const; | |||||
| string GetOutputNameByIndex(uint32_t index) const; | |||||
| int GetOutputIndexByName(const string &name) const; | |||||
| graphStatus RestoreInputNameIdx(const string &name, const int &index); | |||||
| graphStatus RestoreOutputNameIdx(const string &name, const int &index); | |||||
| graphStatus CallInferFunc(Operator &op); | |||||
| void SetOpKernelLibName(const std::string &name); | |||||
| std::string GetOpKernelLibName() const; | |||||
| void SetOpEngineName(const std::string &name); | |||||
| std::string GetOpEngineName() const; | |||||
| void RegisterSubgraphIrName(const std::string &name, SubgraphType type); | |||||
| const std::map<std::string, SubgraphType> &GetSubgraphIrNames() const; | |||||
| SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; | |||||
| graphStatus AddSubgraphName(const std::string &name); | |||||
| const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const; | |||||
| std::string GetSubgraphInstanceName(uint32_t index) const; | |||||
| const std::vector<std::string> &GetSubgraphInstanceNames() const; | |||||
| /// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`, | |||||
| /// because this kind of functions will only append a new subgraph instance name | |||||
| /// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`. | |||||
| /// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first. | |||||
| /// \param index | |||||
| /// \param name | |||||
| /// \return | |||||
| graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); | |||||
| void RemoveSubgraphInstanceName(const std::string &name); | |||||
| graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const; | |||||
| protected: | |||||
| ProtoAttrMapHelper MutableAttrMap() override; | |||||
| ConstProtoAttrMapHelper GetAttrMap() const override; | |||||
| private: | |||||
| OpDesc(const ProtoMsgOwner &proto_msg_owner, ge::proto::OpDef *op_def); | |||||
| bool OpDescMembersAreEqual(const OpDesc &r_op_desc) const; | |||||
| bool OpDescAttrsAreEqual(const OpDesc &r_op_desc) const; | |||||
| bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; | |||||
| GeIrProtoHelper<ge::proto::OpDef> op_def_; | |||||
| std::vector<std::string> subgraph_instance_names_; | |||||
| // subgraph names to index, for a `if` operator: | |||||
| // then_branch: 0 | |||||
| // else_branch: 1 | |||||
| // or for a `case` node: | |||||
| // branches0: 0 | |||||
| // branches1: 1 | |||||
| // branches2: 2 | |||||
| std::map<std::string, uint32_t> subgraph_names_to_index_; | |||||
| // subgraph ir names to type, for a `if` operator: | |||||
| // then_branch: static | |||||
| // else_branch: static | |||||
| // or for a `case` op: | |||||
| // branches: dynamic | |||||
| std::map<std::string, SubgraphType> subgraph_ir_names_to_type_; | |||||
| vector<GeTensorDescPtr> inputs_desc_{}; | |||||
| map<string, uint32_t> input_name_idx_{}; | |||||
| vector<string> register_input_name_{}; | |||||
| std::unordered_set<string> optional_input_names_{}; | |||||
| vector<GeTensorDescPtr> outputs_desc_{}; | |||||
| map<string, uint32_t> output_name_idx_{}; | |||||
| vector<string> register_output_name_{}; | |||||
| std::function<graphStatus(Operator &)> infer_func_ = nullptr; | |||||
| std::function<graphStatus(Operator &)> infer_format_func_ = nullptr; | |||||
| std::function<graphStatus(Operator &)> verifier_func_ = nullptr; | |||||
| string op_kernel_lib_name_; | |||||
| string engine_name_; | |||||
| friend class OpDescUtils; | |||||
| friend class ModelSerializeImp; | |||||
| friend class AttrUtils; | |||||
| friend class GeAttrValueImp; | |||||
| friend class OnnxUtils; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_OP_DESC_H_ | |||||
| @@ -1,48 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_OP_KERNEL_BIN_H_ | |||||
| #define INC_GRAPH_OP_KERNEL_BIN_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| namespace ge { | |||||
| class OpKernelBin { | |||||
| public: | |||||
| OpKernelBin(std::string name, std::vector<char> &&data) : name_(std::move(name)), data_(std::move(data)) {} | |||||
| ~OpKernelBin() = default; | |||||
| const std::string &GetName() const { return name_; } | |||||
| const uint8_t *GetBinData() const { return (const uint8_t *)data_.data(); } | |||||
| size_t GetBinDataSize() const { return data_.size(); } | |||||
| OpKernelBin(const OpKernelBin &) = delete; | |||||
| const OpKernelBin &operator=(const OpKernelBin &) = delete; | |||||
| private: | |||||
| std::string name_; | |||||
| std::vector<char> data_; | |||||
| }; | |||||
| using OpKernelBinPtr = std::shared_ptr<OpKernelBin>; | |||||
| const char *const OP_EXTATTR_NAME_TBE_KERNEL = "tbeKernel"; | |||||
| const char *const OP_EXTATTR_CUSTAICPU_KERNEL = "cust_aicpu_kernel"; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_OP_KERNEL_BIN_H_ | |||||
| @@ -1,56 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ | |||||
| #define INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/operator_factory.h" | |||||
| namespace ge { | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl { | |||||
| public: | |||||
| static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); | |||||
| static graphStatus GetOpsTypeList(std::vector<std::string> &all_ops); | |||||
| static bool IsExistOp(const string &operator_type); | |||||
| static InferShapeFunc GetInferShapeFunc(const std::string &operator_type); | |||||
| static InferFormatFunc GetInferFormatFunc(const std::string &operator_type); | |||||
| static VerifyFunc GetVerifyFunc(const std::string &operator_type); | |||||
| static graphStatus RegisterOperatorCreator(const std::string &operator_type, OpCreator const &op_creator); | |||||
| static graphStatus RegisterInferShapeFunc(const std::string &operator_type, InferShapeFunc const infer_shape_func); | |||||
| static graphStatus RegisterInferFormatFunc(const std::string &operator_type, InferFormatFunc const infer_format_func); | |||||
| static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); | |||||
| static shared_ptr<std::map<string, OpCreator>> operator_creators_; | |||||
| static shared_ptr<std::map<string, InferShapeFunc>> operator_infershape_funcs_; | |||||
| static shared_ptr<std::map<string, InferFormatFunc>> operator_inferformat_funcs_; | |||||
| static shared_ptr<std::map<string, VerifyFunc>> operator_verify_funcs_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ | |||||
| @@ -1,46 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_OPSPROTO_MANAGER_H_ | |||||
| #define INC_GRAPH_OPSPROTO_MANAGER_H_ | |||||
| #include <dirent.h> | |||||
| #include <dlfcn.h> | |||||
| #include <string.h> | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <mutex> | |||||
| namespace ge { | |||||
| class OpsProtoManager { | |||||
| public: | |||||
| static OpsProtoManager *Instance(); | |||||
| bool Initialize(const std::map<std::string, std::string> &options); | |||||
| void Finalize(); | |||||
| private: | |||||
| void LoadOpsProtoPluginSo(std::string &path); | |||||
| std::string pluginPath_; | |||||
| std::vector<void *> handles_; | |||||
| bool is_init_ = false; | |||||
| std::mutex mutex_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_OPSPROTO_MANAGER_H_ | |||||
| @@ -1,53 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_RANGE_VISTOR_H_ | |||||
| #define INC_GRAPH_RANGE_VISTOR_H_ | |||||
| #include <vector> | |||||
| template <class E, class O> | |||||
| class RangeVistor { | |||||
| public: | |||||
| using Iterator = typename std::vector<E>::iterator; | |||||
| using ConstIterator = typename std::vector<E>::const_iterator; | |||||
| RangeVistor(O owner, const std::vector<E> &vs) : owner_(owner), elements_(vs) {} | |||||
| ~RangeVistor() {} | |||||
| Iterator begin() { return elements_.begin(); } | |||||
| Iterator end() { return elements_.end(); } | |||||
| ConstIterator begin() const { return elements_.begin(); } | |||||
| ConstIterator end() const { return elements_.end(); } | |||||
| std::size_t size() const { return elements_.size(); } | |||||
| bool empty() const { return elements_.empty(); } | |||||
| E &at(std::size_t index) { return elements_.at(index); } | |||||
| const E &at(std::size_t index) const { return elements_.at(index); } | |||||
| private: | |||||
| O owner_; | |||||
| std::vector<E> elements_; | |||||
| }; | |||||
| #endif // INC_GRAPH_RANGE_VISTOR_H_ | |||||
| @@ -1,79 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef COMMON_GRAPH_REF_RELATION_H_ | |||||
| #define COMMON_GRAPH_REF_RELATION_H_ | |||||
| #include <deque> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/types.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| #include "node.h" | |||||
| namespace ge { | |||||
| enum InOutFlag { | |||||
| NODE_IN = 0, // input flag | |||||
| NODE_OUT = 1, // output flag | |||||
| }; | |||||
| struct RefCell { | |||||
| std::string node_name; | |||||
| ge::NodePtr node = nullptr; | |||||
| InOutFlag in_out = NODE_IN; | |||||
| int in_out_idx = 0; | |||||
| bool operator==(const RefCell &c) const { | |||||
| return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx; | |||||
| } | |||||
| RefCell() = default; | |||||
| RefCell(std::string name, ge::NodePtr node_ptr, InOutFlag in_out_flag, int idx) { | |||||
| node_name = name; | |||||
| node = node_ptr; | |||||
| in_out = in_out_flag; | |||||
| in_out_idx = idx; | |||||
| }; | |||||
| ~RefCell() = default; | |||||
| }; | |||||
| struct RefCellHash { | |||||
| size_t operator()(const RefCell &c) const { | |||||
| unsigned long number = reinterpret_cast<unsigned long>(reinterpret_cast<uintptr_t>(c.node.get())); | |||||
| string tmp = c.node_name + std::to_string(c.in_out) + std::to_string(c.in_out_idx) + std::to_string(number); | |||||
| return std::hash<string>()(tmp); | |||||
| } | |||||
| }; | |||||
| class RefRelations { | |||||
| public: | |||||
| graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set<RefCell, RefCellHash> &result); | |||||
| graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); | |||||
| graphStatus Clear(); | |||||
| RefRelations(); | |||||
| ~RefRelations() = default; | |||||
| public: | |||||
| class Impl; | |||||
| std::shared_ptr<Impl> impl_ = nullptr; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // COMMON_GRAPH_REF_RELATION_H_ | |||||
| @@ -1,46 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ | |||||
| #define INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include <vector> | |||||
| #include "external/graph/ge_error_codes.h" | |||||
| #include "external/graph/tensor.h" | |||||
| namespace ge { | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext { | |||||
| public: | |||||
| static graphStatus GetContext(const std::string &context_id, RuntimeInferenceContext **ctx); | |||||
| static graphStatus CreateContext(const std::string &context_id); | |||||
| static void DestroyContext(const std::string &context_id); | |||||
| graphStatus SetTensor(int64_t node_id, int output_id, Tensor &&tensor); | |||||
| graphStatus GetTensor(int64_t node_id, int output_id, Tensor &tensor); | |||||
| private: | |||||
| std::map<int64_t, std::vector<Tensor>> tensors_; | |||||
| std::mutex mu_; | |||||
| static std::map<std::string, std::unique_ptr<RuntimeInferenceContext>> contexts_; | |||||
| static std::mutex ctx_mu_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ | |||||
| @@ -1,40 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_SHAPE_REFINER_H_ | |||||
| #define INC_GRAPH_SHAPE_REFINER_H_ | |||||
| #include <string> | |||||
| #include "external/graph/inference_context.h" | |||||
| #include "external/graph/ge_error_codes.h" | |||||
| #include "graph/node.h" | |||||
| namespace ge { | |||||
| // ShapeRefiner performs shape inference for compute graphs | |||||
| class ShapeRefiner { | |||||
| public: | |||||
| static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph); | |||||
| static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph); | |||||
| static graphStatus InferShapeAndType(const NodePtr &node); | |||||
| static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); | |||||
| static void ClearContextMap(); | |||||
| private: | |||||
| static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_SHAPE_REFINER_H_ | |||||
| @@ -1,130 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MAIN_TUNING_UTILS_H | |||||
| #define MAIN_TUNING_UTILS_H | |||||
| #include <fcntl.h> | |||||
| #include <sys/stat.h> | |||||
| #include <sys/types.h> | |||||
| #include <unistd.h> | |||||
| #include <algorithm> | |||||
| #include <cstring> | |||||
| #include <fstream> | |||||
| #include <iomanip> | |||||
| #include <queue> | |||||
| #include <mutex> | |||||
| #include <graph/anchor.h> | |||||
| #include <graph/detail/attributes_holder.h> | |||||
| #include <graph/ge_tensor.h> | |||||
| #include <graph/graph.h> | |||||
| #include <graph/model.h> | |||||
| #include <graph/node.h> | |||||
| #include <graph/utils/graph_utils.h> | |||||
| #include <graph/utils/type_utils.h> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "utils/attr_utils.h" | |||||
| #include "utils/node_utils.h" | |||||
| #include "external/ge/ge_api_types.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| namespace ge { | |||||
| // Configure build mode, default value is "normal" | |||||
| const char *const BUILD_MODE = "ge.buildMode"; | |||||
| const char *const BUILD_STEP = "ge.buildStep"; | |||||
| // Configure tuning path | |||||
| const char *const TUNING_PATH = "ge.tuningPath"; | |||||
| // for interface: aclgrphBuildModel | |||||
| const std::set<std::string> ir_builder_supported_options_for_lx_fusion = {BUILD_MODE, BUILD_STEP, TUNING_PATH}; | |||||
| // Build model | |||||
| const char *const BUILD_MODE_NORMAL = "normal"; | |||||
| const char *const BUILD_MODE_TUNING = "tuning"; | |||||
| const char *const BUILD_MODE_BASELINE = "baseline"; | |||||
| const std::set<std::string> build_mode_options = {BUILD_MODE_NORMAL, BUILD_MODE_TUNING, BUILD_MODE_BASELINE}; | |||||
| // Build step | |||||
| const char *const BUILD_STEP_BEFORE_UB_MATCH = "before_ub_match"; | |||||
| const char *const BUILD_STEP_AFTER_UB_MATCH = "after_ub_match"; | |||||
| const char *const BUILD_STEP_AFTER_BUILDER = "after_builder"; | |||||
| const char *const BUILD_STEP_AFTER_BUILDER_SUB = "after_builder_sub"; | |||||
| const char *const BUILD_STEP_AFTER_MERGE = "after_merge"; | |||||
| const std::set<std::string> build_step_options = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_UB_MATCH, | |||||
| BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB, | |||||
| BUILD_STEP_AFTER_MERGE}; | |||||
| using SubgraphCreateOutNode = std::unordered_map<ComputeGraphPtr, NodePtr>; | |||||
| using NodetoNodeMap = std::unordered_map<NodePtr, NodePtr>; | |||||
| using NodeSet = std::set<NodePtr>; | |||||
| using NodeNametoNodeNameMap = std::unordered_map<std::string, std::string>; | |||||
| using NodetoNodeNameMap = std::unordered_map<NodePtr, std::string>; | |||||
| class TuningUtils { | |||||
| public: | |||||
| TuningUtils() = default; | |||||
| ~TuningUtils() = default; | |||||
| // Dump all the subgraphs and modify | |||||
| // the subgraphs in them to be executable subgraphs if exe_flag is true | |||||
| // `tuning_path` means path to save the graphs | |||||
| static graphStatus ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_subgraphs, | |||||
| std::vector<ComputeGraphPtr> non_tuning_subgraphs = {}, bool exe_flag = false, | |||||
| const std::string &path = "", const std::string &user_path = ""); | |||||
| // Recovery `graph` from graph dump files configured in options | |||||
| static graphStatus ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph); | |||||
| private: | |||||
| // part 1 | |||||
| struct HelpInfo { | |||||
| int64_t index; | |||||
| bool exe_flag; | |||||
| bool is_tuning_graph; | |||||
| const std::string &path; | |||||
| const std::string &user_path; | |||||
| }; | |||||
| static graphStatus MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info); | |||||
| static graphStatus HandlePld(NodePtr &node); | |||||
| static graphStatus HandleEnd(NodePtr &node); | |||||
| static graphStatus ChangePld2Data(NodePtr &node, NodePtr &data_node); | |||||
| static graphStatus ChangeEnd2NetOutput(NodePtr &node, NodePtr &out_node); | |||||
| static graphStatus LinkEnd2NetOutput(NodePtr &node, NodePtr &out_node); | |||||
| static graphStatus CreateDataNode(NodePtr &node, NodePtr &data_node); | |||||
| static graphStatus CreateNetOutput(NodePtr &node, NodePtr &out_node); | |||||
| static graphStatus AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node); | |||||
| static graphStatus AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node); | |||||
| static void DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path); | |||||
| static SubgraphCreateOutNode create_output_; | |||||
| // part 2 | |||||
| static graphStatus MergeAllSubGraph(std::vector<ComputeGraphPtr> &graphs, ComputeGraphPtr &graph); | |||||
| static graphStatus MergeSubGraph(ComputeGraphPtr &graph); | |||||
| // Deletes new data and output nodes added by call `MakeExeGraph()` func in part 1 | |||||
| static graphStatus RemoveDataNetoutputEdge(ComputeGraphPtr &graph); | |||||
| static graphStatus GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor, | |||||
| AnchorPtr &src_out_anchor); | |||||
| static NodeNametoNodeNameMap data_2_netoutput_; | |||||
| static NodetoNodeNameMap data_node_2_netoutput_; | |||||
| static NodetoNodeMap data_node_2_netoutput_node_; | |||||
| static NodeSet netoutput_nodes_; | |||||
| static NodeSet merged_graph_nodes_; | |||||
| static std::mutex mutex_; | |||||
| // for debug | |||||
| static std::string PrintCheckLog(); | |||||
| static std::string GetNodeNameByAnchor(const Anchor *anchor); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // MAIN_TUNING_UTILS_H | |||||
| @@ -1,133 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_USR_TYPES_H_ | |||||
| #define INC_GRAPH_USR_TYPES_H_ | |||||
| #include <atomic> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| namespace ge { | |||||
| #define USR_TYPE_DEC(type, name) \ | |||||
| inline void set_##name(const type &value) { name = value; } \ | |||||
| type *mutable_##name() { return &name; } | |||||
| #define USR_TYPE_HAS_DEC(type, name) \ | |||||
| inline void set_##name(const type &value) { name = value; } \ | |||||
| \ | |||||
| private: \ | |||||
| bool has_mutable_##name{false}; \ | |||||
| \ | |||||
| public: \ | |||||
| bool has_##name() const { return (has_mutable_##name) || QuantizeFactorHasData(name); } \ | |||||
| type *mutable_##name() { \ | |||||
| has_mutable_##name = true; \ | |||||
| return &name; \ | |||||
| } | |||||
| #define USR_TYPE_BYTES_DEC(name) \ | |||||
| inline void clear_##name() { name.clear(); } \ | |||||
| inline void set_##name(const void *value, size_t size) { \ | |||||
| name.assign(reinterpret_cast<uint8_t *>(const_cast<void *>(value)), \ | |||||
| reinterpret_cast<uint8_t *>(const_cast<void *>(value)) + size); \ | |||||
| } | |||||
| enum UsrQuantizeScaleType { USR_VECTOR_SCALE = 0, USR_SCALAR_SCALE = 1 }; | |||||
| enum UsrQuantizeScaleMode { USR_NORMAL_MODE = 0, USR_SQRT_MODE = 1 }; | |||||
| enum UsrQuantizeAlgorithm { | |||||
| USR_NON_OFFSET_ALGO = 0, | |||||
| USR_HALF_OFFSET_ALGO = 1, | |||||
| USR_ALL_OFFSET_ALGO = 2, | |||||
| }; | |||||
| struct UsrQuantizeFactor { | |||||
| public: | |||||
| // QuantizeScaleMode scale_mode; | |||||
| UsrQuantizeScaleMode scale_mode{USR_NORMAL_MODE}; | |||||
| std::vector<uint8_t> scale_value; | |||||
| int64_t scale_offset{0}; | |||||
| std::vector<uint8_t> offset_data_value; | |||||
| int64_t offset_data_offset{0}; | |||||
| std::vector<uint8_t> offset_weight_value; | |||||
| int64_t offset_weight_offset{0}; | |||||
| std::vector<uint8_t> offset_pad_value; | |||||
| int64_t offset_pad_offset{0}; | |||||
| USR_TYPE_DEC(UsrQuantizeScaleMode, scale_mode); | |||||
| USR_TYPE_BYTES_DEC(scale_value); | |||||
| USR_TYPE_DEC(int64_t, scale_offset); | |||||
| USR_TYPE_BYTES_DEC(offset_data_value); | |||||
| USR_TYPE_DEC(int64_t, offset_data_offset); | |||||
| USR_TYPE_BYTES_DEC(offset_weight_value); | |||||
| USR_TYPE_DEC(int64_t, offset_weight_offset); | |||||
| USR_TYPE_BYTES_DEC(offset_pad_value); | |||||
| USR_TYPE_DEC(int64_t, offset_pad_offset); | |||||
| }; | |||||
| static inline bool QuantizeFactorHasData(const UsrQuantizeFactor &factor) { | |||||
| return factor.scale_value.size() > 0 || factor.offset_data_value.size() > 0 || | |||||
| factor.offset_weight_value.size() > 0 || factor.offset_pad_value.size() > 0; | |||||
| } | |||||
| struct UsrQuantizeCalcFactor { | |||||
| public: | |||||
| std::vector<uint8_t> offsetw; | |||||
| int64_t offsetw_offset{0}; | |||||
| std::vector<uint8_t> offsetd; | |||||
| int64_t offsetd_offset{0}; | |||||
| std::vector<uint8_t> scalereq; | |||||
| int64_t scaledreq_offset{0}; | |||||
| std::vector<uint8_t> offsetdnext; | |||||
| int64_t offsetdnext_offset{0}; | |||||
| USR_TYPE_BYTES_DEC(offsetw); | |||||
| USR_TYPE_DEC(int64_t, offsetw_offset); | |||||
| USR_TYPE_BYTES_DEC(offsetd); | |||||
| USR_TYPE_DEC(int64_t, offsetd_offset); | |||||
| USR_TYPE_BYTES_DEC(scalereq); | |||||
| USR_TYPE_DEC(int64_t, scaledreq_offset); | |||||
| USR_TYPE_BYTES_DEC(offsetdnext); | |||||
| USR_TYPE_DEC(int64_t, offsetdnext_offset); | |||||
| }; | |||||
| static inline bool QuantizeFactorHasData(const UsrQuantizeCalcFactor &factor) { | |||||
| return factor.offsetw.size() > 0 || factor.offsetd.size() > 0 || factor.scalereq.size() > 0 || | |||||
| factor.offsetdnext.size() > 0; | |||||
| } | |||||
| struct UsrQuantizeFactorParams { | |||||
| UsrQuantizeAlgorithm quantize_algo{USR_NON_OFFSET_ALGO}; | |||||
| UsrQuantizeScaleType scale_type{USR_VECTOR_SCALE}; | |||||
| UsrQuantizeFactor quantize_param; | |||||
| UsrQuantizeFactor dequantize_param; | |||||
| UsrQuantizeFactor requantize_param; | |||||
| UsrQuantizeCalcFactor quantizecalc_param; | |||||
| USR_TYPE_DEC(UsrQuantizeAlgorithm, quantize_algo); | |||||
| USR_TYPE_DEC(UsrQuantizeScaleType, scale_type); | |||||
| USR_TYPE_HAS_DEC(UsrQuantizeFactor, quantize_param); | |||||
| USR_TYPE_HAS_DEC(UsrQuantizeFactor, dequantize_param); | |||||
| USR_TYPE_HAS_DEC(UsrQuantizeFactor, requantize_param); | |||||
| USR_TYPE_HAS_DEC(UsrQuantizeCalcFactor, quantizecalc_param); | |||||
| }; | |||||
| #undef USR_TYPE_DEC | |||||
| #undef USR_TYPE_HAS_DEC | |||||
| #undef USR_TYPE_BYTES_DEC | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_USR_TYPES_H_ | |||||
| @@ -1,45 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_UTILS_ANCHOR_UTILS_H_ | |||||
| #define INC_GRAPH_UTILS_ANCHOR_UTILS_H_ | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/node.h" | |||||
| namespace ge { | |||||
| class AnchorUtils { | |||||
| public: | |||||
| // Get anchor format | |||||
| static Format GetFormat(const DataAnchorPtr &dataAnchor); | |||||
| // Set anchor format | |||||
| static graphStatus SetFormat(const DataAnchorPtr &dataAnchor, Format dataFormat); | |||||
| // Get anchor status | |||||
| static AnchorStatus GetStatus(const DataAnchorPtr &dataAnchor); | |||||
| // Set anchor status | |||||
| static graphStatus SetStatus(const DataAnchorPtr &dataAnchor, AnchorStatus anchorStatus); | |||||
| static bool HasControlEdge(const AnchorPtr &anchor); | |||||
| static bool IsControlEdge(const AnchorPtr &src, const AnchorPtr &dst); | |||||
| static int GetIdx(const AnchorPtr &anchor); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_UTILS_ANCHOR_UTILS_H_ | |||||
| @@ -1,150 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_UTILS_ATTR_UTILS_H_ | |||||
| #define INC_GRAPH_UTILS_ATTR_UTILS_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/detail/attributes_holder.h" | |||||
| #include "graph/ge_attr_value.h" | |||||
| #include "graph/types.h" | |||||
| namespace ge { | |||||
| class OpDesc; | |||||
| using OpDescPtr = std::shared_ptr<OpDesc>; | |||||
| using ConstOpDescPtr = std::shared_ptr<const OpDesc>; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||||
| public: | |||||
| class ConstAttrHolderAdapter; | |||||
| class AttrHolderAdapter; | |||||
| // Set | |||||
| static bool HasAttr(ConstAttrHolderAdapter &&obj, const string &name); | |||||
| static bool SetInt(AttrHolderAdapter &&obj, const string &name, const int64_t &value); | |||||
| static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<int64_t> &value); | |||||
| static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<uint32_t> &value); | |||||
| static bool SetListInt(AttrHolderAdapter &&obj, const string &name, const vector<int32_t> &value); | |||||
| static bool SetListInt(AttrHolderAdapter &&obj, const string &name, std::initializer_list<int64_t> &&value); | |||||
| static bool SetFloat(AttrHolderAdapter &&obj, const string &name, const float &value); | |||||
| static bool SetListFloat(AttrHolderAdapter &&obj, const string &name, const vector<float> &value); | |||||
| static bool SetBool(AttrHolderAdapter &&obj, const string &name, const bool &value); | |||||
| static bool SetListBool(AttrHolderAdapter &&obj, const string &name, const vector<bool> &value); | |||||
| static bool SetStr(AttrHolderAdapter &&obj, const string &name, const string &value); | |||||
| static bool SetListStr(AttrHolderAdapter &&obj, const string &name, const vector<string> &value); | |||||
| static bool SetTensorDesc(AttrHolderAdapter &&obj, const string &name, const GeTensorDesc &value); | |||||
| static bool SetListTensorDesc(AttrHolderAdapter &&obj, const string &name, const vector<GeTensorDesc> &value); | |||||
| static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensorPtr &value); | |||||
| static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const ConstGeTensorPtr &value); | |||||
| static bool SetTensor(AttrHolderAdapter &&obj, const string &name, const GeTensor &value); | |||||
| static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<GeTensorPtr> &value); | |||||
| static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<ConstGeTensorPtr> &value); | |||||
| static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, | |||||
| std::initializer_list<ConstGeTensorPtr> &&value); | |||||
| static bool SetListTensor(AttrHolderAdapter &&obj, const string &name, const vector<GeTensor> &value); | |||||
| static bool SetGraph(AttrHolderAdapter &&obj, const string &name, const ComputeGraphPtr &value); | |||||
| static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector<ComputeGraphPtr> &value); | |||||
| static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value); | |||||
| static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector<GeAttrValue::BYTES> &value); | |||||
| static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value); | |||||
| static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name, | |||||
| const vector<GeAttrValue::NAMED_ATTRS> &value); | |||||
| static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<ConstOpDescPtr> &value); | |||||
| static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<OpDescPtr> &value); | |||||
| // Get | |||||
| static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int64_t &value); | |||||
| static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, int32_t &value); | |||||
| static bool GetInt(ConstAttrHolderAdapter &&obj, const string &name, uint32_t &value); | |||||
| static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<int64_t> &value); | |||||
| static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<int32_t> &value); | |||||
| static bool GetListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<uint32_t> &value); | |||||
| static bool GetFloat(ConstAttrHolderAdapter &&obj, const string &name, float &value); | |||||
| static bool GetListFloat(ConstAttrHolderAdapter &&obj, const string &name, vector<float> &value); | |||||
| static bool GetBool(ConstAttrHolderAdapter &&obj, const string &name, bool &value); | |||||
| static bool GetListBool(ConstAttrHolderAdapter &&obj, const string &name, vector<bool> &value); | |||||
| static bool GetStr(ConstAttrHolderAdapter &&obj, const string &name, string &value); | |||||
| static bool GetListStr(ConstAttrHolderAdapter &&obj, const string &name, vector<string> &value); | |||||
| static bool GetTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, GeTensorDesc &value); | |||||
| static bool GetListTensorDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<GeTensorDesc> &value); | |||||
| static bool GetTensor(ConstAttrHolderAdapter &&obj, const string &name, ConstGeTensorPtr &value); | |||||
| static bool MutableTensor(AttrHolderAdapter &&obj, const string &name, GeTensorPtr &value); | |||||
| static bool GetListTensor(ConstAttrHolderAdapter &&obj, const string &name, vector<ConstGeTensorPtr> &value); | |||||
| static bool MutableListTensor(AttrHolderAdapter &&obj, const string &name, vector<GeTensorPtr> &value); | |||||
| static bool GetGraph(ConstAttrHolderAdapter &&obj, const string &name, ComputeGraphPtr &value); | |||||
| static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector<ComputeGraphPtr> &value); | |||||
| static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value); | |||||
| static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<GeAttrValue::BYTES> &value); | |||||
| static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS &value); | |||||
| static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, | |||||
| vector<GeAttrValue::NAMED_ATTRS> &value); | |||||
| static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<OpDescPtr> &value); | |||||
| // Value will be moved | |||||
| static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); | |||||
| static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); | |||||
| // Value will be moved | |||||
| static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | |||||
| static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | |||||
| static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value); | |||||
| static bool GetListListInt(ConstAttrHolderAdapter &&obj, const string &name, vector<vector<int64_t>> &value); | |||||
| static bool SetListDataType(AttrHolderAdapter &&obj, const string &name, const vector<ge::DataType> &value); | |||||
| static bool GetListDataType(ConstAttrHolderAdapter &&obj, const string &name, vector<ge::DataType> &value); | |||||
| static bool SetDataType(AttrHolderAdapter &&obj, const string &name, const ge::DataType &value); | |||||
| static bool GetDataType(ConstAttrHolderAdapter &&obj, const string &name, ge::DataType &value); | |||||
| static OpDescPtr CloneOpDesc(const ConstOpDescPtr &orgOpDesc); | |||||
| static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); | |||||
| static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj); | |||||
| class AttrHolderAdapter { | |||||
| public: | |||||
| AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} | |||||
| ~AttrHolderAdapter() {} | |||||
| template <class T> | |||||
| AttrHolderAdapter(const std::shared_ptr<T> &obj) : obj_(obj.get()) {} | |||||
| AttrHolderAdapter(AttrHolder &obj) : obj_(&obj) {} | |||||
| operator bool() const { return obj_ != nullptr; } | |||||
| AttrHolder *operator->() { return obj_; } | |||||
| AttrHolder *get() { return obj_; } | |||||
| AttrHolder *obj_; | |||||
| }; | |||||
| class ConstAttrHolderAdapter { | |||||
| public: | |||||
| ConstAttrHolderAdapter(const AttrHolder *obj) : obj_(obj) {} | |||||
| ~ConstAttrHolderAdapter() {} | |||||
| template <class T> | |||||
| ConstAttrHolderAdapter(const std::shared_ptr<T> obj) : obj_(obj.get()) {} | |||||
| ConstAttrHolderAdapter(const AttrHolder &obj) : obj_(&obj) {} | |||||
| operator bool() const { return obj_ != nullptr; } | |||||
| const AttrHolder *operator->() const { return obj_; } | |||||
| const AttrHolder *get() const { return obj_; } | |||||
| private: | |||||
| const AttrHolder *obj_; | |||||
| }; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_UTILS_ATTR_UTILS_H_ | |||||
| @@ -1,771 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_ | |||||
| #define INC_GRAPH_UTILS_GRAPH_UTILS_H_ | |||||
| #include <fstream> | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <list> | |||||
| #include <unordered_map> | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/node.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/utils/anchor_utils.h" | |||||
| #include "graph/graph.h" | |||||
| #include "graph/model.h" | |||||
| #define GE_DUMP(compute_graph, name) \ | |||||
| do { \ | |||||
| GraphUtils::DumpGEGraph(compute_graph, name); \ | |||||
| GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); \ | |||||
| uint64_t i = 0; \ | |||||
| for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { \ | |||||
| auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \ | |||||
| GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); \ | |||||
| GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ | |||||
| do { \ | |||||
| DataType ret; \ | |||||
| attr.GetValue<DataType>(ret); \ | |||||
| } while (0) | |||||
| #define PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ | |||||
| do { \ | |||||
| if (value_type == VT_ENUM) { \ | |||||
| REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ | |||||
| stream << ret; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ | |||||
| do { \ | |||||
| if (value_type == VT_ENUM) { \ | |||||
| REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ | |||||
| stream << "["; \ | |||||
| for (int i = 0; i < ret.size(); i++) { \ | |||||
| stream << ret[i]; \ | |||||
| if (i + 1 != ret.size()) stream << ", "; \ | |||||
| } \ | |||||
| stream << "]"; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define PRINT_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ | |||||
| else PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) | |||||
| #define PRINT_LIST_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ | |||||
| else PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) | |||||
| #define PRINT_SHAPE(i_o, n, idx, stream) \ | |||||
| do { \ | |||||
| auto op = n->GetOpDesc(); \ | |||||
| GeTensorDesc td = i_o == "input" ? op->GetInputDesc(idx) : op->GetOutputDesc(idx); \ | |||||
| auto shape = td.GetShape().GetDims(); \ | |||||
| stream << "["; \ | |||||
| for (int i = 0; i < shape.size(); i++) { \ | |||||
| stream << shape[i]; \ | |||||
| if (i + 1 < shape.size()) stream << ", "; \ | |||||
| } \ | |||||
| stream << "]"; \ | |||||
| } while (0) | |||||
| #define PRINT_ATTR_FUNC(stream) \ | |||||
| [&](GeAttrValue attr) { \ | |||||
| auto type = attr.GetValueType(); \ | |||||
| PRINT_ATTR_VALUE_IF(type, GeAttrValue::ValueType::VT_STRING, GeAttrValue::STR, attr, stream) \ | |||||
| PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_FLOAT, GeAttrValue::FLOAT, attr, stream) \ | |||||
| PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_BOOL, GeAttrValue::BOOL, attr, stream) \ | |||||
| PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_INT, GeAttrValue::INT, attr, stream) \ | |||||
| PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_STRING, GeAttrValue::LIST_STR, attr, stream) \ | |||||
| PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_FLOAT, GeAttrValue::LIST_FLOAT, attr, stream) \ | |||||
| PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_BOOL, GeAttrValue::LIST_BOOL, attr, stream) \ | |||||
| PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_INT, GeAttrValue::LIST_INT, attr, stream) \ | |||||
| else if (type == GeAttrValue::ValueType::VT_TENSOR_DESC) stream << "TENSOR_DESC"; \ | |||||
| else if (type == GeAttrValue::ValueType::VT_TENSOR) stream << "TENSOR"; \ | |||||
| else if (type == GeAttrValue::ValueType::VT_BYTES) stream << "BYTES"; \ | |||||
| else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR_DESC) stream << "LIST_TENSOR_DESC"; \ | |||||
| else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR) stream << "LIST_TENSOR"; \ | |||||
| else if (type == GeAttrValue::ValueType::VT_LIST_BYTES) stream << "LIST_BYTES"; \ | |||||
| }; | |||||
| namespace ge { | |||||
| enum IOType { kIn, kOut }; | |||||
| struct NodeIndexIO { | |||||
| NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type) | |||||
| : node_(std::move(node)), index_(index), io_type_(io_type) { | |||||
| if (node_ != nullptr) { | |||||
| value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_); | |||||
| } | |||||
| } | |||||
| NodeIndexIO(ge::NodePtr node, int index, IOType io_type) | |||||
| : node_(std::move(node)), index_(static_cast<uint32_t>(index)), io_type_(io_type) { | |||||
| if (node_ != nullptr) { | |||||
| value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_); | |||||
| } | |||||
| } | |||||
| ~NodeIndexIO() {} | |||||
| NodePtr node_ = nullptr; | |||||
| uint32_t index_ = 0; | |||||
| IOType io_type_ = kOut; | |||||
| std::string value_; | |||||
| const std::string &ToString() const { return value_; } | |||||
| }; | |||||
| class GraphUtils { | |||||
| public: | |||||
| static ComputeGraphPtr GetComputeGraph(const Graph &graph); | |||||
| static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); | |||||
| static graphStatus RecoverGraphOperators(const Graph &graph); | |||||
| static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs); | |||||
| static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); | |||||
| static graphStatus AddEdge(const OutDataAnchorPtr &src, const Format &src_format, const InDataAnchorPtr &dst, | |||||
| const Format &dst_format); | |||||
| static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst); | |||||
| static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); | |||||
| static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); | |||||
| // check whether src is link to dst and then remove | |||||
| static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); | |||||
| static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst); | |||||
| static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); | |||||
| static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); | |||||
| static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, | |||||
| const InDataAnchorPtr &new_dst); | |||||
| static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, | |||||
| const InControlAnchorPtr &new_dst); | |||||
| static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, | |||||
| const NodePtr &new_node); | |||||
| static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node); | |||||
| static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node); | |||||
| static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, | |||||
| const std::vector<OpDescPtr> &vec_op_desc); | |||||
| /// | |||||
| /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst | |||||
| /// @param [in] src | |||||
| /// @param [in] dsts | |||||
| /// @param [in] insert_node | |||||
| /// @param [in] input_index | |||||
| /// @param [in] output_index | |||||
| /// @return graphStatus | |||||
| /// | |||||
| static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||||
| const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); | |||||
| static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); | |||||
| static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); | |||||
| static void RecordOriginalNames(std::vector<ge::NodePtr> original_nodes, const ge::NodePtr &node); | |||||
| static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node); | |||||
| static bool MatchDumpStr(const std::string &suffix); | |||||
| static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false, | |||||
| const std::string &user_graph_name = ""); | |||||
| static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); | |||||
| static bool LoadGEGraph(const char *file, ge::ComputeGraphPtr &compute_graph); | |||||
| static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos); | |||||
| static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); | |||||
| static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph); | |||||
| static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message); | |||||
| static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *real_path); | |||||
| static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node); | |||||
| /// | |||||
| /// Isolating `node`, relinking data links from the in-anchor peer nodes to | |||||
| /// the out-anchor peer nodes according to `io_map`, relinking control links | |||||
| /// to ensure that input nodes of `node` are before out nodes | |||||
| /// | |||||
| /// Link the `io_map[i]` input anchor peer node to `i` output anchor peer | |||||
| /// nodes, then unlink all links connecting with `node`. If `io_map[i]` < 0, | |||||
| /// unlink all links from `i` output anchor without any relinking. | |||||
| /// | |||||
| /// @param node | |||||
| /// @param io_map | |||||
| /// @return | |||||
| /// | |||||
| static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list<int> &io_map); | |||||
| static graphStatus IsolateNode(const NodePtr &node, const std::vector<int> &io_map); | |||||
| /// | |||||
| /// Isolate `node` which must be one input one output, equivalent to | |||||
| /// `IsolateNode(node, {0})` | |||||
| /// @param node | |||||
| /// @return | |||||
| /// | |||||
| static graphStatus IsolateNodeOneIO(const NodePtr &node); | |||||
| /// | |||||
| /// The data anchors replacing behavior is the same with | |||||
| /// `ReplaceNodeDataAnchors`. In addition, replace all `old_node` control | |||||
| /// anchors with `new_node`'s. | |||||
| /// @param new_node | |||||
| /// @param old_node | |||||
| /// @param inputs_map | |||||
| /// @param outputs_map | |||||
| /// @return | |||||
| /// | |||||
| static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, | |||||
| std::initializer_list<int> inputs_map, std::initializer_list<int> outputs_map); | |||||
| static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, | |||||
| const std::vector<int> &inputs_map, const std::vector<int> &outputs_map); | |||||
| /// | |||||
| /// Replace `old_node` data anchors with `new_node`'s according to `inputs_map` and `outputs_map`. | |||||
| /// Replace the `i` in/out data anchor on `old_node` with | |||||
| /// `inputs_map[i]`/`outputs_map[i]` data anchor on `new_node`. | |||||
| /// If `inputs_map[i]`/`outputs_map[i]` < 0 or the index not contained in | |||||
| /// `inputs_map[i]`/`outputs_map[i]`, the `i` data anchor will remain | |||||
| /// on `old_node`. | |||||
| /// @param new_node | |||||
| /// @param old_node | |||||
| /// @param inputs_map | |||||
| /// @param outputs_map | |||||
| /// @return | |||||
| /// | |||||
| static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, | |||||
| std::initializer_list<int> inputs_map, | |||||
| std::initializer_list<int> outputs_map); | |||||
| static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, | |||||
| const std::vector<int> &inputs_map, const std::vector<int> &outputs_map); | |||||
| /// | |||||
| /// Copy all in-control edges from `src_node` to `dst_node` | |||||
| /// @param src_node | |||||
| /// @param dst_node | |||||
| /// @return | |||||
| /// | |||||
| static graphStatus CopyInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); | |||||
| static graphStatus MoveInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); | |||||
| /// | |||||
| /// Copy all out-control edges from `src_node` to `dst_node` | |||||
| /// @param src_node | |||||
| /// @param dst_node | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); | |||||
| /// | |||||
| /// Move all out-control edges from `src_node` to `dst_node` | |||||
| /// @param src_node | |||||
| /// @param dst_node | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | |||||
| /// | |||||
| /// Copy all in-data edges from `src_node` to `dst_node` | |||||
| /// @param src_node | |||||
| /// @param dst_node | |||||
| /// @return | |||||
| /// | |||||
| static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node); | |||||
| static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | |||||
| /// | |||||
| /// Make a copy of ComputeGraph. | |||||
| /// @param graph: original graph. | |||||
| /// @param prefix: node name prefix of new graph. | |||||
| /// @return ComputeGraphPtr | |||||
| /// | |||||
| static ComputeGraphPtr CloneGraph(const ComputeGraphPtr &graph, const string &prefix, | |||||
| std::vector<NodePtr> &input_nodes, std::vector<NodePtr> &output_nodes); | |||||
| /// | |||||
| /// Copy tensor attribute to new node. | |||||
| /// @param [in] dst_desc: cloned node. | |||||
| /// @param [in] src_node: original node. | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node); | |||||
| static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec); | |||||
| /// | |||||
| /// Get reference-mapping of all data_anchors in graph | |||||
| /// @param [in] graph | |||||
| /// @param [out] symbol_to_anchors | |||||
| /// @param [out] anchor_to_symbol | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus GetRefMapping(const ComputeGraphPtr &graph, | |||||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||||
| std::map<std::string, std::string> &anchor_to_symbol); | |||||
| /// | |||||
| /// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs | |||||
| /// of the graph have UNKNOWN_SHAPE operators or not. | |||||
| /// Note: This function will only look 'down' from the graph, not 'up'. For example, the following | |||||
| /// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE | |||||
| /// ROOT graph: A -----> B -----> C | |||||
| /// K subgraph U | |||||
| /// | | |||||
| /// V | |||||
| /// SUB graph: D --> E --> F | |||||
| /// K K K | |||||
| /// @param [in] graph | |||||
| /// @return bool | |||||
| /// | |||||
| static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph); | |||||
| static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name); | |||||
| private: | |||||
| /// | |||||
| /// Get reference-mapping for in_data_anchors of node | |||||
| /// @param [in] node | |||||
| /// @param [out] symbol_to_anchors | |||||
| /// @param [out] anchor_to_symbol | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus HandleInAnchorMapping(const NodePtr &node, | |||||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||||
| std::map<std::string, std::string> &anchor_to_symbol); | |||||
| /// | |||||
| /// Get reference-mapping for out_data_anchors of node | |||||
| /// @param [in] node | |||||
| /// @param [out] symbol_to_anchors | |||||
| /// @param [out] anchor_to_symbol | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus HandleOutAnchorMapping(const NodePtr &node, | |||||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||||
| std::map<std::string, std::string> &anchor_to_symbol); | |||||
| /// | |||||
| /// Handle input of subgraph | |||||
| /// @param [in] node | |||||
| /// @param [out] symbol_to_anchors | |||||
| /// @param [out] anchor_to_symbol | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus HandleSubgraphInput(const NodePtr &node, | |||||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||||
| std::map<std::string, std::string> &anchor_to_symbol); | |||||
| /// | |||||
| /// Handle input of Merge op | |||||
| /// @param [in] node | |||||
| /// @param [out] symbol_to_anchors | |||||
| /// @param [out] anchor_to_symbol | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus HandleMergeInput(const NodePtr &node, | |||||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||||
| std::map<std::string, std::string> &anchor_to_symbol); | |||||
| /// | |||||
| /// Handle output of subgraph | |||||
| /// @param [in] node | |||||
| /// @param [out] symbol_to_anchors | |||||
| /// @param [out] anchor_to_symbol | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus HandleSubgraphOutput(const NodePtr &node, | |||||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||||
| std::map<std::string, std::string> &anchor_to_symbol); | |||||
| /// | |||||
| /// Relink all edges for cloned ComputeGraph. | |||||
| /// @param [in] node: original node. | |||||
| /// @param [in] prefix: node name prefix of new node. | |||||
| /// @param [in] all_nodes: all nodes in new graph. | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus RelinkGraphEdges(const NodePtr &node, const string &prefix, | |||||
| const std::unordered_map<string, NodePtr> &all_nodes); | |||||
| /// | |||||
| /// Union ref-mapping | |||||
| /// @param [in] exist_node_info1 | |||||
| /// @param [in] exist_node_info2 | |||||
| /// @param [out] symbol_to_anchors | |||||
| /// @param [out] anchor_to_symbol | |||||
| /// @param [out] symbol | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, | |||||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||||
| std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol); | |||||
| /// | |||||
| /// Update symbol mapping with a new reference pair | |||||
| /// @param [in] cur_node_info | |||||
| /// @param [in] exist_node_info | |||||
| /// @param [out] symbol_to_anchors | |||||
| /// @param [out] anchor_to_symbol | |||||
| /// @return success: GRAPH_SUCESS | |||||
| /// | |||||
| static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, | |||||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||||
| std::map<std::string, std::string> &anchor_to_symbol); | |||||
| /// | |||||
| /// Check if out_data_anchor is reference of input | |||||
| /// @param [in] out_data_anchor | |||||
| /// @param [out] reuse_in_index | |||||
| /// @return bool | |||||
| /// | |||||
| static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index); | |||||
| }; | |||||
| class ComputeGraphBuilder { | |||||
| public: | |||||
| ComputeGraphBuilder() : owner_graph_(nullptr) {} | |||||
| ComputeGraphBuilder(const ComputeGraphBuilder &) = delete; | |||||
| ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete; | |||||
| ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete; | |||||
| ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete; | |||||
| ~ComputeGraphBuilder() = default; | |||||
| /// | |||||
| /// @brief Add node to graph | |||||
| /// @param [in] op_desc | |||||
| /// @return ComputeGraphBuilder | |||||
| /// | |||||
| virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc); | |||||
| /// | |||||
| /// @brief Add data-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] out_anchor_ind | |||||
| /// @param [in] dst_name | |||||
| /// @param [in] in_anchor_ind | |||||
| /// @return ComputeGraphBuilder | |||||
| /// | |||||
| virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, | |||||
| const std::string &dst_name, uint32_t in_anchor_ind); | |||||
| /// | |||||
| /// @brief Add ctrl-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] dst_name | |||||
| /// @return ComputeGraphBuilder | |||||
| /// | |||||
| virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name); | |||||
| /// | |||||
| /// @brief Build graph | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return ComputeGraphPtr | |||||
| /// | |||||
| virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0; | |||||
| /// @brief Get node with name | |||||
| /// @param [in] name | |||||
| /// @return NodePtr | |||||
| /// | |||||
| NodePtr GetNode(const std::string &name); | |||||
| /// @brief Get all nodes | |||||
| /// @return std::vector<NodePtr> | |||||
| /// | |||||
| std::vector<NodePtr> GetAllNodes(); | |||||
| protected: | |||||
| /// | |||||
| /// @brief Build nodes | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildNodes(graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Build data-links | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildDataLinks(graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Build ctrl-links | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg); | |||||
| ComputeGraphPtr owner_graph_; | |||||
| // node_name -> node | |||||
| std::map<std::string, NodePtr> node_names_; | |||||
| std::vector<OpDescPtr> nodes_; | |||||
| // <src_node_name, out_anchor_ind> -> <dst_node_name, in_anchor_ind> | |||||
| std::vector<std::pair<std::pair<std::string, uint32_t>, std::pair<std::string, uint32_t>>> data_links_; | |||||
| // src_node_name -> dst_node_name | |||||
| std::vector<std::pair<std::string, std::string>> ctrl_links_; | |||||
| }; | |||||
| class CompleteGraphBuilder : public ComputeGraphBuilder { | |||||
| public: | |||||
| explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {} | |||||
| CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; | |||||
| CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; | |||||
| CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; | |||||
| CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete; | |||||
| ~CompleteGraphBuilder() = default; | |||||
| /// | |||||
| /// @brief Add node to graph | |||||
| /// @param [in] op_desc | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override; | |||||
| /// | |||||
| /// @brief Add data-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] out_anchor_ind | |||||
| /// @param [in] dst_name | |||||
| /// @param [in] in_anchor_ind | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, | |||||
| uint32_t in_anchor_ind) override; | |||||
| /// | |||||
| /// @brief Add ctrl-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] dst_name | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; | |||||
| /// | |||||
| /// @brief Set index_th input anchor for graph | |||||
| /// @param [in] index | |||||
| /// @param [in] node_names | |||||
| /// @param [in] anchor_inds | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &SetInput(uint32_t index, const std::vector<std::string> &node_names, | |||||
| const std::vector<uint32_t> &anchor_inds); | |||||
| /// | |||||
| /// @brief Set index_th input of graph as useless | |||||
| /// @param [in] index | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &SetUselessInput(uint32_t index); | |||||
| /// | |||||
| /// @brief Add output anchor for graph | |||||
| /// @param [in] owner_node_name | |||||
| /// @param [in] anchor_ind | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); | |||||
| /// | |||||
| /// @brief Add target for graph | |||||
| /// @param [in] target_name | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &AddTarget(const std::string &target_name); | |||||
| /// | |||||
| /// @brief Set parent-node of graph | |||||
| /// @param [in] parent_node | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node); | |||||
| /// | |||||
| /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node | |||||
| /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &SetInputMapping(const std::map<uint32_t, uint32_t> &input_mapping); | |||||
| /// | |||||
| /// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind | |||||
| /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &SetOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping); | |||||
| /// | |||||
| /// @brief Build graph | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return ComputeGraphPtr | |||||
| /// | |||||
| ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; | |||||
| private: | |||||
| /// | |||||
| /// @brief Add data nodes | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void AddDataNodes(graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Add data node | |||||
| /// @param [in] index | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Add RetVal nodes | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void AddRetValNodes(graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Build target-nodes for graph | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildGraphTargets(graphStatus &error_code, std::string &error_msg); | |||||
| std::string name_; | |||||
| NodePtr parent_node_; | |||||
| std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_; | |||||
| std::vector<std::pair<std::string, uint32_t>> graph_outputs_; | |||||
| std::vector<std::string> graph_targets_; | |||||
| // index_of_graph_input -> in_anchor_index_of_parent_node | |||||
| std::map<uint32_t, uint32_t> input_mapping_; | |||||
| // index_of_graph_output -> out_anchor_index_of_parent_node | |||||
| std::map<uint32_t, uint32_t> output_mapping_; | |||||
| }; | |||||
| class PartialGraphBuilder : public ComputeGraphBuilder { | |||||
| public: | |||||
| PartialGraphBuilder() = default; | |||||
| PartialGraphBuilder(const PartialGraphBuilder &) = delete; | |||||
| PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete; | |||||
| PartialGraphBuilder(const PartialGraphBuilder &&) = delete; | |||||
| PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete; | |||||
| ~PartialGraphBuilder() = default; | |||||
| /// | |||||
| /// @brief Add node to graph | |||||
| /// @param [in] op_desc | |||||
| /// @return PartialGraphBuilder | |||||
| /// | |||||
| PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override; | |||||
| /// | |||||
| /// @brief Add data-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] out_anchor_ind | |||||
| /// @param [in] dst_name | |||||
| /// @param [in] in_anchor_ind | |||||
| /// @return PartialGraphBuilder | |||||
| /// | |||||
| PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, | |||||
| uint32_t in_anchor_ind) override; | |||||
| /// | |||||
| /// @brief Add ctrl-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] dst_name | |||||
| /// @return PartialGraphBuilder | |||||
| /// | |||||
| PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; | |||||
| /// | |||||
| /// @brief Set owner graph | |||||
| /// @param [in] graph | |||||
| /// @return PartialGraphBuilder | |||||
| /// | |||||
| PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph); | |||||
| /// | |||||
| /// @brief Add exist node | |||||
| /// @param [in] node | |||||
| /// @return PartialGraphBuilder | |||||
| /// | |||||
| PartialGraphBuilder &AddExistNode(const NodePtr &node); | |||||
| /// | |||||
| /// @brief Build multi nodes with links | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return ComputeGraphPtr | |||||
| /// | |||||
| ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; | |||||
| private: | |||||
| /// | |||||
| /// @brief Build exist nodes | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildExistNodes(graphStatus &error_code, std::string &error_msg); | |||||
| std::vector<NodePtr> exist_nodes_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_ | |||||
| @@ -1,170 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_UTILS_NODE_UTILS_H_ | |||||
| #define INC_GRAPH_UTILS_NODE_UTILS_H_ | |||||
| #include <set> | |||||
| #include <map> | |||||
| #include <vector> | |||||
| #include "external/graph/operator.h" | |||||
| #include "graph/node.h" | |||||
| namespace ge { | |||||
| // Op types of Const like Opps. | |||||
| extern const std::set<std::string> kConstOpTypes; | |||||
| // Op types of If like Opps. | |||||
| extern const std::set<std::string> kIfOpTypes; | |||||
| // Op types of While like Opps. | |||||
| extern const std::set<std::string> kWhileOpTypes; | |||||
| // Op types of Case like Opps. | |||||
| extern const std::set<std::string> kCaseOpTypes; | |||||
| // Op types of For like Opps. | |||||
| extern const std::set<std::string> kForOpTypes; | |||||
| class NodeUtils { | |||||
| public: | |||||
| static graphStatus AddSendEventId(const NodePtr &node, const uint32_t &event_id); | |||||
| static graphStatus AddRecvEventId(const NodePtr &node, const uint32_t &event_id); | |||||
| static graphStatus GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send); | |||||
| static graphStatus GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv); | |||||
| static graphStatus ClearSendInfo(); | |||||
| static graphStatus ClearRecvInfo(); | |||||
| static graphStatus GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst); | |||||
| static graphStatus GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data, | |||||
| InControlAnchorPtr &in_control); | |||||
| static graphStatus ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor); | |||||
| static graphStatus SetAllAnchorStatus(const NodePtr &nodePtr); | |||||
| static graphStatus SetAllAnchorStatus(Node &node); | |||||
| static bool IsAnchorStatusSet(const NodePtr &nodePtr); | |||||
| static bool IsAnchorStatusSet(const Node &node); | |||||
| static graphStatus MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node); | |||||
| static void UpdateIsInputConst(const NodePtr &nodePtr); | |||||
| static void UpdateIsInputConst(Node &node); | |||||
| static bool IsConst(const Node &node); | |||||
| static void UnlinkAll(const Node &node); | |||||
| static graphStatus UpdatePeerNodeInputDesc(const NodePtr &node_ptr); | |||||
| static graphStatus AppendInputAnchor(const NodePtr &node, uint32_t num); | |||||
| static graphStatus RemoveInputAnchor(const NodePtr &node, uint32_t num); | |||||
| static graphStatus AppendOutputAnchor(const NodePtr &node, uint32_t num); | |||||
| static graphStatus RemoveOutputAnchor(const NodePtr &node, uint32_t num); | |||||
| static bool IsInNodesEmpty(const Node &node); | |||||
| static GeTensorDesc GetOutputDesc(const Node &node, uint32_t index); | |||||
| static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); | |||||
| static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); | |||||
| static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); | |||||
| // check node whether unknown shape.If node shape contain -1 or -2,out param "is_unknow" will be true; | |||||
| // for func op, it will check subgraph yet, if some node shape of subgraph contain -1 or -2, | |||||
| // the out param "is_unknow" will be true too | |||||
| static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); | |||||
| static std::string GetNodeType(const Node &node); | |||||
| static std::string GetNodeType(const NodePtr &node); | |||||
| static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); | |||||
| static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); | |||||
| /// | |||||
| /// Check if node is input of subgraph | |||||
| /// @param [in] node | |||||
| /// @return bool | |||||
| /// | |||||
| static bool IsSubgraphInput(const NodePtr &node); | |||||
| /// | |||||
| /// Check if node is output of subgraph | |||||
| /// @param [in] node | |||||
| /// @return bool | |||||
| /// | |||||
| static bool IsSubgraphOutput(const NodePtr &node); | |||||
| /// | |||||
| /// @brief Get subgraph original input node. | |||||
| /// @param [in] node | |||||
| /// @return Node | |||||
| /// | |||||
| static NodePtr GetParentInput(const Node &node); | |||||
| static NodePtr GetParentInput(const NodePtr &node); | |||||
| /// | |||||
| /// @brief Get is dynamic shape graph from node. | |||||
| /// @param [in] node | |||||
| /// @return bool | |||||
| /// | |||||
| static bool IsDynamicShape(const Node &node); | |||||
| static bool IsDynamicShape(const NodePtr &node); | |||||
| /// | |||||
| /// @brief Check is varying_input for while node | |||||
| /// @param [in] node: Data node for subgraph | |||||
| /// @return bool | |||||
| /// | |||||
| static bool IsWhileVaryingInput(const ge::NodePtr &node); | |||||
| /// | |||||
| /// @brief Get subgraph input is constant. | |||||
| /// @param [in] node | |||||
| /// @param [out] string | |||||
| /// @return bool | |||||
| /// | |||||
| static bool GetConstOpType(const NodePtr &node, std::string &type); | |||||
| /// | |||||
| /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. | |||||
| /// @param [in] node | |||||
| /// @return return GRAPH_SUCCESS if remove successfully, other for failed. | |||||
| /// | |||||
| static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); | |||||
| /// | |||||
| /// @brief Get subgraph input data node by index. | |||||
| /// @param [in] node | |||||
| /// @return Node | |||||
| /// | |||||
| static vector<NodePtr> GetSubgraphDataNodesByIndex(const Node &node, int index); | |||||
| /// | |||||
| /// @brief Get subgraph input data node by index. | |||||
| /// @param [in] node | |||||
| /// @return Node | |||||
| /// | |||||
| static vector<NodePtr> GetSubgraphOutputNodes(const Node &node); | |||||
| static NodePtr GetInDataNodeByIndex(const Node &node, const int index); | |||||
| static vector<pair<InDataAnchorPtr, NodePtr>> GetOutDataNodesWithAnchorByIndex(const Node &node, const int index); | |||||
| static ge::ConstNodePtr GetNodeFromOperator(const Operator &oprt); | |||||
| static graphStatus GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor); | |||||
| static graphStatus GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor); | |||||
| private: | |||||
| static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | |||||
| static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_UTILS_NODE_UTILS_H_ | |||||
| @@ -1,181 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_UTILS_OP_DESC_UTILS_H_ | |||||
| #define INC_GRAPH_UTILS_OP_DESC_UTILS_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/def_types.h" | |||||
| #include "graph/node.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/operator.h" | |||||
| #include "graph/range_vistor.h" | |||||
| namespace ge { | |||||
| class OpDesc; | |||||
| using OpDescPtr = std::shared_ptr<OpDesc>; | |||||
| class OpDescUtils { | |||||
| public: | |||||
| template <class T> | |||||
| using Vistor = RangeVistor<T, std::shared_ptr<OpDesc>>; | |||||
| OpDescUtils() = default; | |||||
| ~OpDescUtils() = default; | |||||
| static bool HasQuantizeFactorParams(const OpDescPtr& op_desc); | |||||
| static bool HasQuantizeFactorParams(const OpDesc& op_desc); | |||||
| static graphStatus GetQuantizeFactorParams(const OpDescPtr& op_desc, QuantizeFactorParams& quant); | |||||
| static graphStatus GetQuantizeFactorParams(const OpDesc& op_desc, QuantizeFactorParams& quant); | |||||
| static graphStatus SetQuantizeFactorParams(const OpDescPtr& op_desc, const QuantizeFactorParams& quant); | |||||
| static graphStatus SetQuantizeFactorParams(OpDesc& op_desc, const QuantizeFactorParams& quant); | |||||
| static vector<ge::NodePtr> GetConstInputNode(const ge::Node& node); | |||||
| static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr>& input_nodes); | |||||
| static vector<ConstGeTensorPtr> GetWeights(const ge::Node& node); | |||||
| static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr& node); | |||||
| static vector<GeTensorPtr> MutableWeights(const ge::Node& node); | |||||
| static vector<GeTensorPtr> MutableWeights(const ge::NodePtr node); | |||||
| static graphStatus SetWeights(ge::Node& node, const vector<ge::GeTensorPtr>& weights); | |||||
| static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr>& weights); | |||||
| static graphStatus ClearWeights(ge::NodePtr node); | |||||
| static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index); | |||||
| static bool ClearInputDesc(const ge::NodePtr& node); | |||||
| static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, uint32_t index); | |||||
| static bool ClearOutputDesc(const ge::NodePtr& node); | |||||
| static vector<ge::NodePtr> GetConstInputs(const ge::Node& node); | |||||
| static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr& node); | |||||
| static size_t GetNonConstInputsSize(const ge::Node& node); | |||||
| static size_t GetNonConstInputsSize(ge::ConstNodePtr node); | |||||
| // Index: Indicates the index of all non const inputs | |||||
| static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, size_t index_non_const = 0); | |||||
| static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, size_t index_non_const = 0); | |||||
| static bool GetNonConstInputIndex(const ge::Node& node, size_t index_non_const, size_t& index); | |||||
| static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, size_t index_non_const, size_t& index); | |||||
| // Index: Indicates the index of all inputs | |||||
| static bool IsNonConstInput(const ge::Node& node, size_t index = 0); | |||||
| static bool IsNonConstInput(const ge::ConstNodePtr& node, size_t index = 0); | |||||
| static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr& node); | |||||
| static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr); | |||||
| static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); | |||||
| static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); | |||||
| static OpDescPtr GetOpDescFromOperator(const Operator& oprt); | |||||
| static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); | |||||
| static graphStatus SetSubgraphInstanceName(const std::string& subgraph_name, | |||||
| const std::string& subgraph_instance_name, OpDescPtr& op_desc); | |||||
| private: | |||||
| static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); | |||||
| static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); | |||||
| static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight); | |||||
| static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); | |||||
| }; | |||||
| class OpDescBuilder { | |||||
| public: | |||||
| OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} | |||||
| OpDescBuilder(const OpDescBuilder&) = delete; | |||||
| OpDescBuilder& operator=(const OpDescBuilder&) = delete; | |||||
| OpDescBuilder(const OpDescBuilder&&) = delete; | |||||
| OpDescBuilder& operator=(const OpDescBuilder&&) = delete; | |||||
| ~OpDescBuilder() = default; | |||||
| /// | |||||
| /// @brief Add input | |||||
| /// @param [in] name | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddInput(const std::string& name); | |||||
| /// | |||||
| /// @brief Add input | |||||
| /// @param [in] name | |||||
| /// @param [in] tensor | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddInput(const std::string& name, const GeTensorDesc& tensor); | |||||
| /// | |||||
| /// @brief Add dynamic input | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); | |||||
| /// | |||||
| /// @brief Add dynamic input | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @param [in] tensor | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); | |||||
| /// | |||||
| /// @brief Add output | |||||
| /// @param [in] name | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddOutput(const std::string& name); | |||||
| /// | |||||
| /// @brief Add output | |||||
| /// @param [in] name | |||||
| /// @param [in] tensor | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddOutput(const std::string& name, const GeTensorDesc& tensor); | |||||
| /// | |||||
| /// @brief Add dynamic output | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); | |||||
| /// | |||||
| /// @brief Add dynamic output | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @param [in] tensor | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); | |||||
| /// | |||||
| /// @brief Build op_desc | |||||
| /// @return OpDescPtr | |||||
| /// | |||||
| OpDescPtr Build(); | |||||
| private: | |||||
| std::string name_; | |||||
| std::string type_; | |||||
| std::vector<std::pair<std::string, GeTensorDesc>> inputs_; | |||||
| std::vector<std::pair<std::string, GeTensorDesc>> outputs_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ | |||||
| @@ -1,43 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ | |||||
| #define INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ | |||||
| #include <memory> | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "graph/tensor.h" | |||||
| namespace ge { | |||||
| using GeTensorPtr = std::shared_ptr<GeTensor>; | |||||
| using ConstGeTensorPtr = std::shared_ptr<const GeTensor>; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorAdapter { | |||||
| public: | |||||
| static GeTensorDesc TensorDesc2GeTensorDesc(const TensorDesc &tensorDesc); | |||||
| static TensorDesc GeTensorDesc2TensorDesc(const GeTensorDesc &geTensorDesc); | |||||
| static GeTensorPtr Tensor2GeTensor(const Tensor &tensor); | |||||
| static Tensor GeTensor2Tensor(const ConstGeTensorPtr &geTensor); | |||||
| static ConstGeTensorPtr AsGeTensorPtr(const Tensor &tensor); // Share value | |||||
| static GeTensorPtr AsGeTensorPtr(Tensor &tensor); // Share value | |||||
| static const GeTensor AsGeTensor(const Tensor &tensor); // Share value | |||||
| static GeTensor AsGeTensor(Tensor &tensor); // Share value | |||||
| static const Tensor AsTensor(const GeTensor &tensor); // Share value | |||||
| static Tensor AsTensor(GeTensor &tensor); // Share value | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ | |||||
| @@ -1,77 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_UTILS_TENSOR_UTILS_H_ | |||||
| #define INC_GRAPH_UTILS_TENSOR_UTILS_H_ | |||||
| #include <vector> | |||||
| #include "graph/def_types.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| namespace ge { | |||||
| class TensorUtils { | |||||
| public: | |||||
| static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, int64_t &size); | |||||
| static void SetSize(GeTensorDesc &tensorDesc, int64_t size); | |||||
| static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr); | |||||
| static uint32_t GetWeightSize(const GeTensor &tensor); | |||||
| static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc); | |||||
| static uint8_t *GetWeightAddr(const ConstGeTensorPtr &tensorPtr, uint8_t *base); | |||||
| static uint8_t *GetWeightAddr(const GeTensor &tensor, uint8_t *base); | |||||
| static void SetWeightSize(GeTensorDesc &tensorDesc, uint32_t size); | |||||
| static ge::graphStatus GetReuseInput(const GeTensorDesc &tensorDesc, bool &flag); | |||||
| static void SetReuseInput(GeTensorDesc &tensorDesc, bool flag); | |||||
| static ge::graphStatus GetOutputTensor(const GeTensorDesc &tensorDesc, bool &flag); | |||||
| static void SetOutputTensor(GeTensorDesc &tensorDesc, bool flag); | |||||
| static graphStatus GetDeviceType(const GeTensorDesc &tensorDesc, DeviceType &type); | |||||
| static void SetDeviceType(GeTensorDesc &tensorDesc, DeviceType type); | |||||
| static ge::graphStatus GetInputTensor(const GeTensorDesc &tensorDesc, bool &flag); | |||||
| static void SetInputTensor(GeTensorDesc &tensorDesc, bool flag); | |||||
| static ge::graphStatus GetRealDimCnt(const GeTensorDesc &tensorDesc, uint32_t &cnt); | |||||
| static void SetRealDimCnt(GeTensorDesc &tensorDesc, uint32_t cnt); | |||||
| static ge::graphStatus GetReuseInputIndex(const GeTensorDesc &tensorDesc, uint32_t &idx); | |||||
| static void SetReuseInputIndex(GeTensorDesc &tensorDesc, uint32_t idx); | |||||
| static ge::graphStatus GetDataOffset(const GeTensorDesc &tensorDesc, int64_t &offset); | |||||
| static void SetDataOffset(GeTensorDesc &tensorDesc, int64_t offset); | |||||
| static ge::graphStatus GetCmpsSize(const GeTensorDesc &tensorDesc, uint32_t &cmp_size); | |||||
| static void SetCmpsSize(GeTensorDesc &tensorDesc, uint32_t cmp_size); | |||||
| static ge::graphStatus GetCmpsTab(const GeTensorDesc &tensorDesc, vector<uint8_t> &vec); | |||||
| static void SetCmpsTab(GeTensorDesc &tensorDesc, const uint8_t *data, size_t size); | |||||
| static ge::graphStatus GetCmpsTabOffset(const GeTensorDesc &tensorDesc, int64_t &tab_offset); | |||||
| static void SetCmpsTabOffset(GeTensorDesc &tensorDesc, int64_t tab_offset); | |||||
| static ge::graphStatus GetCmpsInfo(const GeTensorDesc &tensorDesc, CompressInfo &info); | |||||
| static void SetCmpsInfo(GeTensorDesc &tensorDesc, const CompressInfo &info); | |||||
| static bool HasAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc); | |||||
| static ge::graphStatus GetAlloffsetQuantizeInfo(const GeTensorDesc &tensorDesc, AllOffsetQuantizeInfo &info); | |||||
| static void SetAlloffsetQuantizeInfo(GeTensorDesc &tensorDesc, const AllOffsetQuantizeInfo &info); | |||||
| static ge::graphStatus GetRC(const GeTensorDesc &tensorDesc, uint32_t &rc); | |||||
| static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc); | |||||
| /// | |||||
| /// calculate tensor mem size. | |||||
| /// @param shape tensor shape | |||||
| /// @param format tensor format | |||||
| /// @param data_type tensor data type | |||||
| /// @param mem_size -1 means unknown shape,other means mem size | |||||
| /// @return GRAPH_SUCCESS:success, other:failed | |||||
| /// | |||||
| static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size); | |||||
| static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); | |||||
| static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ | |||||
| @@ -1,53 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_UTILS_TYPE_UTILS_H_ | |||||
| #define INC_GRAPH_UTILS_TYPE_UTILS_H_ | |||||
| #include <map> | |||||
| #include <unordered_set> | |||||
| #include <string> | |||||
| #include "graph/def_types.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| #include "graph/types.h" | |||||
| #include "graph/usr_types.h" | |||||
| #include "register/register_types.h" | |||||
| #include "external/register/register_fmk_types.h" | |||||
| namespace ge { | |||||
| class TypeUtils { | |||||
| public: | |||||
| static bool IsDataTypeValid(DataType dt); | |||||
| static bool IsFormatValid(Format format); | |||||
| static bool IsInternalFormat(Format format); | |||||
| static std::string ImplyTypeToSerialString(domi::ImplyType imply_type); | |||||
| static std::string DataTypeToSerialString(DataType data_type); | |||||
| static DataType SerialStringToDataType(const std::string &str); | |||||
| static std::string FormatToSerialString(Format format); | |||||
| static Format SerialStringToFormat(const std::string &str); | |||||
| static Format DataFormatToFormat(const std::string &str); | |||||
| static Format DomiFormatToFormat(domi::domiTensorFormat_t domi_format); | |||||
| static std::string FmkTypeToSerialString(domi::FrameworkType fmk_type); | |||||
| static graphStatus Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def); | |||||
| static graphStatus Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr); | |||||
| static bool GetDataTypeLength(ge::DataType data_type, uint32_t &length); | |||||
| static bool CheckUint64MulOverflow(uint64_t a, uint32_t b); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_UTILS_TYPE_UTILS_H_ | |||||
| @@ -1,127 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package toolkit.dumpdata; | |||||
| enum OutputDataType { | |||||
| DT_UNDEFINED = 0; | |||||
| DT_FLOAT = 1; | |||||
| DT_FLOAT16 = 2; | |||||
| DT_INT8 = 3; | |||||
| DT_UINT8 = 4; | |||||
| DT_INT16 = 5; | |||||
| DT_UINT16 = 6; | |||||
| DT_INT32 = 7; | |||||
| DT_INT64 = 8; | |||||
| DT_UINT32 = 9; | |||||
| DT_UINT64 = 10; | |||||
| DT_BOOL = 11; | |||||
| DT_DOUBLE = 12; | |||||
| DT_STRING = 13; | |||||
| DT_DUAL_SUB_INT8 = 14; | |||||
| DT_DUAL_SUB_UINT8 = 15; | |||||
| DT_COMPLEX64 = 16; | |||||
| DT_COMPLEX128 = 17; | |||||
| DT_QINT8 = 18; | |||||
| DT_QINT16 = 19; | |||||
| DT_QINT32 = 20; | |||||
| DT_QUINT8 = 21; | |||||
| DT_QUINT16 = 22; | |||||
| DT_RESOURCE = 23; | |||||
| DT_STRING_REF = 24; | |||||
| DT_DUAL = 25; | |||||
| } | |||||
| enum OutputFormat { | |||||
| FORMAT_NCHW = 0; | |||||
| FORMAT_NHWC = 1; | |||||
| FORMAT_ND = 2; | |||||
| FORMAT_NC1HWC0 = 3; | |||||
| FORMAT_FRACTAL_Z = 4; | |||||
| FORMAT_NC1C0HWPAD = 5; | |||||
| FORMAT_NHWC1C0 = 6; | |||||
| FORMAT_FSR_NCHW = 7; | |||||
| FORMAT_FRACTAL_DECONV = 8; | |||||
| FORMAT_C1HWNC0 = 9; | |||||
| FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; | |||||
| FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; | |||||
| FORMAT_NC1HWC0_C04 = 12; | |||||
| FORMAT_FRACTAL_Z_C04 = 13; | |||||
| FORMAT_CHWN = 14; | |||||
| FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; | |||||
| FORMAT_HWCN = 16; | |||||
| FORMAT_NC1KHKWHWC0 = 17; | |||||
| FORMAT_BN_WEIGHT = 18; | |||||
| FORMAT_FILTER_HWCK = 19; | |||||
| FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; | |||||
| FORMAT_HASHTABLE_LOOKUP_KEYS = 21; | |||||
| FORMAT_HASHTABLE_LOOKUP_VALUE = 22; | |||||
| FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; | |||||
| FORMAT_HASHTABLE_LOOKUP_HITS=24; | |||||
| FORMAT_C1HWNCoC0 = 25; | |||||
| FORMAT_MD = 26; | |||||
| FORMAT_NDHWC = 27; | |||||
| FORMAT_FRACTAL_ZZ = 28; | |||||
| FORMAT_FRACTAL_NZ = 29; | |||||
| FORMAT_RESERVED = 30; | |||||
| } | |||||
| message OriginalOp { | |||||
| string name = 1; | |||||
| uint32 output_index = 2; | |||||
| OutputDataType data_type = 3; | |||||
| OutputFormat format = 4; | |||||
| } | |||||
| message Shape { | |||||
| repeated uint64 dim = 1; | |||||
| } | |||||
| message OpOutput { | |||||
| OutputDataType data_type = 1; | |||||
| OutputFormat format = 2; | |||||
| Shape shape = 3; | |||||
| OriginalOp original_op = 4; // the original op corresponding to the output | |||||
| bytes data = 5; | |||||
| uint64 size = 6; | |||||
| } | |||||
| message OpInput { | |||||
| OutputDataType data_type = 1; | |||||
| OutputFormat format = 2; | |||||
| Shape shape = 3; | |||||
| bytes data = 4; | |||||
| uint64 size = 5; | |||||
| } | |||||
| enum BufferType { | |||||
| L1 = 0; | |||||
| } | |||||
| message OpBuffer { | |||||
| BufferType buffer_type = 1; | |||||
| bytes data = 2; | |||||
| uint64 size = 3; | |||||
| } | |||||
| message DumpData{ | |||||
| string version = 1; | |||||
| uint64 dump_time = 2; | |||||
| repeated OpOutput output = 3; | |||||
| repeated OpInput input = 4; | |||||
| repeated OpBuffer buffer = 5; | |||||
| } | |||||
| @@ -1,26 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| import "om.proto"; | |||||
| package domi; | |||||
| message FusionModelDef { | |||||
| string version = 1; | |||||
| repeated OpDef fusion_op = 2; | |||||
| } | |||||
| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| syntax = "proto3"; | |||||
| package aicpu.FWKAdapter; | |||||
| option cc_enable_arenas = true; | |||||
| // Defines an struct for input and output. | |||||
| message TensorDataInfo { | |||||
| // value DataType | |||||
| uint32 dtype = 1; | |||||
| // shape dim | |||||
| repeated int64 dim = 2; | |||||
| // data point addr | |||||
| int64 data_addr = 3; | |||||
| } | |||||
| message KernelRunParam { | |||||
| // input | |||||
| repeated TensorDataInfo input = 1; | |||||
| // output | |||||
| repeated TensorDataInfo output = 2; | |||||
| } | |||||